├── .gitignore
├── LICENSE
├── README.md
├── adet
│   ├── __init__.py
│   ├── checkpoint
│   │   ├── __init__.py
│   │   └── adet_checkpoint.py
│   ├── config
│   │   ├── __init__.py
│   │   ├── config.py
│   │   └── defaults.py
│   ├── data
│   │   ├── __init__.py
│   │   ├── augmentation.py
│   │   ├── builtin.py
│   │   ├── dataset_mapper.py
│   │   ├── datasets
│   │   │   └── text.py
│   │   └── detection_utils.py
│   ├── evaluation
│   │   ├── __init__.py
│   │   ├── lexicon_procesor.py
│   │   ├── rrc_evaluation_funcs.py
│   │   ├── text_eval_script.py
│   │   └── text_evaluation.py
│   ├── layers
│   │   ├── __init__.py
│   │   ├── csrc
│   │   │   ├── DeformAttn
│   │   │   │   ├── ms_deform_attn.h
│   │   │   │   ├── ms_deform_attn_cpu.cpp
│   │   │   │   ├── ms_deform_attn_cpu.h
│   │   │   │   ├── ms_deform_attn_cuda.cu
│   │   │   │   ├── ms_deform_attn_cuda.h
│   │   │   │   └── ms_deform_im2col_cuda.cuh
│   │   │   ├── cuda_version.cu
│   │   │   └── vision.cpp
│   │   ├── deformable_transformer.py
│   │   ├── ms_deform_attn.py
│   │   └── pos_encoding.py
│   ├── modeling
│   │   ├── __init__.py
│   │   ├── testr
│   │   │   ├── __init__.py
│   │   │   ├── losses.py
│   │   │   ├── matcher.py
│   │   │   └── models.py
│   │   └── transformer_detector.py
│   └── utils
│       ├── __init__.py
│       ├── comm.py
│       ├── misc.py
│       └── visualizer.py
├── configs
│   └── TESTR
│       ├── Base-TESTR.yaml
│       ├── CTW1500
│       │   ├── Base-CTW1500-Polygon.yaml
│       │   ├── Base-CTW1500.yaml
│       │   ├── TESTR_R_50.yaml
│       │   └── TESTR_R_50_Polygon.yaml
│       ├── ICDAR15
│       │   ├── Base-ICDAR15-Polygon.yaml
│       │   └── TESTR_R_50_Polygon.yaml
│       ├── Pretrain
│       │   ├── Base-Pretrain-Polygon.yaml
│       │   ├── Base-Pretrain.yaml
│       │   ├── TESTR_R_50.yaml
│       │   └── TESTR_R_50_Polygon.yaml
│       └── TotalText
│           ├── Base-TotalText-Polygon.yaml
│           ├── Base-TotalText.yaml
│           ├── TESTR_R_50.yaml
│           └── TESTR_R_50_Polygon.yaml
├── demo
│   ├── demo.py
│   └── predictor.py
├── figures
│   └── arch.svg
├── setup.py
└── tools
    └── train_net.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # output dir
2 | output
3 | instant_test_output
4 | inference_test_output
5 |
6 |
7 | *.jpg
8 | *.png
9 | *.txt
10 |
11 | # compilation and distribution
12 | __pycache__
13 | _ext
14 | *.pyc
15 | *.so
16 | AdelaiDet.egg-info/
17 | build/
18 | dist/
19 |
20 | # pytorch/python/numpy formats
21 | *.pth
22 | *.pkl
23 | *.npy
24 |
25 | # ipython/jupyter notebooks
26 | *.ipynb
27 | **/.ipynb_checkpoints/
28 |
29 | # Editor temporaries
30 | *.swn
31 | *.swo
32 | *.swp
33 | *~
34 |
35 | # Pycharm editor settings
36 | .idea
37 | .vscode
38 | .python-version
39 |
40 | # project dirs
41 | /datasets/coco
42 | /datasets/lvis
43 | /datasets/pic
44 | /datasets/ytvos
45 | /models
46 | /demo_outputs
47 | /example_inputs
48 | /debug
49 | /weights
50 | /export
51 | eval.sh
52 |
53 | demo/performance.py
54 | demo/demo2.py
55 | train.sh
56 | benchmark.sh
57 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # TESTR: Text Spotting Transformers
2 |
3 | This repository is the official implementation of the following paper:
4 |
5 | [Text Spotting Transformers](https://openaccess.thecvf.com/content/CVPR2022/html/Zhang_Text_Spotting_Transformers_CVPR_2022_paper.html)
6 |
7 | [Xiang Zhang](https://xzhang.dev), Yongwen Su, [Subarna Tripathi](https://subarnatripathi.github.io), and [Zhuowen Tu](https://pages.ucsd.edu/~ztu/), CVPR 2022
8 |
9 | ![TESTR architecture](figures/arch.svg)
10 |
11 | ## Getting Started
12 | We use the following environment in our experiments. It is recommended to install the dependencies via Anaconda:
13 |
14 | + CUDA 11.3
15 | + Python 3.8
16 | + PyTorch 1.10.1
17 | + Official Pre-Built Detectron2
18 |
19 | #### Installation
20 |
21 | Please refer to the **Installation** section of AdelaiDet: [README.md](https://github.com/aim-uofa/AdelaiDet/blob/master/README.md).
22 |
23 | If you have not installed Detectron2, follow the official guide: [INSTALL.md](https://github.com/facebookresearch/detectron2/blob/main/INSTALL.md).
24 |
25 | After that, build this repository with
26 |
27 | ```bash
28 | python setup.py build develop
29 | ```
30 |
31 | #### Preparing Datasets
32 |
33 | Please download TotalText, CTW1500, MLT, and CurvedSynText150k according to the guide provided by AdelaiDet: [README.md](https://github.com/aim-uofa/AdelaiDet/blob/master/datasets/README.md).
34 |
35 | The ICDAR2015 dataset can be downloaded via this [link](https://ucsdcloud-my.sharepoint.com/:u:/g/personal/xiz102_ucsd_edu/EWgEM5BSRjBEua4B_qLrGR0BaombUL8K3d23ldXOb7wUNA?e=7VzH34).
36 |
37 | Extract all the datasets and make sure you organize them as follows
38 |
39 | ```
40 | - datasets
41 | | - CTW1500
42 | | | - annotations
43 | | | - ctwtest_text_image
44 | | | - ctwtrain_text_image
45 | | - totaltext (or icdar2015)
46 | | | - test_images
47 | | | - train_images
48 | | | - test.json
49 | | | - train.json
50 | | - mlt2017 (or syntext1, syntext2)
51 | | - annotations
52 | | - images
53 | ```
54 |
55 | After that, download the [polygonal annotations](https://ucsdcloud-my.sharepoint.com/:u:/g/personal/xiz102_ucsd_edu/ES4aqkvamlJAgiPNFJuYkX4BLo-5cDx9TD_6pnMJnVhXpw?e=tu9D8t), along with the [evaluation files](https://ucsdcloud-my.sharepoint.com/:u:/g/personal/xiz102_ucsd_edu/Ea5oF7VFoe5NngUoPmLTerQBMdiVUhHcx2pPu3Q5p3hZvg?e=2NJNWh), and extract them under the `datasets` folder.
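
A quick sanity check that the layout is picked up: the splits listed in `adet/data/builtin.py` are registered (relative to `./datasets`) as soon as the package is imported. A minimal sketch, assuming the project has been built and you run it from the repository root (the split name is one of those defined in `builtin.py`):

```python
from detectron2.data import DatasetCatalog, MetadataCatalog

import adet.data.builtin  # noqa: F401  # importing this module registers all text splits

# "totaltext_poly_train" is one of the splits registered in adet/data/builtin.py
dicts = DatasetCatalog.get("totaltext_poly_train")
meta = MetadataCatalog.get("totaltext_poly_train")
print(len(dicts), "training images,", meta.json_file)
```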
56 |
57 | #### Visualization Demo
58 |
59 | You can try to visualize the predictions of the network using the following command:
60 |
61 | ```bash
62 | python demo/demo.py --config-file <config-file> --input <input-images> --output <output-dir> --opts MODEL.WEIGHTS <model-weights> MODEL.TRANSFORMER.INFERENCE_TH_TEST 0.3
63 | ```
64 |
65 | You may want to adjust `INFERENCE_TH_TEST` to filter out predictions with lower scores.
66 |
67 | #### Training
68 |
69 | You can train from scratch or finetune the model by putting pretrained weights in the `weights` folder.
70 |
71 | Example commands:
72 |
73 | ```bash
74 | python tools/train_net.py --config-file <config-file> --num-gpus 8
75 | ```
76 |
77 | All configuration files can be found in `configs/TESTR`, excluding those files named `Base-xxxx.yaml`.
78 |
79 | `TESTR_R_50.yaml` is the config for TESTR-Bezier, while `TESTR_R_50_Polygon.yaml` is for TESTR-Polygon.
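
These YAML files are ordinary Detectron2 configs; the TESTR-specific keys they set are declared in `adet/config/defaults.py`. A minimal sketch of loading one programmatically (the weight path below is only an example for fine-tuning):

```python
from adet.config import get_cfg  # Detectron2 defaults plus the TESTR keys from adet/config/defaults.py

cfg = get_cfg()
cfg.merge_from_file("configs/TESTR/TotalText/TESTR_R_50_Polygon.yaml")
# Example fine-tuning setup: point MODEL.WEIGHTS at a pretrained checkpoint placed under ./weights
cfg.merge_from_list(["MODEL.WEIGHTS", "weights/pretrain_polygon.pth"])
print(cfg.MODEL.TRANSFORMER.USE_POLYGON)  # expected True for the *_Polygon configs
print(cfg.MODEL.TRANSFORMER.NUM_QUERIES)  # number of object queries (default 100 in defaults.py)
```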
80 |
81 | #### Evaluation
82 |
83 | ```bash
84 | python tools/train_net.py --config-file <config-file> --eval-only MODEL.WEIGHTS <model-weights>
85 | ```
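
For the lexicon-based results (the Full/Strong/Weak/Generic rows in the table below), the relevant switches are the `TEST.*` options declared in `adet/config/defaults.py`; like `MODEL.WEIGHTS` above, they can also be appended to the evaluation command as `KEY VALUE` pairs. A minimal sketch, assuming an ICDAR2015 config:

```python
from adet.config import get_cfg

cfg = get_cfg()
cfg.merge_from_file("configs/TESTR/ICDAR15/TESTR_R_50_Polygon.yaml")  # example config
# From adet/config/defaults.py: LEXICON_TYPE is 1 - Generic, 2 - Weak, 3 - Strong for ICDAR2015,
# and 1 - Full lexicon for TotalText / CTW1500.
cfg.TEST.USE_LEXICON = True
cfg.TEST.LEXICON_TYPE = 3
```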
86 |
87 | ## Pretrained Models
88 |
89 | | Dataset | Annotation Type | Lexicon | Det-P | Det-R | Det-F | E2E-P | E2E-R | E2E-F | Link |
90 | | --------- | --------------- | ------- | ----- | ----- | ----- | ----- | ----- | ----- | -------- |
91 | | Pretrain | Bezier | None | 88.87 | 76.47 | 82.20 | 63.58 | 56.92 | 60.06 | OneDrive |
92 | | Pretrain | Polygonal | None | 88.18 | 77.51 | 82.50 | 66.19 | 61.14 | 63.57 | OneDrive |
93 | | TotalText | Bezier | None | 92.83 | 83.65 | 88.00 | 74.26 | 69.05 | 71.56 | OneDrive |
94 | | TotalText | Bezier | Full | - | - | - | 86.42 | 80.35 | 83.28 | |
95 | | TotalText | Polygonal | None | 93.36 | 81.35 | 86.94 | 76.85 | 69.98 | 73.25 | OneDrive |
96 | | TotalText | Polygonal | Full | - | - | - | 88.00 | 80.13 | 83.88 | |
97 | | CTW1500 | Bezier | None | 89.71 | 83.07 | 86.27 | 55.44 | 51.34 | 53.31 | OneDrive |
98 | | CTW1500 | Bezier | Full | - | - | - | 83.05 | 76.90 | 79.85 | |
99 | | CTW1500 | Polygonal | None | 92.04 | 82.63 | 87.08 | 59.14 | 53.09 | 55.95 | OneDrive |
100 | | CTW1500 | Polygonal | Full | - | - | - | 86.16 | 77.34 | 81.51 | |
101 | | ICDAR15 | Polygonal | None | 90.31 | 89.70 | 90.00 | 65.49 | 65.05 | 65.27 | OneDrive |
102 | | ICDAR15 | Polygonal | Strong | - | - | - | 87.11 | 83.29 | 85.16 | |
103 | | ICDAR15 | Polygonal | Weak | - | - | - | 80.36 | 78.38 | 79.36 | |
104 | | ICDAR15 | Polygonal | Generic | - | - | - | 73.82 | 73.33 | 73.57 | |
105 |
252 | The `Lite` models only use the image feature from the last stage of ResNet.
253 |
254 | | Method | Annotation Type | Lexicon | Det-P | Det-R | Det-F | E2E-P | E2E-R | E2E-F | Link |
255 | | ---------------- | --------------- | ------- | ------ | ------ | ------ | ------ | ------ | ------ | ------------------------------------------------------------ |
256 | | Pretrain (Lite) | Polygonal | None | 90.28 | 72.58 | 80.47 | 59.49 | 50.22 | 54.46 | [OneDrive](https://ucsdcloud-my.sharepoint.com/:u:/g/personal/xiz102_ucsd_edu/EcG-WKHN7dlNnzUJ5g301goBNknB-_IyADfVoW9q8efIIA?e=8dHW3K) |
257 | | TotalText (Lite) | Polygonal | None | 92.16 | 79.09 | 85.12 | 66.42 | 59.06 | 62.52 | [OneDrive](https://ucsdcloud-my.sharepoint.com/:u:/g/personal/xiz102_ucsd_edu/ETL5VCes0eJBuktGqHqGu_wBltwbngIhqmqePIWfaWgGxw?e=hDkana) |
258 |
259 | ## Citation
260 | ```
261 | @InProceedings{Zhang_2022_CVPR,
262 | author = {Zhang, Xiang and Su, Yongwen and Tripathi, Subarna and Tu, Zhuowen},
263 | title = {Text Spotting Transformers},
264 | booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
265 | month = {June},
266 | year = {2022},
267 | pages = {9519-9528}
268 | }
269 | ```
270 |
271 | ## License
272 | This repository is released under the Apache License 2.0. The license can be found in the [LICENSE](LICENSE) file.
273 |
274 |
275 | ## Acknowledgement
276 |
277 | Thanks to [AdelaiDet](https://github.com/aim-uofa/AdelaiDet) for a standardized training and inference framework, and [Deformable-DETR](https://github.com/fundamentalvision/Deformable-DETR) for the implementation of multi-scale deformable cross-attention.
278 |
--------------------------------------------------------------------------------
/adet/__init__.py:
--------------------------------------------------------------------------------
1 | from adet import modeling
2 |
3 | __version__ = "0.1.1"
4 |
--------------------------------------------------------------------------------
/adet/checkpoint/__init__.py:
--------------------------------------------------------------------------------
1 | from .adet_checkpoint import AdetCheckpointer
2 |
3 | __all__ = ["AdetCheckpointer"]
4 |
--------------------------------------------------------------------------------
/adet/checkpoint/adet_checkpoint.py:
--------------------------------------------------------------------------------
1 | import pickle, os
2 | from fvcore.common.file_io import PathManager
3 | from detectron2.checkpoint import DetectionCheckpointer
4 |
5 |
6 | class AdetCheckpointer(DetectionCheckpointer):
7 | """
8 |     Same as :class:`DetectionCheckpointer`, but is able to convert models
9 | in AdelaiDet, such as LPF backbone.
10 | """
11 | def _load_file(self, filename):
12 | if filename.endswith(".pkl"):
13 | with PathManager.open(filename, "rb") as f:
14 | data = pickle.load(f, encoding="latin1")
15 | if "model" in data and "__author__" in data:
16 | # file is in Detectron2 model zoo format
17 | self.logger.info("Reading a file from '{}'".format(data["__author__"]))
18 | return data
19 | else:
20 | # assume file is from Caffe2 / Detectron1 model zoo
21 | if "blobs" in data:
22 | # Detection models have "blobs", but ImageNet models don't
23 | data = data["blobs"]
24 | data = {k: v for k, v in data.items() if not k.endswith("_momentum")}
25 | if "weight_order" in data:
26 | del data["weight_order"]
27 | return {"model": data, "__author__": "Caffe2", "matching_heuristics": True}
28 |
29 | loaded = super()._load_file(filename) # load native pth checkpoint
30 | if "model" not in loaded:
31 | loaded = {"model": loaded}
32 |
33 | basename = os.path.basename(filename).lower()
34 | if "lpf" in basename or "dla" in basename:
35 | loaded["matching_heuristics"] = True
36 | return loaded
37 |
--------------------------------------------------------------------------------
/adet/config/__init__.py:
--------------------------------------------------------------------------------
1 | from .config import get_cfg
2 |
3 | __all__ = [
4 | "get_cfg",
5 | ]
6 |
--------------------------------------------------------------------------------
/adet/config/config.py:
--------------------------------------------------------------------------------
1 | from detectron2.config import CfgNode
2 |
3 |
4 | def get_cfg() -> CfgNode:
5 | """
6 | Get a copy of the default config.
7 |
8 | Returns:
9 | a detectron2 CfgNode instance.
10 | """
11 | from .defaults import _C
12 |
13 | return _C.clone()
14 |
--------------------------------------------------------------------------------
/adet/config/defaults.py:
--------------------------------------------------------------------------------
1 | from detectron2.config.defaults import _C
2 | from detectron2.config import CfgNode as CN
3 |
4 |
5 | # ---------------------------------------------------------------------------- #
6 | # Additional Configs
7 | # ---------------------------------------------------------------------------- #
8 | _C.MODEL.MOBILENET = False
9 | _C.MODEL.BACKBONE.ANTI_ALIAS = False
10 | _C.MODEL.RESNETS.DEFORM_INTERVAL = 1
11 | _C.INPUT.HFLIP_TRAIN = True
12 | _C.INPUT.CROP.CROP_INSTANCE = True
13 |
14 | # ---------------------------------------------------------------------------- #
15 | # FCOS Head
16 | # ---------------------------------------------------------------------------- #
17 | _C.MODEL.FCOS = CN()
18 |
19 | # This is the number of foreground classes.
20 | _C.MODEL.FCOS.NUM_CLASSES = 80
21 | _C.MODEL.FCOS.IN_FEATURES = ["p3", "p4", "p5", "p6", "p7"]
22 | _C.MODEL.FCOS.FPN_STRIDES = [8, 16, 32, 64, 128]
23 | _C.MODEL.FCOS.PRIOR_PROB = 0.01
24 | _C.MODEL.FCOS.INFERENCE_TH_TRAIN = 0.05
25 | _C.MODEL.FCOS.INFERENCE_TH_TEST = 0.05
26 | _C.MODEL.FCOS.NMS_TH = 0.6
27 | _C.MODEL.FCOS.PRE_NMS_TOPK_TRAIN = 1000
28 | _C.MODEL.FCOS.PRE_NMS_TOPK_TEST = 1000
29 | _C.MODEL.FCOS.POST_NMS_TOPK_TRAIN = 100
30 | _C.MODEL.FCOS.POST_NMS_TOPK_TEST = 100
31 | _C.MODEL.FCOS.TOP_LEVELS = 2
32 | _C.MODEL.FCOS.NORM = "GN" # Support GN or none
33 | _C.MODEL.FCOS.USE_SCALE = True
34 |
35 | # The options for the quality of box prediction
36 | # It can be "ctrness" (as described in FCOS paper) or "iou"
37 | # Using "iou" here generally has ~0.4 better AP on COCO
38 | # Note that for compatibility, we still use the term "ctrness" in the code
39 | _C.MODEL.FCOS.BOX_QUALITY = "ctrness"
40 |
41 | # Multiply centerness before threshold
42 | # This will affect the final performance by about 0.05 AP but save some time
43 | _C.MODEL.FCOS.THRESH_WITH_CTR = False
44 |
45 | # Focal loss parameters
46 | _C.MODEL.FCOS.LOSS_ALPHA = 0.25
47 | _C.MODEL.FCOS.LOSS_GAMMA = 2.0
48 |
49 | # The normalizer of the classification loss
50 | # The normalizer can be "fg" (normalized by the number of the foreground samples),
51 | # "moving_fg" (normalized by the MOVING number of the foreground samples),
52 | # or "all" (normalized by the number of all samples)
53 | _C.MODEL.FCOS.LOSS_NORMALIZER_CLS = "fg"
54 | _C.MODEL.FCOS.LOSS_WEIGHT_CLS = 1.0
55 |
56 | _C.MODEL.FCOS.SIZES_OF_INTEREST = [64, 128, 256, 512]
57 | _C.MODEL.FCOS.USE_RELU = True
58 | _C.MODEL.FCOS.USE_DEFORMABLE = False
59 |
60 | # the number of convolutions used in the cls and bbox tower
61 | _C.MODEL.FCOS.NUM_CLS_CONVS = 4
62 | _C.MODEL.FCOS.NUM_BOX_CONVS = 4
63 | _C.MODEL.FCOS.NUM_SHARE_CONVS = 0
64 | _C.MODEL.FCOS.CENTER_SAMPLE = True
65 | _C.MODEL.FCOS.POS_RADIUS = 1.5
66 | _C.MODEL.FCOS.LOC_LOSS_TYPE = 'giou'
67 | _C.MODEL.FCOS.YIELD_PROPOSAL = False
68 | _C.MODEL.FCOS.YIELD_BOX_FEATURES = False
69 |
70 | # ---------------------------------------------------------------------------- #
71 | # VoVNet backbone
72 | # ---------------------------------------------------------------------------- #
73 | _C.MODEL.VOVNET = CN()
74 | _C.MODEL.VOVNET.CONV_BODY = "V-39-eSE"
75 | _C.MODEL.VOVNET.OUT_FEATURES = ["stage2", "stage3", "stage4", "stage5"]
76 |
77 | # Options: FrozenBN, GN, "SyncBN", "BN"
78 | _C.MODEL.VOVNET.NORM = "FrozenBN"
79 | _C.MODEL.VOVNET.OUT_CHANNELS = 256
80 | _C.MODEL.VOVNET.BACKBONE_OUT_CHANNELS = 256
81 |
82 | # ---------------------------------------------------------------------------- #
83 | # DLA backbone
84 | # ---------------------------------------------------------------------------- #
85 |
86 | _C.MODEL.DLA = CN()
87 | _C.MODEL.DLA.CONV_BODY = "DLA34"
88 | _C.MODEL.DLA.OUT_FEATURES = ["stage2", "stage3", "stage4", "stage5"]
89 |
90 | # Options: FrozenBN, GN, "SyncBN", "BN"
91 | _C.MODEL.DLA.NORM = "FrozenBN"
92 |
93 | # ---------------------------------------------------------------------------- #
94 | # BAText Options
95 | # ---------------------------------------------------------------------------- #
96 | _C.MODEL.BATEXT = CN()
97 | _C.MODEL.BATEXT.VOC_SIZE = 96
98 | _C.MODEL.BATEXT.NUM_CHARS = 25
99 | _C.MODEL.BATEXT.POOLER_RESOLUTION = (8, 32)
100 | _C.MODEL.BATEXT.IN_FEATURES = ["p2", "p3", "p4"]
101 | _C.MODEL.BATEXT.POOLER_SCALES = (0.25, 0.125, 0.0625)
102 | _C.MODEL.BATEXT.SAMPLING_RATIO = 1
103 | _C.MODEL.BATEXT.CONV_DIM = 256
104 | _C.MODEL.BATEXT.NUM_CONV = 2
105 | _C.MODEL.BATEXT.RECOGNITION_LOSS = "ctc"
106 | _C.MODEL.BATEXT.RECOGNIZER = "attn"
107 | _C.MODEL.BATEXT.CANONICAL_SIZE = 96 # largest min_size for level 3 (stride=8)
108 | _C.MODEL.BATEXT.USE_COORDCONV = False
109 | _C.MODEL.BATEXT.USE_AET = False
110 | _C.MODEL.BATEXT.CUSTOM_DICT = "" # Path to the class file.
111 |
112 | # ---------------------------------------------------------------------------- #
113 | # BlendMask Options
114 | # ---------------------------------------------------------------------------- #
115 | _C.MODEL.BLENDMASK = CN()
116 | _C.MODEL.BLENDMASK.ATTN_SIZE = 14
117 | _C.MODEL.BLENDMASK.TOP_INTERP = "bilinear"
118 | _C.MODEL.BLENDMASK.BOTTOM_RESOLUTION = 56
119 | _C.MODEL.BLENDMASK.POOLER_TYPE = "ROIAlignV2"
120 | _C.MODEL.BLENDMASK.POOLER_SAMPLING_RATIO = 1
121 | _C.MODEL.BLENDMASK.POOLER_SCALES = (0.25,)
122 | _C.MODEL.BLENDMASK.INSTANCE_LOSS_WEIGHT = 1.0
123 | _C.MODEL.BLENDMASK.VISUALIZE = False
124 |
125 | # ---------------------------------------------------------------------------- #
126 | # Basis Module Options
127 | # ---------------------------------------------------------------------------- #
128 | _C.MODEL.BASIS_MODULE = CN()
129 | _C.MODEL.BASIS_MODULE.NAME = "ProtoNet"
130 | _C.MODEL.BASIS_MODULE.NUM_BASES = 4
131 | _C.MODEL.BASIS_MODULE.LOSS_ON = False
132 | _C.MODEL.BASIS_MODULE.ANN_SET = "coco"
133 | _C.MODEL.BASIS_MODULE.CONVS_DIM = 128
134 | _C.MODEL.BASIS_MODULE.IN_FEATURES = ["p3", "p4", "p5"]
135 | _C.MODEL.BASIS_MODULE.NORM = "SyncBN"
136 | _C.MODEL.BASIS_MODULE.NUM_CONVS = 3
137 | _C.MODEL.BASIS_MODULE.COMMON_STRIDE = 8
138 | _C.MODEL.BASIS_MODULE.NUM_CLASSES = 80
139 | _C.MODEL.BASIS_MODULE.LOSS_WEIGHT = 0.3
140 |
141 | # ---------------------------------------------------------------------------- #
142 | # MEInst Head
143 | # ---------------------------------------------------------------------------- #
144 | _C.MODEL.MEInst = CN()
145 |
146 | # This is the number of foreground classes.
147 | _C.MODEL.MEInst.NUM_CLASSES = 80
148 | _C.MODEL.MEInst.IN_FEATURES = ["p3", "p4", "p5", "p6", "p7"]
149 | _C.MODEL.MEInst.FPN_STRIDES = [8, 16, 32, 64, 128]
150 | _C.MODEL.MEInst.PRIOR_PROB = 0.01
151 | _C.MODEL.MEInst.INFERENCE_TH_TRAIN = 0.05
152 | _C.MODEL.MEInst.INFERENCE_TH_TEST = 0.05
153 | _C.MODEL.MEInst.NMS_TH = 0.6
154 | _C.MODEL.MEInst.PRE_NMS_TOPK_TRAIN = 1000
155 | _C.MODEL.MEInst.PRE_NMS_TOPK_TEST = 1000
156 | _C.MODEL.MEInst.POST_NMS_TOPK_TRAIN = 100
157 | _C.MODEL.MEInst.POST_NMS_TOPK_TEST = 100
158 | _C.MODEL.MEInst.TOP_LEVELS = 2
159 | _C.MODEL.MEInst.NORM = "GN" # Support GN or none
160 | _C.MODEL.MEInst.USE_SCALE = True
161 |
162 | # Multiply centerness before threshold
163 | # This will affect the final performance by about 0.05 AP but save some time
164 | _C.MODEL.MEInst.THRESH_WITH_CTR = False
165 |
166 | # Focal loss parameters
167 | _C.MODEL.MEInst.LOSS_ALPHA = 0.25
168 | _C.MODEL.MEInst.LOSS_GAMMA = 2.0
169 | _C.MODEL.MEInst.SIZES_OF_INTEREST = [64, 128, 256, 512]
170 | _C.MODEL.MEInst.USE_RELU = True
171 | _C.MODEL.MEInst.USE_DEFORMABLE = False
172 | _C.MODEL.MEInst.LAST_DEFORMABLE = False
173 | _C.MODEL.MEInst.TYPE_DEFORMABLE = "DCNv1" # or DCNv2.
174 |
175 | # the number of convolutions used in the cls and bbox tower
176 | _C.MODEL.MEInst.NUM_CLS_CONVS = 4
177 | _C.MODEL.MEInst.NUM_BOX_CONVS = 4
178 | _C.MODEL.MEInst.NUM_SHARE_CONVS = 0
179 | _C.MODEL.MEInst.CENTER_SAMPLE = True
180 | _C.MODEL.MEInst.POS_RADIUS = 1.5
181 | _C.MODEL.MEInst.LOC_LOSS_TYPE = 'giou'
182 |
183 | # ---------------------------------------------------------------------------- #
184 | # Mask Encoding
185 | # ---------------------------------------------------------------------------- #
186 | # Whether to use mask branch.
187 | _C.MODEL.MEInst.MASK_ON = True
188 | # IOU overlap ratios [IOU_THRESHOLD]
189 | # Overlap threshold for an RoI to be considered background (if < IOU_THRESHOLD)
190 | # Overlap threshold for an RoI to be considered foreground (if >= IOU_THRESHOLD)
191 | _C.MODEL.MEInst.IOU_THRESHOLDS = [0.5]
192 | _C.MODEL.MEInst.IOU_LABELS = [0, 1]
193 | # Whether to use class_agnostic or class_specific.
194 | _C.MODEL.MEInst.AGNOSTIC = True
195 | # Some operations in mask encoding.
196 | _C.MODEL.MEInst.WHITEN = True
197 | _C.MODEL.MEInst.SIGMOID = True
198 |
199 | # The number of convolutions used in the mask tower.
200 | _C.MODEL.MEInst.NUM_MASK_CONVS = 4
201 |
202 | # The dim of mask before/after mask encoding.
203 | _C.MODEL.MEInst.DIM_MASK = 60
204 | _C.MODEL.MEInst.MASK_SIZE = 28
205 | # The default path for parameters of mask encoding.
206 | _C.MODEL.MEInst.PATH_COMPONENTS = "datasets/coco/components/" \
207 | "coco_2017_train_class_agnosticTrue_whitenTrue_sigmoidTrue_60.npz"
208 | # An indicator for encoding parameters loading during training.
209 | _C.MODEL.MEInst.FLAG_PARAMETERS = False
210 | # The loss for mask branch, can be mse now.
211 | _C.MODEL.MEInst.MASK_LOSS_TYPE = "mse"
212 |
213 | # Whether to use gcn in mask prediction.
214 | # Large Kernel Matters -- https://arxiv.org/abs/1703.02719
215 | _C.MODEL.MEInst.USE_GCN_IN_MASK = False
216 | _C.MODEL.MEInst.GCN_KERNEL_SIZE = 9
217 | # Whether to compute loss on original mask (binary mask).
218 | _C.MODEL.MEInst.LOSS_ON_MASK = False
219 |
220 | # ---------------------------------------------------------------------------- #
221 | # CondInst Options
222 | # ---------------------------------------------------------------------------- #
223 | _C.MODEL.CONDINST = CN()
224 |
225 | # the downsampling ratio of the final instance masks to the input image
226 | _C.MODEL.CONDINST.MASK_OUT_STRIDE = 4
227 | _C.MODEL.CONDINST.BOTTOM_PIXELS_REMOVED = -1
228 |
229 | # if not -1, we only compute the mask loss for MAX_PROPOSALS random proposals PER GPU
230 | _C.MODEL.CONDINST.MAX_PROPOSALS = -1
231 | # if not -1, we only compute the mask loss for top `TOPK_PROPOSALS_PER_IM` proposals
232 | # PER IMAGE in terms of their detection scores
233 | _C.MODEL.CONDINST.TOPK_PROPOSALS_PER_IM = -1
234 |
235 | _C.MODEL.CONDINST.MASK_HEAD = CN()
236 | _C.MODEL.CONDINST.MASK_HEAD.CHANNELS = 8
237 | _C.MODEL.CONDINST.MASK_HEAD.NUM_LAYERS = 3
238 | _C.MODEL.CONDINST.MASK_HEAD.USE_FP16 = False
239 | _C.MODEL.CONDINST.MASK_HEAD.DISABLE_REL_COORDS = False
240 |
241 | _C.MODEL.CONDINST.MASK_BRANCH = CN()
242 | _C.MODEL.CONDINST.MASK_BRANCH.OUT_CHANNELS = 8
243 | _C.MODEL.CONDINST.MASK_BRANCH.IN_FEATURES = ["p3", "p4", "p5"]
244 | _C.MODEL.CONDINST.MASK_BRANCH.CHANNELS = 128
245 | _C.MODEL.CONDINST.MASK_BRANCH.NORM = "BN"
246 | _C.MODEL.CONDINST.MASK_BRANCH.NUM_CONVS = 4
247 | _C.MODEL.CONDINST.MASK_BRANCH.SEMANTIC_LOSS_ON = False
248 |
249 | # The options for BoxInst, which can train the instance segmentation model with box annotations only
250 | # Please refer to the paper https://arxiv.org/abs/2012.02310
251 | _C.MODEL.BOXINST = CN()
252 | # Whether to enable BoxInst
253 | _C.MODEL.BOXINST.ENABLED = False
254 | _C.MODEL.BOXINST.BOTTOM_PIXELS_REMOVED = 10
255 |
256 | _C.MODEL.BOXINST.PAIRWISE = CN()
257 | _C.MODEL.BOXINST.PAIRWISE.SIZE = 3
258 | _C.MODEL.BOXINST.PAIRWISE.DILATION = 2
259 | _C.MODEL.BOXINST.PAIRWISE.WARMUP_ITERS = 10000
260 | _C.MODEL.BOXINST.PAIRWISE.COLOR_THRESH = 0.3
261 |
262 | # ---------------------------------------------------------------------------- #
263 | # TOP Module Options
264 | # ---------------------------------------------------------------------------- #
265 | _C.MODEL.TOP_MODULE = CN()
266 | _C.MODEL.TOP_MODULE.NAME = "conv"
267 | _C.MODEL.TOP_MODULE.DIM = 16
268 |
269 | # ---------------------------------------------------------------------------- #
270 | # BiFPN options
271 | # ---------------------------------------------------------------------------- #
272 |
273 | _C.MODEL.BiFPN = CN()
274 | # Names of the input feature maps to be used by BiFPN
275 | # They must have contiguous power of 2 strides
276 | # e.g., ["res2", "res3", "res4", "res5"]
277 | _C.MODEL.BiFPN.IN_FEATURES = ["res2", "res3", "res4", "res5"]
278 | _C.MODEL.BiFPN.OUT_CHANNELS = 160
279 | _C.MODEL.BiFPN.NUM_REPEATS = 6
280 |
281 | # Options: "" (no norm), "GN"
282 | _C.MODEL.BiFPN.NORM = ""
283 |
284 | # ---------------------------------------------------------------------------- #
285 | # SOLOv2 Options
286 | # ---------------------------------------------------------------------------- #
287 | _C.MODEL.SOLOV2 = CN()
288 |
289 | # Instance hyper-parameters
290 | _C.MODEL.SOLOV2.INSTANCE_IN_FEATURES = ["p2", "p3", "p4", "p5", "p6"]
291 | _C.MODEL.SOLOV2.FPN_INSTANCE_STRIDES = [8, 8, 16, 32, 32]
292 | _C.MODEL.SOLOV2.FPN_SCALE_RANGES = ((1, 96), (48, 192), (96, 384), (192, 768), (384, 2048))
293 | _C.MODEL.SOLOV2.SIGMA = 0.2
294 | # Channel size for the instance head.
295 | _C.MODEL.SOLOV2.INSTANCE_IN_CHANNELS = 256
296 | _C.MODEL.SOLOV2.INSTANCE_CHANNELS = 512
297 | # Convolutions to use in the instance head.
298 | _C.MODEL.SOLOV2.NUM_INSTANCE_CONVS = 4
299 | _C.MODEL.SOLOV2.USE_DCN_IN_INSTANCE = False
300 | _C.MODEL.SOLOV2.TYPE_DCN = 'DCN'
301 | _C.MODEL.SOLOV2.NUM_GRIDS = [40, 36, 24, 16, 12]
302 | # Number of foreground classes.
303 | _C.MODEL.SOLOV2.NUM_CLASSES = 80
304 | _C.MODEL.SOLOV2.NUM_KERNELS = 256
305 | _C.MODEL.SOLOV2.NORM = "GN"
306 | _C.MODEL.SOLOV2.USE_COORD_CONV = True
307 | _C.MODEL.SOLOV2.PRIOR_PROB = 0.01
308 |
309 | # Mask hyper-parameters.
310 | # Channel size for the mask tower.
311 | _C.MODEL.SOLOV2.MASK_IN_FEATURES = ["p2", "p3", "p4", "p5"]
312 | _C.MODEL.SOLOV2.MASK_IN_CHANNELS = 256
313 | _C.MODEL.SOLOV2.MASK_CHANNELS = 128
314 | _C.MODEL.SOLOV2.NUM_MASKS = 256
315 |
316 | # Test cfg.
317 | _C.MODEL.SOLOV2.NMS_PRE = 500
318 | _C.MODEL.SOLOV2.SCORE_THR = 0.1
319 | _C.MODEL.SOLOV2.UPDATE_THR = 0.05
320 | _C.MODEL.SOLOV2.MASK_THR = 0.5
321 | _C.MODEL.SOLOV2.MAX_PER_IMG = 100
322 | # NMS type: matrix OR mask.
323 | _C.MODEL.SOLOV2.NMS_TYPE = "matrix"
324 | # Matrix NMS kernel type: gaussian OR linear.
325 | _C.MODEL.SOLOV2.NMS_KERNEL = "gaussian"
326 | _C.MODEL.SOLOV2.NMS_SIGMA = 2
327 |
328 | # Loss cfg.
329 | _C.MODEL.SOLOV2.LOSS = CN()
330 | _C.MODEL.SOLOV2.LOSS.FOCAL_USE_SIGMOID = True
331 | _C.MODEL.SOLOV2.LOSS.FOCAL_ALPHA = 0.25
332 | _C.MODEL.SOLOV2.LOSS.FOCAL_GAMMA = 2.0
333 | _C.MODEL.SOLOV2.LOSS.FOCAL_WEIGHT = 1.0
334 | _C.MODEL.SOLOV2.LOSS.DICE_WEIGHT = 3.0
335 |
336 |
337 | # ---------------------------------------------------------------------------- #
338 | # (Deformable) Transformer Options
339 | # ---------------------------------------------------------------------------- #
340 | _C.MODEL.TRANSFORMER = CN()
341 | _C.MODEL.TRANSFORMER.USE_POLYGON = False
342 | _C.MODEL.TRANSFORMER.ENABLED = False
343 | _C.MODEL.TRANSFORMER.INFERENCE_TH_TEST = 0.45
344 | _C.MODEL.TRANSFORMER.VOC_SIZE = 96
345 | _C.MODEL.TRANSFORMER.NUM_CHARS = 25
346 | _C.MODEL.TRANSFORMER.AUX_LOSS = True
347 | _C.MODEL.TRANSFORMER.ENC_LAYERS = 6
348 | _C.MODEL.TRANSFORMER.DEC_LAYERS = 6
349 | _C.MODEL.TRANSFORMER.DIM_FEEDFORWARD = 1024
350 | _C.MODEL.TRANSFORMER.HIDDEN_DIM = 256
351 | _C.MODEL.TRANSFORMER.DROPOUT = 0.1
352 | _C.MODEL.TRANSFORMER.NHEADS = 8
353 | _C.MODEL.TRANSFORMER.NUM_QUERIES = 100
354 | _C.MODEL.TRANSFORMER.ENC_N_POINTS = 4
355 | _C.MODEL.TRANSFORMER.DEC_N_POINTS = 4
356 | _C.MODEL.TRANSFORMER.POSITION_EMBEDDING_SCALE = 6.283185307179586 # 2 PI
357 | _C.MODEL.TRANSFORMER.NUM_FEATURE_LEVELS = 4
358 | _C.MODEL.TRANSFORMER.NUM_CTRL_POINTS = 8
359 |
360 | _C.MODEL.TRANSFORMER.LOSS = CN()
361 | _C.MODEL.TRANSFORMER.LOSS.AUX_LOSS = True
362 | _C.MODEL.TRANSFORMER.LOSS.POINT_CLASS_WEIGHT = 2.0
363 | _C.MODEL.TRANSFORMER.LOSS.POINT_COORD_WEIGHT = 5.0
364 | _C.MODEL.TRANSFORMER.LOSS.POINT_TEXT_WEIGHT = 2.0
365 | _C.MODEL.TRANSFORMER.LOSS.BOX_CLASS_WEIGHT = 2.0
366 | _C.MODEL.TRANSFORMER.LOSS.BOX_COORD_WEIGHT = 5.0
367 | _C.MODEL.TRANSFORMER.LOSS.BOX_GIOU_WEIGHT = 2.0
368 | _C.MODEL.TRANSFORMER.LOSS.FOCAL_ALPHA = 0.25
369 | _C.MODEL.TRANSFORMER.LOSS.FOCAL_GAMMA = 2.0
370 |
371 |
372 | _C.SOLVER.OPTIMIZER = "ADAMW"
373 | _C.SOLVER.LR_BACKBONE = 1e-5
374 | _C.SOLVER.LR_BACKBONE_NAMES = []
375 | _C.SOLVER.LR_LINEAR_PROJ_NAMES = []
376 | _C.SOLVER.LR_LINEAR_PROJ_MULT = 0.1
377 |
378 | _C.TEST.USE_LEXICON = False
379 | # 1 - Generic, 2 - Weak, 3 - Strong (for icdar2015)
380 | # 1 - Full lexicon (for totaltext/ctw1500)
381 | _C.TEST.LEXICON_TYPE = 1
382 | _C.TEST.WEIGHTED_EDIT_DIST = False
383 |
--------------------------------------------------------------------------------
/adet/data/__init__.py:
--------------------------------------------------------------------------------
1 | from . import builtin # ensure the builtin datasets are registered
2 | from .dataset_mapper import DatasetMapperWithBasis
3 |
4 |
5 | __all__ = ["DatasetMapperWithBasis"]
6 |
--------------------------------------------------------------------------------
/adet/data/augmentation.py:
--------------------------------------------------------------------------------
1 | import random
2 |
3 | import numpy as np
4 | from fvcore.transforms import transform as T
5 |
6 | from detectron2.data.transforms import RandomCrop, StandardAugInput
7 | from detectron2.structures import BoxMode
8 |
9 |
10 | def gen_crop_transform_with_instance(crop_size, image_size, instances, crop_box=True):
11 | """
12 | Generate a CropTransform so that the cropping region contains
13 | the center of the given instance.
14 |
15 | Args:
16 | crop_size (tuple): h, w in pixels
17 | image_size (tuple): h, w
18 |         instances (list or ndarray): bounding boxes of the instances, in XYXY_ABS
19 |             format; one box is chosen at random and the crop contains its center.
20 | """
21 | bbox = random.choice(instances)
22 | crop_size = np.asarray(crop_size, dtype=np.int32)
23 | center_yx = (bbox[1] + bbox[3]) * 0.5, (bbox[0] + bbox[2]) * 0.5
24 | assert (
25 | image_size[0] >= center_yx[0] and image_size[1] >= center_yx[1]
26 | ), "The annotation bounding box is outside of the image!"
27 | assert (
28 | image_size[0] >= crop_size[0] and image_size[1] >= crop_size[1]
29 | ), "Crop size is larger than image size!"
30 |
31 | min_yx = np.maximum(np.floor(center_yx).astype(np.int32) - crop_size, 0)
32 | max_yx = np.maximum(np.asarray(image_size, dtype=np.int32) - crop_size, 0)
33 | max_yx = np.minimum(max_yx, np.ceil(center_yx).astype(np.int32))
34 |
35 | y0 = np.random.randint(min_yx[0], max_yx[0] + 1)
36 | x0 = np.random.randint(min_yx[1], max_yx[1] + 1)
37 |
38 | # if some instance is cropped extend the box
39 | if not crop_box:
40 | num_modifications = 0
41 | modified = True
42 |
43 | # convert crop_size to float
44 | crop_size = crop_size.astype(np.float32)
45 | while modified:
46 | modified, x0, y0, crop_size = adjust_crop(x0, y0, crop_size, instances)
47 | num_modifications += 1
48 | if num_modifications > 100:
49 | raise ValueError(
50 |                     "Cannot finish cropping adjustment within 100 tries (#instances {}).".format(
51 | len(instances)
52 | )
53 | )
54 | return T.CropTransform(0, 0, image_size[1], image_size[0])
55 |
56 | return T.CropTransform(*map(int, (x0, y0, crop_size[1], crop_size[0])))
57 |
58 |
59 | def adjust_crop(x0, y0, crop_size, instances, eps=1e-3):
60 | modified = False
61 |
62 | x1 = x0 + crop_size[1]
63 | y1 = y0 + crop_size[0]
64 |
65 | for bbox in instances:
66 |
67 | if bbox[0] < x0 - eps and bbox[2] > x0 + eps:
68 | crop_size[1] += x0 - bbox[0]
69 | x0 = bbox[0]
70 | modified = True
71 |
72 | if bbox[0] < x1 - eps and bbox[2] > x1 + eps:
73 | crop_size[1] += bbox[2] - x1
74 | x1 = bbox[2]
75 | modified = True
76 |
77 | if bbox[1] < y0 - eps and bbox[3] > y0 + eps:
78 | crop_size[0] += y0 - bbox[1]
79 | y0 = bbox[1]
80 | modified = True
81 |
82 | if bbox[1] < y1 - eps and bbox[3] > y1 + eps:
83 | crop_size[0] += bbox[3] - y1
84 | y1 = bbox[3]
85 | modified = True
86 |
87 | return modified, x0, y0, crop_size
88 |
89 |
90 | class RandomCropWithInstance(RandomCrop):
91 | """ Instance-aware cropping.
92 | """
93 |
94 | def __init__(self, crop_type, crop_size, crop_instance=True):
95 | """
96 | Args:
97 | crop_instance (bool): if False, extend cropping boxes to avoid cropping instances
98 | """
99 | super().__init__(crop_type, crop_size)
100 | self.crop_instance = crop_instance
101 | self.input_args = ("image", "boxes")
102 |
103 | def get_transform(self, img, boxes):
104 | image_size = img.shape[:2]
105 | crop_size = self.get_crop_size(image_size)
106 | return gen_crop_transform_with_instance(
107 | crop_size, image_size, boxes, crop_box=self.crop_instance
108 | )
109 |
--------------------------------------------------------------------------------
/adet/data/builtin.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | from detectron2.data.datasets.register_coco import register_coco_instances
4 | from detectron2.data.datasets.builtin_meta import _get_builtin_metadata
5 |
6 | from .datasets.text import register_text_instances
7 |
8 | # register plane reconstruction
9 |
10 | _PREDEFINED_SPLITS_PIC = {
11 | "pic_person_train": ("pic/image/train", "pic/annotations/train_person.json"),
12 | "pic_person_val": ("pic/image/val", "pic/annotations/val_person.json"),
13 | }
14 |
15 | metadata_pic = {
16 | "thing_classes": ["person"]
17 | }
18 |
19 | _PREDEFINED_SPLITS_TEXT = {
20 | # datasets with bezier annotations
21 | "totaltext_train": ("totaltext/train_images", "totaltext/train.json"),
22 | "totaltext_val": ("totaltext/test_images", "totaltext/test.json"),
23 | "ctw1500_word_train": ("CTW1500/ctwtrain_text_image", "CTW1500/annotations/train_ctw1500_maxlen100_v2.json"),
24 | "ctw1500_word_test": ("CTW1500/ctwtest_text_image","CTW1500/annotations/test_ctw1500_maxlen100.json"),
25 | "syntext1_train": ("syntext1/images", "syntext1/annotations/train.json"),
26 | "syntext2_train": ("syntext2/images", "syntext2/annotations/train.json"),
27 | "mltbezier_word_train": ("mlt2017/images","mlt2017/annotations/train.json"),
28 | "rects_train": ("ReCTS/ReCTS_train_images", "ReCTS/annotations/rects_train.json"),
29 | "rects_val": ("ReCTS/ReCTS_val_images", "ReCTS/annotations/rects_val.json"),
30 | "rects_test": ("ReCTS/ReCTS_test_images", "ReCTS/annotations/rects_test.json"),
31 | "art_train": ("ArT/rename_artimg_train", "ArT/annotations/abcnet_art_train.json"),
32 | "lsvt_train": ("LSVT/rename_lsvtimg_train", "LSVT/annotations/abcnet_lsvt_train.json"),
33 | "chnsyn_train": ("ChnSyn/syn_130k_images", "ChnSyn/annotations/chn_syntext.json"),
34 | # datasets with polygon annotations
35 | "totaltext_poly_train": ("totaltext/train_images", "totaltext/train_poly.json"),
36 | "totaltext_poly_val": ("totaltext/test_images", "totaltext/test_poly.json"),
37 | "ctw1500_word_poly_train": ("CTW1500/ctwtrain_text_image", "CTW1500/annotations/train_poly.json"),
38 | "ctw1500_word_poly_test": ("CTW1500/ctwtest_text_image","CTW1500/annotations/test_poly.json"),
39 | "syntext1_poly_train": ("syntext1/images", "syntext1/annotations/train_poly.json"),
40 | "syntext2_poly_train": ("syntext2/images", "syntext2/annotations/train_poly.json"),
41 | "mltbezier_word_poly_train": ("mlt2017/images","mlt2017/annotations/train_poly.json"),
42 | "icdar2015_train": ("icdar2015/train_images", "icdar2015/train_poly.json"),
43 | "icdar2015_test": ("icdar2015/test_images", "icdar2015/test_poly.json"),
44 | "icdar2019_train": ("icdar2019/train_images", "icdar2019/train_poly.json"),
45 | "textocr_train": ("textocr/train_images", "textocr/annotations/train_poly.json"),
46 | "textocr_val": ("textocr/train_images", "textocr/annotations/val_poly.json"),
47 | }
48 |
49 | metadata_text = {
50 | "thing_classes": ["text"]
51 | }
52 |
53 |
54 | def register_all_coco(root="datasets"):
55 | for key, (image_root, json_file) in _PREDEFINED_SPLITS_PIC.items():
56 | # Assume pre-defined datasets live in `./datasets`.
57 | register_coco_instances(
58 | key,
59 | metadata_pic,
60 | os.path.join(root, json_file) if "://" not in json_file else json_file,
61 | os.path.join(root, image_root),
62 | )
63 | for key, (image_root, json_file) in _PREDEFINED_SPLITS_TEXT.items():
64 | # Assume pre-defined datasets live in `./datasets`.
65 | register_text_instances(
66 | key,
67 | metadata_text,
68 | os.path.join(root, json_file) if "://" not in json_file else json_file,
69 | os.path.join(root, image_root),
70 | )
71 |
72 |
73 | register_all_coco()
74 |
--------------------------------------------------------------------------------
/adet/data/dataset_mapper.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import logging
3 | import os.path as osp
4 |
5 | import numpy as np
6 | import torch
7 | from fvcore.common.file_io import PathManager
8 | from PIL import Image
9 | from pycocotools import mask as maskUtils
10 |
11 | from detectron2.data import detection_utils as utils
12 | from detectron2.data import transforms as T
13 | from detectron2.data.dataset_mapper import DatasetMapper
14 | from detectron2.data.detection_utils import SizeMismatchError
15 | from detectron2.structures import BoxMode
16 |
17 | from .augmentation import RandomCropWithInstance
18 | from .detection_utils import (annotations_to_instances, build_augmentation,
19 | transform_instance_annotations)
20 |
21 | """
22 | This file contains the default mapping that's applied to "dataset dicts".
23 | """
24 |
25 | __all__ = ["DatasetMapperWithBasis"]
26 |
27 | logger = logging.getLogger(__name__)
28 |
29 |
30 | def segmToRLE(segm, img_size):
31 | h, w = img_size
32 | if type(segm) == list:
33 | # polygon -- a single object might consist of multiple parts
34 | # we merge all parts into one mask rle code
35 | rles = maskUtils.frPyObjects(segm, h, w)
36 | rle = maskUtils.merge(rles)
37 | elif type(segm["counts"]) == list:
38 | # uncompressed RLE
39 | rle = maskUtils.frPyObjects(segm, h, w)
40 | else:
41 | # rle
42 | rle = segm
43 | return rle
44 |
45 |
46 | def segmToMask(segm, img_size):
47 | rle = segmToRLE(segm, img_size)
48 | m = maskUtils.decode(rle)
49 | return m
50 |
51 |
52 | class DatasetMapperWithBasis(DatasetMapper):
53 | """
54 |     This mapper extends the default Detectron2 mapper to read an additional basis semantic label
55 | """
56 |
57 | def __init__(self, cfg, is_train=True):
58 | super().__init__(cfg, is_train)
59 |
60 | # Rebuild augmentations
61 | logger.info(
62 | "Rebuilding the augmentations. The previous augmentations will be overridden."
63 | )
64 | self.augmentation = build_augmentation(cfg, is_train)
65 |
66 | if cfg.INPUT.CROP.ENABLED and is_train:
67 | self.augmentation.insert(
68 | 0,
69 | RandomCropWithInstance(
70 | cfg.INPUT.CROP.TYPE,
71 | cfg.INPUT.CROP.SIZE,
72 | cfg.INPUT.CROP.CROP_INSTANCE,
73 | ),
74 | )
75 | logging.getLogger(__name__).info(
76 | "Cropping used in training: " + str(self.augmentation[0])
77 | )
78 |
79 | self.basis_loss_on = cfg.MODEL.BASIS_MODULE.LOSS_ON
80 | self.ann_set = cfg.MODEL.BASIS_MODULE.ANN_SET
81 | self.boxinst_enabled = cfg.MODEL.BOXINST.ENABLED
82 |
83 | if self.boxinst_enabled:
84 | self.use_instance_mask = False
85 | self.recompute_boxes = False
86 |
87 | def __call__(self, dataset_dict):
88 | """
89 | Args:
90 | dataset_dict (dict): Metadata of one image, in Detectron2 Dataset format.
91 |
92 | Returns:
93 | dict: a format that builtin models in detectron2 accept
94 | """
95 | dataset_dict = copy.deepcopy(dataset_dict) # it will be modified by code below
96 | # USER: Write your own image loading if it's not from a file
97 | try:
98 | image = utils.read_image(
99 | dataset_dict["file_name"], format=self.image_format
100 | )
101 | except Exception as e:
102 | print(dataset_dict["file_name"])
103 | print(e)
104 | raise e
105 | try:
106 | utils.check_image_size(dataset_dict, image)
107 | except SizeMismatchError as e:
108 | expected_wh = (dataset_dict["width"], dataset_dict["height"])
109 | image_wh = (image.shape[1], image.shape[0])
110 | if (image_wh[1], image_wh[0]) == expected_wh:
111 | print("transposing image {}".format(dataset_dict["file_name"]))
112 | image = image.transpose(1, 0, 2)
113 | else:
114 | raise e
115 |
116 | # USER: Remove if you don't do semantic/panoptic segmentation.
117 | if "sem_seg_file_name" in dataset_dict:
118 | sem_seg_gt = utils.read_image(
119 | dataset_dict.pop("sem_seg_file_name"), "L"
120 | ).squeeze(2)
121 | else:
122 | sem_seg_gt = None
123 |
124 | boxes = np.asarray(
125 | [
126 | BoxMode.convert(
127 | instance["bbox"], instance["bbox_mode"], BoxMode.XYXY_ABS
128 | )
129 | for instance in dataset_dict["annotations"]
130 | ]
131 | )
132 | aug_input = T.StandardAugInput(image, boxes=boxes, sem_seg=sem_seg_gt)
133 | transforms = aug_input.apply_augmentations(self.augmentation)
134 | image, sem_seg_gt = aug_input.image, aug_input.sem_seg
135 |
136 | image_shape = image.shape[:2] # h, w
137 | # Pytorch's dataloader is efficient on torch.Tensor due to shared-memory,
138 | # but not efficient on large generic data structures due to the use of pickle & mp.Queue.
139 | # Therefore it's important to use torch.Tensor.
140 | dataset_dict["image"] = torch.as_tensor(
141 | np.ascontiguousarray(image.transpose(2, 0, 1))
142 | )
143 | if sem_seg_gt is not None:
144 | dataset_dict["sem_seg"] = torch.as_tensor(sem_seg_gt.astype("long"))
145 |
146 | # USER: Remove if you don't use pre-computed proposals.
147 | # Most users would not need this feature.
148 | if self.proposal_topk:
149 | utils.transform_proposals(
150 | dataset_dict,
151 | image_shape,
152 | transforms,
153 | proposal_topk=self.proposal_topk,
154 | min_box_size=self.proposal_min_box_size,
155 | )
156 |
157 | if not self.is_train:
158 | dataset_dict.pop("annotations", None)
159 | dataset_dict.pop("sem_seg_file_name", None)
160 | dataset_dict.pop("pano_seg_file_name", None)
161 | return dataset_dict
162 |
163 | if "annotations" in dataset_dict:
164 | # USER: Modify this if you want to keep them for some reason.
165 | for anno in dataset_dict["annotations"]:
166 | if not self.use_instance_mask:
167 | anno.pop("segmentation", None)
168 | if not self.use_keypoint:
169 | anno.pop("keypoints", None)
170 |
171 | # USER: Implement additional transformations if you have other types of data
172 | annos = [
173 | transform_instance_annotations(
174 | obj,
175 | transforms,
176 | image_shape,
177 | keypoint_hflip_indices=self.keypoint_hflip_indices,
178 | )
179 | for obj in dataset_dict.pop("annotations")
180 | if obj.get("iscrowd", 0) == 0
181 | ]
182 | instances = annotations_to_instances(
183 | annos, image_shape, mask_format=self.instance_mask_format
184 | )
185 |
186 | # After transforms such as cropping are applied, the bounding box may no longer
187 | # tightly bound the object. As an example, imagine a triangle object
188 | # [(0,0), (2,0), (0,2)] cropped by a box [(1,0),(2,2)] (XYXY format). The tight
189 |         # bounding box of the cropped triangle should be [(1,0),(2,1)], which is not equal to the intersection of the original bounding box and the cropping box.
190 | if self.recompute_boxes:
191 | instances.gt_boxes = instances.gt_masks.get_bounding_boxes()
192 | dataset_dict["instances"] = utils.filter_empty_instances(instances)
193 |
194 | if self.basis_loss_on and self.is_train:
195 | # load basis supervisions
196 | if self.ann_set == "coco":
197 | basis_sem_path = (
198 | dataset_dict["file_name"]
199 | .replace("train2017", "thing_train2017")
200 | .replace("image/train", "thing_train")
201 | )
202 | else:
203 | basis_sem_path = (
204 | dataset_dict["file_name"]
205 | .replace("coco", "lvis")
206 | .replace("train2017", "thing_train")
207 | )
208 | # change extension to npz
209 | basis_sem_path = osp.splitext(basis_sem_path)[0] + ".npz"
210 | basis_sem_gt = np.load(basis_sem_path)["mask"]
211 | basis_sem_gt = transforms.apply_segmentation(basis_sem_gt)
212 | basis_sem_gt = torch.as_tensor(basis_sem_gt.astype("long"))
213 | dataset_dict["basis_sem"] = basis_sem_gt
214 | return dataset_dict
215 |
--------------------------------------------------------------------------------
/adet/data/datasets/text.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2 | import contextlib
3 | import io
4 | import logging
5 | import os
6 | from fvcore.common.timer import Timer
7 | from fvcore.common.file_io import PathManager
8 |
9 | from detectron2.structures import BoxMode
10 |
11 | from detectron2.data import DatasetCatalog, MetadataCatalog
12 |
13 | """
14 | This file contains functions to parse COCO-format text annotations into dicts in "Detectron2 format".
15 | """
16 |
17 |
18 | logger = logging.getLogger(__name__)
19 |
20 | __all__ = ["load_text_json", "register_text_instances"]
21 |
22 |
23 | def register_text_instances(name, metadata, json_file, image_root):
24 | """
25 | Register a dataset in json annotation format for text detection and recognition.
26 |
27 | Args:
28 | name (str): a name that identifies the dataset, e.g. "lvis_v0.5_train".
29 | metadata (dict): extra metadata associated with this dataset. It can be an empty dict.
30 | json_file (str): path to the json instance annotation file.
31 | image_root (str or path-like): directory which contains all the images.
32 | """
33 | DatasetCatalog.register(name, lambda: load_text_json(json_file, image_root, name))
34 | MetadataCatalog.get(name).set(
35 | json_file=json_file, image_root=image_root, evaluator_type="text", **metadata
36 | )
37 |
38 |
39 | def load_text_json(json_file, image_root, dataset_name=None, extra_annotation_keys=None):
40 | """
41 | Load a json file with totaltext annotation format.
42 | Currently supports text detection and recognition.
43 |
44 | Args:
45 | json_file (str): full path to the json file in totaltext annotation format.
46 | image_root (str or path-like): the directory where the images in this json file exists.
47 | dataset_name (str): the name of the dataset (e.g., coco_2017_train).
48 | If provided, this function will also put "thing_classes" into
49 | the metadata associated with this dataset.
50 | extra_annotation_keys (list[str]): list of per-annotation keys that should also be
51 | loaded into the dataset dict (besides "iscrowd", "bbox", "keypoints",
52 | "category_id", "segmentation"). The values for these keys will be returned as-is.
53 | For example, the densepose annotations are loaded in this way.
54 |
55 | Returns:
56 | list[dict]: a list of dicts in Detectron2 standard dataset dicts format. (See
57 | `Using Custom Datasets </tutorials/datasets.html>`_ )
58 |
59 | Notes:
60 | 1. This function does not read the image files.
61 | The results do not have the "image" field.
62 | """
63 | from pycocotools.coco import COCO
64 |
65 | timer = Timer()
66 | json_file = PathManager.get_local_path(json_file)
67 | with contextlib.redirect_stdout(io.StringIO()):
68 | coco_api = COCO(json_file)
69 | if timer.seconds() > 1:
70 | logger.info("Loading {} takes {:.2f} seconds.".format(json_file, timer.seconds()))
71 |
72 | id_map = None
73 | if dataset_name is not None:
74 | meta = MetadataCatalog.get(dataset_name)
75 | cat_ids = sorted(coco_api.getCatIds())
76 | cats = coco_api.loadCats(cat_ids)
77 | # The categories in a custom json file may not be sorted.
78 | thing_classes = [c["name"] for c in sorted(cats, key=lambda x: x["id"])]
79 | meta.thing_classes = thing_classes
80 |
81 | # In COCO, certain category ids are artificially removed,
82 | # and by convention they are always ignored.
83 | # We deal with COCO's id issue and translate
84 | # the category ids to contiguous ids in [0, 80).
85 |
86 | # It works by looking at the "categories" field in the json, therefore
87 | # if users' own json also have incontiguous ids, we'll
88 | # apply this mapping as well but print a warning.
89 | if not (min(cat_ids) == 1 and max(cat_ids) == len(cat_ids)):
90 | if "coco" not in dataset_name:
91 | logger.warning(
92 | """
93 | Category ids in annotations are not in [1, #categories]! We'll apply a mapping for you.
94 | """
95 | )
96 | id_map = {v: i for i, v in enumerate(cat_ids)}
97 | meta.thing_dataset_id_to_contiguous_id = id_map
98 |
99 | # sort indices for reproducible results
100 | img_ids = sorted(coco_api.imgs.keys())
101 | # imgs is a list of dicts, each looks something like:
102 | # {'license': 4,
103 | # 'url': 'http://farm6.staticflickr.com/5454/9413846304_881d5e5c3b_z.jpg',
104 | # 'file_name': 'COCO_val2014_000000001268.jpg',
105 | # 'height': 427,
106 | # 'width': 640,
107 | # 'date_captured': '2013-11-17 05:57:24',
108 | # 'id': 1268}
109 | imgs = coco_api.loadImgs(img_ids)
110 | # anns is a list[list[dict]], where each dict is an annotation
111 | # record for an object. The inner list enumerates the objects in an image
112 | # and the outer list enumerates over images. Example of anns[0]:
113 | # [{'segmentation': [[192.81,
114 | # 247.09,
115 | # ...
116 | # 219.03,
117 | # 249.06]],
118 | # 'area': 1035.749,
119 | # 'rec': [84, 72, ... 96],
120 | # 'bezier_pts': [169.0, 425.0, ..., ]
121 | # 'iscrowd': 0,
122 | # 'image_id': 1268,
123 | # 'bbox': [192.81, 224.8, 74.73, 33.43],
124 | # 'category_id': 16,
125 | # 'id': 42986},
126 | # ...]
127 | anns = [coco_api.imgToAnns[img_id] for img_id in img_ids]
128 |
129 | if "minival" not in json_file:
130 | # The popular valminusminival & minival annotations for COCO2014 contain this bug.
131 | # However the ratio of buggy annotations there is tiny and does not affect accuracy.
132 | # Therefore we explicitly white-list them.
133 | ann_ids = [ann["id"] for anns_per_image in anns for ann in anns_per_image]
134 | assert len(set(ann_ids)) == len(ann_ids), "Annotation ids in '{}' are not unique!".format(
135 | json_file
136 | )
137 |
138 | imgs_anns = list(zip(imgs, anns))
139 |
140 | logger.info("Loaded {} images in COCO format from {}".format(len(imgs_anns), json_file))
141 |
142 | dataset_dicts = []
143 |
144 | ann_keys = ["iscrowd", "bbox", "rec", "category_id"] + (extra_annotation_keys or [])
145 |
146 | num_instances_without_valid_segmentation = 0
147 |
148 | for (img_dict, anno_dict_list) in imgs_anns:
149 | record = {}
150 | record["file_name"] = os.path.join(image_root, img_dict["file_name"])
151 | record["height"] = img_dict["height"]
152 | record["width"] = img_dict["width"]
153 | image_id = record["image_id"] = img_dict["id"]
154 |
155 | objs = []
156 | for anno in anno_dict_list:
157 | # Check that the image_id in this annotation is the same as
158 | # the image_id we're looking at.
159 | # This fails only when the data parsing logic or the annotation file is buggy.
160 |
161 | # The original COCO valminusminival2014 & minival2014 annotation files
162 | # actually contains bugs that, together with certain ways of using COCO API,
163 | # can trigger this assertion.
164 | assert anno["image_id"] == image_id
165 |
166 | assert anno.get("ignore", 0) == 0, '"ignore" in COCO json file is not supported.'
167 |
168 | obj = {key: anno[key] for key in ann_keys if key in anno}
169 |
170 | segm = anno.get("segmentation", None)
171 | if segm: # either list[list[float]] or dict(RLE)
172 | if not isinstance(segm, dict):
173 | # filter out invalid polygons (< 3 points)
174 | segm = [poly for poly in segm if len(poly) % 2 == 0 and len(poly) >= 6]
175 | if len(segm) == 0:
176 | num_instances_without_valid_segmentation += 1
177 | continue # ignore this instance
178 | obj["segmentation"] = segm
179 |
180 | bezierpts = anno.get("bezier_pts", None)
181 | # Bezier Points are the control points for BezierAlign Text recognition (BAText)
182 | if bezierpts: # list[float]
183 | obj["beziers"] = bezierpts
184 |
185 | polypts = anno.get("polys", None)
186 | if polypts:
187 | obj["polygons"] = polypts
188 |
189 | text = anno.get("rec", None)
190 | if text:
191 | obj["text"] = text
192 |
193 | obj["bbox_mode"] = BoxMode.XYWH_ABS
194 | if id_map:
195 | obj["category_id"] = id_map[obj["category_id"]]
196 | objs.append(obj)
197 | record["annotations"] = objs
198 | dataset_dicts.append(record)
199 |
200 | if num_instances_without_valid_segmentation > 0:
201 | logger.warning(
202 | "Filtered out {} instances without valid segmentation. "
203 | "There might be issues in your dataset generation process.".format(
204 | num_instances_without_valid_segmentation
205 | )
206 | )
207 | return dataset_dicts
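
A minimal usage sketch for the registration helpers above, assuming a hypothetical dataset name and local paths (adjust the annotation file and image directory to your own data):

from detectron2.data import DatasetCatalog
from adet.data.datasets.text import register_text_instances

# Hypothetical name and paths; the json must be in the COCO-style text format parsed above.
register_text_instances(
    "my_text_train",
    metadata={},
    json_file="datasets/mydata/annotations/train.json",
    image_root="datasets/mydata/train_images",
)

dicts = DatasetCatalog.get("my_text_train")  # lazily calls load_text_json
print(len(dicts), dicts[0]["file_name"])
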
--------------------------------------------------------------------------------
/adet/data/detection_utils.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 | import numpy as np
4 | import torch
5 |
6 | from detectron2.data import transforms as T
7 | from detectron2.data.detection_utils import \
8 | annotations_to_instances as d2_anno_to_inst
9 | from detectron2.data.detection_utils import \
10 | transform_instance_annotations as d2_transform_inst_anno
11 |
12 |
13 | def transform_instance_annotations(
14 | annotation, transforms, image_size, *, keypoint_hflip_indices=None
15 | ):
16 |
17 | annotation = d2_transform_inst_anno(
18 | annotation,
19 | transforms,
20 | image_size,
21 | keypoint_hflip_indices=keypoint_hflip_indices,
22 | )
23 |
24 | if "beziers" in annotation:
25 | beziers = transform_ctrl_pnts_annotations(annotation["beziers"], transforms)
26 | annotation["beziers"] = beziers
27 |
28 | if "polygons" in annotation:
29 | polys = transform_ctrl_pnts_annotations(annotation["polygons"], transforms)
30 | annotation["polygons"] = polys
31 |
32 | return annotation
33 |
34 |
35 | def transform_ctrl_pnts_annotations(pnts, transforms):
36 | """
37 | Transform control-point annotations (bezier or polygon points) of an image.
38 |
39 | Args:
40 | pnts (list[float]): flattened (x, y) control-point coordinates in Detectron2 Dataset format.
41 | transforms (TransformList):
42 | """
43 | # (N*2,) -> (N, 2)
44 | pnts = np.asarray(pnts, dtype="float64").reshape(-1, 2)
45 | pnts = transforms.apply_coords(pnts).reshape(-1)
46 |
47 | # This assumes that HorizFlipTransform is the only one that does flip
48 | do_hflip = (
49 | sum(isinstance(t, T.HFlipTransform) for t in transforms.transforms) % 2 == 1
50 | )
51 | if do_hflip:
52 | raise ValueError("Flipping text data is not supported (also disencouraged).")
53 |
54 | return pnts
55 |
56 |
57 | def annotations_to_instances(annos, image_size, mask_format="polygon"):
58 | instance = d2_anno_to_inst(annos, image_size, mask_format)
59 |
60 | if not annos:
61 | return instance
62 |
63 | # add attributes
64 | if "beziers" in annos[0]:
65 | beziers = [obj.get("beziers", []) for obj in annos]
66 | instance.beziers = torch.as_tensor(beziers, dtype=torch.float32)
67 |
68 | if "rec" in annos[0]:
69 | text = [obj.get("rec", []) for obj in annos]
70 | instance.text = torch.as_tensor(text, dtype=torch.int32)
71 |
72 | if "polygons" in annos[0]:
73 | polys = [obj.get("polygons", []) for obj in annos]
74 | instance.polygons = torch.as_tensor(polys, dtype=torch.float32)
75 |
76 | return instance
77 |
78 |
79 | def build_augmentation(cfg, is_train):
80 | """
81 | Build a list of augmentations, with the option to disable horizontal flip.
82 |
83 | Returns:
84 | list[Augmentation]
85 | """
86 | if is_train:
87 | min_size = cfg.INPUT.MIN_SIZE_TRAIN
88 | max_size = cfg.INPUT.MAX_SIZE_TRAIN
89 | sample_style = cfg.INPUT.MIN_SIZE_TRAIN_SAMPLING
90 | else:
91 | min_size = cfg.INPUT.MIN_SIZE_TEST
92 | max_size = cfg.INPUT.MAX_SIZE_TEST
93 | sample_style = "choice"
94 | if sample_style == "range":
95 | assert (
96 | len(min_size) == 2
97 | ), "more than 2 ({}) min_size(s) are provided for ranges".format(len(min_size))
98 |
99 | logger = logging.getLogger(__name__)
100 |
101 | augmentation = []
102 | augmentation.append(T.ResizeShortestEdge(min_size, max_size, sample_style))
103 | if is_train:
104 | if cfg.INPUT.HFLIP_TRAIN:
105 | augmentation.append(T.RandomFlip())
106 | logger.info("Augmentations used in training: " + str(augmentation))
107 | return augmentation
108 |
109 |
110 | build_transform_gen = build_augmentation
111 | """
112 | Alias for backward-compatibility.
113 | """
114 |
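
A small sketch of transform_ctrl_pnts_annotations on synthetic control points, assuming Detectron2's ResizeTransform/TransformList are available (the coordinate values here are made up):

import numpy as np
from detectron2.data.transforms import ResizeTransform, TransformList
from adet.data.detection_utils import transform_ctrl_pnts_annotations

# Hypothetical flattened (x, y) control points on a 100x200 image, resized to 50x100.
pnts = np.arange(16, dtype="float64")
tfm = TransformList([ResizeTransform(100, 200, 50, 100)])
out = transform_ctrl_pnts_annotations(pnts, tfm)
print(out)  # every coordinate scaled by 0.5, still a flat array of length 16
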
--------------------------------------------------------------------------------
/adet/evaluation/__init__.py:
--------------------------------------------------------------------------------
1 | from .text_evaluation import TextEvaluator
2 | from .text_eval_script import text_eval_main
3 | from . import rrc_evaluation_funcs
--------------------------------------------------------------------------------
/adet/evaluation/lexicon_procesor.py:
--------------------------------------------------------------------------------
1 | import editdistance
2 | import numpy as np
3 | from numba import njit
4 | from numba.core import types
5 | from numba.typed import Dict
6 |
7 | @njit
8 | def weighted_edit_distance(word1: str, word2: str, scores: np.ndarray, ct_labels_inv):
9 | m: int = len(word1)
10 | n: int = len(word2)
11 | dp = np.zeros((n+1, m+1), dtype=np.float32)
12 | dp[0, :] = np.arange(m+1)
13 | dp[:, 0] = np.arange(n+1)
14 | for i in range(1, n + 1): ## word2
15 | for j in range(1, m + 1): ## word1
16 | delect_cost = _ed_delete_cost(j-1, i-1, word1, word2, scores, ct_labels_inv) ## delete a[i]
17 | insert_cost = _ed_insert_cost(j-1, i-1, word1, word2, scores, ct_labels_inv) ## insert b[j]
18 | if word1[j - 1] != word2[i - 1]:
19 | replace_cost = _ed_replace_cost(j-1, i-1, word1, word2, scores, ct_labels_inv) ## replace a[i] with b[j]
20 | else:
21 | replace_cost = 0
22 | dp[i][j] = min(dp[i-1][j] + insert_cost, dp[i][j-1] + delect_cost, dp[i-1][j-1] + replace_cost)
23 |
24 | return dp[n][m]
25 |
26 | @njit
27 | def _ed_delete_cost(j, i, word1, word2, scores, ct_labels_inv):
28 | ## delete a[i]
29 | return _get_score(scores[j], word1[j], ct_labels_inv)
30 |
31 | @njit
32 | def _ed_insert_cost(i, j, word1, word2, scores, ct_labels_inv):
33 | ## insert b[j]
34 | if i < len(word1) - 1:
35 | return (_get_score(scores[i], word1[i], ct_labels_inv) + _get_score(scores[i+1], word1[i+1], ct_labels_inv))/2
36 | else:
37 | return _get_score(scores[i], word1[i], ct_labels_inv)
38 |
39 | @njit
40 | def _ed_replace_cost(i, j, word1, word2, scores, ct_labels_inv):
41 | ## replace a[i] with b[j]
42 | # if word1 == "eeatpisaababarait".upper():
43 | # print(scores[c2][i]/scores[c1][i])
44 | return max(1 - _get_score(scores[i], word2[j], ct_labels_inv)/_get_score(scores[i], word1[i], ct_labels_inv)*5, 0)
45 |
46 | @njit
47 | def _get_score(scores, char, ct_labels_inv):
48 | upper = ct_labels_inv[char.upper()]
49 | lower = ct_labels_inv[char.lower()]
50 | return max(scores[upper], scores[lower])
51 |
52 | class LexiconMatcher:
53 | def __init__(self, dataset, lexicon_type, use_lexicon, ct_labels, weighted_ed=False):
54 | self.use_lexicon = use_lexicon
55 | self.lexicon_type = lexicon_type
56 | self.dataset = dataset
57 | self.ct_labels_inv = Dict.empty(
58 | key_type=types.string,
59 | value_type=types.int64,
60 | )
61 | for i, c in enumerate(ct_labels):
62 | self.ct_labels_inv[c] = i
63 | # maps char to index
64 | self.is_full_lex_dataset = "totaltext" in dataset or "ctw1500" in dataset
65 | self._load_lexicon(dataset, lexicon_type)
66 | self.weighted_ed = weighted_ed
67 |
68 | def find_match_word(self, rec_str, img_id=None, scores=None):
69 | if not self.use_lexicon:
70 | return rec_str
71 | rec_str = rec_str.upper()
72 | dist_min = 100
73 | match_word = ''
74 | match_dist = 100
75 |
76 | lexicons = self.lexicons if self.lexicon_type != 3 else self.lexicons[img_id]
77 | pairs = self.pairs if self.lexicon_type != 3 else self.pairs[img_id]
78 |
79 | # scores of shape (seq_len, n_symbols) must be provided for weighted editdistance
80 | assert not self.weighted_ed or scores is not None
81 |
82 | for word in lexicons:
83 | word = word.upper()
84 | if self.weighted_ed:
85 | ed = weighted_edit_distance(rec_str, word, scores, self.ct_labels_inv)
86 | else:
87 | ed = editdistance.eval(rec_str, word)
88 | if ed < dist_min:
89 | dist_min = ed
90 | match_word = pairs[word]
91 | match_dist = ed
92 |
93 | if self.is_full_lex_dataset:
94 | # always return matched results for the full lexicon (for totaltext/ctw1500)
95 | return match_word
96 | else:
97 | # filter unmatched words for icdar
98 | return match_word if match_dist < 2.5 or self.lexicon_type == 1 else None
99 |
100 | @staticmethod
101 | def _get_lexicon_path(dataset):
102 | if "icdar2015" in dataset:
103 | g_lexicon_path = "datasets/evaluation/lexicons/ic15/GenericVocabulary_new.txt"
104 | g_pairlist_path = "datasets/evaluation/lexicons/ic15/GenericVocabulary_pair_list.txt"
105 | w_lexicon_path = "datasets/evaluation/lexicons/ic15/ch4_test_vocabulary_new.txt"
106 | w_pairlist_path = "datasets/evaluation/lexicons/ic15/ch4_test_vocabulary_pair_list.txt"
107 | s_lexicon_paths = [
108 | (str(fid+1), f"datasets/evaluation/lexicons/ic15/new_strong_lexicon/new_voc_img_{fid+1}.txt") for fid in range(500)]
109 | s_pairlist_paths = [
110 | (str(fid+1), f"datasets/evaluation/lexicons/ic15/new_strong_lexicon/pair_voc_img_{fid+1}.txt") for fid in range(500)]
111 | elif "totaltext" in dataset:
112 | s_lexicon_paths = s_pairlist_paths = None
113 | g_lexicon_path = "datasets/evaluation/lexicons/totaltext/tt_lexicon.txt"
114 | g_pairlist_path = "datasets/evaluation/lexicons/totaltext/tt_pair_list.txt"
115 | w_lexicon_path = "datasets/evaluation/lexicons/totaltext/weak_voc_new.txt"
116 | w_pairlist_path = "datasets/evaluation/lexicons/totaltext/weak_voc_pair_list.txt"
117 | elif "ctw1500" in dataset:
118 | s_lexicon_paths = s_pairlist_paths = w_lexicon_path = w_pairlist_path = None
119 | g_lexicon_path = "datasets/evaluation/lexicons/ctw1500/ctw1500_lexicon.txt"
120 | g_pairlist_path = "datasets/evaluation/lexicons/ctw1500/ctw1500_pair_list.txt"
121 | return g_lexicon_path, g_pairlist_path, w_lexicon_path, w_pairlist_path, s_lexicon_paths, s_pairlist_paths
122 |
123 | def _load_lexicon(self, dataset, lexicon_type):
124 | if not self.use_lexicon:
125 | return
126 | g_lexicon_path, g_pairlist_path, w_lexicon_path, w_pairlist_path, s_lexicon_path, s_pairlist_path = self._get_lexicon_path(
127 | dataset)
128 | if lexicon_type in (1, 2):
129 | # generic/weak lexicon
130 | lexicon_path = g_lexicon_path if lexicon_type == 1 else w_lexicon_path
131 | pairlist_path = g_pairlist_path if lexicon_type == 1 else w_pairlist_path
132 | if lexicon_path is None or pairlist_path is None:
133 | self.use_lexicon = False
134 | return
135 | with open(pairlist_path) as fp:
136 | pairs = dict()
137 | for line in fp.readlines():
138 | line = line.strip()
139 | if self.is_full_lex_dataset:
140 | # might contain space in key word
141 | split = line.split(' ')
142 | half = len(split) // 2
143 | word = ' '.join(split[:half]).upper()
144 | else:
145 | word = line.split(' ')[0].upper()
146 | word_gt = line[len(word)+1:]
147 | pairs[word] = word_gt
148 | with open(lexicon_path) as fp:
149 | lexicons = []
150 | for line in fp.readlines():
151 | lexicons.append(line.strip())
152 | self.lexicons = lexicons
153 | self.pairs = pairs
154 | elif lexicon_type == 3:
155 | # strong lexicon
156 | if s_lexicon_path is None or s_pairlist_path is None:
157 | self.use_lexicon = False
158 | return
159 | lexicons, pairlists = dict(), dict()
160 | for (fid, lexicon_path), (_, pairlist_path) in zip(s_lexicon_path, s_pairlist_path):
161 | with open(lexicon_path) as fp:
162 | lexicon = []
163 | for line in fp.readlines():
164 | lexicon.append(line.strip())
165 | with open(pairlist_path) as fp:
166 | pairs = dict()
167 | for line in fp.readlines():
168 | line = line.strip()
169 | word = line.split(' ')[0].upper()
170 | word_gt = line[len(word)+1:]
171 | pairs[word] = word_gt
172 | lexicons[fid] = lexicon
173 | pairlists[fid] = pairs
174 | self.lexicons = lexicons
175 | self.pairs = pairlists
176 |
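
The generic/weak/strong lexicon files are expected on disk under datasets/evaluation/lexicons/. The matching logic itself reduces to nearest-neighbour search under edit distance; a standalone sketch with a hypothetical in-memory lexicon and the unweighted editdistance path:

import editdistance  # same dependency the matcher above uses

# Hypothetical lexicon; the real matcher reads word/pair lists from the lexicon files.
lexicon = {"HELLO": "Hello", "WORLD": "World", "TEXT": "Text"}

def match_word(rec_str, max_dist=2.5):
    # Plain (unweighted) edit-distance matching, mirroring find_match_word.
    rec_str = rec_str.upper()
    best_word, best_dist = None, float("inf")
    for word, word_gt in lexicon.items():
        d = editdistance.eval(rec_str, word)
        if d < best_dist:
            best_word, best_dist = word_gt, d
    return best_word if best_dist < max_dist else None

print(match_word("helo"))    # "Hello" (distance 1)
print(match_word("zzzzzz"))  # None (no close match)
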
--------------------------------------------------------------------------------
/adet/evaluation/text_evaluation.py:
--------------------------------------------------------------------------------
1 | import contextlib
2 | import copy
3 | import io
4 | import itertools
5 | import json
6 | import logging
7 | import numpy as np
8 | import os
9 | import re
10 | import torch
11 | from collections import OrderedDict
12 | from fvcore.common.file_io import PathManager
13 | from pycocotools.coco import COCO
14 |
15 | from detectron2.utils import comm
16 | from detectron2.data import MetadataCatalog
17 | from detectron2.evaluation.evaluator import DatasetEvaluator
18 |
19 | import glob
20 | import shutil
21 | from shapely.geometry import Polygon, LinearRing
22 | from adet.evaluation import text_eval_script
23 | import zipfile
24 | import pickle
25 |
26 | from adet.evaluation.lexicon_procesor import LexiconMatcher
27 |
28 | NULL_CHAR = u'口'
29 |
30 | class TextEvaluator(DatasetEvaluator):
31 | """
32 | Evaluate text proposals and recognition.
33 | """
34 |
35 | def __init__(self, dataset_name, cfg, distributed, output_dir=None):
36 | self._tasks = ("polygon", "recognition")
37 | self._distributed = distributed
38 | self._output_dir = output_dir
39 |
40 | self._cpu_device = torch.device("cpu")
41 | self._logger = logging.getLogger(__name__)
42 |
43 | self._metadata = MetadataCatalog.get(dataset_name)
44 | if not hasattr(self._metadata, "json_file"):
45 | raise AttributeError(
46 | f"json_file was not found in MetaDataCatalog for '{dataset_name}'."
47 | )
48 |
49 | self.voc_size = cfg.MODEL.BATEXT.VOC_SIZE
50 | self.use_customer_dictionary = cfg.MODEL.BATEXT.CUSTOM_DICT
51 | self.use_polygon = cfg.MODEL.TRANSFORMER.USE_POLYGON
52 | if not self.use_customer_dictionary:
53 | self.CTLABELS = [' ','!','"','#','$','%','&','\'','(',')','*','+',',','-','.','/','0','1','2','3','4','5','6','7','8','9',':',';','<','=','>','?','@','A','B','C','D','E','F','G','H','I','J','K','L','M','N','O','P','Q','R','S','T','U','V','W','X','Y','Z','[','\\',']','^','_','`','a','b','c','d','e','f','g','h','i','j','k','l','m','n','o','p','q','r','s','t','u','v','w','x','y','z','{','|','}','~']
54 | else:
55 | with open(self.use_customer_dictionary, 'rb') as fp:
56 | self.CTLABELS = pickle.load(fp)
57 | self._lexicon_matcher = LexiconMatcher(dataset_name, cfg.TEST.LEXICON_TYPE, cfg.TEST.USE_LEXICON,
58 | self.CTLABELS + [NULL_CHAR],
59 | weighted_ed=cfg.TEST.WEIGHTED_EDIT_DIST)
60 | assert(int(self.voc_size - 1) == len(self.CTLABELS)), "voc_size does not match the dictionary size, got {} and {}.".format(int(self.voc_size - 1), len(self.CTLABELS))
61 |
62 | json_file = PathManager.get_local_path(self._metadata.json_file)
63 | with contextlib.redirect_stdout(io.StringIO()):
64 | self._coco_api = COCO(json_file)
65 |
66 | # use dataset_name to decide eval_gt_path
67 | if "totaltext" in dataset_name:
68 | self._text_eval_gt_path = "datasets/evaluation/gt_totaltext.zip"
69 | self._word_spotting = True
70 | elif "ctw1500" in dataset_name:
71 | self._text_eval_gt_path = "datasets/evaluation/gt_ctw1500.zip"
72 | self._word_spotting = False
73 | elif "icdar2015" in dataset_name:
74 | self._text_eval_gt_path = "datasets/evaluation/gt_icdar2015.zip"
75 | self._word_spotting = False
76 | else:
77 | self._text_eval_gt_path = ""
78 | self._text_eval_confidence = cfg.MODEL.FCOS.INFERENCE_TH_TEST
79 |
80 | def reset(self):
81 | self._predictions = []
82 |
83 | def process(self, inputs, outputs):
84 | for input, output in zip(inputs, outputs):
85 | prediction = {"image_id": input["image_id"]}
86 |
87 | instances = output["instances"].to(self._cpu_device)
88 | prediction["instances"] = self.instances_to_coco_json(instances, input["image_id"])
89 | self._predictions.append(prediction)
90 |
91 | def to_eval_format(self, file_path, temp_dir="temp_det_results", cf_th=0.5):
92 | def fis_ascii(s):
93 | a = (ord(c) < 128 for c in s)
94 | return all(a)
95 |
96 | def de_ascii(s):
97 | a = [c for c in s if ord(c) < 128]
98 | outa = ''
99 | for i in a:
100 | outa +=i
101 | return outa
102 |
103 | with open(file_path, 'r') as f:
104 | data = json.load(f)
105 | with open('temp_all_det_cors.txt', 'w') as f2:
106 | for ix in range(len(data)):
107 | if data[ix]['score'] > 0.1:
108 | outstr = '{}: '.format(data[ix]['image_id'])
109 | xmin = 1000000
110 | ymin = 1000000
111 | xmax = 0
112 | ymax = 0
113 | for i in range(len(data[ix]['polys'])):
114 | outstr = outstr + str(int(data[ix]['polys'][i][0])) +','+str(int(data[ix]['polys'][i][1])) +','
115 | ass = de_ascii(data[ix]['rec'])
116 | if len(ass)>=0: #
117 | outstr = outstr + str(round(data[ix]['score'], 3)) +',####'+ass+'\n'
118 | f2.writelines(outstr)
119 | f2.close()
120 | dirn = temp_dir
121 | lsc = [cf_th]
122 | fres = open('temp_all_det_cors.txt', 'r').readlines()
123 | for isc in lsc:
124 | if not os.path.isdir(dirn):
125 | os.mkdir(dirn)
126 |
127 | for line in fres:
128 | line = line.strip()
129 | s = line.split(': ')
130 | filename = '{:07d}.txt'.format(int(s[0]))
131 | outName = os.path.join(dirn, filename)
132 | with open(outName, 'a') as fout:
133 | ptr = s[1].strip().split(',####')
134 | score = ptr[0].split(',')[-1]
135 | if float(score) < isc:
136 | continue
137 | cors = ','.join(e for e in ptr[0].split(',')[:-1])
138 | fout.writelines(cors+',####'+ptr[1]+'\n')
139 | os.remove("temp_all_det_cors.txt")
140 |
141 | def sort_detection(self, temp_dir):
142 | origin_file = temp_dir
143 | output_file = "final_"+temp_dir
144 |
145 | if not os.path.isdir(output_file):
146 | os.mkdir(output_file)
147 |
148 | files = glob.glob(origin_file+'*.txt')
149 | files.sort()
150 |
151 | for i in files:
152 | out = i.replace(origin_file, output_file)
153 | fin = open(i, 'r').readlines()
154 | fout = open(out, 'w')
155 | for iline, line in enumerate(fin):
156 | ptr = line.strip().split(',####')
157 | rec = ptr[1]
158 | cors = ptr[0].split(',')
159 | assert(len(cors) %2 == 0), 'cors invalid.'
160 | pts = [(int(cors[j]), int(cors[j+1])) for j in range(0,len(cors),2)]
161 | try:
162 | pgt = Polygon(pts)
163 | except Exception as e:
164 | print(e)
165 | print('An invalid detection in {} line {} is removed ... '.format(i, iline))
166 | continue
167 |
168 | if not pgt.is_valid:
169 | print('An invalid detection in {} line {} is removed ... '.format(i, iline))
170 | continue
171 |
172 | pRing = LinearRing(pts)
173 | if pRing.is_ccw:
174 | pts.reverse()
175 | outstr = ''
176 | for ipt in pts[:-1]:
177 | outstr += (str(int(ipt[0]))+','+ str(int(ipt[1]))+',')
178 | outstr += (str(int(pts[-1][0]))+','+ str(int(pts[-1][1])))
179 | outstr = outstr+',####' + rec
180 | fout.writelines(outstr+'\n')
181 | fout.close()
182 | os.chdir(output_file)
183 |
184 | def zipdir(path, ziph):
185 | # ziph is zipfile handle
186 | for root, dirs, files in os.walk(path):
187 | for file in files:
188 | ziph.write(os.path.join(root, file))
189 |
190 | zipf = zipfile.ZipFile('../det.zip', 'w', zipfile.ZIP_DEFLATED)
191 | zipdir('./', zipf)
192 | zipf.close()
193 | os.chdir("../")
194 | # clean temp files
195 | shutil.rmtree(origin_file)
196 | shutil.rmtree(output_file)
197 | return "det.zip"
198 |
199 | def evaluate_with_official_code(self, result_path, gt_path):
200 | return text_eval_script.text_eval_main(det_file=result_path, gt_file=gt_path, is_word_spotting=self._word_spotting)
201 |
202 | def evaluate(self):
203 | if self._distributed:
204 | comm.synchronize()
205 | predictions = comm.gather(self._predictions, dst=0)
206 | predictions = list(itertools.chain(*predictions))
207 |
208 | if not comm.is_main_process():
209 | return {}
210 | else:
211 | predictions = self._predictions
212 |
213 | if len(predictions) == 0:
214 | self._logger.warning("[COCOEvaluator] Did not receive valid predictions.")
215 | return {}
216 |
217 | coco_results = list(itertools.chain(*[x["instances"] for x in predictions]))
218 | PathManager.mkdirs(self._output_dir)
219 |
220 | file_path = os.path.join(self._output_dir, "text_results.json")
221 | self._logger.info("Saving results to {}".format(file_path))
222 | with PathManager.open(file_path, "w") as f:
223 | f.write(json.dumps(coco_results))
224 | f.flush()
225 |
226 | self._results = OrderedDict()
227 |
228 | if not self._text_eval_gt_path:
229 | return copy.deepcopy(self._results)
230 | # eval text
231 | temp_dir = "temp_det_results/"
232 | self.to_eval_format(file_path, temp_dir, self._text_eval_confidence)
233 | result_path = self.sort_detection(temp_dir)
234 | text_result = self.evaluate_with_official_code(result_path, self._text_eval_gt_path)
235 | os.remove(result_path)
236 |
237 | # parse
238 | template = r"(\S+): (\S+): (\S+), (\S+): (\S+), (\S+): (\S+)"
239 | for task in ("e2e_method", "det_only_method"):
240 | result = text_result[task]
241 | groups = re.match(template, result).groups()
242 | self._results[groups[0]] = {groups[i*2+1]: float(groups[(i+1)*2]) for i in range(3)}
243 |
244 | return copy.deepcopy(self._results)
245 |
246 |
247 | def instances_to_coco_json(self, instances, img_id):
248 | num_instances = len(instances)
249 | if num_instances == 0:
250 | return []
251 |
252 | scores = instances.scores.tolist()
253 | if self.use_polygon:
254 | pnts = instances.polygons.numpy()
255 | else:
256 | pnts = instances.beziers.numpy()
257 | recs = instances.recs.numpy()
258 | rec_scores = instances.rec_scores.numpy()
259 |
260 | results = []
261 | for pnt, rec, score, rec_score in zip(pnts, recs, scores, rec_scores):
262 | # convert beziers to polygons
263 | poly = self.pnt_to_polygon(pnt)
264 | s = self.decode(rec)
265 | word = self._lexicon_matcher.find_match_word(s, img_id=str(img_id), scores=rec_score)
266 | if word is None:
267 | continue
268 | result = {
269 | "image_id": img_id,
270 | "category_id": 1,
271 | "polys": poly,
272 | "rec": word,
273 | "score": score,
274 | }
275 | results.append(result)
276 | return results
277 |
278 |
279 | def pnt_to_polygon(self, ctrl_pnt):
280 | if self.use_polygon:
281 | return ctrl_pnt.reshape(-1, 2).tolist()
282 | else:
283 | u = np.linspace(0, 1, 20)
284 | ctrl_pnt = ctrl_pnt.reshape(2, 4, 2).transpose(0, 2, 1).reshape(4, 4)
285 | points = np.outer((1 - u) ** 3, ctrl_pnt[:, 0]) \
286 | + np.outer(3 * u * ((1 - u) ** 2), ctrl_pnt[:, 1]) \
287 | + np.outer(3 * (u ** 2) * (1 - u), ctrl_pnt[:, 2]) \
288 | + np.outer(u ** 3, ctrl_pnt[:, 3])
289 |
290 | # convert points to polygon
291 | points = np.concatenate((points[:, :2], points[:, 2:]), axis=0)
292 | return points.tolist()
293 |
294 | def ctc_decode(self, rec):
295 | # ctc decoding
296 | last_char = False
297 | s = ''
298 | for c in rec:
299 | c = int(c)
300 | if c < self.voc_size - 1:
301 | if last_char != c:
302 | if self.voc_size == 96:
303 | s += self.CTLABELS[c]
304 | last_char = c
305 | else:
306 | s += str(chr(self.CTLABELS[c]))
307 | last_char = c
308 | elif c == self.voc_size -1:
309 | s += u'口'
310 | else:
311 | last_char = False
312 | return s
313 |
314 |
315 | def decode(self, rec):
316 | s = ''
317 | for c in rec:
318 | c = int(c)
319 | if c < self.voc_size - 1:
320 | if self.voc_size == 96:
321 | s += self.CTLABELS[c]
322 | else:
323 | s += str(chr(self.CTLABELS[c]))
324 | elif c == self.voc_size -1:
325 | s += NULL_CHAR
326 |
327 | return s
328 |
329 |
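
pnt_to_polygon above samples each cubic Bezier at 20 parameter values using the Bernstein form B(u) = (1-u)^3 P0 + 3u(1-u)^2 P1 + 3u^2(1-u) P2 + u^3 P3. A standalone sketch of that sampling step with made-up control points:

import numpy as np

def sample_cubic_bezier(p0, p1, p2, p3, n=20):
    # Evaluate the Bernstein form of a cubic Bezier at n parameter values in [0, 1].
    u = np.linspace(0, 1, n)[:, None]
    return ((1 - u) ** 3) * p0 + 3 * u * (1 - u) ** 2 * p1 \
        + 3 * u ** 2 * (1 - u) * p2 + u ** 3 * p3

# Hypothetical control points (x, y) for the top edge of a text instance.
p0, p1, p2, p3 = map(np.array, ([0.0, 0.0], [10.0, 5.0], [20.0, 5.0], [30.0, 0.0]))
top = sample_cubic_bezier(p0, p1, p2, p3)
print(top.shape)  # (20, 2)
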
--------------------------------------------------------------------------------
/adet/layers/__init__.py:
--------------------------------------------------------------------------------
1 | from .ms_deform_attn import MSDeformAttn
2 |
3 | __all__ = [k for k in globals().keys() if not k.startswith("_")]
--------------------------------------------------------------------------------
/adet/layers/csrc/DeformAttn/ms_deform_attn.h:
--------------------------------------------------------------------------------
1 | /*!
2 | **************************************************************************************************
3 | * Deformable DETR
4 | * Copyright (c) 2020 SenseTime. All Rights Reserved.
5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6 | **************************************************************************************************
7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
8 | **************************************************************************************************
9 | */
10 |
11 | #pragma once
12 |
13 | #include "ms_deform_attn_cpu.h"
14 |
15 | #ifdef WITH_CUDA
16 | #include "ms_deform_attn_cuda.h"
17 | #endif
18 |
19 |
20 | at::Tensor
21 | ms_deform_attn_forward(
22 | const at::Tensor &value,
23 | const at::Tensor &spatial_shapes,
24 | const at::Tensor &level_start_index,
25 | const at::Tensor &sampling_loc,
26 | const at::Tensor &attn_weight,
27 | const int im2col_step)
28 | {
29 | if (value.type().is_cuda())
30 | {
31 | #ifdef WITH_CUDA
32 | return ms_deform_attn_cuda_forward(
33 | value, spatial_shapes, level_start_index, sampling_loc, attn_weight, im2col_step);
34 | #else
35 | AT_ERROR("Not compiled with GPU support");
36 | #endif
37 | }
38 | AT_ERROR("Not implemented on the CPU");
39 | }
40 |
41 | std::vector<at::Tensor>
42 | ms_deform_attn_backward(
43 | const at::Tensor &value,
44 | const at::Tensor &spatial_shapes,
45 | const at::Tensor &level_start_index,
46 | const at::Tensor &sampling_loc,
47 | const at::Tensor &attn_weight,
48 | const at::Tensor &grad_output,
49 | const int im2col_step)
50 | {
51 | if (value.type().is_cuda())
52 | {
53 | #ifdef WITH_CUDA
54 | return ms_deform_attn_cuda_backward(
55 | value, spatial_shapes, level_start_index, sampling_loc, attn_weight, grad_output, im2col_step);
56 | #else
57 | AT_ERROR("Not compiled with GPU support");
58 | #endif
59 | }
60 | AT_ERROR("Not implemented on the CPU");
61 | }
62 |
63 |
--------------------------------------------------------------------------------
/adet/layers/csrc/DeformAttn/ms_deform_attn_cpu.cpp:
--------------------------------------------------------------------------------
1 | /*!
2 | **************************************************************************************************
3 | * Deformable DETR
4 | * Copyright (c) 2020 SenseTime. All Rights Reserved.
5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6 | **************************************************************************************************
7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
8 | **************************************************************************************************
9 | */
10 |
11 | #include <vector>
12 |
13 | #include <ATen/ATen.h>
14 | #include <ATen/cuda/CUDAContext.h>
15 |
16 |
17 | at::Tensor
18 | ms_deform_attn_cpu_forward(
19 | const at::Tensor &value,
20 | const at::Tensor &spatial_shapes,
21 | const at::Tensor &level_start_index,
22 | const at::Tensor &sampling_loc,
23 | const at::Tensor &attn_weight,
24 | const int im2col_step)
25 | {
26 | AT_ERROR("Not implement on cpu");
27 | }
28 |
29 | std::vector<at::Tensor>
30 | ms_deform_attn_cpu_backward(
31 | const at::Tensor &value,
32 | const at::Tensor &spatial_shapes,
33 | const at::Tensor &level_start_index,
34 | const at::Tensor &sampling_loc,
35 | const at::Tensor &attn_weight,
36 | const at::Tensor &grad_output,
37 | const int im2col_step)
38 | {
39 | AT_ERROR("Not implement on cpu");
40 | }
41 |
42 |
--------------------------------------------------------------------------------
/adet/layers/csrc/DeformAttn/ms_deform_attn_cpu.h:
--------------------------------------------------------------------------------
1 | /*!
2 | **************************************************************************************************
3 | * Deformable DETR
4 | * Copyright (c) 2020 SenseTime. All Rights Reserved.
5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6 | **************************************************************************************************
7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
8 | **************************************************************************************************
9 | */
10 |
11 | #pragma once
12 | #include <torch/extension.h>
13 |
14 | at::Tensor
15 | ms_deform_attn_cpu_forward(
16 | const at::Tensor &value,
17 | const at::Tensor &spatial_shapes,
18 | const at::Tensor &level_start_index,
19 | const at::Tensor &sampling_loc,
20 | const at::Tensor &attn_weight,
21 | const int im2col_step);
22 |
23 | std::vector<at::Tensor>
24 | ms_deform_attn_cpu_backward(
25 | const at::Tensor &value,
26 | const at::Tensor &spatial_shapes,
27 | const at::Tensor &level_start_index,
28 | const at::Tensor &sampling_loc,
29 | const at::Tensor &attn_weight,
30 | const at::Tensor &grad_output,
31 | const int im2col_step);
32 |
33 |
34 |
--------------------------------------------------------------------------------
/adet/layers/csrc/DeformAttn/ms_deform_attn_cuda.cu:
--------------------------------------------------------------------------------
1 | /*!
2 | **************************************************************************************************
3 | * Deformable DETR
4 | * Copyright (c) 2020 SenseTime. All Rights Reserved.
5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6 | **************************************************************************************************
7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
8 | **************************************************************************************************
9 | */
10 |
11 | #include <vector>
12 | #include "ms_deform_im2col_cuda.cuh"
13 |
14 | #include <ATen/ATen.h>
15 | #include <ATen/cuda/CUDAContext.h>
16 | #include <cuda.h>
17 | #include <cuda_runtime.h>
18 |
19 |
20 | at::Tensor ms_deform_attn_cuda_forward(
21 | const at::Tensor &value,
22 | const at::Tensor &spatial_shapes,
23 | const at::Tensor &level_start_index,
24 | const at::Tensor &sampling_loc,
25 | const at::Tensor &attn_weight,
26 | const int im2col_step)
27 | {
28 | AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous");
29 | AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous");
30 | AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous");
31 | AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous");
32 | AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous");
33 |
34 | AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor");
35 | AT_ASSERTM(spatial_shapes.type().is_cuda(), "spatial_shapes must be a CUDA tensor");
36 | AT_ASSERTM(level_start_index.type().is_cuda(), "level_start_index must be a CUDA tensor");
37 | AT_ASSERTM(sampling_loc.type().is_cuda(), "sampling_loc must be a CUDA tensor");
38 | AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor");
39 |
40 | const int batch = value.size(0);
41 | const int spatial_size = value.size(1);
42 | const int num_heads = value.size(2);
43 | const int channels = value.size(3);
44 |
45 | const int num_levels = spatial_shapes.size(0);
46 |
47 | const int num_query = sampling_loc.size(1);
48 | const int num_point = sampling_loc.size(4);
49 |
50 | const int im2col_step_ = std::min(batch, im2col_step);
51 |
52 | AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)", batch, im2col_step_);
53 |
54 | auto output = at::zeros({batch, num_query, num_heads, channels}, value.options());
55 |
56 | const int batch_n = im2col_step_;
57 | auto output_n = output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels});
58 | auto per_value_size = spatial_size * num_heads * channels;
59 | auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2;
60 | auto per_attn_weight_size = num_query * num_heads * num_levels * num_point;
61 | for (int n = 0; n < batch/im2col_step_; ++n)
62 | {
63 | auto columns = output_n.select(0, n);
64 | AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_forward_cuda", ([&] {
65 | ms_deformable_im2col_cuda(at::cuda::getCurrentCUDAStream(),
66 | value.data<scalar_t>() + n * im2col_step_ * per_value_size,
67 | spatial_shapes.data<int64_t>(),
68 | level_start_index.data<int64_t>(),
69 | sampling_loc.data<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
70 | attn_weight.data<scalar_t>() + n * im2col_step_ * per_attn_weight_size,
71 | batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point,
72 | columns.data<scalar_t>());
73 |
74 | }));
75 | }
76 |
77 | output = output.view({batch, num_query, num_heads*channels});
78 |
79 | return output;
80 | }
81 |
82 |
83 | std::vector<at::Tensor> ms_deform_attn_cuda_backward(
84 | const at::Tensor &value,
85 | const at::Tensor &spatial_shapes,
86 | const at::Tensor &level_start_index,
87 | const at::Tensor &sampling_loc,
88 | const at::Tensor &attn_weight,
89 | const at::Tensor &grad_output,
90 | const int im2col_step)
91 | {
92 |
93 | AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous");
94 | AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous");
95 | AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous");
96 | AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous");
97 | AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous");
98 | AT_ASSERTM(grad_output.is_contiguous(), "grad_output tensor has to be contiguous");
99 |
100 | AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor");
101 | AT_ASSERTM(spatial_shapes.type().is_cuda(), "spatial_shapes must be a CUDA tensor");
102 | AT_ASSERTM(level_start_index.type().is_cuda(), "level_start_index must be a CUDA tensor");
103 | AT_ASSERTM(sampling_loc.type().is_cuda(), "sampling_loc must be a CUDA tensor");
104 | AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor");
105 | AT_ASSERTM(grad_output.type().is_cuda(), "grad_output must be a CUDA tensor");
106 |
107 | const int batch = value.size(0);
108 | const int spatial_size = value.size(1);
109 | const int num_heads = value.size(2);
110 | const int channels = value.size(3);
111 |
112 | const int num_levels = spatial_shapes.size(0);
113 |
114 | const int num_query = sampling_loc.size(1);
115 | const int num_point = sampling_loc.size(4);
116 |
117 | const int im2col_step_ = std::min(batch, im2col_step);
118 |
119 | AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)", batch, im2col_step_);
120 |
121 | auto grad_value = at::zeros_like(value);
122 | auto grad_sampling_loc = at::zeros_like(sampling_loc);
123 | auto grad_attn_weight = at::zeros_like(attn_weight);
124 |
125 | const int batch_n = im2col_step_;
126 | auto per_value_size = spatial_size * num_heads * channels;
127 | auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2;
128 | auto per_attn_weight_size = num_query * num_heads * num_levels * num_point;
129 | auto grad_output_n = grad_output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels});
130 |
131 | for (int n = 0; n < batch/im2col_step_; ++n)
132 | {
133 | auto grad_output_g = grad_output_n.select(0, n);
134 | AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_backward_cuda", ([&] {
135 | ms_deformable_col2im_cuda(at::cuda::getCurrentCUDAStream(),
136 | grad_output_g.data<scalar_t>(),
137 | value.data<scalar_t>() + n * im2col_step_ * per_value_size,
138 | spatial_shapes.data<int64_t>(),
139 | level_start_index.data<int64_t>(),
140 | sampling_loc.data<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
141 | attn_weight.data<scalar_t>() + n * im2col_step_ * per_attn_weight_size,
142 | batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point,
143 | grad_value.data<scalar_t>() + n * im2col_step_ * per_value_size,
144 | grad_sampling_loc.data<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
145 | grad_attn_weight.data<scalar_t>() + n * im2col_step_ * per_attn_weight_size);
146 |
147 | }));
148 | }
149 |
150 | return {
151 | grad_value, grad_sampling_loc, grad_attn_weight
152 | };
153 | }
--------------------------------------------------------------------------------
/adet/layers/csrc/DeformAttn/ms_deform_attn_cuda.h:
--------------------------------------------------------------------------------
1 | /*!
2 | **************************************************************************************************
3 | * Deformable DETR
4 | * Copyright (c) 2020 SenseTime. All Rights Reserved.
5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6 | **************************************************************************************************
7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
8 | **************************************************************************************************
9 | */
10 |
11 | #pragma once
12 | #include <torch/extension.h>
13 |
14 | at::Tensor ms_deform_attn_cuda_forward(
15 | const at::Tensor &value,
16 | const at::Tensor &spatial_shapes,
17 | const at::Tensor &level_start_index,
18 | const at::Tensor &sampling_loc,
19 | const at::Tensor &attn_weight,
20 | const int im2col_step);
21 |
22 | std::vector<at::Tensor> ms_deform_attn_cuda_backward(
23 | const at::Tensor &value,
24 | const at::Tensor &spatial_shapes,
25 | const at::Tensor &level_start_index,
26 | const at::Tensor &sampling_loc,
27 | const at::Tensor &attn_weight,
28 | const at::Tensor &grad_output,
29 | const int im2col_step);
30 |
31 |
--------------------------------------------------------------------------------
/adet/layers/csrc/cuda_version.cu:
--------------------------------------------------------------------------------
1 | #include <cuda_runtime_api.h>
2 |
3 | namespace adet {
4 | int get_cudart_version() {
5 | return CUDART_VERSION;
6 | }
7 | } // namespace adet
8 |
--------------------------------------------------------------------------------
/adet/layers/csrc/vision.cpp:
--------------------------------------------------------------------------------
1 | // Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2 | #include "DeformAttn/ms_deform_attn.h"
3 |
4 | namespace adet {
5 |
6 | #ifdef WITH_CUDA
7 | extern int get_cudart_version();
8 | #endif
9 |
10 | std::string get_cuda_version() {
11 | #ifdef WITH_CUDA
12 | std::ostringstream oss;
13 |
14 | // copied from
15 | // https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/cuda/detail/CUDAHooks.cpp#L231
16 | auto printCudaStyleVersion = [&](int v) {
17 | oss << (v / 1000) << "." << (v / 10 % 100);
18 | if (v % 10 != 0) {
19 | oss << "." << (v % 10);
20 | }
21 | };
22 | printCudaStyleVersion(get_cudart_version());
23 | return oss.str();
24 | #else
25 | return std::string("not available");
26 | #endif
27 | }
28 |
29 | // similar to
30 | // https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Version.cpp
31 | std::string get_compiler_version() {
32 | std::ostringstream ss;
33 | #if defined(__GNUC__)
34 | #ifndef __clang__
35 | { ss << "GCC " << __GNUC__ << "." << __GNUC_MINOR__; }
36 | #endif
37 | #endif
38 |
39 | #if defined(__clang_major__)
40 | {
41 | ss << "clang " << __clang_major__ << "." << __clang_minor__ << "."
42 | << __clang_patchlevel__;
43 | }
44 | #endif
45 |
46 | #if defined(_MSC_VER)
47 | { ss << "MSVC " << _MSC_FULL_VER; }
48 | #endif
49 | return ss.str();
50 | }
51 |
52 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
53 | m.def("ms_deform_attn_forward", &ms_deform_attn_forward, "ms_deform_attn_forward");
54 | m.def("ms_deform_attn_backward", &ms_deform_attn_backward, "ms_deform_attn_backward");
55 | }
56 |
57 | } // namespace adet
58 |
--------------------------------------------------------------------------------
/adet/layers/ms_deform_attn.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------------------------------
2 | # Deformable DETR
3 | # Copyright (c) 2020 SenseTime. All Rights Reserved.
4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
5 | # ------------------------------------------------------------------------------------------------
6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
7 | # ------------------------------------------------------------------------------------------------
8 | import warnings
9 | import math
10 |
11 | import torch
12 | from torch import nn
13 | import torch.nn.functional as F
14 | from torch.nn.init import xavier_uniform_, constant_
15 | from torch.autograd.function import once_differentiable
16 |
17 | from adet import _C
18 |
19 | class _MSDeformAttnFunction(torch.autograd.Function):
20 | @staticmethod
21 | def forward(ctx, value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, im2col_step):
22 | ctx.im2col_step = im2col_step
23 | output = _C.ms_deform_attn_forward(
24 | value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, ctx.im2col_step)
25 | ctx.save_for_backward(value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights)
26 | return output
27 |
28 | @staticmethod
29 | @once_differentiable
30 | def backward(ctx, grad_output):
31 | value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights = ctx.saved_tensors
32 | grad_value, grad_sampling_loc, grad_attn_weight = \
33 | _C.ms_deform_attn_backward(
34 | value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, grad_output, ctx.im2col_step)
35 |
36 | return grad_value, None, None, grad_sampling_loc, grad_attn_weight, None
37 |
38 |
39 | def ms_deform_attn_core_pytorch(value, value_spatial_shapes, sampling_locations, attention_weights):
40 | # for debug and test only,
41 | # need to use cuda version instead
42 | N_, S_, M_, D_ = value.shape
43 | _, Lq_, M_, L_, P_, _ = sampling_locations.shape
44 | value_list = value.split([H_ * W_ for H_, W_ in value_spatial_shapes], dim=1)
45 | sampling_grids = 2 * sampling_locations - 1
46 | sampling_value_list = []
47 | for lid_, (H_, W_) in enumerate(value_spatial_shapes):
48 | # N_, H_*W_, M_, D_ -> N_, H_*W_, M_*D_ -> N_, M_*D_, H_*W_ -> N_*M_, D_, H_, W_
49 | value_l_ = value_list[lid_].flatten(2).transpose(1, 2).reshape(N_*M_, D_, H_, W_)
50 | # N_, Lq_, M_, P_, 2 -> N_, M_, Lq_, P_, 2 -> N_*M_, Lq_, P_, 2
51 | sampling_grid_l_ = sampling_grids[:, :, :, lid_].transpose(1, 2).flatten(0, 1)
52 | # N_*M_, D_, Lq_, P_
53 | sampling_value_l_ = F.grid_sample(value_l_, sampling_grid_l_,
54 | mode='bilinear', padding_mode='zeros', align_corners=False)
55 | sampling_value_list.append(sampling_value_l_)
56 | # (N_, Lq_, M_, L_, P_) -> (N_, M_, Lq_, L_, P_) -> (N_, M_, 1, Lq_, L_*P_)
57 | attention_weights = attention_weights.transpose(1, 2).reshape(N_*M_, 1, Lq_, L_*P_)
58 | output = (torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights).sum(-1).view(N_, M_*D_, Lq_)
59 | return output.transpose(1, 2).contiguous()
60 |
61 |
62 | def _is_power_of_2(n):
63 | if (not isinstance(n, int)) or (n < 0):
64 | raise ValueError("invalid input for _is_power_of_2: {} (type: {})".format(n, type(n)))
65 | return (n & (n-1) == 0) and n != 0
66 |
67 |
68 | class MSDeformAttn(nn.Module):
69 | def __init__(self, d_model=256, n_levels=4, n_heads=8, n_points=4):
70 | """
71 | Multi-Scale Deformable Attention Module
72 | :param d_model hidden dimension
73 | :param n_levels number of feature levels
74 | :param n_heads number of attention heads
75 | :param n_points number of sampling points per attention head per feature level
76 | """
77 | super().__init__()
78 | if d_model % n_heads != 0:
79 | raise ValueError('d_model must be divisible by n_heads, but got {} and {}'.format(d_model, n_heads))
80 | _d_per_head = d_model // n_heads
81 | # you'd better set _d_per_head to a power of 2 which is more efficient in our CUDA implementation
82 | if not _is_power_of_2(_d_per_head):
83 | warnings.warn("You'd better set d_model in MSDeformAttn to make the dimension of each attention head a power of 2 "
84 | "which is more efficient in our CUDA implementation.")
85 |
86 | self.im2col_step = 64
87 |
88 | self.d_model = d_model
89 | self.n_levels = n_levels
90 | self.n_heads = n_heads
91 | self.n_points = n_points
92 |
93 | self.sampling_offsets = nn.Linear(d_model, n_heads * n_levels * n_points * 2)
94 | self.attention_weights = nn.Linear(d_model, n_heads * n_levels * n_points)
95 | self.value_proj = nn.Linear(d_model, d_model)
96 | self.output_proj = nn.Linear(d_model, d_model)
97 |
98 | self._reset_parameters()
99 |
100 | def _reset_parameters(self):
101 | constant_(self.sampling_offsets.weight.data, 0.)
102 | thetas = torch.arange(self.n_heads, dtype=torch.float32) * (2.0 * math.pi / self.n_heads)
103 | grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)
104 | grid_init = (grid_init / grid_init.abs().max(-1, keepdim=True)[0]).view(self.n_heads, 1, 1, 2).repeat(1, self.n_levels, self.n_points, 1)
105 | for i in range(self.n_points):
106 | grid_init[:, :, i, :] *= i + 1
107 | with torch.no_grad():
108 | self.sampling_offsets.bias = nn.Parameter(grid_init.view(-1))
109 | constant_(self.attention_weights.weight.data, 0.)
110 | constant_(self.attention_weights.bias.data, 0.)
111 | xavier_uniform_(self.value_proj.weight.data)
112 | constant_(self.value_proj.bias.data, 0.)
113 | xavier_uniform_(self.output_proj.weight.data)
114 | constant_(self.output_proj.bias.data, 0.)
115 |
116 | def forward(self, query, reference_points, input_flatten, input_spatial_shapes, input_level_start_index, input_padding_mask=None):
117 | """
118 | :param query (N, Length_{query}, C)
119 | :param reference_points (N, Length_{query}, n_levels, 2), range in [0, 1], top-left (0,0), bottom-right (1, 1), including padding area
120 | or (N, Length_{query}, n_levels, 4), add additional (w, h) to form reference boxes
121 | :param input_flatten (N, \sum_{l=0}^{L-1} H_l \cdot W_l, C)
122 | :param input_spatial_shapes (n_levels, 2), [(H_0, W_0), (H_1, W_1), ..., (H_{L-1}, W_{L-1})]
123 | :param input_level_start_index (n_levels, ), [0, H_0*W_0, H_0*W_0+H_1*W_1, H_0*W_0+H_1*W_1+H_2*W_2, ..., H_0*W_0+H_1*W_1+...+H_{L-1}*W_{L-1}]
124 | :param input_padding_mask (N, \sum_{l=0}^{L-1} H_l \cdot W_l), True for padding elements, False for non-padding elements
125 |
126 | :return output (N, Length_{query}, C)
127 | """
128 | N, Len_q, _ = query.shape
129 | N, Len_in, _ = input_flatten.shape
130 | assert (input_spatial_shapes[:, 0] * input_spatial_shapes[:, 1]).sum() == Len_in
131 |
132 | value = self.value_proj(input_flatten)
133 | if input_padding_mask is not None:
134 | value = value.masked_fill(input_padding_mask[..., None], float(0))
135 | value = value.view(N, Len_in, self.n_heads, self.d_model // self.n_heads)
136 | sampling_offsets = self.sampling_offsets(query).view(N, Len_q, self.n_heads, self.n_levels, self.n_points, 2)
137 | attention_weights = self.attention_weights(query).view(N, Len_q, self.n_heads, self.n_levels * self.n_points)
138 | attention_weights = F.softmax(attention_weights, -1).view(N, Len_q, self.n_heads, self.n_levels, self.n_points)
139 | # N, Len_q, n_heads, n_levels, n_points, 2
140 | if reference_points.shape[-1] == 2:
141 | offset_normalizer = torch.stack([input_spatial_shapes[..., 1], input_spatial_shapes[..., 0]], -1)
142 | sampling_locations = reference_points[:, :, None, :, None, :] \
143 | + sampling_offsets / offset_normalizer[None, None, None, :, None, :]
144 | elif reference_points.shape[-1] == 4:
145 | sampling_locations = reference_points[:, :, None, :, None, :2] \
146 | + sampling_offsets / self.n_points * reference_points[:, :, None, :, None, 2:] * 0.5
147 | else:
148 | raise ValueError(
149 | 'Last dim of reference_points must be 2 or 4, but get {} instead.'.format(reference_points.shape[-1]))
150 | output = _MSDeformAttnFunction.apply(
151 | value, input_spatial_shapes, input_level_start_index, sampling_locations, attention_weights, self.im2col_step)
152 | output = self.output_proj(output)
153 | return output
154 |
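
ms_deform_attn_core_pytorch above is the pure-PyTorch reference used for debugging; a shape-level sketch with random tensors (importing the module requires the compiled adet._C extension, since it is imported at module load time):

import torch
from adet.layers.ms_deform_attn import ms_deform_attn_core_pytorch

N, M, D = 2, 8, 32                      # batch, heads, channels per head
spatial_shapes = [(32, 32), (16, 16)]   # two feature levels
L = len(spatial_shapes)
S = sum(h * w for h, w in spatial_shapes)
Lq, P = 100, 4                          # queries, sampling points per level

value = torch.rand(N, S, M, D)
sampling_locations = torch.rand(N, Lq, M, L, P, 2)   # normalized to [0, 1]
attention_weights = torch.rand(N, Lq, M, L, P)
attention_weights = attention_weights / attention_weights.sum(dim=(-1, -2), keepdim=True)

out = ms_deform_attn_core_pytorch(value, spatial_shapes, sampling_locations, attention_weights)
print(out.shape)  # torch.Size([2, 100, 256]) == (N, Lq, M * D)
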
--------------------------------------------------------------------------------
/adet/layers/pos_encoding.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import torch.nn as nn
4 |
5 | class PositionalEncoding1D(nn.Module):
6 | def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
7 | """
8 | :param num_pos_feats: number of channels of the positional embedding (the last dimension of the tensor you want to apply pos emb to).
9 | """
10 | super().__init__()
11 | self.channels = num_pos_feats
12 | dim_t = torch.arange(0, self.channels, 2).float()
13 | if scale is not None and normalize is False:
14 | raise ValueError("normalize should be True if scale is passed")
15 | if scale is None:
16 | scale = 2 * np.pi
17 | self.scale = scale
18 | self.normalize = normalize
19 | inv_freq = 1. / (temperature ** (dim_t / self.channels))
20 | self.register_buffer('inv_freq', inv_freq)
21 |
22 | def forward(self, tensor):
23 | """
24 | :param tensor: A 2d tensor of size (len, c)
25 | :return: Positional Encoding Matrix of size (len, c)
26 | """
27 | if tensor.ndim != 2:
28 | raise RuntimeError("The input tensor has to be 2D!")
29 | x, orig_ch = tensor.shape
30 | pos_x = torch.arange(
31 | 1, x + 1, device=tensor.device).type(self.inv_freq.type())
32 |
33 | if self.normalize:
34 | eps = 1e-6
35 | pos_x = pos_x / (pos_x[-1:] + eps) * self.scale
36 |
37 | sin_inp_x = torch.einsum("i,j->ij", pos_x, self.inv_freq)
38 | emb_x = torch.cat((sin_inp_x.sin(), sin_inp_x.cos()), dim=-1)
39 | emb = torch.zeros((x, self.channels),
40 | device=tensor.device).type(tensor.type())
41 | emb[:, :self.channels] = emb_x
42 |
43 | return emb[:, :orig_ch]
44 |
45 |
46 | class PositionalEncoding2D(nn.Module):
47 | """
48 | This is a more standard version of the position embedding, very similar to the one
49 | used by the Attention is all you need paper, generalized to work on images.
50 | """
51 | def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
52 | super().__init__()
53 | self.num_pos_feats = num_pos_feats
54 | self.temperature = temperature
55 | self.normalize = normalize
56 | if scale is not None and normalize is False:
57 | raise ValueError("normalize should be True if scale is passed")
58 | if scale is None:
59 | scale = 2 * np.pi
60 | self.scale = scale
61 |
62 | def forward(self, tensors):
63 | x = tensors.tensors
64 | mask = tensors.mask
65 | assert mask is not None
66 | not_mask = ~mask
67 | y_embed = not_mask.cumsum(1, dtype=torch.float32)
68 | x_embed = not_mask.cumsum(2, dtype=torch.float32)
69 | if self.normalize:
70 | eps = 1e-6
71 | y_embed = (y_embed - 0.5) / (y_embed[:, -1:, :] + eps) * self.scale
72 | x_embed = (x_embed - 0.5) / (x_embed[:, :, -1:] + eps) * self.scale
73 |
74 | dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
75 | dim_t = self.temperature ** (2 * torch.div(dim_t, 2, rounding_mode='trunc') / self.num_pos_feats)
76 |
77 | pos_x = x_embed[:, :, :, None] / dim_t
78 | pos_y = y_embed[:, :, :, None] / dim_t
79 | pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
80 | pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
81 | pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
82 | return pos
83 |
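As a quick reference for the encoding above, here is a minimal standalone sketch (not repository code) of the same sine/cosine construction used by PositionalEncoding1D, written as a plain function so the shape bookkeeping is easy to follow:

    import numpy as np
    import torch

    def sine_pos_embed_1d(length, channels, temperature=10000, normalize=True, scale=2 * np.pi):
        # Mirrors PositionalEncoding1D: sin terms fill the first half of the channels, cos the second.
        dim_t = torch.arange(0, channels, 2).float()
        inv_freq = 1.0 / (temperature ** (dim_t / channels))
        pos = torch.arange(1, length + 1).float()
        if normalize:
            pos = pos / (pos[-1:] + 1e-6) * scale
        angles = torch.einsum("i,j->ij", pos, inv_freq)          # (length, channels // 2)
        return torch.cat((angles.sin(), angles.cos()), dim=-1)   # (length, channels)

    emb = sine_pos_embed_1d(length=25, channels=256)
    print(emb.shape)  # torch.Size([25, 256])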
--------------------------------------------------------------------------------
/adet/modeling/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2 | from .transformer_detector import TransformerDetector
3 |
4 | _EXCLUDE = {"torch", "ShapeSpec"}
5 | __all__ = [k for k in globals().keys() if k not in _EXCLUDE and not k.startswith("_")]
6 |
--------------------------------------------------------------------------------
/adet/modeling/testr/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mlpc-ucsd/TESTR/f369f138e041d1d27348a1f6600e456452001d23/adet/modeling/testr/__init__.py
--------------------------------------------------------------------------------
/adet/modeling/testr/losses.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import copy
5 | from adet.utils.misc import accuracy, generalized_box_iou, box_cxcywh_to_xyxy, box_xyxy_to_cxcywh, is_dist_avail_and_initialized
6 | from detectron2.utils.comm import get_world_size
7 |
8 |
9 | def sigmoid_focal_loss(inputs, targets, num_inst, alpha: float = 0.25, gamma: float = 2):
10 | """
11 | Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002.
12 | Args:
13 | inputs: A float tensor of arbitrary shape.
14 | The predictions for each example.
15 | targets: A float tensor with the same shape as inputs. Stores the binary
16 | classification label for each element in inputs
17 | (0 for the negative class and 1 for the positive class).
18 |         alpha: (optional) Weighting factor in range (0,1) to balance
19 |                positive vs negative examples. Default = 0.25; a negative value disables weighting.
20 | gamma: Exponent of the modulating factor (1 - p_t) to
21 | balance easy vs hard examples.
22 | Returns:
23 | Loss tensor
24 | """
25 | prob = inputs.sigmoid()
26 | ce_loss = F.binary_cross_entropy_with_logits(
27 | inputs, targets, reduction="none")
28 | p_t = prob * targets + (1 - prob) * (1 - targets)
29 | loss = ce_loss * ((1 - p_t) ** gamma)
30 |
31 | if alpha >= 0:
32 | alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
33 | loss = alpha_t * loss
34 |
35 | if loss.ndim == 4:
36 | return loss.mean((1, 2)).sum() / num_inst
37 | elif loss.ndim == 3:
38 | return loss.mean(1).sum() / num_inst
39 | else:
40 | raise NotImplementedError(f"Unsupported dim {loss.ndim}")
41 |
42 |
43 | class SetCriterion(nn.Module):
44 | """ This class computes the loss for TESTR.
45 | The process happens in two steps:
46 |         1) we compute a Hungarian assignment between the ground truth and the outputs of the model
47 |         2) we supervise each pair of matched ground-truth / prediction (class, control points, and text)
48 | """
49 |
50 | def __init__(self, num_classes, enc_matcher, dec_matcher, weight_dict, enc_losses, dec_losses, num_ctrl_points, focal_alpha=0.25, focal_gamma=2.0):
51 | """ Create the criterion.
52 | Parameters:
53 | num_classes: number of object categories, omitting the special no-object category
54 |             enc_matcher, dec_matcher: modules able to compute a matching between targets and the encoder / decoder proposals
55 |             weight_dict: dict containing as key the names of the losses and as values their relative weight.
56 |             enc_losses, dec_losses: lists of the losses applied to encoder / decoder outputs. See get_loss for available losses.
57 |             focal_alpha, focal_gamma: alpha and gamma in Focal Loss
58 | """
59 | super().__init__()
60 | self.num_classes = num_classes
61 | self.enc_matcher = enc_matcher
62 | self.dec_matcher = dec_matcher
63 | self.weight_dict = weight_dict
64 | self.enc_losses = enc_losses
65 | self.dec_losses = dec_losses
66 | self.focal_alpha = focal_alpha
67 | self.focal_gamma = focal_gamma
68 | self.num_ctrl_points = num_ctrl_points
69 |
70 | def loss_labels(self, outputs, targets, indices, num_inst, log=False):
71 |         """Classification loss (sigmoid focal loss)
72 | targets dicts must contain the key "labels" containing a tensor of dim [nb_target_boxes]
73 | """
74 | assert 'pred_logits' in outputs
75 | src_logits = outputs['pred_logits']
76 |
77 | idx = self._get_src_permutation_idx(indices)
78 |
79 | target_classes = torch.full(src_logits.shape[:-1], self.num_classes,
80 | dtype=torch.int64, device=src_logits.device)
81 | target_classes_o = torch.cat([t["labels"][J]
82 | for t, (_, J) in zip(targets, indices)])
83 | if len(target_classes_o.shape) < len(target_classes[idx].shape):
84 | target_classes_o = target_classes_o[..., None]
85 | target_classes[idx] = target_classes_o
86 |
87 | shape = list(src_logits.shape)
88 | shape[-1] += 1
89 | target_classes_onehot = torch.zeros(shape,
90 | dtype=src_logits.dtype, layout=src_logits.layout, device=src_logits.device)
91 | target_classes_onehot.scatter_(-1, target_classes.unsqueeze(-1), 1)
92 | target_classes_onehot = target_classes_onehot[..., :-1]
93 | loss_ce = sigmoid_focal_loss(src_logits, target_classes_onehot, num_inst,
94 | alpha=self.focal_alpha, gamma=self.focal_gamma) * src_logits.shape[1]
95 | losses = {'loss_ce': loss_ce}
96 |
97 | if log:
98 | # TODO this should probably be a separate loss, not hacked in this one here
99 | losses['class_error'] = 100 - \
100 | accuracy(src_logits[idx], target_classes_o)[0]
101 | return losses
102 |
103 | @torch.no_grad()
104 | def loss_cardinality(self, outputs, targets, indices, num_inst):
105 |         """ Compute the cardinality error, i.e. the absolute error in the number of predicted non-empty boxes.
106 |         This is not really a loss; it is intended for logging purposes only and does not propagate gradients.
107 | """
108 | pred_logits = outputs['pred_logits']
109 | device = pred_logits.device
110 | tgt_lengths = torch.as_tensor(
111 | [len(v["labels"]) for v in targets], device=device)
112 | card_pred = (pred_logits.mean(-2).argmax(-1) == 0).sum(1)
113 | card_err = F.l1_loss(card_pred.float(), tgt_lengths.float())
114 | losses = {'cardinality_error': card_err}
115 | return losses
116 |
117 | def loss_boxes(self, outputs, targets, indices, num_inst):
118 | """Compute the losses related to the bounding boxes, the L1 regression loss and the GIoU loss
119 | targets dicts must contain the key "boxes" containing a tensor of dim [nb_target_boxes, 4]
120 |         The target boxes are expected in format (center_x, center_y, w, h), normalized by the image size.
121 | """
122 | assert 'pred_boxes' in outputs
123 | idx = self._get_src_permutation_idx(indices)
124 | src_boxes = outputs['pred_boxes'][idx]
125 | target_boxes = torch.cat([t['boxes'][i]
126 | for t, (_, i) in zip(targets, indices)], dim=0)
127 |
128 | loss_bbox = F.l1_loss(src_boxes, target_boxes, reduction='none')
129 |
130 | losses = {}
131 | losses['loss_bbox'] = loss_bbox.sum() / num_inst
132 |
133 | loss_giou = 1 - torch.diag(generalized_box_iou(
134 | box_cxcywh_to_xyxy(src_boxes),
135 | box_cxcywh_to_xyxy(target_boxes)))
136 | losses['loss_giou'] = loss_giou.sum() / num_inst
137 | return losses
138 |
139 | def loss_texts(self, outputs, targets, indices, num_inst):
140 | assert 'pred_texts' in outputs
141 | idx = self._get_src_permutation_idx(indices)
142 | src_texts = outputs['pred_texts'][idx]
143 | target_ctrl_points = torch.cat([t['texts'][i] for t, (_, i) in zip(targets, indices)], dim=0)
144 | return {'loss_texts': F.cross_entropy(src_texts.transpose(1, 2), target_ctrl_points.long())}
145 |
146 |
147 | def loss_ctrl_points(self, outputs, targets, indices, num_inst):
148 |         """Compute the L1 regression loss on the control point coordinates.
149 | """
150 | assert 'pred_ctrl_points' in outputs
151 | idx = self._get_src_permutation_idx(indices)
152 | src_ctrl_points = outputs['pred_ctrl_points'][idx]
153 | target_ctrl_points = torch.cat([t['ctrl_points'][i] for t, (_, i) in zip(targets, indices)], dim=0)
154 |
155 | loss_ctrl_points = F.l1_loss(src_ctrl_points, target_ctrl_points, reduction='sum')
156 |
157 | losses = {'loss_ctrl_points': loss_ctrl_points / num_inst}
158 | return losses
159 |
160 | @staticmethod
161 | def _get_src_permutation_idx(indices):
162 | # permute predictions following indices
163 | batch_idx = torch.cat([torch.full_like(src, i)
164 | for i, (src, _) in enumerate(indices)])
165 | src_idx = torch.cat([src for (src, _) in indices])
166 | return batch_idx, src_idx
167 |
168 | @staticmethod
169 | def _get_tgt_permutation_idx(indices):
170 | # permute targets following indices
171 | batch_idx = torch.cat([torch.full_like(tgt, i)
172 | for i, (_, tgt) in enumerate(indices)])
173 | tgt_idx = torch.cat([tgt for (_, tgt) in indices])
174 | return batch_idx, tgt_idx
175 |
176 | def get_loss(self, loss, outputs, targets, indices, num_inst, **kwargs):
177 | loss_map = {
178 | 'labels': self.loss_labels,
179 | 'cardinality': self.loss_cardinality,
180 | 'ctrl_points': self.loss_ctrl_points,
181 | 'boxes': self.loss_boxes,
182 | 'texts': self.loss_texts,
183 | }
184 | assert loss in loss_map, f'do you really want to compute {loss} loss?'
185 | return loss_map[loss](outputs, targets, indices, num_inst, **kwargs)
186 |
187 | def forward(self, outputs, targets):
188 | """ This performs the loss computation.
189 | Parameters:
190 | outputs: dict of tensors, see the output specification of the model for the format
191 | targets: list of dicts, such that len(targets) == batch_size.
192 |                      The expected keys in each dict depend on the losses applied; see each loss's doc
193 | """
194 | outputs_without_aux = {k: v for k, v in outputs.items() if k != 'aux_outputs' and k != 'enc_outputs'}
195 |
196 | # Retrieve the matching between the outputs of the last layer and the targets
197 | indices = self.dec_matcher(outputs_without_aux, targets)
198 |
199 |         # Compute the average number of target instances across all nodes, for normalization purposes
200 | num_inst = sum(len(t['ctrl_points']) for t in targets)
201 | num_inst = torch.as_tensor(
202 | [num_inst], dtype=torch.float, device=next(iter(outputs.values())).device)
203 | if is_dist_avail_and_initialized():
204 | torch.distributed.all_reduce(num_inst)
205 | num_inst = torch.clamp(num_inst / get_world_size(), min=1).item()
206 |
207 | # Compute all the requested losses
208 | losses = {}
209 | for loss in self.dec_losses:
210 | kwargs = {}
211 | losses.update(self.get_loss(loss, outputs, targets,
212 | indices, num_inst, **kwargs))
213 |
214 | # In case of auxiliary losses, we repeat this process with the output of each intermediate layer.
215 | if 'aux_outputs' in outputs:
216 | for i, aux_outputs in enumerate(outputs['aux_outputs']):
217 | indices = self.dec_matcher(aux_outputs, targets)
218 | for loss in self.dec_losses:
219 | kwargs = {}
220 | if loss == 'labels':
221 | # Logging is enabled only for the last layer
222 | kwargs['log'] = False
223 | l_dict = self.get_loss(
224 | loss, aux_outputs, targets, indices, num_inst, **kwargs)
225 | l_dict = {k + f'_{i}': v for k, v in l_dict.items()}
226 | losses.update(l_dict)
227 |
228 | if 'enc_outputs' in outputs:
229 | enc_outputs = outputs['enc_outputs']
230 | indices = self.enc_matcher(enc_outputs, targets)
231 | for loss in self.enc_losses:
232 | kwargs = {}
233 | if loss == 'labels':
234 | kwargs['log'] = False
235 | l_dict = self.get_loss(
236 | loss, enc_outputs, targets, indices, num_inst, **kwargs)
237 | l_dict = {k + f'_enc': v for k, v in l_dict.items()}
238 | losses.update(l_dict)
239 |
240 | return losses
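The modulation inside sigmoid_focal_loss above can be checked in isolation; the toy sketch below (random tensors, not repository code) shows that setting gamma to zero recovers plain binary cross-entropy and that the modulating factor only ever down-weights terms:

    import torch
    import torch.nn.functional as F

    torch.manual_seed(0)
    logits = torch.randn(2, 5, 4)                      # e.g. (batch, queries, classes)
    targets = torch.randint(0, 2, logits.shape).float()

    prob = logits.sigmoid()
    ce = F.binary_cross_entropy_with_logits(logits, targets, reduction="none")
    p_t = prob * targets + (1 - prob) * (1 - targets)
    focal = ce * (1 - p_t) ** 2.0                      # gamma = 2, as in the default above

    print(torch.allclose(ce * (1 - p_t) ** 0.0, ce))   # True: gamma = 0 gives back plain BCE
    print(bool((focal <= ce + 1e-6).all()))            # True: easy examples are down-weighted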
--------------------------------------------------------------------------------
/adet/modeling/testr/matcher.py:
--------------------------------------------------------------------------------
1 | """
2 | Modules to compute the matching cost and solve the corresponding LSAP.
3 | """
4 | import torch
5 | from scipy.optimize import linear_sum_assignment
6 | from torch import nn
7 | from adet.utils.misc import box_cxcywh_to_xyxy, generalized_box_iou
8 |
9 |
10 | class CtrlPointHungarianMatcher(nn.Module):
11 | """This class computes an assignment between the targets and the predictions of the network
12 | For efficiency reasons, the targets don't include the no_object. Because of this, in general,
13 | there are more predictions than targets. In this case, we do a 1-to-1 matching of the best predictions,
14 | while the others are un-matched (and thus treated as non-objects).
15 | """
16 |
17 | def __init__(self,
18 | class_weight: float = 1,
19 | coord_weight: float = 1,
20 | focal_alpha: float = 0.25,
21 | focal_gamma: float = 2.0):
22 | """Creates the matcher
23 | Params:
24 | class_weight: This is the relative weight of the classification error in the matching cost
25 | coord_weight: This is the relative weight of the L1 error of the keypoint coordinates in the matching cost
26 | """
27 | super().__init__()
28 | self.class_weight = class_weight
29 | self.coord_weight = coord_weight
30 | self.alpha = focal_alpha
31 | self.gamma = focal_gamma
32 |         assert class_weight != 0 or coord_weight != 0, "all costs can't be 0"
33 |
34 | def forward(self, outputs, targets):
35 | """ Performs the matching
36 | Params:
37 | outputs: This is a dict that contains at least these entries:
38 |                  "pred_logits": Tensor of dim [batch_size, num_queries, n_points, num_classes] with the classification logits
39 |                  "pred_ctrl_points": Tensor of dim [batch_size, num_queries, n_points, 2] with the predicted control point coordinates
40 |             targets: This is a list of targets (len(targets) = batch_size), where each target is a dict containing:
41 |                  "labels": Tensor of dim [num_target_boxes] (where num_target_boxes is the number of ground-truth
42 |                            objects in the target) containing the class labels
43 |                  "ctrl_points": Tensor of dim [num_target_boxes, n_points, 2] containing the target control point coordinates
44 | Returns:
45 | A list of size batch_size, containing tuples of (index_i, index_j) where:
46 | - index_i is the indices of the selected predictions (in order)
47 | - index_j is the indices of the corresponding selected targets (in order)
48 | For each batch element, it holds:
49 | len(index_i) = len(index_j) = min(num_queries, num_target_boxes)
50 | """
51 | with torch.no_grad():
52 | bs, num_queries = outputs["pred_logits"].shape[:2]
53 |
54 | # We flatten to compute the cost matrices in a batch
55 | out_prob = outputs["pred_logits"].flatten(0, 1).sigmoid()
56 | # [batch_size, n_queries, n_points, 2] --> [batch_size * num_queries, n_points * 2]
57 | out_pts = outputs["pred_ctrl_points"].flatten(0, 1).flatten(-2)
58 |
59 | # Also concat the target labels and boxes
60 | tgt_pts = torch.cat([v["ctrl_points"] for v in targets]).flatten(-2)
61 | neg_cost_class = (1 - self.alpha) * (out_prob ** self.gamma) * \
62 | (-(1 - out_prob + 1e-8).log())
63 | pos_cost_class = self.alpha * \
64 | ((1 - out_prob) ** self.gamma) * (-(out_prob + 1e-8).log())
65 | # FIXME: hack here for label ID 0
66 | cost_class = (pos_cost_class[..., 0] - neg_cost_class[..., 0]).mean(-1, keepdims=True)
67 |
68 | cost_kpts = torch.cdist(out_pts, tgt_pts, p=1)
69 |
70 | C = self.class_weight * cost_class + self.coord_weight * cost_kpts
71 | C = C.view(bs, num_queries, -1).cpu()
72 |
73 | sizes = [len(v["ctrl_points"]) for v in targets]
74 | indices = [linear_sum_assignment(
75 | c[i]) for i, c in enumerate(C.split(sizes, -1))]
76 | return [(torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in indices]
77 |
78 |
79 | class BoxHungarianMatcher(nn.Module):
80 | """This class computes an assignment between the targets and the predictions of the network
81 | For efficiency reasons, the targets don't include the no_object. Because of this, in general,
82 | there are more predictions than targets. In this case, we do a 1-to-1 matching of the best predictions,
83 | while the others are un-matched (and thus treated as non-objects).
84 | """
85 |
86 | def __init__(self,
87 | class_weight: float = 1,
88 | coord_weight: float = 1,
89 | giou_weight: float = 1,
90 | focal_alpha: float = 0.25,
91 | focal_gamma: float = 2.0):
92 | """Creates the matcher
93 | Params:
94 | class_weight: This is the relative weight of the classification error in the matching cost
95 | coord_weight: This is the relative weight of the L1 error of the bounding box coordinates in the matching cost
96 | giou_weight: This is the relative weight of the giou loss of the bounding box in the matching cost
97 | """
98 | super().__init__()
99 | self.class_weight = class_weight
100 | self.coord_weight = coord_weight
101 | self.giou_weight = giou_weight
102 | self.alpha = focal_alpha
103 | self.gamma = focal_gamma
104 |         assert class_weight != 0 or coord_weight != 0 or giou_weight != 0, "all costs can't be 0"
105 |
106 | def forward(self, outputs, targets):
107 | """ Performs the matching
108 | Params:
109 | outputs: This is a dict that contains at least these entries:
110 | "pred_logits": Tensor of dim [batch_size, num_queries, num_classes] with the classification logits
111 | "pred_boxes": Tensor of dim [batch_size, num_queries, 4] with the predicted box coordinates
112 | targets: This is a list of targets (len(targets) = batch_size), where each target is a dict containing:
113 | "labels": Tensor of dim [num_target_boxes] (where num_target_boxes is the number of ground-truth
114 | objects in the target) containing the class labels
115 | "boxes": Tensor of dim [num_target_boxes, 4] containing the target box coordinates
116 | Returns:
117 | A list of size batch_size, containing tuples of (index_i, index_j) where:
118 | - index_i is the indices of the selected predictions (in order)
119 | - index_j is the indices of the corresponding selected targets (in order)
120 | For each batch element, it holds:
121 | len(index_i) = len(index_j) = min(num_queries, num_target_boxes)
122 | """
123 | with torch.no_grad():
124 | bs, num_queries = outputs["pred_logits"].shape[:2]
125 |
126 | # We flatten to compute the cost matrices in a batch
127 | out_prob = outputs["pred_logits"].flatten(0, 1).sigmoid()
128 | out_bbox = outputs["pred_boxes"].flatten(
129 | 0, 1) # [batch_size * num_queries, 4]
130 |
131 | # Also concat the target labels and boxes
132 | tgt_ids = torch.cat([v["labels"] for v in targets])
133 | tgt_bbox = torch.cat([v["boxes"] for v in targets])
134 |
135 | # Compute the classification cost.
136 | neg_cost_class = (1 - self.alpha) * (out_prob ** self.gamma) * \
137 | (-(1 - out_prob + 1e-8).log())
138 | pos_cost_class = self.alpha * \
139 | ((1 - out_prob) ** self.gamma) * (-(out_prob + 1e-8).log())
140 | cost_class = pos_cost_class[:, tgt_ids] - \
141 | neg_cost_class[:, tgt_ids]
142 |
143 | # Compute the L1 cost between boxes
144 | cost_bbox = torch.cdist(out_bbox, tgt_bbox, p=1)
145 |
146 |             # Compute the GIoU cost between boxes
147 | cost_giou = -generalized_box_iou(box_cxcywh_to_xyxy(out_bbox),
148 | box_cxcywh_to_xyxy(tgt_bbox))
149 |
150 | # Final cost matrix
151 | C = self.coord_weight * cost_bbox + self.class_weight * \
152 | cost_class + self.giou_weight * cost_giou
153 | C = C.view(bs, num_queries, -1).cpu()
154 |
155 | sizes = [len(v["boxes"]) for v in targets]
156 | indices = [linear_sum_assignment(
157 | c[i]) for i, c in enumerate(C.split(sizes, -1))]
158 | return [(torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in indices]
159 |
160 |
161 | def build_matcher(cfg):
162 | cfg = cfg.MODEL.TRANSFORMER.LOSS
163 | return BoxHungarianMatcher(class_weight=cfg.BOX_CLASS_WEIGHT,
164 | coord_weight=cfg.BOX_COORD_WEIGHT,
165 | giou_weight=cfg.BOX_GIOU_WEIGHT,
166 | focal_alpha=cfg.FOCAL_ALPHA,
167 | focal_gamma=cfg.FOCAL_GAMMA), \
168 | CtrlPointHungarianMatcher(class_weight=cfg.POINT_CLASS_WEIGHT,
169 | coord_weight=cfg.POINT_COORD_WEIGHT,
170 | focal_alpha=cfg.FOCAL_ALPHA,
171 | focal_gamma=cfg.FOCAL_GAMMA)
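Both matchers end with the same per-image pattern: split the flattened cost matrix by the number of ground-truth instances in each image and run linear_sum_assignment on each block. A toy illustration (made-up shapes, not repository code):

    import torch
    from scipy.optimize import linear_sum_assignment

    torch.manual_seed(0)
    bs, num_queries = 2, 4
    sizes = [2, 3]                                 # ground-truth instances per image
    C = torch.rand(bs, num_queries, sum(sizes))    # stand-in for the combined cost matrix

    indices = [linear_sum_assignment(c[i]) for i, c in enumerate(C.split(sizes, -1))]
    for img, (pred_idx, tgt_idx) in enumerate(indices):
        print(img, pred_idx, tgt_idx)              # each target matched to exactly one query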
--------------------------------------------------------------------------------
/adet/modeling/testr/models.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from torch import nn
4 | import torch.nn.functional as F
5 |
6 | from adet.layers.deformable_transformer import DeformableTransformer
7 |
8 | from adet.layers.pos_encoding import PositionalEncoding1D
9 | from adet.utils.misc import NestedTensor, inverse_sigmoid_offset, nested_tensor_from_tensor_list, sigmoid_offset
10 |
11 | class MLP(nn.Module):
12 | """ Very simple multi-layer perceptron (also called FFN)"""
13 |
14 | def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
15 | super().__init__()
16 | self.num_layers = num_layers
17 | h = [hidden_dim] * (num_layers - 1)
18 | self.layers = nn.ModuleList(nn.Linear(n, k)
19 | for n, k in zip([input_dim] + h, h + [output_dim]))
20 |
21 | def forward(self, x):
22 | for i, layer in enumerate(self.layers):
23 | x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
24 | return x
25 |
26 | class TESTR(nn.Module):
27 | """
28 |     TESTR: a deformable-transformer encoder-decoder text spotter.
29 |     Predicts per-query control points (Bezier curves or polygons) and character sequences.
30 | """
31 | def __init__(self, cfg, backbone):
32 | super().__init__()
33 | self.device = torch.device(cfg.MODEL.DEVICE)
34 |
35 | self.backbone = backbone
36 |
37 | # fmt: off
38 | self.d_model = cfg.MODEL.TRANSFORMER.HIDDEN_DIM
39 | self.nhead = cfg.MODEL.TRANSFORMER.NHEADS
40 | self.num_encoder_layers = cfg.MODEL.TRANSFORMER.ENC_LAYERS
41 | self.num_decoder_layers = cfg.MODEL.TRANSFORMER.DEC_LAYERS
42 | self.dim_feedforward = cfg.MODEL.TRANSFORMER.DIM_FEEDFORWARD
43 | self.dropout = cfg.MODEL.TRANSFORMER.DROPOUT
44 | self.activation = "relu"
45 | self.return_intermediate_dec = True
46 | self.num_feature_levels = cfg.MODEL.TRANSFORMER.NUM_FEATURE_LEVELS
47 |         self.dec_n_points = cfg.MODEL.TRANSFORMER.DEC_N_POINTS
48 |         self.enc_n_points = cfg.MODEL.TRANSFORMER.ENC_N_POINTS
49 | self.num_proposals = cfg.MODEL.TRANSFORMER.NUM_QUERIES
50 | self.pos_embed_scale = cfg.MODEL.TRANSFORMER.POSITION_EMBEDDING_SCALE
51 | self.num_ctrl_points = cfg.MODEL.TRANSFORMER.NUM_CTRL_POINTS
52 | self.num_classes = 1
53 | self.max_text_len = cfg.MODEL.TRANSFORMER.NUM_CHARS
54 | self.voc_size = cfg.MODEL.TRANSFORMER.VOC_SIZE
55 | self.sigmoid_offset = not cfg.MODEL.TRANSFORMER.USE_POLYGON
56 |
57 | self.text_pos_embed = PositionalEncoding1D(self.d_model, normalize=True, scale=self.pos_embed_scale)
58 | # fmt: on
59 |
60 | self.transformer = DeformableTransformer(
61 | d_model=self.d_model, nhead=self.nhead, num_encoder_layers=self.num_encoder_layers,
62 | num_decoder_layers=self.num_decoder_layers, dim_feedforward=self.dim_feedforward,
63 | dropout=self.dropout, activation=self.activation, return_intermediate_dec=self.return_intermediate_dec,
64 | num_feature_levels=self.num_feature_levels, dec_n_points=self.dec_n_points,
65 | enc_n_points=self.enc_n_points, num_proposals=self.num_proposals,
66 | )
67 | self.ctrl_point_class = nn.Linear(self.d_model, self.num_classes)
68 | self.ctrl_point_coord = MLP(self.d_model, self.d_model, 2, 3)
69 | self.bbox_coord = MLP(self.d_model, self.d_model, 4, 3)
70 | self.bbox_class = nn.Linear(self.d_model, self.num_classes)
71 | self.text_class = nn.Linear(self.d_model, self.voc_size + 1)
72 |
73 | # shared prior between instances (objects)
74 | self.ctrl_point_embed = nn.Embedding(self.num_ctrl_points, self.d_model)
75 | self.text_embed = nn.Embedding(self.max_text_len, self.d_model)
76 |
77 |
78 | if self.num_feature_levels > 1:
79 | strides = [8, 16, 32]
80 | num_channels = [512, 1024, 2048]
81 | num_backbone_outs = len(strides)
82 | input_proj_list = []
83 |             for lvl in range(num_backbone_outs):
84 |                 in_channels = num_channels[lvl]
85 | input_proj_list.append(nn.Sequential(
86 | nn.Conv2d(in_channels, self.d_model, kernel_size=1),
87 | nn.GroupNorm(32, self.d_model),
88 | ))
89 | for _ in range(self.num_feature_levels - num_backbone_outs):
90 | input_proj_list.append(nn.Sequential(
91 | nn.Conv2d(in_channels, self.d_model,
92 | kernel_size=3, stride=2, padding=1),
93 | nn.GroupNorm(32, self.d_model),
94 | ))
95 | in_channels = self.d_model
96 | self.input_proj = nn.ModuleList(input_proj_list)
97 | else:
98 | strides = [32]
99 | num_channels = [2048]
100 | self.input_proj = nn.ModuleList([
101 | nn.Sequential(
102 | nn.Conv2d(
103 | num_channels[0], self.d_model, kernel_size=1),
104 | nn.GroupNorm(32, self.d_model),
105 | )])
106 | self.aux_loss = cfg.MODEL.TRANSFORMER.AUX_LOSS
107 |
108 | prior_prob = 0.01
109 | bias_value = -np.log((1 - prior_prob) / prior_prob)
110 | self.ctrl_point_class.bias.data = torch.ones(self.num_classes) * bias_value
111 | self.bbox_class.bias.data = torch.ones(self.num_classes) * bias_value
112 | nn.init.constant_(self.ctrl_point_coord.layers[-1].weight.data, 0)
113 | nn.init.constant_(self.ctrl_point_coord.layers[-1].bias.data, 0)
114 | for proj in self.input_proj:
115 | nn.init.xavier_uniform_(proj[0].weight, gain=1)
116 | nn.init.constant_(proj[0].bias, 0)
117 |
118 | num_pred = self.num_decoder_layers
119 | self.ctrl_point_class = nn.ModuleList(
120 | [self.ctrl_point_class for _ in range(num_pred)])
121 | self.ctrl_point_coord = nn.ModuleList(
122 | [self.ctrl_point_coord for _ in range(num_pred)])
123 | self.transformer.decoder.bbox_embed = None
124 |
125 | nn.init.constant_(self.bbox_coord.layers[-1].bias.data[2:], 0.0)
126 | self.transformer.bbox_class_embed = self.bbox_class
127 | self.transformer.bbox_embed = self.bbox_coord
128 |
129 | self.to(self.device)
130 |
131 |
132 | def forward(self, samples: NestedTensor):
133 |         """ The forward expects a NestedTensor, which consists of:
134 |                - samples.tensors: batched images, of shape [batch_size x 3 x H x W]
135 |                - samples.mask: a binary mask of shape [batch_size x H x W], containing 1 on padded pixels
136 |             It returns a dict with the following elements:
137 |                - "pred_logits": the per-control-point classification logits for all queries.
138 |                                 Shape = [batch_size x num_queries x num_ctrl_points x num_classes]
139 |                - "pred_ctrl_points": the normalized control point coordinates for all queries, represented as
140 |                                      (x, y). These values are normalized in [0, 1],
141 |                                      relative to the size of each individual image.
142 |                - "pred_texts": the character classification logits for every query.
143 |                - "aux_outputs": Optional, only returned when auxiliary losses are activated. It is a list of
144 |                                 dictionaries containing the three above keys for each decoder layer.
145 |             """
146 | if isinstance(samples, (list, torch.Tensor)):
147 | samples = nested_tensor_from_tensor_list(samples)
148 | features, pos = self.backbone(samples)
149 |
150 | if self.num_feature_levels == 1:
151 | features = [features[-1]]
152 | pos = [pos[-1]]
153 |
154 | srcs = []
155 | masks = []
156 | for l, feat in enumerate(features):
157 | src, mask = feat.decompose()
158 | srcs.append(self.input_proj[l](src))
159 | masks.append(mask)
160 | assert mask is not None
161 | if self.num_feature_levels > len(srcs):
162 | _len_srcs = len(srcs)
163 | for l in range(_len_srcs, self.num_feature_levels):
164 | if l == _len_srcs:
165 | src = self.input_proj[l](features[-1].tensors)
166 | else:
167 | src = self.input_proj[l](srcs[-1])
168 | m = masks[0]
169 | mask = F.interpolate(
170 | m[None].float(), size=src.shape[-2:]).to(torch.bool)[0]
171 | pos_l = self.backbone[1](NestedTensor(src, mask)).to(src.dtype)
172 | srcs.append(src)
173 | masks.append(mask)
174 | pos.append(pos_l)
175 |
176 | # n_points, embed_dim --> n_objects, n_points, embed_dim
177 | ctrl_point_embed = self.ctrl_point_embed.weight[None, ...].repeat(self.num_proposals, 1, 1)
178 | text_pos_embed = self.text_pos_embed(self.text_embed.weight)[None, ...].repeat(self.num_proposals, 1, 1)
179 | text_embed = self.text_embed.weight[None, ...].repeat(self.num_proposals, 1, 1)
180 |
181 | hs, hs_text, init_reference, inter_references, enc_outputs_class, enc_outputs_coord_unact = self.transformer(
182 | srcs, masks, pos, ctrl_point_embed, text_embed, text_pos_embed, text_mask=None)
183 |
184 | outputs_classes = []
185 | outputs_coords = []
186 | outputs_texts = []
187 | for lvl in range(hs.shape[0]):
188 | if lvl == 0:
189 | reference = init_reference
190 | else:
191 | reference = inter_references[lvl - 1]
192 | reference = inverse_sigmoid_offset(reference, offset=self.sigmoid_offset)
193 | outputs_class = self.ctrl_point_class[lvl](hs[lvl])
194 | tmp = self.ctrl_point_coord[lvl](hs[lvl])
195 | if reference.shape[-1] == 2:
196 | tmp += reference[:, :, None, :]
197 | else:
198 | assert reference.shape[-1] == 4
199 | tmp += reference[:, :, None, :2]
200 | outputs_texts.append(self.text_class(hs_text[lvl]))
201 | outputs_coord = sigmoid_offset(tmp, offset=self.sigmoid_offset)
202 | outputs_classes.append(outputs_class)
203 | outputs_coords.append(outputs_coord)
204 | outputs_class = torch.stack(outputs_classes)
205 | outputs_coord = torch.stack(outputs_coords)
206 | outputs_text = torch.stack(outputs_texts)
207 |
208 | out = {'pred_logits': outputs_class[-1],
209 | 'pred_ctrl_points': outputs_coord[-1],
210 | 'pred_texts': outputs_text[-1]}
211 | if self.aux_loss:
212 | out['aux_outputs'] = self._set_aux_loss(
213 | outputs_class, outputs_coord, outputs_text)
214 |
215 | enc_outputs_coord = enc_outputs_coord_unact.sigmoid()
216 | out['enc_outputs'] = {
217 | 'pred_logits': enc_outputs_class, 'pred_boxes': enc_outputs_coord}
218 | return out
219 |
220 | @torch.jit.unused
221 | def _set_aux_loss(self, outputs_class, outputs_coord, outputs_text):
222 | # this is a workaround to make torchscript happy, as torchscript
223 | # doesn't support dictionary with non-homogeneous values, such
224 | # as a dict having both a Tensor and a list.
225 | return [{'pred_logits': a, 'pred_ctrl_points': b, 'pred_texts': c}
226 | for a, b, c in zip(outputs_class[:-1], outputs_coord[:-1], outputs_text[:-1])]
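One detail worth making concrete is how TESTR shares a single set of control-point queries across all proposals: the learned embedding is simply repeated per proposal before entering the decoder. A toy sketch (sizes chosen to match the polygon configs further below, but purely illustrative):

    import torch
    import torch.nn as nn

    num_proposals, num_ctrl_points, d_model = 100, 16, 256
    ctrl_point_embed = nn.Embedding(num_ctrl_points, d_model)

    # Same broadcast as in TESTR.forward: every proposal starts from the same point prior.
    queries = ctrl_point_embed.weight[None, ...].repeat(num_proposals, 1, 1)
    print(queries.shape)  # torch.Size([100, 16, 256])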
--------------------------------------------------------------------------------
/adet/modeling/transformer_detector.py:
--------------------------------------------------------------------------------
1 | from typing import List
2 | import numpy as np
3 | import torch
4 | from torch import nn
5 | import torch.nn.functional as F
6 |
7 | from detectron2.modeling.meta_arch.build import META_ARCH_REGISTRY
8 | from detectron2.modeling import build_backbone
9 | from detectron2.modeling.postprocessing import detector_postprocess as d2_postprocess
10 | from detectron2.structures import ImageList, Instances
11 |
12 | from adet.layers.pos_encoding import PositionalEncoding2D
13 | from adet.modeling.testr.losses import SetCriterion
14 | from adet.modeling.testr.matcher import build_matcher
15 | from adet.modeling.testr.models import TESTR
16 | from adet.utils.misc import NestedTensor, box_xyxy_to_cxcywh
17 |
18 |
19 | class Joiner(nn.Sequential):
20 | def __init__(self, backbone, position_embedding):
21 | super().__init__(backbone, position_embedding)
22 |
23 | def forward(self, tensor_list: NestedTensor):
24 | xs = self[0](tensor_list)
25 | out: List[NestedTensor] = []
26 | pos = []
27 | for _, x in xs.items():
28 | out.append(x)
29 | # position encoding
30 | pos.append(self[1](x).to(x.tensors.dtype))
31 |
32 | return out, pos
33 |
34 | class MaskedBackbone(nn.Module):
35 | """ This is a thin wrapper around D2's backbone to provide padding masking"""
36 | def __init__(self, cfg):
37 | super().__init__()
38 | self.backbone = build_backbone(cfg)
39 | backbone_shape = self.backbone.output_shape()
40 | self.feature_strides = [backbone_shape[f].stride for f in backbone_shape.keys()]
41 | self.num_channels = backbone_shape[list(backbone_shape.keys())[-1]].channels
42 |
43 | def forward(self, images):
44 | features = self.backbone(images.tensor)
45 | masks = self.mask_out_padding(
46 | [features_per_level.shape for features_per_level in features.values()],
47 | images.image_sizes,
48 | images.tensor.device,
49 | )
50 | assert len(features) == len(masks)
51 | for i, k in enumerate(features.keys()):
52 | features[k] = NestedTensor(features[k], masks[i])
53 | return features
54 |
55 | def mask_out_padding(self, feature_shapes, image_sizes, device):
56 | masks = []
57 | assert len(feature_shapes) == len(self.feature_strides)
58 | for idx, shape in enumerate(feature_shapes):
59 | N, _, H, W = shape
60 | masks_per_feature_level = torch.ones((N, H, W), dtype=torch.bool, device=device)
61 | for img_idx, (h, w) in enumerate(image_sizes):
62 | masks_per_feature_level[
63 | img_idx,
64 | : int(np.ceil(float(h) / self.feature_strides[idx])),
65 | : int(np.ceil(float(w) / self.feature_strides[idx])),
66 | ] = 0
67 | masks.append(masks_per_feature_level)
68 | return masks
69 |
70 |
71 | def detector_postprocess(results, output_height, output_width, mask_threshold=0.5):
72 | """
73 |     In addition to the post-processing of detectron2, we add scaling for
74 |     Bezier control points and polygons.
75 | """
76 | scale_x, scale_y = (output_width / results.image_size[1], output_height / results.image_size[0])
77 |     # results = d2_postprocess(results, output_height, output_width, mask_threshold)
78 |
79 | # scale bezier points
80 | if results.has("beziers"):
81 | beziers = results.beziers
82 | # scale and clip in place
83 | h, w = results.image_size
84 | beziers[:, 0].clamp_(min=0, max=w)
85 | beziers[:, 1].clamp_(min=0, max=h)
86 | beziers[:, 6].clamp_(min=0, max=w)
87 | beziers[:, 7].clamp_(min=0, max=h)
88 | beziers[:, 8].clamp_(min=0, max=w)
89 | beziers[:, 9].clamp_(min=0, max=h)
90 | beziers[:, 14].clamp_(min=0, max=w)
91 | beziers[:, 15].clamp_(min=0, max=h)
92 | beziers[:, 0::2] *= scale_x
93 | beziers[:, 1::2] *= scale_y
94 |
95 | if results.has("polygons"):
96 | polygons = results.polygons
97 | polygons[:, 0::2] *= scale_x
98 | polygons[:, 1::2] *= scale_y
99 |
100 | return results
101 |
102 |
103 | @META_ARCH_REGISTRY.register()
104 | class TransformerDetector(nn.Module):
105 | """
106 |     Detectron2 meta-architecture wrapping TESTR: it normalizes and batches the inputs, computes the
107 |     set-prediction losses during training, and decodes control points and text transcriptions at inference.
108 | """
109 | def __init__(self, cfg):
110 | super().__init__()
111 | self.device = torch.device(cfg.MODEL.DEVICE)
112 |
113 | d2_backbone = MaskedBackbone(cfg)
114 | N_steps = cfg.MODEL.TRANSFORMER.HIDDEN_DIM // 2
115 | self.test_score_threshold = cfg.MODEL.TRANSFORMER.INFERENCE_TH_TEST
116 | self.use_polygon = cfg.MODEL.TRANSFORMER.USE_POLYGON
117 | backbone = Joiner(d2_backbone, PositionalEncoding2D(N_steps, normalize=True))
118 | backbone.num_channels = d2_backbone.num_channels
119 | self.testr = TESTR(cfg, backbone)
120 |
121 | box_matcher, point_matcher = build_matcher(cfg)
122 |
123 | loss_cfg = cfg.MODEL.TRANSFORMER.LOSS
124 | weight_dict = {'loss_ce': loss_cfg.POINT_CLASS_WEIGHT, 'loss_ctrl_points': loss_cfg.POINT_COORD_WEIGHT, 'loss_texts': loss_cfg.POINT_TEXT_WEIGHT}
125 | enc_weight_dict = {'loss_bbox': loss_cfg.BOX_COORD_WEIGHT, 'loss_giou': loss_cfg.BOX_GIOU_WEIGHT, 'loss_ce': loss_cfg.BOX_CLASS_WEIGHT}
126 | if loss_cfg.AUX_LOSS:
127 | aux_weight_dict = {}
128 | # decoder aux loss
129 | for i in range(cfg.MODEL.TRANSFORMER.DEC_LAYERS - 1):
130 | aux_weight_dict.update(
131 | {k + f'_{i}': v for k, v in weight_dict.items()})
132 | # encoder aux loss
133 | aux_weight_dict.update(
134 | {k + f'_enc': v for k, v in enc_weight_dict.items()})
135 | weight_dict.update(aux_weight_dict)
136 |
137 | enc_losses = ['labels', 'boxes']
138 | dec_losses = ['labels', 'ctrl_points', 'texts']
139 |
140 | self.criterion = SetCriterion(self.testr.num_classes, box_matcher, point_matcher,
141 | weight_dict, enc_losses, dec_losses, self.testr.num_ctrl_points,
142 | focal_alpha=loss_cfg.FOCAL_ALPHA, focal_gamma=loss_cfg.FOCAL_GAMMA)
143 |
144 | pixel_mean = torch.Tensor(cfg.MODEL.PIXEL_MEAN).to(self.device).view(3, 1, 1)
145 | pixel_std = torch.Tensor(cfg.MODEL.PIXEL_STD).to(self.device).view(3, 1, 1)
146 | self.normalizer = lambda x: (x - pixel_mean) / pixel_std
147 | self.to(self.device)
148 |
149 | def preprocess_image(self, batched_inputs):
150 | """
151 | Normalize, pad and batch the input images.
152 | """
153 | images = [self.normalizer(x["image"].to(self.device)) for x in batched_inputs]
154 | images = ImageList.from_tensors(images)
155 | return images
156 |
157 | def forward(self, batched_inputs):
158 | """
159 | Args:
160 | batched_inputs: a list, batched outputs of :class:`DatasetMapper` .
161 | Each item in the list contains the inputs for one image.
162 | For now, each item in the list is a dict that contains:
163 |
164 | * image: Tensor, image in (C, H, W) format.
165 | * instances (optional): groundtruth :class:`Instances`
166 | * proposals (optional): :class:`Instances`, precomputed proposals.
167 |
168 | Other information that's included in the original dicts, such as:
169 |
170 | * "height", "width" (int): the output resolution of the model, used in inference.
171 | See :meth:`postprocess` for details.
172 |
173 | Returns:
174 | list[dict]:
175 | Each dict is the output for one input image.
176 | The dict contains one key "instances" whose value is a :class:`Instances`.
177 | The :class:`Instances` object has the following keys:
178 |                 "scores", "pred_classes", "rec_scores", "recs", and "polygons" or "beziers"
179 | """
180 | images = self.preprocess_image(batched_inputs)
181 | output = self.testr(images)
182 |
183 | if self.training:
184 | gt_instances = [x["instances"].to(self.device) for x in batched_inputs]
185 | targets = self.prepare_targets(gt_instances)
186 | loss_dict = self.criterion(output, targets)
187 | weight_dict = self.criterion.weight_dict
188 | for k in loss_dict.keys():
189 | if k in weight_dict:
190 | loss_dict[k] *= weight_dict[k]
191 | return loss_dict
192 | else:
193 | ctrl_point_cls = output["pred_logits"]
194 | ctrl_point_coord = output["pred_ctrl_points"]
195 | text_pred = output["pred_texts"]
196 | results = self.inference(ctrl_point_cls, ctrl_point_coord, text_pred, images.image_sizes)
197 | processed_results = []
198 | for results_per_image, input_per_image, image_size in zip(results, batched_inputs, images.image_sizes):
199 | height = input_per_image.get("height", image_size[0])
200 | width = input_per_image.get("width", image_size[1])
201 | r = detector_postprocess(results_per_image, height, width)
202 | processed_results.append({"instances": r})
203 | return processed_results
204 |
205 | def prepare_targets(self, targets):
206 | new_targets = []
207 | for targets_per_image in targets:
208 | h, w = targets_per_image.image_size
209 | image_size_xyxy = torch.as_tensor([w, h, w, h], dtype=torch.float, device=self.device)
210 | gt_classes = targets_per_image.gt_classes
211 | gt_boxes = targets_per_image.gt_boxes.tensor / image_size_xyxy
212 | gt_boxes = box_xyxy_to_cxcywh(gt_boxes)
213 | raw_ctrl_points = targets_per_image.polygons if self.use_polygon else targets_per_image.beziers
214 | gt_ctrl_points = raw_ctrl_points.reshape(-1, self.testr.num_ctrl_points, 2) / torch.as_tensor([w, h], dtype=torch.float, device=self.device)[None, None, :]
215 | gt_text = targets_per_image.text
216 | new_targets.append({"labels": gt_classes, "boxes": gt_boxes, "ctrl_points": gt_ctrl_points, "texts": gt_text})
217 | return new_targets
218 |
219 | def inference(self, ctrl_point_cls, ctrl_point_coord, text_pred, image_sizes):
220 | assert len(ctrl_point_cls) == len(image_sizes)
221 | results = []
222 |
223 | text_pred = torch.softmax(text_pred, dim=-1)
224 | prob = ctrl_point_cls.mean(-2).sigmoid()
225 | scores, labels = prob.max(-1)
226 |
227 | for scores_per_image, labels_per_image, ctrl_point_per_image, text_per_image, image_size in zip(
228 | scores, labels, ctrl_point_coord, text_pred, image_sizes
229 | ):
230 | selector = scores_per_image >= self.test_score_threshold
231 | scores_per_image = scores_per_image[selector]
232 | labels_per_image = labels_per_image[selector]
233 | ctrl_point_per_image = ctrl_point_per_image[selector]
234 | text_per_image = text_per_image[selector]
235 | result = Instances(image_size)
236 | result.scores = scores_per_image
237 | result.pred_classes = labels_per_image
238 | result.rec_scores = text_per_image
239 | ctrl_point_per_image[..., 0] *= image_size[1]
240 | ctrl_point_per_image[..., 1] *= image_size[0]
241 | if self.use_polygon:
242 | result.polygons = ctrl_point_per_image.flatten(1)
243 | else:
244 | result.beziers = ctrl_point_per_image.flatten(1)
245 | _, topi = text_per_image.topk(1)
246 | result.recs = topi.squeeze(-1)
247 | results.append(result)
248 | return results
249 |
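The inference path above reduces to: average the per-control-point logits, threshold the resulting scores, and scale the surviving normalized points to pixel coordinates. A standalone sketch with toy tensors (not repository code):

    import torch

    torch.manual_seed(0)
    num_queries, num_ctrl_points = 8, 16
    ctrl_point_cls = torch.randn(num_queries, num_ctrl_points, 1)
    ctrl_point_coord = torch.rand(num_queries, num_ctrl_points, 2)  # normalized (x, y)
    image_size = (720, 1280)                                        # (height, width)
    threshold = 0.5

    prob = ctrl_point_cls.mean(-2).sigmoid()   # average over control points, then sigmoid
    scores, labels = prob.max(-1)
    keep = scores >= threshold
    points = ctrl_point_coord[keep].clone()
    points[..., 0] *= image_size[1]            # x in pixels
    points[..., 1] *= image_size[0]            # y in pixels
    print(int(keep.sum()), points.shape)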
--------------------------------------------------------------------------------
/adet/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mlpc-ucsd/TESTR/f369f138e041d1d27348a1f6600e456452001d23/adet/utils/__init__.py
--------------------------------------------------------------------------------
/adet/utils/comm.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | import torch.distributed as dist
4 |
5 | from detectron2.utils.comm import get_world_size
6 |
7 |
8 | def reduce_sum(tensor):
9 | world_size = get_world_size()
10 | if world_size < 2:
11 | return tensor
12 | tensor = tensor.clone()
13 | dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
14 | return tensor
15 |
16 |
17 | def reduce_mean(tensor):
18 | num_gpus = get_world_size()
19 | total = reduce_sum(tensor)
20 | return total.float() / num_gpus
21 |
22 |
23 | def aligned_bilinear(tensor, factor):
24 | assert tensor.dim() == 4
25 | assert factor >= 1
26 | assert int(factor) == factor
27 |
28 | if factor == 1:
29 | return tensor
30 |
31 | h, w = tensor.size()[2:]
32 | tensor = F.pad(tensor, pad=(0, 1, 0, 1), mode="replicate")
33 | oh = factor * h + 1
34 | ow = factor * w + 1
35 | tensor = F.interpolate(
36 | tensor, size=(oh, ow),
37 | mode='bilinear',
38 | align_corners=True
39 | )
40 | tensor = F.pad(
41 | tensor, pad=(factor // 2, 0, factor // 2, 0),
42 | mode="replicate"
43 | )
44 |
45 | return tensor[:, :, :oh - 1, :ow - 1]
46 |
47 |
48 | def compute_locations(h, w, stride, device):
49 | shifts_x = torch.arange(
50 | 0, w * stride, step=stride,
51 | dtype=torch.float32, device=device
52 | )
53 | shifts_y = torch.arange(
54 | 0, h * stride, step=stride,
55 | dtype=torch.float32, device=device
56 | )
57 | shift_y, shift_x = torch.meshgrid(shifts_y, shifts_x)
58 | shift_x = shift_x.reshape(-1)
59 | shift_y = shift_y.reshape(-1)
60 | locations = torch.stack((shift_x, shift_y), dim=1) + stride // 2
61 | return locations
62 |
63 |
64 | def compute_ious(pred, target):
65 | """
66 | Args:
67 | pred: Nx4 predicted bounding boxes
68 | target: Nx4 target bounding boxes
69 | Both are in the form of FCOS prediction (l, t, r, b)
70 | """
71 | pred_left = pred[:, 0]
72 | pred_top = pred[:, 1]
73 | pred_right = pred[:, 2]
74 | pred_bottom = pred[:, 3]
75 |
76 | target_left = target[:, 0]
77 | target_top = target[:, 1]
78 | target_right = target[:, 2]
79 | target_bottom = target[:, 3]
80 |
81 |     target_area = (target_left + target_right) * \
82 |                   (target_top + target_bottom)
83 |     pred_area = (pred_left + pred_right) * \
84 |                 (pred_top + pred_bottom)
85 |
86 | w_intersect = torch.min(pred_left, target_left) + \
87 | torch.min(pred_right, target_right)
88 | h_intersect = torch.min(pred_bottom, target_bottom) + \
89 | torch.min(pred_top, target_top)
90 |
91 | g_w_intersect = torch.max(pred_left, target_left) + \
92 | torch.max(pred_right, target_right)
93 | g_h_intersect = torch.max(pred_bottom, target_bottom) + \
94 | torch.max(pred_top, target_top)
95 |     ac_union = g_w_intersect * g_h_intersect
96 |
97 |     area_intersect = w_intersect * h_intersect
98 |     area_union = target_area + pred_area - area_intersect
99 |
100 |     ious = (area_intersect + 1.0) / (area_union + 1.0)
101 |     gious = ious - (ac_union - area_union) / ac_union
102 |
103 | return ious, gious
104 |
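compute_locations above produces the pixel-space centre of every cell of a feature map. A small self-contained mirror (assumes PyTorch >= 1.10 for the explicit indexing argument; not repository code):

    import torch

    h, w, stride = 2, 3, 8
    shifts_x = torch.arange(0, w * stride, step=stride, dtype=torch.float32)
    shifts_y = torch.arange(0, h * stride, step=stride, dtype=torch.float32)
    shift_y, shift_x = torch.meshgrid(shifts_y, shifts_x, indexing="ij")
    locations = torch.stack((shift_x.reshape(-1), shift_y.reshape(-1)), dim=1) + stride // 2
    print(locations)  # rows: (4, 4), (12, 4), (20, 4), (4, 12), (12, 12), (20, 12)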
--------------------------------------------------------------------------------
/adet/utils/misc.py:
--------------------------------------------------------------------------------
1 | from typing import List, Optional
2 | import torch
3 | from torch.functional import Tensor
4 | from torchvision.ops.boxes import box_area
5 | import torch.distributed as dist
6 |
7 |
8 | def is_dist_avail_and_initialized():
9 | if not dist.is_available():
10 | return False
11 | if not dist.is_initialized():
12 | return False
13 | return True
14 |
15 |
16 | @torch.no_grad()
17 | def accuracy(output, target, topk=(1,)):
18 | """Computes the precision@k for the specified values of k"""
19 | if target.numel() == 0:
20 | return [torch.zeros([], device=output.device)]
21 | if target.ndim == 2:
22 | assert output.ndim == 3
23 | output = output.mean(1)
24 | maxk = max(topk)
25 | batch_size = target.size(0)
26 |
27 | _, pred = output.topk(maxk, -1)
28 | pred = pred.t()
29 | correct = pred.eq(target.view(1, -1).expand_as(pred))
30 |
31 | res = []
32 | for k in topk:
33 | correct_k = correct[:k].view(-1).float().sum(0)
34 | res.append(correct_k.mul_(100.0 / batch_size))
35 | return res
36 |
37 |
38 | def box_cxcywh_to_xyxy(x):
39 | x_c, y_c, w, h = x.unbind(-1)
40 | b = [(x_c - 0.5 * w), (y_c - 0.5 * h),
41 | (x_c + 0.5 * w), (y_c + 0.5 * h)]
42 | return torch.stack(b, dim=-1)
43 |
44 |
45 | def box_xyxy_to_cxcywh(x):
46 | x0, y0, x1, y1 = x.unbind(-1)
47 | b = [(x0 + x1) / 2, (y0 + y1) / 2,
48 | (x1 - x0), (y1 - y0)]
49 | return torch.stack(b, dim=-1)
50 |
51 |
52 | # modified from torchvision to also return the union
53 | def box_iou(boxes1, boxes2):
54 | area1 = box_area(boxes1)
55 | area2 = box_area(boxes2)
56 |
57 | lt = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2]
58 | rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2]
59 |
60 | wh = (rb - lt).clamp(min=0) # [N,M,2]
61 | inter = wh[:, :, 0] * wh[:, :, 1] # [N,M]
62 |
63 | union = area1[:, None] + area2 - inter
64 |
65 | iou = inter / union
66 | return iou, union
67 |
68 |
69 | def generalized_box_iou(boxes1, boxes2):
70 | """
71 | Generalized IoU from https://giou.stanford.edu/
72 | The boxes should be in [x0, y0, x1, y1] format
73 | Returns a [N, M] pairwise matrix, where N = len(boxes1)
74 | and M = len(boxes2)
75 | """
76 | # degenerate boxes gives inf / nan results
77 | # so do an early check
78 | assert (boxes1[:, 2:] >= boxes1[:, :2]).all()
79 | assert (boxes2[:, 2:] >= boxes2[:, :2]).all()
80 | iou, union = box_iou(boxes1, boxes2)
81 |
82 | lt = torch.min(boxes1[:, None, :2], boxes2[:, :2])
83 | rb = torch.max(boxes1[:, None, 2:], boxes2[:, 2:])
84 |
85 | wh = (rb - lt).clamp(min=0) # [N,M,2]
86 | area = wh[:, :, 0] * wh[:, :, 1]
87 |
88 | return iou - (area - union) / area
89 |
90 |
91 | def masks_to_boxes(masks):
92 | """Compute the bounding boxes around the provided masks
93 | The masks should be in format [N, H, W] where N is the number of masks, (H, W) are the spatial dimensions.
94 | Returns a [N, 4] tensors, with the boxes in xyxy format
95 | """
96 | if masks.numel() == 0:
97 | return torch.zeros((0, 4), device=masks.device)
98 |
99 | h, w = masks.shape[-2:]
100 |
101 | y = torch.arange(0, h, dtype=torch.float)
102 | x = torch.arange(0, w, dtype=torch.float)
103 | y, x = torch.meshgrid(y, x)
104 |
105 | x_mask = (masks * x.unsqueeze(0))
106 | x_max = x_mask.flatten(1).max(-1)[0]
107 | x_min = x_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0]
108 |
109 | y_mask = (masks * y.unsqueeze(0))
110 | y_max = y_mask.flatten(1).max(-1)[0]
111 | y_min = y_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0]
112 |
113 | return torch.stack([x_min, y_min, x_max, y_max], 1)
114 |
115 | def inverse_sigmoid(x, eps=1e-5):
116 | x = x.clamp(min=0, max=1)
117 | x1 = x.clamp(min=eps)
118 | x2 = (1 - x).clamp(min=eps)
119 | return torch.log(x1/x2)
120 |
121 | def sigmoid_offset(x, offset=True):
122 | # modified sigmoid for range [-0.5, 1.5]
123 | if offset:
124 | return x.sigmoid() * 2 - 0.5
125 | else:
126 | return x.sigmoid()
127 |
128 | def inverse_sigmoid_offset(x, eps=1e-5, offset=True):
129 | if offset:
130 | x = (x + 0.5) / 2.0
131 | return inverse_sigmoid(x, eps)
132 |
133 | def _max_by_axis(the_list):
134 | # type: (List[List[int]]) -> List[int]
135 | maxes = the_list[0]
136 | for sublist in the_list[1:]:
137 | for index, item in enumerate(sublist):
138 | maxes[index] = max(maxes[index], item)
139 | return maxes
140 |
141 |
142 | def nested_tensor_from_tensor_list(tensor_list: List[Tensor]):
143 | # TODO make this more general
144 | if tensor_list[0].ndim == 3:
145 | # TODO make it support different-sized images
146 | max_size = _max_by_axis([list(img.shape) for img in tensor_list])
147 | # min_size = tuple(min(s) for s in zip(*[img.shape for img in tensor_list]))
148 | batch_shape = [len(tensor_list)] + max_size
149 | b, c, h, w = batch_shape
150 | dtype = tensor_list[0].dtype
151 | device = tensor_list[0].device
152 | tensor = torch.zeros(batch_shape, dtype=dtype, device=device)
153 | mask = torch.ones((b, h, w), dtype=torch.bool, device=device)
154 | for img, pad_img, m in zip(tensor_list, tensor, mask):
155 | pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
156 | m[: img.shape[1], :img.shape[2]] = False
157 | else:
158 | raise ValueError('not supported')
159 | return NestedTensor(tensor, mask)
160 |
161 |
162 | class NestedTensor(object):
163 | def __init__(self, tensors, mask: Optional[Tensor]):
164 | self.tensors = tensors
165 | self.mask = mask
166 |
167 | def to(self, device):
168 | # type: (Device) -> NestedTensor # noqa
169 | cast_tensor = self.tensors.to(device)
170 | mask = self.mask
171 | if mask is not None:
172 | assert mask is not None
173 | cast_mask = mask.to(device)
174 | else:
175 | cast_mask = None
176 | return NestedTensor(cast_tensor, cast_mask)
177 |
178 | def decompose(self):
179 | return self.tensors, self.mask
180 |
181 | def __repr__(self):
182 | return str(self.tensors)
183 |
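A quick sanity sketch for the box utilities above (assumes the adet package is importable, e.g. after installing the repository via setup.py): the cxcywh/xyxy conversion round-trips, and the GIoU of a box with itself is 1.

    import torch
    from adet.utils.misc import box_cxcywh_to_xyxy, box_xyxy_to_cxcywh, generalized_box_iou

    boxes_cxcywh = torch.tensor([[0.5, 0.5, 0.4, 0.2],
                                 [0.3, 0.6, 0.2, 0.2]])
    boxes_xyxy = box_cxcywh_to_xyxy(boxes_cxcywh)

    print(torch.allclose(box_xyxy_to_cxcywh(boxes_xyxy), boxes_cxcywh))  # True
    print(torch.diag(generalized_box_iou(boxes_xyxy, boxes_xyxy)))       # ~tensor([1., 1.])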
--------------------------------------------------------------------------------
/adet/utils/visualizer.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pickle
3 | from detectron2.utils.visualizer import Visualizer
4 | import matplotlib.colors as mplc
5 | import matplotlib.font_manager as mfm
6 |
7 | class TextVisualizer(Visualizer):
8 | def __init__(self, image, metadata, instance_mode, cfg):
9 | Visualizer.__init__(self, image, metadata, instance_mode=instance_mode)
10 | self.voc_size = cfg.MODEL.BATEXT.VOC_SIZE
11 |         self.use_custom_dictionary = cfg.MODEL.BATEXT.CUSTOM_DICT
12 | self.use_polygon = cfg.MODEL.TRANSFORMER.USE_POLYGON
13 |         if not self.use_custom_dictionary:
14 | self.CTLABELS = [' ','!','"','#','$','%','&','\'','(',')','*','+',',','-','.','/','0','1','2','3','4','5','6','7','8','9',':',';','<','=','>','?','@','A','B','C','D','E','F','G','H','I','J','K','L','M','N','O','P','Q','R','S','T','U','V','W','X','Y','Z','[','\\',']','^','_','`','a','b','c','d','e','f','g','h','i','j','k','l','m','n','o','p','q','r','s','t','u','v','w','x','y','z','{','|','}','~']
15 | else:
16 |             with open(self.use_custom_dictionary, 'rb') as fp:
17 | self.CTLABELS = pickle.load(fp)
18 |             assert(int(self.voc_size - 1) == len(self.CTLABELS)), "voc_size does not match dictionary size, got {} and {}.".format(int(self.voc_size - 1), len(self.CTLABELS))
19 |
20 | def draw_instance_predictions(self, predictions):
21 | if self.use_polygon:
22 | ctrl_pnts = predictions.polygons.numpy()
23 | else:
24 | ctrl_pnts = predictions.beziers.numpy()
25 | scores = predictions.scores.tolist()
26 | recs = predictions.recs
27 |
28 | self.overlay_instances(ctrl_pnts, recs, scores)
29 |
30 | return self.output
31 |
32 | def _ctrl_pnt_to_poly(self, pnt):
33 | if self.use_polygon:
34 | points = pnt.reshape(-1, 2)
35 | else:
36 | # bezier to polygon
37 | u = np.linspace(0, 1, 20)
38 | pnt = pnt.reshape(2, 4, 2).transpose(0, 2, 1).reshape(4, 4)
39 | points = np.outer((1 - u) ** 3, pnt[:, 0]) \
40 | + np.outer(3 * u * ((1 - u) ** 2), pnt[:, 1]) \
41 | + np.outer(3 * (u ** 2) * (1 - u), pnt[:, 2]) \
42 | + np.outer(u ** 3, pnt[:, 3])
43 | points = np.concatenate((points[:, :2], points[:, 2:]), axis=0)
44 |
45 | return points
46 |
47 | def _decode_recognition(self, rec):
48 | s = ''
49 | for c in rec:
50 | c = int(c)
51 | if c < self.voc_size - 1:
52 | if self.voc_size == 96:
53 | s += self.CTLABELS[c]
54 | else:
55 | s += str(chr(self.CTLABELS[c]))
56 | elif c == self.voc_size -1:
57 | s += u'口'
58 | return s
59 |
60 | def _ctc_decode_recognition(self, rec):
61 | # ctc decoding
62 | last_char = False
63 | s = ''
64 | for c in rec:
65 | c = int(c)
66 | if c < self.voc_size - 1:
67 | if last_char != c:
68 | if self.voc_size == 96:
69 | s += self.CTLABELS[c]
70 | last_char = c
71 | else:
72 | s += str(chr(self.CTLABELS[c]))
73 | last_char = c
74 | elif c == self.voc_size -1:
75 | s += u'口'
76 | else:
77 | last_char = False
78 | return s
79 |
80 | def overlay_instances(self, ctrl_pnts, recs, scores, alpha=0.5):
81 | color = (0.1, 0.2, 0.5)
82 |
83 | for ctrl_pnt, rec, score in zip(ctrl_pnts, recs, scores):
84 | polygon = self._ctrl_pnt_to_poly(ctrl_pnt)
85 | self.draw_polygon(polygon, color, alpha=alpha)
86 |
87 | # draw text in the top left corner
88 | text = self._decode_recognition(rec)
89 | text = "{:.3f}: {}".format(score, text)
90 | lighter_color = self._change_color_brightness(color, brightness_factor=0.7)
91 | text_pos = polygon[0]
92 | horiz_align = "left"
93 | font_size = self._default_font_size
94 |
95 | self.draw_text(
96 | text,
97 | text_pos,
98 | color=lighter_color,
99 | horizontal_alignment=horiz_align,
100 | font_size=font_size,
101 |                 draw_chinese=(self.voc_size != 96)
102 | )
103 |
104 |
105 | def draw_text(
106 | self,
107 | text,
108 | position,
109 | *,
110 | font_size=None,
111 | color="g",
112 | horizontal_alignment="center",
113 | rotation=0,
114 | draw_chinese=False
115 | ):
116 | """
117 | Args:
118 | text (str): class label
119 | position (tuple): a tuple of the x and y coordinates to place text on image.
120 | font_size (int, optional): font of the text. If not provided, a font size
121 | proportional to the image width is calculated and used.
122 | color: color of the text. Refer to `matplotlib.colors` for full list
123 | of formats that are accepted.
124 | horizontal_alignment (str): see `matplotlib.text.Text`
125 | rotation: rotation angle in degrees CCW
126 | Returns:
127 | output (VisImage): image object with text drawn.
128 | """
129 | if not font_size:
130 | font_size = self._default_font_size
131 |
132 | # since the text background is dark, we don't want the text to be dark
133 | color = np.maximum(list(mplc.to_rgb(color)), 0.2)
134 | color[np.argmax(color)] = max(0.8, np.max(color))
135 |
136 | x, y = position
137 | if draw_chinese:
138 | font_path = "./simsun.ttc"
139 | prop = mfm.FontProperties(fname=font_path)
140 | self.output.ax.text(
141 | x,
142 | y,
143 | text,
144 | size=font_size * self.output.scale,
145 | family="sans-serif",
146 | bbox={"facecolor": "black", "alpha": 0.8, "pad": 0.7, "edgecolor": "none"},
147 | verticalalignment="top",
148 | horizontalalignment=horizontal_alignment,
149 | color=color,
150 | zorder=10,
151 | rotation=rotation,
152 | fontproperties=prop
153 | )
154 | else:
155 | self.output.ax.text(
156 | x,
157 | y,
158 | text,
159 | size=font_size * self.output.scale,
160 | family="sans-serif",
161 | bbox={"facecolor": "black", "alpha": 0.8, "pad": 0.7, "edgecolor": "none"},
162 | verticalalignment="top",
163 | horizontalalignment=horizontal_alignment,
164 | color=color,
165 | zorder=10,
166 | rotation=rotation,
167 | )
168 | return self.output
--------------------------------------------------------------------------------
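
The control-point handling in the `TextVisualizer` code above evaluates a cubic Bézier curve through four `np.outer` terms and then stacks the top and bottom curves into one polygon. A standalone sketch of the same Bernstein-basis formula (not repo code; `ctrl` and `sample_cubic_bezier` are illustrative names):

import numpy as np

def sample_cubic_bezier(ctrl, n=20):
    # ctrl: (k, 4) array whose columns are the control points P0..P3
    # (one row per coordinate channel, matching the pnt[:, i] indexing above).
    u = np.linspace(0, 1, n)
    return (np.outer((1 - u) ** 3, ctrl[:, 0])
            + np.outer(3 * u * (1 - u) ** 2, ctrl[:, 1])
            + np.outer(3 * (u ** 2) * (1 - u), ctrl[:, 2])
            + np.outer(u ** 3, ctrl[:, 3]))

# degenerate example: evenly spaced control points give a straight line from (0, 0) to (3, 3)
print(sample_cubic_bezier(np.array([[0., 1., 2., 3.], [0., 1., 2., 3.]]), n=3))
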
/configs/TESTR/Base-TESTR.yaml:
--------------------------------------------------------------------------------
1 | MODEL:
2 | META_ARCHITECTURE: "TransformerDetector"
3 | MASK_ON: False
4 | PIXEL_MEAN: [123.675, 116.280, 103.530]
5 | PIXEL_STD: [58.395, 57.120, 57.375]
6 | BACKBONE:
7 | NAME: "build_resnet_backbone"
8 | RESNETS:
9 | STRIDE_IN_1X1: False
10 | OUT_FEATURES: ["res3", "res4", "res5"]
11 | TRANSFORMER:
12 | ENABLED: True
13 | INFERENCE_TH_TEST: 0.45
14 | SOLVER:
15 | WEIGHT_DECAY: 1e-4
16 | OPTIMIZER: "ADAMW"
17 | LR_BACKBONE_NAMES: ['backbone.0']
18 | LR_LINEAR_PROJ_NAMES: ['reference_points', 'sampling_offsets']
19 | LR_LINEAR_PROJ_MULT: 0.1
20 | CLIP_GRADIENTS:
21 | ENABLED: True
22 | CLIP_TYPE: "full_model"
23 | CLIP_VALUE: 0.1
24 | NORM_TYPE: 2.0
25 | INPUT:
26 | HFLIP_TRAIN: False
27 | MIN_SIZE_TRAIN: (480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800, 832, 864, 896)
28 | MAX_SIZE_TRAIN: 1600
29 | MIN_SIZE_TEST: 1600
30 | MAX_SIZE_TEST: 1824
31 | CROP:
32 | ENABLED: True
33 | CROP_INSTANCE: False
34 | SIZE: [0.1, 0.1]
35 | FORMAT: "RGB"
--------------------------------------------------------------------------------
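
Base-TESTR.yaml is not meant to be run directly; the dataset-specific files below inherit from it via `_BASE_`. A minimal sketch of how these YAMLs are consumed, mirroring `setup_cfg` in demo/demo.py and `setup` in tools/train_net.py further down (the override value is only an example):

from adet.config import get_cfg

cfg = get_cfg()
# _BASE_ chains are resolved relative to the file being merged, so this pulls in
# Base-TotalText-Polygon.yaml and Base-TESTR.yaml automatically.
cfg.merge_from_file("configs/TESTR/TotalText/TESTR_R_50_Polygon.yaml")
cfg.merge_from_list(["MODEL.TRANSFORMER.INFERENCE_TH_TEST", "0.5"])  # example CLI-style override
cfg.freeze()
print(cfg.MODEL.TRANSFORMER.NUM_CTRL_POINTS)  # 16, from the polygon base config
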
/configs/TESTR/CTW1500/Base-CTW1500-Polygon.yaml:
--------------------------------------------------------------------------------
1 | _BASE_: "../Base-TESTR.yaml"
2 | MODEL:
3 | TRANSFORMER:
4 | INFERENCE_TH_TEST: 0.7
5 | NUM_CHARS: 100
6 | USE_POLYGON: True
7 | NUM_CTRL_POINTS: 16
8 | LOSS:
9 | POINT_TEXT_WEIGHT: 4.0
10 | DATASETS:
11 | TRAIN: ("ctw1500_word_poly_train",)
12 | TEST: ("ctw1500_word_poly_test",)
13 | INPUT:
14 | MIN_SIZE_TEST: 1000
15 | MAX_SIZE_TRAIN: 1333
16 | TEST:
17 | USE_LEXICON: False
18 | LEXICON_TYPE: 1
19 |
20 |
--------------------------------------------------------------------------------
/configs/TESTR/CTW1500/Base-CTW1500.yaml:
--------------------------------------------------------------------------------
1 | _BASE_: "../Base-TESTR.yaml"
2 | MODEL:
3 | TRANSFORMER:
4 | INFERENCE_TH_TEST: 0.6
5 | NUM_CHARS: 100
6 | DATASETS:
7 | TRAIN: ("ctw1500_word_train",)
8 | TEST: ("ctw1500_word_test",)
9 | INPUT:
10 | MIN_SIZE_TEST: 1000
11 | MAX_SIZE_TRAIN: 1333
12 | TEST:
13 | USE_LEXICON: False
14 | LEXICON_TYPE: 1
15 |
16 |
--------------------------------------------------------------------------------
/configs/TESTR/CTW1500/TESTR_R_50.yaml:
--------------------------------------------------------------------------------
1 | _BASE_: "Base-CTW1500.yaml"
2 | MODEL:
3 | WEIGHTS: "weights/TESTR/pretrain_testr_R_50.pth"
4 | RESNETS:
5 | DEPTH: 50
6 | TRANSFORMER:
7 | NUM_FEATURE_LEVELS: 4
8 | ENC_LAYERS: 6
9 | DEC_LAYERS: 6
10 | DIM_FEEDFORWARD: 1024
11 | HIDDEN_DIM: 256
12 | DROPOUT: 0.1
13 | NHEADS: 8
14 | NUM_QUERIES: 100
15 | ENC_N_POINTS: 4
16 | DEC_N_POINTS: 4
17 | SOLVER:
18 | STEPS: (300000,)
19 | IMS_PER_BATCH: 8
20 | BASE_LR: 2e-5
21 | LR_BACKBONE: 2e-6
22 | WARMUP_ITERS: 0
23 | MAX_ITER: 200000
24 | CHECKPOINT_PERIOD: 10000
25 | TEST:
26 | EVAL_PERIOD: 10000
27 | OUTPUT_DIR: "output/TESTR/ctw1500/TESTR_R_50"
28 |
--------------------------------------------------------------------------------
/configs/TESTR/CTW1500/TESTR_R_50_Polygon.yaml:
--------------------------------------------------------------------------------
1 | _BASE_: "Base-CTW1500-Polygon.yaml"
2 | MODEL:
3 | WEIGHTS: "weights/TESTR/pretrain_testr_R_50_polygon.pth"
4 | RESNETS:
5 | DEPTH: 50
6 | TRANSFORMER:
7 | NUM_FEATURE_LEVELS: 4
8 | ENC_LAYERS: 6
9 | DEC_LAYERS: 6
10 | DIM_FEEDFORWARD: 1024
11 | HIDDEN_DIM: 256
12 | DROPOUT: 0.1
13 | NHEADS: 8
14 | NUM_QUERIES: 100
15 | ENC_N_POINTS: 4
16 | DEC_N_POINTS: 4
17 | SOLVER:
18 | STEPS: (300000,)
19 | IMS_PER_BATCH: 8
20 | BASE_LR: 1e-5
21 | LR_BACKBONE: 1e-6
22 | WARMUP_ITERS: 0
23 | MAX_ITER: 200000
24 | CHECKPOINT_PERIOD: 10000
25 | TEST:
26 | EVAL_PERIOD: 10000
27 | OUTPUT_DIR: "output/TESTR/ctw1500/TESTR_R_50_Polygon"
28 |
--------------------------------------------------------------------------------
/configs/TESTR/ICDAR15/Base-ICDAR15-Polygon.yaml:
--------------------------------------------------------------------------------
1 | _BASE_: "../Base-TESTR.yaml"
2 | DATASETS:
3 | TRAIN: ("icdar2015_train",)
4 | TEST: ("icdar2015_test",)
5 | MODEL:
6 | TRANSFORMER:
7 | USE_POLYGON: True
8 | NUM_CTRL_POINTS: 16
9 | LOSS:
10 | POINT_TEXT_WEIGHT: 4.0
11 | TEST:
12 | USE_LEXICON: True
13 | LEXICON_TYPE: 3
14 | WEIGHTED_EDIT_DIST: True
15 | INPUT:
16 | MIN_SIZE_TRAIN: (800, 832, 864, 896, 1000, 1200, 1400)
17 | MAX_SIZE_TRAIN: 2333
18 | MIN_SIZE_TEST: 1440
19 | MAX_SIZE_TEST: 4000
20 |
--------------------------------------------------------------------------------
/configs/TESTR/ICDAR15/TESTR_R_50_Polygon.yaml:
--------------------------------------------------------------------------------
1 | _BASE_: "Base-ICDAR15-Polygon.yaml"
2 | MODEL:
3 | WEIGHTS: "weights/TESTR/pretrain_testr_R_50_polygon.pth"
4 | RESNETS:
5 | DEPTH: 50
6 | TRANSFORMER:
7 | NUM_FEATURE_LEVELS: 4
8 | INFERENCE_TH_TEST: 0.3
9 | ENC_LAYERS: 6
10 | DEC_LAYERS: 6
11 | DIM_FEEDFORWARD: 1024
12 | HIDDEN_DIM: 256
13 | DROPOUT: 0.1
14 | NHEADS: 8
15 | NUM_QUERIES: 100
16 | ENC_N_POINTS: 4
17 | DEC_N_POINTS: 4
18 | SOLVER:
19 | IMS_PER_BATCH: 8
20 | BASE_LR: 1e-5
21 | LR_BACKBONE: 1e-6
22 | WARMUP_ITERS: 0
23 | STEPS: (200000,)
24 | MAX_ITER: 20000
25 | CHECKPOINT_PERIOD: 1000
26 | TEST:
27 | EVAL_PERIOD: 1000
28 | OUTPUT_DIR: "output/TESTR/icdar15/TESTR_R_50_Polygon"
29 |
--------------------------------------------------------------------------------
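
As in the CTW1500 files above, `STEPS` in this fine-tuning config lies beyond `MAX_ITER` (200000 vs. 20000), so the milestone decay never triggers and the learning rate stays at `BASE_LR` for the whole run; the TotalText configs below follow the same pattern. A toy check with a plain `MultiStepLR` (used here only to illustrate the milestone logic, not the detectron2 warmup scheduler):

import torch

p = torch.nn.Parameter(torch.zeros(1))
opt = torch.optim.AdamW([p], lr=1e-5)
sched = torch.optim.lr_scheduler.MultiStepLR(opt, milestones=[200000], gamma=0.1)
for _ in range(20000):             # MAX_ITER iterations; the 200000 milestone is never reached
    opt.step()
    sched.step()
print(opt.param_groups[0]["lr"])   # still 1e-5
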
/configs/TESTR/Pretrain/Base-Pretrain-Polygon.yaml:
--------------------------------------------------------------------------------
1 | _BASE_: "../Base-TESTR.yaml"
2 | DATASETS:
3 | TRAIN: ("mltbezier_word_poly_train", "totaltext_poly_train", "syntext1_poly_train", "syntext2_poly_train",)
4 | TEST: ("totaltext_poly_val",)
5 | MODEL:
6 | TRANSFORMER:
7 | USE_POLYGON: True
8 | NUM_CTRL_POINTS: 16
9 | LOSS:
10 | POINT_TEXT_WEIGHT: 4.0
11 |
--------------------------------------------------------------------------------
/configs/TESTR/Pretrain/Base-Pretrain.yaml:
--------------------------------------------------------------------------------
1 | _BASE_: "../Base-TESTR.yaml"
2 | DATASETS:
3 | TRAIN: ("mltbezier_word_train", "totaltext_train", "syntext1_train", "syntext2_train",)
4 | TEST: ("totaltext_val",)
5 |
--------------------------------------------------------------------------------
/configs/TESTR/Pretrain/TESTR_R_50.yaml:
--------------------------------------------------------------------------------
1 | _BASE_: "Base-Pretrain.yaml"
2 | MODEL:
3 | WEIGHTS: "detectron2://ImageNetPretrained/torchvision/R-50.pkl"
4 | RESNETS:
5 | DEPTH: 50
6 | TRANSFORMER:
7 | NUM_FEATURE_LEVELS: 4
8 | ENC_LAYERS: 6
9 | DEC_LAYERS: 6
10 | DIM_FEEDFORWARD: 1024
11 | HIDDEN_DIM: 256
12 | DROPOUT: 0.1
13 | NHEADS: 8
14 | NUM_QUERIES: 100
15 | ENC_N_POINTS: 4
16 | DEC_N_POINTS: 4
17 | SOLVER:
18 | IMS_PER_BATCH: 8
19 | BASE_LR: 2e-4
20 | LR_BACKBONE: 2e-5
21 | WARMUP_ITERS: 0
22 | STEPS: (340000,)
23 | MAX_ITER: 440000
24 | CHECKPOINT_PERIOD: 10000
25 | TEST:
26 | EVAL_PERIOD: 10000
27 | OUTPUT_DIR: "output/TESTR/pretrain/TESTR_R_50"
28 |
--------------------------------------------------------------------------------
/configs/TESTR/Pretrain/TESTR_R_50_Polygon.yaml:
--------------------------------------------------------------------------------
1 | _BASE_: "Base-Pretrain-Polygon.yaml"
2 | MODEL:
3 | WEIGHTS: "detectron2://ImageNetPretrained/torchvision/R-50.pkl"
4 | RESNETS:
5 | DEPTH: 50
6 | TRANSFORMER:
7 | NUM_FEATURE_LEVELS: 4
8 | ENC_LAYERS: 6
9 | DEC_LAYERS: 6
10 | DIM_FEEDFORWARD: 1024
11 | HIDDEN_DIM: 256
12 | DROPOUT: 0.1
13 | NHEADS: 8
14 | NUM_QUERIES: 100
15 | ENC_N_POINTS: 4
16 | DEC_N_POINTS: 4
17 | SOLVER:
18 | IMS_PER_BATCH: 8
19 | BASE_LR: 1e-4
20 | LR_BACKBONE: 1e-5
21 | WARMUP_ITERS: 0
22 | STEPS: (340000,)
23 | MAX_ITER: 440000
24 | CHECKPOINT_PERIOD: 10000
25 | TEST:
26 | EVAL_PERIOD: 10000
27 | OUTPUT_DIR: "output/TESTR/pretrain/TESTR_R_50_Polygon"
28 |
--------------------------------------------------------------------------------
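
The fine-tuning configs in this section point `MODEL.WEIGHTS` at `weights/TESTR/pretrain_testr_R_50*.pth`, which is where the result of these pretraining runs is expected to end up. A hedged sketch of that handoff (`model_final.pth` is detectron2's default name for the last checkpoint; substitute an intermediate `model_*.pth` if preferred):

import os
import shutil

os.makedirs("weights/TESTR", exist_ok=True)
shutil.copy(
    "output/TESTR/pretrain/TESTR_R_50_Polygon/model_final.pth",
    "weights/TESTR/pretrain_testr_R_50_polygon.pth",
)
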
/configs/TESTR/TotalText/Base-TotalText-Polygon.yaml:
--------------------------------------------------------------------------------
1 | _BASE_: "../Base-TESTR.yaml"
2 | DATASETS:
3 | TRAIN: ("totaltext_poly_train",)
4 | TEST: ("totaltext_poly_val",)
5 | MODEL:
6 | TRANSFORMER:
7 | USE_POLYGON: True
8 | NUM_CTRL_POINTS: 16
9 | LOSS:
10 | POINT_TEXT_WEIGHT: 4.0
11 | TEST:
12 | USE_LEXICON: False
13 | LEXICON_TYPE: 1
--------------------------------------------------------------------------------
/configs/TESTR/TotalText/Base-TotalText.yaml:
--------------------------------------------------------------------------------
1 | _BASE_: "../Base-TESTR.yaml"
2 | DATASETS:
3 | TRAIN: ("totaltext_train",)
4 | TEST: ("totaltext_val",)
5 | TEST:
6 | USE_LEXICON: False
7 | LEXICON_TYPE: 1
8 |
--------------------------------------------------------------------------------
/configs/TESTR/TotalText/TESTR_R_50.yaml:
--------------------------------------------------------------------------------
1 | _BASE_: "Base-TotalText.yaml"
2 | MODEL:
3 | WEIGHTS: "weights/TESTR/pretrain_testr_R_50.pth"
4 | RESNETS:
5 | DEPTH: 50
6 | TRANSFORMER:
7 | NUM_FEATURE_LEVELS: 4
8 | ENC_LAYERS: 6
9 | DEC_LAYERS: 6
10 | DIM_FEEDFORWARD: 1024
11 | HIDDEN_DIM: 256
12 | DROPOUT: 0.1
13 | NHEADS: 8
14 | NUM_QUERIES: 100
15 | ENC_N_POINTS: 4
16 | DEC_N_POINTS: 4
17 | SOLVER:
18 | IMS_PER_BATCH: 8
19 | BASE_LR: 2e-5
20 | LR_BACKBONE: 2e-6
21 | WARMUP_ITERS: 0
22 | STEPS: (200000,)
23 | MAX_ITER: 20000
24 | CHECKPOINT_PERIOD: 1000
25 | TEST:
26 | EVAL_PERIOD: 1000
27 | OUTPUT_DIR: "output/TESTR/totaltext/TESTR_R_50"
28 |
--------------------------------------------------------------------------------
/configs/TESTR/TotalText/TESTR_R_50_Polygon.yaml:
--------------------------------------------------------------------------------
1 | _BASE_: "Base-TotalText-Polygon.yaml"
2 | MODEL:
3 | WEIGHTS: "weights/TESTR/pretrain_testr_R_50_polygon.pth"
4 | RESNETS:
5 | DEPTH: 50
6 | TRANSFORMER:
7 | NUM_FEATURE_LEVELS: 4
8 | ENC_LAYERS: 6
9 | DEC_LAYERS: 6
10 | DIM_FEEDFORWARD: 1024
11 | HIDDEN_DIM: 256
12 | DROPOUT: 0.1
13 | NHEADS: 8
14 | NUM_QUERIES: 100
15 | ENC_N_POINTS: 4
16 | DEC_N_POINTS: 4
17 | SOLVER:
18 | IMS_PER_BATCH: 8
19 | BASE_LR: 1e-5
20 | LR_BACKBONE: 1e-6
21 | WARMUP_ITERS: 0
22 | STEPS: (200000,)
23 | MAX_ITER: 20000
24 | CHECKPOINT_PERIOD: 1000
25 | TEST:
26 | EVAL_PERIOD: 1000
27 | OUTPUT_DIR: "output/TESTR/totaltext/TESTR_R_50_Polygon"
28 |
--------------------------------------------------------------------------------
/demo/demo.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2 | import argparse
3 | import glob
4 | import multiprocessing as mp
5 | import os
6 | import time
7 | import cv2
8 | import tqdm
9 |
10 | from detectron2.data.detection_utils import read_image
11 | from detectron2.utils.logger import setup_logger
12 |
13 | from predictor import VisualizationDemo
14 | from adet.config import get_cfg
15 |
16 | # constants
17 | WINDOW_NAME = "COCO detections"
18 |
19 |
20 | def setup_cfg(args):
21 | # load config from file and command-line arguments
22 | cfg = get_cfg()
23 | cfg.merge_from_file(args.config_file)
24 | cfg.merge_from_list(args.opts)
25 | # Set score_threshold for builtin models
26 | cfg.MODEL.RETINANET.SCORE_THRESH_TEST = args.confidence_threshold
27 | cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = args.confidence_threshold
28 | cfg.MODEL.FCOS.INFERENCE_TH_TEST = args.confidence_threshold
29 | cfg.MODEL.MEInst.INFERENCE_TH_TEST = args.confidence_threshold
30 | cfg.MODEL.PANOPTIC_FPN.COMBINE.INSTANCES_CONFIDENCE_THRESH = args.confidence_threshold
31 | cfg.freeze()
32 | return cfg
33 |
34 |
35 | def get_parser():
36 | parser = argparse.ArgumentParser(description="Detectron2 Demo")
37 | parser.add_argument(
38 | "--config-file",
39 | default="configs/quick_schedules/e2e_mask_rcnn_R_50_FPN_inference_acc_test.yaml",
40 | metavar="FILE",
41 | help="path to config file",
42 | )
43 | parser.add_argument("--webcam", action="store_true", help="Take inputs from webcam.")
44 | parser.add_argument("--video-input", help="Path to video file.")
45 | parser.add_argument("--input", nargs="+", help="A list of space separated input images")
46 | parser.add_argument(
47 | "--output",
48 | help="A file or directory to save output visualizations. "
49 | "If not given, will show output in an OpenCV window.",
50 | )
51 |
52 | parser.add_argument(
53 | "--confidence-threshold",
54 | type=float,
55 | default=0.3,
56 | help="Minimum score for instance predictions to be shown",
57 | )
58 | parser.add_argument(
59 | "--opts",
60 | help="Modify config options using the command-line 'KEY VALUE' pairs",
61 | default=[],
62 | nargs=argparse.REMAINDER,
63 | )
64 | return parser
65 |
66 |
67 | if __name__ == "__main__":
68 | mp.set_start_method("spawn", force=True)
69 | args = get_parser().parse_args()
70 | logger = setup_logger()
71 | logger.info("Arguments: " + str(args))
72 |
73 | cfg = setup_cfg(args)
74 |
75 | demo = VisualizationDemo(cfg)
76 |
77 | if args.input:
78 | if os.path.isdir(args.input[0]):
79 | args.input = [os.path.join(args.input[0], fname) for fname in os.listdir(args.input[0])]
80 | elif len(args.input) == 1:
81 | args.input = glob.glob(os.path.expanduser(args.input[0]))
82 | assert args.input, "The input path(s) was not found"
83 | for path in tqdm.tqdm(args.input, disable=not args.output):
84 | # use PIL, to be consistent with evaluation
85 | img = read_image(path, format="BGR")
86 | start_time = time.time()
87 | predictions, visualized_output = demo.run_on_image(img)
88 | logger.info(
89 | "{}: detected {} instances in {:.2f}s".format(
90 | path, len(predictions["instances"]), time.time() - start_time
91 | )
92 | )
93 |
94 | if args.output:
95 | if os.path.isdir(args.output):
96 | assert os.path.isdir(args.output), args.output
97 | out_filename = os.path.join(args.output, os.path.basename(path))
98 | else:
99 | assert len(args.input) == 1, "Please specify a directory with args.output"
100 | out_filename = args.output
101 | visualized_output.save(out_filename)
102 | else:
103 | cv2.imshow(WINDOW_NAME, visualized_output.get_image()[:, :, ::-1])
104 | if cv2.waitKey(0) == 27:
105 | break # esc to quit
106 | elif args.webcam:
107 | assert args.input is None, "Cannot have both --input and --webcam!"
108 | cam = cv2.VideoCapture(0)
109 | for vis in tqdm.tqdm(demo.run_on_video(cam)):
110 | cv2.namedWindow(WINDOW_NAME, cv2.WINDOW_NORMAL)
111 | cv2.imshow(WINDOW_NAME, vis)
112 | if cv2.waitKey(1) == 27:
113 | break # esc to quit
114 | cv2.destroyAllWindows()
115 | elif args.video_input:
116 | video = cv2.VideoCapture(args.video_input)
117 | width = int(video.get(cv2.CAP_PROP_FRAME_WIDTH))
118 | height = int(video.get(cv2.CAP_PROP_FRAME_HEIGHT))
119 | frames_per_second = video.get(cv2.CAP_PROP_FPS)
120 | num_frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
121 | basename = os.path.basename(args.video_input)
122 |
123 | if args.output:
124 | if os.path.isdir(args.output):
125 | output_fname = os.path.join(args.output, basename)
126 | output_fname = os.path.splitext(output_fname)[0] + ".mkv"
127 | else:
128 | output_fname = args.output
129 | assert not os.path.isfile(output_fname), output_fname
130 | output_file = cv2.VideoWriter(
131 | filename=output_fname,
132 | # some installation of opencv may not support x264 (due to its license),
133 | # you can try other format (e.g. MPEG)
134 | fourcc=cv2.VideoWriter_fourcc(*"x264"),
135 | fps=float(frames_per_second),
136 | frameSize=(width, height),
137 | isColor=True,
138 | )
139 | assert os.path.isfile(args.video_input)
140 | for vis_frame in tqdm.tqdm(demo.run_on_video(video), total=num_frames):
141 | if args.output:
142 | output_file.write(vis_frame)
143 | else:
144 | cv2.namedWindow(basename, cv2.WINDOW_NORMAL)
145 | cv2.imshow(basename, vis_frame)
146 | if cv2.waitKey(1) == 27:
147 | break # esc to quit
148 | video.release()
149 | if args.output:
150 | output_file.release()
151 | else:
152 | cv2.destroyAllWindows()
153 |
--------------------------------------------------------------------------------
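
The same pipeline can be driven without the CLI; below is a minimal sketch mirroring the image branch of the `__main__` block above (config, weight, and image paths are placeholders, and the script is assumed to sit next to predictor.py):

import cv2
from detectron2.data.detection_utils import read_image

from adet.config import get_cfg
from predictor import VisualizationDemo

cfg = get_cfg()
cfg.merge_from_file("configs/TESTR/TotalText/TESTR_R_50_Polygon.yaml")
cfg.MODEL.WEIGHTS = "weights/TESTR/totaltext_testr_R_50_polygon.pth"  # placeholder checkpoint
# cfg.MODEL.DEVICE = "cpu"  # uncomment to run without a GPU
cfg.freeze()

demo = VisualizationDemo(cfg)
img = read_image("example.jpg", format="BGR")                          # BGR, as in the loop above
predictions, vis_output = demo.run_on_image(img)
cv2.imwrite("example_vis.jpg", vis_output.get_image()[:, :, ::-1])     # RGB -> BGR for OpenCV
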
/demo/predictor.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2 | import numpy as np
3 | import atexit
4 | import bisect
5 | import multiprocessing as mp
6 | from collections import deque
7 | import cv2
8 | import torch
9 | import matplotlib.pyplot as plt
10 |
11 | from detectron2.data import MetadataCatalog
12 | from detectron2.engine.defaults import DefaultPredictor
13 | from detectron2.utils.video_visualizer import VideoVisualizer
14 | from detectron2.utils.visualizer import ColorMode, Visualizer
15 |
16 | from adet.utils.visualizer import TextVisualizer
17 |
18 |
19 | class VisualizationDemo(object):
20 | def __init__(self, cfg, instance_mode=ColorMode.IMAGE, parallel=False):
21 | """
22 | Args:
23 | cfg (CfgNode):
24 | instance_mode (ColorMode):
25 | parallel (bool): whether to run the model in different processes from visualization.
26 | Useful since the visualization logic can be slow.
27 | """
28 | self.metadata = MetadataCatalog.get(
29 | cfg.DATASETS.TEST[0] if len(cfg.DATASETS.TEST) else "__unused"
30 | )
31 | self.cfg = cfg
32 | self.cpu_device = torch.device("cpu")
33 | self.instance_mode = instance_mode
34 | self.vis_text = cfg.MODEL.ROI_HEADS.NAME == "TextHead" or cfg.MODEL.TRANSFORMER.ENABLED
35 |
36 | self.parallel = parallel
37 | if parallel:
38 | num_gpu = torch.cuda.device_count()
39 | self.predictor = AsyncPredictor(cfg, num_gpus=num_gpu)
40 | else:
41 | self.predictor = DefaultPredictor(cfg)
42 |
43 | def run_on_image(self, image):
44 | """
45 | Args:
46 | image (np.ndarray): an image of shape (H, W, C) (in BGR order).
47 | This is the format used by OpenCV.
48 |
49 | Returns:
50 | predictions (dict): the output of the model.
51 | vis_output (VisImage): the visualized image output.
52 | """
53 | vis_output = None
54 | predictions = self.predictor(image)
55 | # Convert image from OpenCV BGR format to Matplotlib RGB format.
56 | image = image[:, :, ::-1]
57 | if self.vis_text:
58 | visualizer = TextVisualizer(image, self.metadata, instance_mode=self.instance_mode, cfg=self.cfg)
59 | else:
60 | visualizer = Visualizer(image, self.metadata, instance_mode=self.instance_mode)
61 |
62 | if "bases" in predictions:
63 | self.vis_bases(predictions["bases"])
64 | if "panoptic_seg" in predictions:
65 | panoptic_seg, segments_info = predictions["panoptic_seg"]
66 | vis_output = visualizer.draw_panoptic_seg_predictions(
67 | panoptic_seg.to(self.cpu_device), segments_info
68 | )
69 | else:
70 | if "sem_seg" in predictions:
71 | vis_output = visualizer.draw_sem_seg(
72 | predictions["sem_seg"].argmax(dim=0).to(self.cpu_device))
73 | if "instances" in predictions:
74 | instances = predictions["instances"].to(self.cpu_device)
75 | vis_output = visualizer.draw_instance_predictions(predictions=instances)
76 |
77 | return predictions, vis_output
78 |
79 | def _frame_from_video(self, video):
80 | while video.isOpened():
81 | success, frame = video.read()
82 | if success:
83 | yield frame
84 | else:
85 | break
86 |
87 | def vis_bases(self, bases):
88 | basis_colors = [[2, 200, 255], [107, 220, 255], [30, 200, 255], [60, 220, 255]]
89 | bases = bases[0].squeeze()
90 | bases = (bases / 8).tanh().cpu().numpy()
91 | num_bases = len(bases)
92 | fig, axes = plt.subplots(nrows=num_bases // 2, ncols=2)
93 | for i, basis in enumerate(bases):
94 | basis = (basis + 1) / 2
95 | basis = basis / basis.max()
96 | basis_viz = np.zeros((basis.shape[0], basis.shape[1], 3), dtype=np.uint8)
97 | basis_viz[:, :, 0] = basis_colors[i][0]
98 | basis_viz[:, :, 1] = basis_colors[i][1]
99 | basis_viz[:, :, 2] = np.uint8(basis * 255)
100 | basis_viz = cv2.cvtColor(basis_viz, cv2.COLOR_HSV2RGB)
101 | axes[i // 2][i % 2].imshow(basis_viz)
102 | plt.show()
103 |
104 | def run_on_video(self, video):
105 | """
106 | Visualizes predictions on frames of the input video.
107 |
108 | Args:
109 | video (cv2.VideoCapture): a :class:`VideoCapture` object, whose source can be
110 | either a webcam or a video file.
111 |
112 | Yields:
113 | ndarray: BGR visualizations of each video frame.
114 | """
115 | video_visualizer = VideoVisualizer(self.metadata, self.instance_mode)
116 |
117 | def process_predictions(frame, predictions):
118 | frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
119 | if "panoptic_seg" in predictions:
120 | panoptic_seg, segments_info = predictions["panoptic_seg"]
121 | vis_frame = video_visualizer.draw_panoptic_seg_predictions(
122 | frame, panoptic_seg.to(self.cpu_device), segments_info
123 | )
124 | elif "instances" in predictions:
125 | predictions = predictions["instances"].to(self.cpu_device)
126 | vis_frame = video_visualizer.draw_instance_predictions(frame, predictions)
127 | elif "sem_seg" in predictions:
128 | vis_frame = video_visualizer.draw_sem_seg(
129 | frame, predictions["sem_seg"].argmax(dim=0).to(self.cpu_device)
130 | )
131 |
132 | # Converts Matplotlib RGB format to OpenCV BGR format
133 | vis_frame = cv2.cvtColor(vis_frame.get_image(), cv2.COLOR_RGB2BGR)
134 | return vis_frame
135 |
136 | frame_gen = self._frame_from_video(video)
137 | if self.parallel:
138 | buffer_size = self.predictor.default_buffer_size
139 |
140 | frame_data = deque()
141 |
142 | for cnt, frame in enumerate(frame_gen):
143 | frame_data.append(frame)
144 | self.predictor.put(frame)
145 |
146 | if cnt >= buffer_size:
147 | frame = frame_data.popleft()
148 | predictions = self.predictor.get()
149 | yield process_predictions(frame, predictions)
150 |
151 | while len(frame_data):
152 | frame = frame_data.popleft()
153 | predictions = self.predictor.get()
154 | yield process_predictions(frame, predictions)
155 | else:
156 | for frame in frame_gen:
157 | yield process_predictions(frame, self.predictor(frame))
158 |
159 |
160 | class AsyncPredictor:
161 | """
162 | A predictor that runs the model asynchronously, possibly on >1 GPUs.
163 | Because rendering the visualization takes a considerable amount of time,
164 | this helps improve throughput when rendering videos.
165 | """
166 |
167 | class _StopToken:
168 | pass
169 |
170 | class _PredictWorker(mp.Process):
171 | def __init__(self, cfg, task_queue, result_queue):
172 | self.cfg = cfg
173 | self.task_queue = task_queue
174 | self.result_queue = result_queue
175 | super().__init__()
176 |
177 | def run(self):
178 | predictor = DefaultPredictor(self.cfg)
179 |
180 | while True:
181 | task = self.task_queue.get()
182 | if isinstance(task, AsyncPredictor._StopToken):
183 | break
184 | idx, data = task
185 | result = predictor(data)
186 | self.result_queue.put((idx, result))
187 |
188 | def __init__(self, cfg, num_gpus: int = 1):
189 | """
190 | Args:
191 | cfg (CfgNode):
192 | num_gpus (int): if 0, will run on CPU
193 | """
194 | num_workers = max(num_gpus, 1)
195 | self.task_queue = mp.Queue(maxsize=num_workers * 3)
196 | self.result_queue = mp.Queue(maxsize=num_workers * 3)
197 | self.procs = []
198 | for gpuid in range(max(num_gpus, 1)):
199 | cfg = cfg.clone()
200 | cfg.defrost()
201 | cfg.MODEL.DEVICE = "cuda:{}".format(gpuid) if num_gpus > 0 else "cpu"
202 | self.procs.append(
203 | AsyncPredictor._PredictWorker(cfg, self.task_queue, self.result_queue)
204 | )
205 |
206 | self.put_idx = 0
207 | self.get_idx = 0
208 | self.result_rank = []
209 | self.result_data = []
210 |
211 | for p in self.procs:
212 | p.start()
213 | atexit.register(self.shutdown)
214 |
215 | def put(self, image):
216 | self.put_idx += 1
217 | self.task_queue.put((self.put_idx, image))
218 |
219 | def get(self):
220 | self.get_idx += 1 # the index needed for this request
221 | if len(self.result_rank) and self.result_rank[0] == self.get_idx:
222 | res = self.result_data[0]
223 | del self.result_data[0], self.result_rank[0]
224 | return res
225 |
226 | while True:
227 | # make sure the results are returned in the correct order
228 | idx, res = self.result_queue.get()
229 | if idx == self.get_idx:
230 | return res
231 | insert = bisect.bisect(self.result_rank, idx)
232 | self.result_rank.insert(insert, idx)
233 | self.result_data.insert(insert, res)
234 |
235 | def __len__(self):
236 | return self.put_idx - self.get_idx
237 |
238 | def __call__(self, image):
239 | self.put(image)
240 | return self.get()
241 |
242 | def shutdown(self):
243 | for _ in self.procs:
244 | self.task_queue.put(AsyncPredictor._StopToken())
245 |
246 | @property
247 | def default_buffer_size(self):
248 | return len(self.procs) * 5
249 |
--------------------------------------------------------------------------------
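
`AsyncPredictor` keeps one worker process per GPU and reorders results in `get()` so outputs come back in submission order. A hedged usage sketch (placeholder paths; assumed to be saved next to predictor.py and run from the repository root):

import multiprocessing as mp
import numpy as np
import torch

from adet.config import get_cfg
from predictor import AsyncPredictor

if __name__ == "__main__":
    mp.set_start_method("spawn", force=True)   # as in demo.py, needed for CUDA workers
    cfg = get_cfg()
    cfg.merge_from_file("configs/TESTR/TotalText/TESTR_R_50_Polygon.yaml")
    cfg.MODEL.WEIGHTS = "weights/TESTR/totaltext_testr_R_50_polygon.pth"  # placeholder checkpoint
    cfg.freeze()

    predictor = AsyncPredictor(cfg, num_gpus=torch.cuda.device_count())
    frames = [np.zeros((480, 640, 3), dtype=np.uint8) for _ in range(4)]  # dummy BGR frames
    for f in frames:
        predictor.put(f)                          # enqueue without blocking on inference
    results = [predictor.get() for _ in frames]   # get() restores submission order
    predictor.shutdown()
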
/setup.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
3 |
4 | import glob
5 | import os
6 | from setuptools import find_packages, setup
7 | import torch
8 | from torch.utils.cpp_extension import CUDA_HOME, CppExtension, CUDAExtension
9 |
10 | torch_ver = [int(x) for x in torch.__version__.split(".")[:2]]
11 | assert torch_ver >= [1, 3], "Requires PyTorch >= 1.3"
12 |
13 |
14 | def get_extensions():
15 | this_dir = os.path.dirname(os.path.abspath(__file__))
16 | extensions_dir = os.path.join(this_dir, "adet", "layers", "csrc")
17 |
18 | main_source = os.path.join(extensions_dir, "vision.cpp")
19 | sources = glob.glob(os.path.join(extensions_dir, "**", "*.cpp"))
20 | source_cuda = glob.glob(os.path.join(extensions_dir, "**", "*.cu")) + glob.glob(
21 | os.path.join(extensions_dir, "*.cu")
22 | )
23 |
24 | sources = [main_source] + sources
25 |
26 | extension = CppExtension
27 |
28 | extra_compile_args = {"cxx": []}
29 | define_macros = []
30 |
31 | if (torch.cuda.is_available() and CUDA_HOME is not None) or os.getenv("FORCE_CUDA", "0") == "1":
32 | extension = CUDAExtension
33 | sources += source_cuda
34 | define_macros += [("WITH_CUDA", None)]
35 | extra_compile_args["nvcc"] = [
36 | "-DCUDA_HAS_FP16=1",
37 | "-D__CUDA_NO_HALF_OPERATORS__",
38 | "-D__CUDA_NO_HALF_CONVERSIONS__",
39 | "-D__CUDA_NO_HALF2_OPERATORS__",
40 | ]
41 |
42 | if torch_ver < [1, 7]:
43 | # supported by https://github.com/pytorch/pytorch/pull/43931
44 | CC = os.environ.get("CC", None)
45 | if CC is not None:
46 | extra_compile_args["nvcc"].append("-ccbin={}".format(CC))
47 |
48 | sources = [os.path.join(extensions_dir, s) for s in sources]
49 |
50 | include_dirs = [extensions_dir]
51 |
52 | ext_modules = [
53 | extension(
54 | "adet._C",
55 | sources,
56 | include_dirs=include_dirs,
57 | define_macros=define_macros,
58 | extra_compile_args=extra_compile_args,
59 | )
60 | ]
61 |
62 | return ext_modules
63 |
64 |
65 | setup(
66 | name="AdelaiDet",
67 | version="0.2.0",
68 | author="Adelaide Intelligent Machines",
69 | url="https://github.com/stanstarks/AdelaiDet",
70 | description="AdelaiDet is AIM's research "
71 | "platform for instance-level detection tasks based on Detectron2.",
72 | packages=find_packages(exclude=("configs", "tests")),
73 | python_requires=">=3.6",
74 | install_requires=[
75 | "termcolor>=1.1",
76 | "Pillow>=6.0",
77 | "yacs>=0.1.6",
78 | "tabulate",
79 | "cloudpickle",
80 | "matplotlib",
81 | "tqdm>4.29.0",
82 | "tensorboard",
83 | "rapidfuzz",
84 | "Polygon3",
85 | "shapely",
86 | "scikit-image",
87 | "editdistance",
88 | "opencv-python",
89 | "numba",
90 | ],
91 | extras_require={"all": ["psutil"]},
92 | ext_modules=get_extensions(),
93 | cmdclass={"build_ext": torch.utils.cpp_extension.BuildExtension},
94 | )
95 |
--------------------------------------------------------------------------------
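
setup.py compiles the sources under adet/layers/csrc into a single extension module named `adet._C`, switching to `CUDAExtension` when a CUDA toolchain is found or `FORCE_CUDA=1` is set. After an editable install (for example `pip install -e .`), a quick sanity check that the extension actually built:

import torch

# Importing the compiled module raises ImportError if the build step was skipped or failed.
from adet import _C  # noqa: F401

print("torch CUDA available:", torch.cuda.is_available())
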
/tools/train_net.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2 | """
3 | Detection Training Script.
4 |
5 | This script reads a given config file and runs the training or evaluation.
6 | It is an entry point that is made to train standard models in detectron2.
7 |
8 | In order to let one script support training of many models,
9 | this script contains logic that is specific to these built-in models and therefore
10 | may not be suitable for your own project.
11 | For example, your research project perhaps only needs a single "evaluator".
12 |
13 | Therefore, we recommend you use detectron2 as a library and take
14 | this file as an example of how to use the library.
15 | You may want to write your own script with your datasets and other customizations.
16 | """
17 |
18 | import logging
19 | import os
20 | from collections import OrderedDict
21 | from typing import Any, Dict, List, Set
22 | import torch
23 | import itertools
24 | from torch.nn.parallel import DistributedDataParallel
25 |
26 | import detectron2.utils.comm as comm
27 | from detectron2.data import MetadataCatalog, build_detection_train_loader
28 | from detectron2.engine import DefaultTrainer, default_argument_parser, default_setup, hooks, launch
29 | from detectron2.utils.events import EventStorage
30 | from detectron2.evaluation import (
31 | COCOEvaluator,
32 | COCOPanopticEvaluator,
33 | DatasetEvaluators,
34 | LVISEvaluator,
35 | PascalVOCDetectionEvaluator,
36 | SemSegEvaluator,
37 | verify_results,
38 | )
39 | from detectron2.solver.build import maybe_add_gradient_clipping
40 | from detectron2.modeling import GeneralizedRCNNWithTTA
41 | from detectron2.utils.logger import setup_logger
42 |
43 | from adet.data.dataset_mapper import DatasetMapperWithBasis
44 | from adet.config import get_cfg
45 | from adet.checkpoint import AdetCheckpointer
46 | from adet.evaluation import TextEvaluator
47 |
48 |
49 | class Trainer(DefaultTrainer):
50 | """
51 | This is the same Trainer except that we rewrite the
52 | `build_train_loader`/`resume_or_load` methods.
53 | """
54 | def build_hooks(self):
55 | """
56 | Replace `DetectionCheckpointer` with `AdetCheckpointer`.
57 |
58 | Build a list of default hooks, including timing, evaluation,
59 | checkpointing, lr scheduling, precise BN, writing events.
60 | """
61 | ret = super().build_hooks()
62 | for i in range(len(ret)):
63 | if isinstance(ret[i], hooks.PeriodicCheckpointer):
64 | self.checkpointer = AdetCheckpointer(
65 | self.model,
66 | self.cfg.OUTPUT_DIR,
67 | optimizer=self.optimizer,
68 | scheduler=self.scheduler,
69 | )
70 | ret[i] = hooks.PeriodicCheckpointer(self.checkpointer, self.cfg.SOLVER.CHECKPOINT_PERIOD)
71 | return ret
72 |
73 | def resume_or_load(self, resume=True):
74 | checkpoint = self.checkpointer.resume_or_load(self.cfg.MODEL.WEIGHTS, resume=resume)
75 | if resume and self.checkpointer.has_checkpoint():
76 | self.start_iter = checkpoint.get("iteration", -1) + 1
77 |
78 | def train_loop(self, start_iter: int, max_iter: int):
79 | """
80 | Args:
81 | start_iter, max_iter (int): See docs above
82 | """
83 | logger = logging.getLogger("adet.trainer")
84 | logger.info("Starting training from iteration {}".format(start_iter))
85 |
86 | self.iter = self.start_iter = start_iter
87 | self.max_iter = max_iter
88 |
89 | with EventStorage(start_iter) as self.storage:
90 | self.before_train()
91 | for self.iter in range(start_iter, max_iter):
92 | self.before_step()
93 | self.run_step()
94 | self.after_step()
95 | self.after_train()
96 |
97 | def train(self):
98 | """
99 | Run training.
100 |
101 | Returns:
102 | OrderedDict of results, if evaluation is enabled. Otherwise None.
103 | """
104 | self.train_loop(self.start_iter, self.max_iter)
105 | if hasattr(self, "_last_eval_results") and comm.is_main_process():
106 | verify_results(self.cfg, self._last_eval_results)
107 | return self._last_eval_results
108 |
109 | @classmethod
110 | def build_train_loader(cls, cfg):
111 | """
112 | Returns:
113 | iterable
114 |
115 | It calls :func:`detectron2.data.build_detection_train_loader` with a customized
116 | DatasetMapper, which adds categorical labels as a semantic mask.
117 | """
118 | mapper = DatasetMapperWithBasis(cfg, True)
119 | return build_detection_train_loader(cfg, mapper=mapper)
120 |
121 | @classmethod
122 | def build_evaluator(cls, cfg, dataset_name, output_folder=None):
123 | """
124 | Create evaluator(s) for a given dataset.
125 | This uses the special metadata "evaluator_type" associated with each builtin dataset.
126 | For your own dataset, you can simply create an evaluator manually in your
127 | script and do not have to worry about the hacky if-else logic here.
128 | """
129 | if output_folder is None:
130 | output_folder = os.path.join(cfg.OUTPUT_DIR, "inference")
131 | evaluator_list = []
132 | evaluator_type = MetadataCatalog.get(dataset_name).evaluator_type
133 | if evaluator_type in ["sem_seg", "coco_panoptic_seg"]:
134 | evaluator_list.append(
135 | SemSegEvaluator(
136 | dataset_name,
137 | distributed=True,
138 | num_classes=cfg.MODEL.SEM_SEG_HEAD.NUM_CLASSES,
139 | ignore_label=cfg.MODEL.SEM_SEG_HEAD.IGNORE_VALUE,
140 | output_dir=output_folder,
141 | )
142 | )
143 | if evaluator_type in ["coco", "coco_panoptic_seg"]:
144 | evaluator_list.append(COCOEvaluator(dataset_name, cfg, True, output_folder))
145 | if evaluator_type == "coco_panoptic_seg":
146 | evaluator_list.append(COCOPanopticEvaluator(dataset_name, output_folder))
147 | if evaluator_type == "pascal_voc":
148 | return PascalVOCDetectionEvaluator(dataset_name)
149 | if evaluator_type == "lvis":
150 | return LVISEvaluator(dataset_name, cfg, True, output_folder)
151 | if evaluator_type == "text":
152 | return TextEvaluator(dataset_name, cfg, True, output_folder)
153 | if len(evaluator_list) == 0:
154 | raise NotImplementedError(
155 | "no Evaluator for the dataset {} with the type {}".format(
156 | dataset_name, evaluator_type
157 | )
158 | )
159 | if len(evaluator_list) == 1:
160 | return evaluator_list[0]
161 | return DatasetEvaluators(evaluator_list)
162 |
163 | @classmethod
164 | def test_with_TTA(cls, cfg, model):
165 | logger = logging.getLogger("adet.trainer")
166 | # At the end of training, run an evaluation with TTA
167 | # Only support some R-CNN models.
168 | logger.info("Running inference with test-time augmentation ...")
169 | model = GeneralizedRCNNWithTTA(cfg, model)
170 | evaluators = [
171 | cls.build_evaluator(
172 | cfg, name, output_folder=os.path.join(cfg.OUTPUT_DIR, "inference_TTA")
173 | )
174 | for name in cfg.DATASETS.TEST
175 | ]
176 | res = cls.test(cfg, model, evaluators)
177 | res = OrderedDict({k + "_TTA": v for k, v in res.items()})
178 | return res
179 |
180 | @classmethod
181 | def build_optimizer(cls, cfg, model):
182 | def match_name_keywords(n, name_keywords):
183 | out = False
184 | for b in name_keywords:
185 | if b in n:
186 | out = True
187 | break
188 | return out
189 |
190 | params: List[Dict[str, Any]] = []
191 | memo: Set[torch.nn.parameter.Parameter] = set()
192 | for key, value in model.named_parameters(recurse=True):
193 | if not value.requires_grad:
194 | continue
195 | # Avoid duplicating parameters
196 | if value in memo:
197 | continue
198 | memo.add(value)
199 | lr = cfg.SOLVER.BASE_LR
200 | weight_decay = cfg.SOLVER.WEIGHT_DECAY
201 |
202 | if match_name_keywords(key, cfg.SOLVER.LR_BACKBONE_NAMES):
203 | lr = cfg.SOLVER.LR_BACKBONE
204 | elif match_name_keywords(key, cfg.SOLVER.LR_LINEAR_PROJ_NAMES):
205 | lr = cfg.SOLVER.BASE_LR * cfg.SOLVER.LR_LINEAR_PROJ_MULT
206 |
207 | params += [{"params": [value], "lr": lr, "weight_decay": weight_decay}]
208 |
209 | def maybe_add_full_model_gradient_clipping(optim): # optim: the optimizer class
210 | # detectron2 doesn't have full model gradient clipping now
211 | clip_norm_val = cfg.SOLVER.CLIP_GRADIENTS.CLIP_VALUE
212 | enable = (
213 | cfg.SOLVER.CLIP_GRADIENTS.ENABLED
214 | and cfg.SOLVER.CLIP_GRADIENTS.CLIP_TYPE == "full_model"
215 | and clip_norm_val > 0.0
216 | )
217 |
218 | class FullModelGradientClippingOptimizer(optim):
219 | def step(self, closure=None):
220 | all_params = itertools.chain(*[x["params"] for x in self.param_groups])
221 | torch.nn.utils.clip_grad_norm_(all_params, clip_norm_val)
222 | super().step(closure=closure)
223 |
224 | return FullModelGradientClippingOptimizer if enable else optim
225 |
226 | optimizer_type = cfg.SOLVER.OPTIMIZER
227 | if optimizer_type == "SGD":
228 | optimizer = maybe_add_full_model_gradient_clipping(torch.optim.SGD)(
229 | params, cfg.SOLVER.BASE_LR, momentum=cfg.SOLVER.MOMENTUM
230 | )
231 | elif optimizer_type == "ADAMW":
232 | optimizer = maybe_add_full_model_gradient_clipping(torch.optim.AdamW)(
233 | params, cfg.SOLVER.BASE_LR
234 | )
235 | else:
236 | raise NotImplementedError(f"no optimizer type {optimizer_type}")
237 | if not cfg.SOLVER.CLIP_GRADIENTS.CLIP_TYPE == "full_model":
238 | optimizer = maybe_add_gradient_clipping(cfg, optimizer)
239 | return optimizer
240 |
241 |
242 | def setup(args):
243 | """
244 | Create configs and perform basic setups.
245 | """
246 | cfg = get_cfg()
247 | cfg.merge_from_file(args.config_file)
248 | cfg.merge_from_list(args.opts)
249 | cfg.freeze()
250 | default_setup(cfg, args)
251 |
252 | rank = comm.get_rank()
253 | setup_logger(cfg.OUTPUT_DIR, distributed_rank=rank, name="adet")
254 |
255 | return cfg
256 |
257 |
258 | def main(args):
259 | cfg = setup(args)
260 |
261 | if args.eval_only:
262 | model = Trainer.build_model(cfg)
263 | AdetCheckpointer(model, save_dir=cfg.OUTPUT_DIR).resume_or_load(
264 | cfg.MODEL.WEIGHTS, resume=args.resume
265 | )
266 | res = Trainer.test(cfg, model) # d2 defaults.py
267 | if comm.is_main_process():
268 | verify_results(cfg, res)
269 | if cfg.TEST.AUG.ENABLED:
270 | res.update(Trainer.test_with_TTA(cfg, model))
271 | return res
272 |
273 | """
274 | If you'd like to do anything fancier than the standard training logic,
275 | consider writing your own training loop or subclassing the trainer.
276 | """
277 | trainer = Trainer(cfg)
278 | trainer.resume_or_load(resume=args.resume)
279 | if cfg.TEST.AUG.ENABLED:
280 | trainer.register_hooks(
281 | [hooks.EvalHook(0, lambda: trainer.test_with_TTA(cfg, trainer.model))]
282 | )
283 | return trainer.train()
284 |
285 |
286 | if __name__ == "__main__":
287 | args = default_argument_parser().parse_args()
288 | print("Command Line Args:", args)
289 | launch(
290 | main,
291 | args.num_gpus,
292 | num_machines=args.num_machines,
293 | machine_rank=args.machine_rank,
294 | dist_url=args.dist_url,
295 | args=(args,),
296 | )
297 |
--------------------------------------------------------------------------------
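
The per-parameter learning rates assembled in `Trainer.build_optimizer` come from the substring matching in `match_name_keywords`. A hedged toy model (module names chosen only to trigger each branch) shows which group a parameter lands in, using the keyword lists from Base-TESTR.yaml and the learning rates from Pretrain/TESTR_R_50.yaml:

import torch

base_lr, lr_backbone, proj_mult = 2e-4, 2e-5, 0.1   # BASE_LR, LR_BACKBONE, LR_LINEAR_PROJ_MULT
backbone_names = ["backbone.0"]                      # SOLVER.LR_BACKBONE_NAMES
linear_proj_names = ["reference_points", "sampling_offsets"]  # SOLVER.LR_LINEAR_PROJ_NAMES

model = torch.nn.ModuleDict({
    "backbone": torch.nn.ModuleList([torch.nn.Linear(4, 4)]),  # parameter names start with "backbone.0"
    "sampling_offsets": torch.nn.Linear(4, 4),
    "class_head": torch.nn.Linear(4, 4),
})

for name, param in model.named_parameters():
    if any(k in name for k in backbone_names):
        lr = lr_backbone
    elif any(k in name for k in linear_proj_names):
        lr = base_lr * proj_mult
    else:
        lr = base_lr
    print(f"{name:35s} lr={lr:g}")
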