├── .gitignore ├── LICENSE ├── __init__.py ├── compile.sh ├── config └── psenet │ ├── psenet_r50_ctw.py │ ├── psenet_r50_ctw_finetune.py │ ├── psenet_r50_ic15_1024.py │ ├── psenet_r50_ic15_1024_finetune.py │ ├── psenet_r50_ic15_736.py │ ├── psenet_r50_ic15_736_finetune.py │ ├── psenet_r50_synth.py │ ├── psenet_r50_tt.py │ └── psenet_r50_tt_finetune.py ├── dataset ├── __init__.py ├── builder.py └── psenet │ ├── __init__.py │ ├── check_dataloader.py │ ├── psenet_ctw.py │ ├── psenet_ic15.py │ ├── psenet_synth.py │ └── psenet_tt.py ├── eval ├── ctw │ ├── eval.py │ └── file_util.py ├── eval_ctw.sh ├── eval_ic15.sh ├── eval_ic15_rec.sh ├── eval_msra.sh ├── eval_tt.sh ├── eval_tt_rec.sh ├── ic15 │ ├── gt.zip │ ├── rrc_evaluation_funcs.py │ ├── rrc_evaluation_funcs_v1.py │ ├── rrc_evaluation_funcs_v2.py │ ├── script.py │ └── script_self_adapt.py ├── ic15_rec │ ├── gt.zip │ ├── readme.txt │ ├── rrc_evaluation_funcs_1_1.py │ ├── script.py │ └── script_self_adapt.py ├── msra │ ├── eval.py │ └── file_util.py ├── tt │ ├── Deteval.py │ ├── Deteval_rec.py │ └── polygon_wrapper.py └── tt_rec │ ├── gt.zip │ ├── readme.txt │ ├── rrc_evaluation_funcs_1_1.py │ └── script.py ├── logo.jpg ├── models ├── __init__.py ├── backbone │ ├── __init__.py │ ├── builder.py │ └── resnet.py ├── builder.py ├── head │ ├── __init__.py │ ├── builder.py │ └── psenet_head.py ├── loss │ ├── __init__.py │ ├── acc.py │ ├── builder.py │ ├── dice_loss.py │ ├── emb_loss_v1.py │ ├── iou.py │ └── ohem.py ├── neck │ ├── __init__.py │ ├── builder.py │ └── fpn.py ├── post_processing │ ├── __init__.py │ └── pse │ │ ├── __init__.py │ │ ├── pse.cpp │ │ ├── pse.pyx │ │ ├── readme.txt │ │ └── setup.py ├── psenet.py ├── pypse.py └── utils │ ├── __init__.py │ ├── conv_bn_relu.py │ └── fuse_conv_bn.py ├── readme.md ├── requirement.txt ├── test.py ├── train.py └── utils ├── __init__.py ├── average_meter.py ├── logger.py └── result_format.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.pyc 4 | 5 | # C extensions 6 | *.so 7 | *.o 8 | *.nfs* 9 | 10 | # Distribution / packaging 11 | .Python 12 | *build/ 13 | *out/ 14 | *outputs/ 15 | *data/ 16 | *weights/ 17 | *ckpt/ 18 | *pretrain/ 19 | *.pth 20 | *job.* 21 | *env.sh 22 | *.tar 23 | *checkpoints/ 24 | *dataloader_vis/ 25 | *pretrained/ 26 | pretrained 27 | data 28 | vis/ 29 | cc.sh 30 | *~ 31 | tmp/ 32 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright 2018-2019 Open-MMLab. All rights reserved. 2 | 3 | Apache License 4 | Version 2.0, January 2004 5 | http://www.apache.org/licenses/ 6 | 7 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 8 | 9 | 1. Definitions. 10 | 11 | "License" shall mean the terms and conditions for use, reproduction, 12 | and distribution as defined by Sections 1 through 9 of this document. 13 | 14 | "Licensor" shall mean the copyright owner or entity authorized by 15 | the copyright owner that is granting the License. 16 | 17 | "Legal Entity" shall mean the union of the acting entity and all 18 | other entities that control, are controlled by, or are under common 19 | control with that entity. 
For the purposes of this definition, 20 | "control" means (i) the power, direct or indirect, to cause the 21 | direction or management of such entity, whether by contract or 22 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 23 | outstanding shares, or (iii) beneficial ownership of such entity. 24 | 25 | "You" (or "Your") shall mean an individual or Legal Entity 26 | exercising permissions granted by this License. 27 | 28 | "Source" form shall mean the preferred form for making modifications, 29 | including but not limited to software source code, documentation 30 | source, and configuration files. 31 | 32 | "Object" form shall mean any form resulting from mechanical 33 | transformation or translation of a Source form, including but 34 | not limited to compiled object code, generated documentation, 35 | and conversions to other media types. 36 | 37 | "Work" shall mean the work of authorship, whether in Source or 38 | Object form, made available under the License, as indicated by a 39 | copyright notice that is included in or attached to the work 40 | (an example is provided in the Appendix below). 41 | 42 | "Derivative Works" shall mean any work, whether in Source or Object 43 | form, that is based on (or derived from) the Work and for which the 44 | editorial revisions, annotations, elaborations, or other modifications 45 | represent, as a whole, an original work of authorship. For the purposes 46 | of this License, Derivative Works shall not include works that remain 47 | separable from, or merely link (or bind by name) to the interfaces of, 48 | the Work and Derivative Works thereof. 49 | 50 | "Contribution" shall mean any work of authorship, including 51 | the original version of the Work and any modifications or additions 52 | to that Work or Derivative Works thereof, that is intentionally 53 | submitted to Licensor for inclusion in the Work by the copyright owner 54 | or by an individual or Legal Entity authorized to submit on behalf of 55 | the copyright owner. For the purposes of this definition, "submitted" 56 | means any form of electronic, verbal, or written communication sent 57 | to the Licensor or its representatives, including but not limited to 58 | communication on electronic mailing lists, source code control systems, 59 | and issue tracking systems that are managed by, or on behalf of, the 60 | Licensor for the purpose of discussing and improving the Work, but 61 | excluding communication that is conspicuously marked or otherwise 62 | designated in writing by the copyright owner as "Not a Contribution." 63 | 64 | "Contributor" shall mean Licensor and any individual or Legal Entity 65 | on behalf of whom a Contribution has been received by Licensor and 66 | subsequently incorporated within the Work. 67 | 68 | 2. Grant of Copyright License. Subject to the terms and conditions of 69 | this License, each Contributor hereby grants to You a perpetual, 70 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 71 | copyright license to reproduce, prepare Derivative Works of, 72 | publicly display, publicly perform, sublicense, and distribute the 73 | Work and such Derivative Works in Source or Object form. 74 | 75 | 3. Grant of Patent License. 
Subject to the terms and conditions of 76 | this License, each Contributor hereby grants to You a perpetual, 77 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 78 | (except as stated in this section) patent license to make, have made, 79 | use, offer to sell, sell, import, and otherwise transfer the Work, 80 | where such license applies only to those patent claims licensable 81 | by such Contributor that are necessarily infringed by their 82 | Contribution(s) alone or by combination of their Contribution(s) 83 | with the Work to which such Contribution(s) was submitted. If You 84 | institute patent litigation against any entity (including a 85 | cross-claim or counterclaim in a lawsuit) alleging that the Work 86 | or a Contribution incorporated within the Work constitutes direct 87 | or contributory patent infringement, then any patent licenses 88 | granted to You under this License for that Work shall terminate 89 | as of the date such litigation is filed. 90 | 91 | 4. Redistribution. You may reproduce and distribute copies of the 92 | Work or Derivative Works thereof in any medium, with or without 93 | modifications, and in Source or Object form, provided that You 94 | meet the following conditions: 95 | 96 | (a) You must give any other recipients of the Work or 97 | Derivative Works a copy of this License; and 98 | 99 | (b) You must cause any modified files to carry prominent notices 100 | stating that You changed the files; and 101 | 102 | (c) You must retain, in the Source form of any Derivative Works 103 | that You distribute, all copyright, patent, trademark, and 104 | attribution notices from the Source form of the Work, 105 | excluding those notices that do not pertain to any part of 106 | the Derivative Works; and 107 | 108 | (d) If the Work includes a "NOTICE" text file as part of its 109 | distribution, then any Derivative Works that You distribute must 110 | include a readable copy of the attribution notices contained 111 | within such NOTICE file, excluding those notices that do not 112 | pertain to any part of the Derivative Works, in at least one 113 | of the following places: within a NOTICE text file distributed 114 | as part of the Derivative Works; within the Source form or 115 | documentation, if provided along with the Derivative Works; or, 116 | within a display generated by the Derivative Works, if and 117 | wherever such third-party notices normally appear. The contents 118 | of the NOTICE file are for informational purposes only and 119 | do not modify the License. You may add Your own attribution 120 | notices within Derivative Works that You distribute, alongside 121 | or as an addendum to the NOTICE text from the Work, provided 122 | that such additional attribution notices cannot be construed 123 | as modifying the License. 124 | 125 | You may add Your own copyright statement to Your modifications and 126 | may provide additional or different license terms and conditions 127 | for use, reproduction, or distribution of Your modifications, or 128 | for any such Derivative Works as a whole, provided Your use, 129 | reproduction, and distribution of the Work otherwise complies with 130 | the conditions stated in this License. 131 | 132 | 5. Submission of Contributions. Unless You explicitly state otherwise, 133 | any Contribution intentionally submitted for inclusion in the Work 134 | by You to the Licensor shall be under the terms and conditions of 135 | this License, without any additional terms or conditions. 
136 | Notwithstanding the above, nothing herein shall supersede or modify 137 | the terms of any separate license agreement you may have executed 138 | with Licensor regarding such Contributions. 139 | 140 | 6. Trademarks. This License does not grant permission to use the trade 141 | names, trademarks, service marks, or product names of the Licensor, 142 | except as required for reasonable and customary use in describing the 143 | origin of the Work and reproducing the content of the NOTICE file. 144 | 145 | 7. Disclaimer of Warranty. Unless required by applicable law or 146 | agreed to in writing, Licensor provides the Work (and each 147 | Contributor provides its Contributions) on an "AS IS" BASIS, 148 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 149 | implied, including, without limitation, any warranties or conditions 150 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 151 | PARTICULAR PURPOSE. You are solely responsible for determining the 152 | appropriateness of using or redistributing the Work and assume any 153 | risks associated with Your exercise of permissions under this License. 154 | 155 | 8. Limitation of Liability. In no event and under no legal theory, 156 | whether in tort (including negligence), contract, or otherwise, 157 | unless required by applicable law (such as deliberate and grossly 158 | negligent acts) or agreed to in writing, shall any Contributor be 159 | liable to You for damages, including any direct, indirect, special, 160 | incidental, or consequential damages of any character arising as a 161 | result of this License or out of the use or inability to use the 162 | Work (including but not limited to damages for loss of goodwill, 163 | work stoppage, computer failure or malfunction, or any and all 164 | other commercial damages or losses), even if such Contributor 165 | has been advised of the possibility of such damages. 166 | 167 | 9. Accepting Warranty or Additional Liability. While redistributing 168 | the Work or Derivative Works thereof, You may choose to offer, 169 | and charge a fee for, acceptance of support, warranty, indemnity, 170 | or other liability obligations and/or rights consistent with this 171 | License. However, in accepting such obligations, You may act only 172 | on Your own behalf and on Your sole responsibility, not on behalf 173 | of any other Contributor, and only if You agree to indemnify, 174 | defend, and hold each Contributor harmless for any liability 175 | incurred by, or claims asserted against, such Contributor by reason 176 | of your accepting any such warranty or additional liability. 177 | 178 | END OF TERMS AND CONDITIONS 179 | 180 | APPENDIX: How to apply the Apache License to your work. 181 | 182 | To apply the Apache License to your work, attach the following 183 | boilerplate notice, with the fields enclosed by brackets "[]" 184 | replaced with your own identifying information. (Don't include 185 | the brackets!) The text should be enclosed in the appropriate 186 | comment syntax for the file format. We also recommend that a 187 | file or class name and description of purpose be included on the 188 | same "printed page" as the copyright notice for easier 189 | identification within third-party archives. 190 | 191 | Copyright 2018-2019 Open-MMLab. 192 | 193 | Licensed under the Apache License, Version 2.0 (the "License"); 194 | you may not use this file except in compliance with the License. 
195 | You may obtain a copy of the License at 196 | 197 | http://www.apache.org/licenses/LICENSE-2.0 198 | 199 | Unless required by applicable law or agreed to in writing, software 200 | distributed under the License is distributed on an "AS IS" BASIS, 201 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 202 | See the License for the specific language governing permissions and 203 | limitations under the License. 204 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- 1 | #from .pa import pa 2 | -------------------------------------------------------------------------------- /compile.sh: -------------------------------------------------------------------------------- 1 | cd ./models/post_processing/pse/ 2 | python setup.py build_ext --inplace 3 | cd ../../../ -------------------------------------------------------------------------------- /config/psenet/psenet_r50_ctw.py: -------------------------------------------------------------------------------- 1 | model = dict( 2 | type='PSENet', 3 | backbone=dict( 4 | type='resnet50', 5 | pretrained=True 6 | ), 7 | neck=dict( 8 | type='FPN', 9 | in_channels=(256, 512, 1024, 2048), 10 | out_channels=128 11 | ), 12 | detection_head=dict( 13 | type='PSENet_Head', 14 | in_channels=1024, 15 | hidden_dim=256, 16 | num_classes=7, 17 | loss_text=dict( 18 | type='DiceLoss', 19 | loss_weight=0.7 20 | ), 21 | loss_kernel=dict( 22 | type='DiceLoss', 23 | loss_weight=0.3 24 | ) 25 | ) 26 | ) 27 | data = dict( 28 | batch_size=16, 29 | train=dict( 30 | type='PSENET_CTW', 31 | split='train', 32 | is_transform=True, 33 | img_size=736, 34 | short_size=736, 35 | kernel_num=7, 36 | min_scale=0.7, 37 | read_type='cv2' 38 | ), 39 | test=dict( 40 | type='PSENET_CTW', 41 | split='test', 42 | short_size=736, 43 | read_type='cv2' 44 | ) 45 | ) 46 | train_cfg = dict( 47 | lr=1e-3, 48 | schedule=(200, 400,), 49 | epoch=600, 50 | optimizer='SGD' 51 | ) 52 | test_cfg = dict( 53 | min_score=0.85, 54 | min_area=16, 55 | kernel_num=7, 56 | bbox_type='poly', 57 | result_path='outputs/submit_ctw.zip' 58 | ) 59 | -------------------------------------------------------------------------------- /config/psenet/psenet_r50_ctw_finetune.py: -------------------------------------------------------------------------------- 1 | model = dict( 2 | type='PSENet', 3 | backbone=dict( 4 | type='resnet50', 5 | pretrained=True 6 | ), 7 | neck=dict( 8 | type='FPN', 9 | in_channels=(256, 512, 1024, 2048), 10 | out_channels=128 11 | ), 12 | detection_head=dict( 13 | type='PSENet_Head', 14 | in_channels=1024, 15 | hidden_dim=256, 16 | num_classes=7, 17 | loss_text=dict( 18 | type='DiceLoss', 19 | loss_weight=0.7 20 | ), 21 | loss_kernel=dict( 22 | type='DiceLoss', 23 | loss_weight=0.3 24 | ) 25 | ) 26 | ) 27 | data = dict( 28 | batch_size=16, 29 | train=dict( 30 | type='PSENET_CTW', 31 | split='train', 32 | is_transform=True, 33 | img_size=736, 34 | short_size=736, 35 | kernel_num=7, 36 | min_scale=0.7, 37 | read_type='cv2' 38 | ), 39 | test=dict( 40 | type='PSENET_CTW', 41 | split='test', 42 | short_size=736, 43 | read_type='cv2' 44 | ) 45 | ) 46 | train_cfg = dict( 47 | lr=1e-3, 48 | schedule=(200, 400,), 49 | epoch=520, 50 | optimizer='SGD', 51 | pretrain='checkpoints/psenet_r50_synth/checkpoint.pth.tar' 52 | ) 53 | test_cfg = dict( 54 | min_score=0.85, 55 | min_area=16, 56 | kernel_num=7, 57 | bbox_type='poly', 58 | 
result_path='outputs/submit_ctw.zip' 59 | ) 60 | -------------------------------------------------------------------------------- /config/psenet/psenet_r50_ic15_1024.py: -------------------------------------------------------------------------------- 1 | model = dict( 2 | type='PSENet', 3 | backbone=dict( 4 | type='resnet50', 5 | pretrained=True 6 | ), 7 | neck=dict( 8 | type='FPN', 9 | in_channels=(256, 512, 1024, 2048), 10 | out_channels=128 11 | ), 12 | detection_head=dict( 13 | type='PSENet_Head', 14 | in_channels=1024, 15 | hidden_dim=256, 16 | num_classes=7, 17 | loss_text=dict( 18 | type='DiceLoss', 19 | loss_weight=0.7 20 | ), 21 | loss_kernel=dict( 22 | type='DiceLoss', 23 | loss_weight=0.3 24 | ) 25 | ) 26 | ) 27 | data = dict( 28 | batch_size=16, 29 | train=dict( 30 | type='PSENET_IC15', 31 | split='train', 32 | is_transform=True, 33 | img_size=736, 34 | short_size=1024, 35 | kernel_num=7, 36 | min_scale=0.4, 37 | read_type='cv2' 38 | ), 39 | test=dict( 40 | type='PSENET_IC15', 41 | split='test', 42 | short_size=1024, 43 | read_type='cv2' 44 | ) 45 | ) 46 | train_cfg = dict( 47 | lr=1e-3, 48 | schedule=(200, 400,), 49 | epoch=580, 50 | optimizer='SGD' 51 | ) 52 | test_cfg = dict( 53 | min_score=0.85, 54 | min_area=16, 55 | kernel_num=7, 56 | bbox_type='rect', 57 | result_path='outputs/submit_ic15.zip' 58 | ) 59 | -------------------------------------------------------------------------------- /config/psenet/psenet_r50_ic15_1024_finetune.py: -------------------------------------------------------------------------------- 1 | model = dict( 2 | type='PSENet', 3 | backbone=dict( 4 | type='resnet50', 5 | pretrained=True 6 | ), 7 | neck=dict( 8 | type='FPN', 9 | in_channels=(256, 512, 1024, 2048), 10 | out_channels=128 11 | ), 12 | detection_head=dict( 13 | type='PSENet_Head', 14 | in_channels=1024, 15 | hidden_dim=256, 16 | num_classes=7, 17 | loss_text=dict( 18 | type='DiceLoss', 19 | loss_weight=0.7 20 | ), 21 | loss_kernel=dict( 22 | type='DiceLoss', 23 | loss_weight=0.3 24 | ) 25 | ) 26 | ) 27 | data = dict( 28 | batch_size=16, 29 | train=dict( 30 | type='PSENET_IC15', 31 | split='train', 32 | is_transform=True, 33 | img_size=736, 34 | short_size=1024, 35 | kernel_num=7, 36 | min_scale=0.4, 37 | read_type='cv2' 38 | ), 39 | test=dict( 40 | type='PSENET_IC15', 41 | split='test', 42 | short_size=1024, 43 | read_type='cv2' 44 | ) 45 | ) 46 | train_cfg = dict( 47 | lr=1e-3, 48 | schedule=(200, 400,), 49 | epoch=580, 50 | optimizer='SGD', 51 | pretrain='checkpoints/psenet_r50_synth/checkpoint.pth.tar' 52 | ) 53 | test_cfg = dict( 54 | min_score=0.85, 55 | min_area=16, 56 | kernel_num=7, 57 | bbox_type='rect', 58 | result_path='outputs/submit_ic15.zip' 59 | ) 60 | -------------------------------------------------------------------------------- /config/psenet/psenet_r50_ic15_736.py: -------------------------------------------------------------------------------- 1 | model = dict( 2 | type='PSENet', 3 | backbone=dict( 4 | type='resnet50', 5 | pretrained=True 6 | ), 7 | neck=dict( 8 | type='FPN', 9 | in_channels=(256, 512, 1024, 2048), 10 | out_channels=128 11 | ), 12 | detection_head=dict( 13 | type='PSENet_Head', 14 | in_channels=1024, 15 | hidden_dim=256, 16 | num_classes=7, 17 | loss_text=dict( 18 | type='DiceLoss', 19 | loss_weight=0.7 20 | ), 21 | loss_kernel=dict( 22 | type='DiceLoss', 23 | loss_weight=0.3 24 | ) 25 | ) 26 | ) 27 | data = dict( 28 | batch_size=16, 29 | train=dict( 30 | type='PSENET_IC15', 31 | split='train', 32 | is_transform=True, 33 | img_size=736, 34 | 
short_size=736, 35 | kernel_num=7, 36 | min_scale=0.4, 37 | read_type='cv2' 38 | ), 39 | test=dict( 40 | type='PSENET_IC15', 41 | split='test', 42 | short_size=736, 43 | read_type='cv2' 44 | ) 45 | ) 46 | train_cfg = dict( 47 | lr=1e-3, 48 | schedule=(200, 400,), 49 | epoch=600, 50 | optimizer='SGD' 51 | ) 52 | test_cfg = dict( 53 | min_score=0.85, 54 | min_area=16, 55 | kernel_num=7, 56 | bbox_type='rect', 57 | result_path='outputs/submit_ic15.zip' 58 | ) 59 | -------------------------------------------------------------------------------- /config/psenet/psenet_r50_ic15_736_finetune.py: -------------------------------------------------------------------------------- 1 | model = dict( 2 | type='PSENet', 3 | backbone=dict( 4 | type='resnet50', 5 | pretrained=True 6 | ), 7 | neck=dict( 8 | type='FPN', 9 | in_channels=(256, 512, 1024, 2048), 10 | out_channels=128 11 | ), 12 | detection_head=dict( 13 | type='PSENet_Head', 14 | in_channels=1024, 15 | hidden_dim=256, 16 | num_classes=7, 17 | loss_text=dict( 18 | type='DiceLoss', 19 | loss_weight=0.7 20 | ), 21 | loss_kernel=dict( 22 | type='DiceLoss', 23 | loss_weight=0.3 24 | ) 25 | ) 26 | ) 27 | data = dict( 28 | batch_size=16, 29 | train=dict( 30 | type='PSENET_IC15', 31 | split='train', 32 | is_transform=True, 33 | img_size=736, 34 | short_size=736, 35 | kernel_num=7, 36 | min_scale=0.4, 37 | read_type='cv2' 38 | ), 39 | test=dict( 40 | type='PSENET_IC15', 41 | split='test', 42 | short_size=736, 43 | read_type='cv2' 44 | ) 45 | ) 46 | train_cfg = dict( 47 | lr=1e-3, 48 | schedule=(200, 400,), 49 | epoch=600, 50 | optimizer='SGD', 51 | pretrain='checkpoints/psenet_r50_synth/checkpoint.pth.tar' 52 | ) 53 | test_cfg = dict( 54 | min_score=0.85, 55 | min_area=16, 56 | kernel_num=7, 57 | bbox_type='rect', 58 | result_path='outputs/submit_ic15_736.zip' 59 | ) 60 | -------------------------------------------------------------------------------- /config/psenet/psenet_r50_synth.py: -------------------------------------------------------------------------------- 1 | model = dict( 2 | type='PSENet', 3 | backbone=dict( 4 | type='resnet50', 5 | pretrained=True 6 | ), 7 | neck=dict( 8 | type='FPN', 9 | in_channels=(256, 512, 1024, 2048), 10 | out_channels=128 11 | ), 12 | detection_head=dict( 13 | type='PSENet_Head', 14 | in_channels=1024, 15 | hidden_dim=256, 16 | num_classes=7, 17 | loss_text=dict( 18 | type='DiceLoss', 19 | loss_weight=0.7 20 | ), 21 | loss_kernel=dict( 22 | type='DiceLoss', 23 | loss_weight=0.3 24 | ) 25 | ) 26 | ) 27 | data = dict( 28 | batch_size=16, 29 | train=dict( 30 | type='PSENET_Synth', 31 | is_transform=True, 32 | img_size=736, 33 | short_size=736, 34 | kernel_num=7, 35 | min_scale=0.7, 36 | read_type='cv2' 37 | ) 38 | ) 39 | train_cfg = dict( 40 | lr=1e-3, 41 | schedule='polylr', 42 | epoch=1, 43 | optimizer='SGD' 44 | ) 45 | 46 | -------------------------------------------------------------------------------- /config/psenet/psenet_r50_tt.py: -------------------------------------------------------------------------------- 1 | model = dict( 2 | type='PSENet', 3 | backbone=dict( 4 | type='resnet50', 5 | pretrained=True 6 | ), 7 | neck=dict( 8 | type='FPN', 9 | in_channels=(256, 512, 1024, 2048), 10 | out_channels=128 11 | ), 12 | detection_head=dict( 13 | type='PSENet_Head', 14 | in_channels=1024, 15 | hidden_dim=256, 16 | num_classes=7, 17 | loss_text=dict( 18 | type='DiceLoss', 19 | loss_weight=0.7 20 | ), 21 | loss_kernel=dict( 22 | type='DiceLoss', 23 | loss_weight=0.3 24 | ) 25 | ) 26 | ) 27 | data = dict( 28 | 
batch_size=16, 29 | train=dict( 30 | type='PSENET_TT', 31 | split='train', 32 | is_transform=True, 33 | img_size=736, 34 | short_size=736, 35 | kernel_num=7, 36 | min_scale=0.7, 37 | read_type='cv2' 38 | ), 39 | test=dict( 40 | type='PSENET_TT', 41 | split='test', 42 | short_size=736, 43 | read_type='cv2' 44 | ) 45 | ) 46 | train_cfg = dict( 47 | lr=1e-3, 48 | schedule=(200, 400,), 49 | epoch=600, 50 | optimizer='SGD' 51 | ) 52 | test_cfg = dict( 53 | min_score=0.87, 54 | min_area=16, 55 | kernel_num=7, 56 | bbox_type='poly', 57 | result_path='outputs/submit_tt/' 58 | ) -------------------------------------------------------------------------------- /config/psenet/psenet_r50_tt_finetune.py: -------------------------------------------------------------------------------- 1 | model = dict( 2 | type='PSENet', 3 | backbone=dict( 4 | type='resnet50', 5 | pretrained=True 6 | ), 7 | neck=dict( 8 | type='FPN', 9 | in_channels=(256, 512, 1024, 2048), 10 | out_channels=128 11 | ), 12 | detection_head=dict( 13 | type='PSENet_Head', 14 | in_channels=1024, 15 | hidden_dim=256, 16 | num_classes=7, 17 | loss_text=dict( 18 | type='DiceLoss', 19 | loss_weight=0.7 20 | ), 21 | loss_kernel=dict( 22 | type='DiceLoss', 23 | loss_weight=0.3 24 | ) 25 | ) 26 | ) 27 | data = dict( 28 | batch_size=16, 29 | train=dict( 30 | type='PSENET_TT', 31 | split='train', 32 | is_transform=True, 33 | img_size=736, 34 | short_size=736, 35 | kernel_num=7, 36 | min_scale=0.7, 37 | read_type='cv2' 38 | ), 39 | test=dict( 40 | type='PSENET_TT', 41 | split='test', 42 | short_size=736, 43 | read_type='cv2' 44 | ) 45 | ) 46 | train_cfg = dict( 47 | lr=1e-3, 48 | schedule=(200, 400,), 49 | epoch=600, 50 | optimizer='SGD', 51 | pretrain='checkpoints/psenet_r50_synth/checkpoint.pth.tar' 52 | ) 53 | test_cfg = dict( 54 | min_score=0.87, 55 | min_area=16, 56 | kernel_num=7, 57 | bbox_type='poly', 58 | result_path='outputs/submit_tt/' 59 | ) -------------------------------------------------------------------------------- /dataset/__init__.py: -------------------------------------------------------------------------------- 1 | from .psenet import PSENET_IC15, PSENET_TT, PSENET_Synth, PSENET_CTW 2 | from .builder import build_data_loader 3 | 4 | __all__ = ['PSENET_IC15', 'PSENET_TT', 'PSENET_CTW', 'PSENET_Synth'] 5 | -------------------------------------------------------------------------------- /dataset/builder.py: -------------------------------------------------------------------------------- 1 | import dataset 2 | 3 | 4 | def build_data_loader(cfg): 5 | param = dict() 6 | for key in cfg: 7 | if key == 'type': 8 | continue 9 | param[key] = cfg[key] 10 | 11 | data_loader = dataset.__dict__[cfg.type](**param) 12 | 13 | return data_loader 14 | -------------------------------------------------------------------------------- /dataset/psenet/__init__.py: -------------------------------------------------------------------------------- 1 | from .psenet_ic15 import PSENET_IC15 2 | from .psenet_tt import PSENET_TT 3 | from .psenet_ctw import PSENET_CTW 4 | from .psenet_synth import PSENET_Synth -------------------------------------------------------------------------------- /dataset/psenet/check_dataloader.py: -------------------------------------------------------------------------------- 1 | from psenet_ctw import PSENET_CTW 2 | import torch 3 | import numpy as np 4 | import cv2 5 | import random 6 | import os 7 | 8 | torch.manual_seed(123456) 9 | torch.cuda.manual_seed(123456) 10 | np.random.seed(123456) 11 | random.seed(123456) 12 | 13 | 14 | 
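# Visual sanity check for the data pipeline: this script renders a few samples into vis/ so augmentations and labels can be inspected by eye; to_rgb() tiles a single-channel {0, 1} mask into a 3-channel uint8 image, and save() frames each panel and concatenates them side by side.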
def to_rgb(img): 15 | img = img.reshape(img.shape[0], img.shape[1], 1) 16 | img = np.concatenate((img, img, img), axis=2) * 255 17 | return img 18 | 19 | 20 | def save(img_path, imgs): 21 | if not os.path.exists('vis/'): 22 | os.makedirs('vis/') 23 | 24 | for i in range(len(imgs)): 25 | imgs[i] = cv2.copyMakeBorder(imgs[i], 3, 3, 3, 3, cv2.BORDER_CONSTANT, value=[255, 0, 0]) 26 | res = np.concatenate(imgs, axis=1) 27 | if not isinstance(img_path, str): 28 | img_name = img_path[0].split('/')[-1] 29 | else: 30 | img_name = img_path.split('/')[-1] 31 | print('saved %s.' % img_name) 32 | cv2.imwrite('vis/' + img_name, res) 33 | 34 | 35 | 36 | # data_loader = SynthLoader(split='train', is_transform=True, img_size=640, kernel_scale=0.5, short_size=640, 37 | # for_rec=True) 38 | # data_loader = IC15Loader(split='train', is_transform=True, img_size=736, kernel_scale=0.5, short_size=736, 39 | # for_rec=True) 40 | # data_loader = CombineLoader(split='train', is_transform=True, img_size=736, kernel_scale=0.5, short_size=736, 41 | # for_rec=True) 42 | # data_loader = TTLoader(split='train', is_transform=True, img_size=640, kernel_scale=0.8, short_size=640, 43 | # for_rec=True, read_type='pil') 44 | # data_loader = CombineAllLoader(split='train', is_transform=True, img_size=736, kernel_scale=0.5, short_size=736, 45 | # for_rec=True) 46 | data_loader = PSENET_CTW(split='test', is_transform=True, img_size=736) 47 | # data_loader = MSRALoader(split='train', is_transform=True, img_size=736, kernel_scale=0.5, short_size=736, 48 | # for_rec=True) 49 | # data_loader = CTWv2Loader(split='train', is_transform=True, img_size=640, kernel_scale=0.7, short_size=640, 50 | # for_rec=True) 51 | # data_loader = IC15(split='train', is_transform=True, img_size=640,) 52 | 53 | train_loader = torch.utils.data.DataLoader( 54 | data_loader, 55 | batch_size=1, 56 | shuffle=False, 57 | num_workers=0, 58 | drop_last=True) 59 | 60 | for batch_idx, imgs in enumerate(train_loader): 61 | if batch_idx > 100: 62 | break 63 | # image_name = data_loader.img_paths[batch_idx].split('/')[-1].split('.')[0] 64 | 65 | # print('%d/%d %s'%(batch_idx, len(train_loader), data_loader.img_paths[batch_idx])) 66 | print('%d/%d' % (batch_idx, len(train_loader))) 67 | 68 | img = imgs['imgs'][0].numpy() # __getitem__ returns a dict of tensors; take the image batch 69 | img = ((img * np.array([0.229, 0.224, 0.225]).reshape(3, 1, 1) + 70 | np.array([0.485, 0.456, 0.406]).reshape(3, 1, 1)) * 255).astype(np.uint8) # undo ImageNet normalization 71 | img = np.transpose(img, (1, 2, 0))[:, :, ::-1].copy() 72 | 73 | # gt_text = to_rgb(gt_texts[0].numpy()) 74 | # gt_kernel_0 = to_rgb(gt_kernels[0, 0].numpy()) 75 | # gt_kernel_1 = to_rgb(gt_kernels[0, 1].numpy()) 76 | # gt_kernel_2 = to_rgb(gt_kernels[0, 2].numpy()) 77 | # gt_kernel_3 = to_rgb(gt_kernels[0, 3].numpy()) 78 | # gt_kernel_4 = to_rgb(gt_kernels[0, 4].numpy()) 79 | # gt_kernel_5 = to_rgb(gt_kernels[0, 5].numpy()) 80 | # gt_text_mask = to_rgb(training_masks[0].numpy().astype(np.uint8)) 81 | 82 | 83 | # save('%d.png' % batch_idx, [img, gt_text, gt_kernel_0, gt_kernel_1, gt_kernel_2, gt_kernel_3, gt_kernel_4, gt_kernel_5, gt_text_mask]) 84 | save('%d_test.png' % batch_idx, [img]) -------------------------------------------------------------------------------- /dataset/psenet/psenet_ctw.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from PIL import Image 3 | from torch.utils import data 4 | import cv2 5 | import random 6 | import torchvision.transforms as transforms 7 | import torch 8 | import pyclipper 9 | import Polygon as plg 10 | import math
11 | import string 12 | import scipy.io as scio 13 | import mmcv 14 | 15 | ctw_root_dir = './data/ctw1500/' 16 | ctw_train_data_dir = ctw_root_dir + 'train/text_image/' 17 | ctw_train_gt_dir = ctw_root_dir + 'train/text_label_curve/' 18 | ctw_test_data_dir = ctw_root_dir + 'test/text_image/' 19 | ctw_test_gt_dir = ctw_root_dir + 'test/text_label_circum/' 20 | 21 | 22 | def get_img(img_path, read_type='pil'): 23 | try: 24 | if read_type == 'cv2': 25 | img = cv2.imread(img_path) 26 | img = img[:, :, [2, 1, 0]] 27 | elif read_type == 'pil': 28 | img = np.array(Image.open(img_path)) 29 | except Exception as e: 30 | print('Cannot read image: %s.' % img_path) 31 | raise 32 | return img 33 | 34 | 35 | def get_ann(img, gt_path): 36 | h, w = img.shape[0:2] 37 | lines = mmcv.list_from_file(gt_path) 38 | bboxes = [] 39 | words = [] 40 | for line in lines: 41 | line = line.replace('\xef\xbb\xbf', '') 42 | gt = line.split(',') 43 | 44 | x1 = int(gt[0]) 45 | y1 = int(gt[1]) 46 | 47 | bbox = [int(gt[i]) for i in range(4, 32)] # 14 curve points, stored as offsets 48 | bbox = np.asarray(bbox) + ([x1 * 1.0, y1 * 1.0] * 14) # offsets are relative to the top-left corner (x1, y1) 49 | bbox = np.asarray(bbox) / ([w * 1.0, h * 1.0] * 14) # normalize to [0, 1] by image size 50 | 51 | bboxes.append(bbox) 52 | words.append('???') # CTW1500 curve labels carry no transcriptions 53 | return bboxes, words 54 | 55 | 56 | def random_horizontal_flip(imgs): 57 | if random.random() < 0.5: 58 | for i in range(len(imgs)): 59 | imgs[i] = np.flip(imgs[i], axis=1).copy() 60 | return imgs 61 | 62 | 63 | def random_rotate(imgs): 64 | max_angle = 10 65 | angle = random.random() * 2 * max_angle - max_angle 66 | for i in range(len(imgs)): 67 | img = imgs[i] 68 | w, h = img.shape[:2] # note: shape[:2] is (rows, cols); the swapped names are used consistently below 69 | rotation_matrix = cv2.getRotationMatrix2D((h / 2, w / 2), angle, 1) 70 | img_rotation = cv2.warpAffine(img, rotation_matrix, (h, w), flags=cv2.INTER_NEAREST) 71 | imgs[i] = img_rotation 72 | return imgs 73 | 74 | 75 | def scale_aligned(img, scale): 76 | h, w = img.shape[0:2] 77 | h = int(h * scale + 0.5) 78 | w = int(w * scale + 0.5) 79 | if h % 32 != 0: 80 | h = h + (32 - h % 32) 81 | if w % 32 != 0: 82 | w = w + (32 - w % 32) 83 | img = cv2.resize(img, dsize=(w, h)) 84 | return img 85 | 86 | 87 | def random_scale(img, short_size=736): 88 | h, w = img.shape[0:2] 89 | 90 | random_scale = np.array([0.7, 0.8, 0.9, 1.0, 1.1, 1.2, 1.3]) 91 | scale = (np.random.choice(random_scale) * short_size) / min(h, w) 92 | 93 | img = scale_aligned(img, scale) 94 | return img 95 | 96 | 97 | def scale_aligned_short(img, short_size=736): 98 | h, w = img.shape[0:2] 99 | scale = short_size * 1.0 / min(h, w) 100 | h = int(h * scale + 0.5) 101 | w = int(w * scale + 0.5) 102 | if h % 32 != 0: 103 | h = h + (32 - h % 32) 104 | if w % 32 != 0: 105 | w = w + (32 - w % 32) 106 | img = cv2.resize(img, dsize=(w, h)) 107 | return img 108 | 109 | 110 | def random_crop_padding(imgs, target_size): 111 | h, w = imgs[0].shape[0:2] 112 | t_w, t_h = target_size 113 | p_w, p_h = target_size 114 | if w == t_w and h == t_h: 115 | return imgs 116 | 117 | t_h = t_h if t_h < h else h 118 | t_w = t_w if t_w < w else w 119 | 120 | if random.random() > 3.0 / 8.0 and np.max(imgs[1]) > 0: 121 | # make sure to crop the text region 122 | tl = np.min(np.where(imgs[1] > 0), axis=1) - (t_h, t_w) 123 | tl[tl < 0] = 0 124 | br = np.max(np.where(imgs[1] > 0), axis=1) - (t_h, t_w) 125 | br[br < 0] = 0 126 | br[0] = min(br[0], h - t_h) 127 | br[1] = min(br[1], w - t_w) 128 | 129 | i = random.randint(tl[0], br[0]) if tl[0] < br[0] else 0 130 | j = random.randint(tl[1], br[1]) if tl[1] < br[1] else 0 131 | else: 132 | i = random.randint(0, h - t_h) if h - t_h > 0 else 0 133 | j = random.randint(0, w
- t_w) if w - t_w > 0 else 0 134 | 135 | n_imgs = [] 136 | for idx in range(len(imgs)): 137 | if len(imgs[idx].shape) == 3: 138 | s3_length = int(imgs[idx].shape[-1]) 139 | img = imgs[idx][i:i + t_h, j:j + t_w, :] 140 | img_p = cv2.copyMakeBorder(img, 0, p_h - t_h, 0, p_w - t_w, borderType=cv2.BORDER_CONSTANT, 141 | value=tuple(0 for i in range(s3_length))) 142 | else: 143 | img = imgs[idx][i:i + t_h, j:j + t_w] 144 | img_p = cv2.copyMakeBorder(img, 0, p_h - t_h, 0, p_w - t_w, borderType=cv2.BORDER_CONSTANT, value=(0,)) 145 | n_imgs.append(img_p) 146 | return n_imgs 147 | 148 | 149 | def dist(a, b): 150 | return np.linalg.norm((a - b), ord=2, axis=0) 151 | 152 | 153 | def perimeter(bbox): 154 | peri = 0.0 155 | for i in range(bbox.shape[0]): 156 | peri += dist(bbox[i], bbox[(i + 1) % bbox.shape[0]]) 157 | return peri 158 | 159 | 160 | def shrink(bboxes, rate, max_shr=20): 161 | rate = rate * rate 162 | shrinked_bboxes = [] 163 | for bbox in bboxes: 164 | area = plg.Polygon(bbox).area() 165 | peri = perimeter(bbox) 166 | 167 | try: 168 | pco = pyclipper.PyclipperOffset() 169 | pco.AddPath(bbox, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON) 170 | offset = min(int(area * (1 - rate) / (peri + 0.001) + 0.5), max_shr) 171 | 172 | shrinked_bbox = pco.Execute(-offset) 173 | if len(shrinked_bbox) == 0: 174 | shrinked_bboxes.append(bbox) 175 | continue 176 | 177 | shrinked_bbox = np.array(shrinked_bbox[0]) 178 | if shrinked_bbox.shape[0] <= 2: 179 | shrinked_bboxes.append(bbox) 180 | continue 181 | 182 | shrinked_bboxes.append(shrinked_bbox) 183 | except Exception as e: 184 | # pyclipper failed (shrinked_bbox may not even be bound here); fall back to the unshrunk bbox 185 | print('area:', area, 'peri:', peri) 186 | shrinked_bboxes.append(bbox) 187 | 188 | return shrinked_bboxes 189 | 190 | 191 | class PSENET_CTW(data.Dataset): 192 | def __init__(self, 193 | split='train', 194 | is_transform=False, 195 | img_size=None, 196 | short_size=736, 197 | kernel_num=7, 198 | min_scale=0.4, 199 | read_type='pil', 200 | report_speed=False): 201 | self.split = split 202 | self.is_transform = is_transform 203 | 204 | self.img_size = img_size if (img_size is None or isinstance(img_size, tuple)) else (img_size, img_size) 205 | self.kernel_num = kernel_num 206 | self.min_scale = min_scale 207 | self.short_size = short_size 208 | self.read_type = read_type 209 | 210 | if split == 'train': 211 | data_dirs = [ctw_train_data_dir] 212 | gt_dirs = [ctw_train_gt_dir] 213 | elif split == 'test': 214 | data_dirs = [ctw_test_data_dir] 215 | gt_dirs = [ctw_test_gt_dir] 216 | else: 217 | raise ValueError('Error: split must be test or train!') 218 | 219 | 220 | self.img_paths = [] 221 | self.gt_paths = [] 222 | 223 | for data_dir, gt_dir in zip(data_dirs, gt_dirs): 224 | img_names = [img_name for img_name in mmcv.utils.scandir(data_dir, '.jpg')] 225 | img_names.extend([img_name for img_name in mmcv.utils.scandir(data_dir, '.png')]) 226 | 227 | img_paths = [] 228 | gt_paths = [] 229 | for idx, img_name in enumerate(img_names): 230 | img_path = data_dir + img_name 231 | img_paths.append(img_path) 232 | 233 | gt_name = img_name.split('.')[0] + '.txt' 234 | gt_path = gt_dir + gt_name 235 | gt_paths.append(gt_path) 236 | 237 | self.img_paths.extend(img_paths) 238 | self.gt_paths.extend(gt_paths) 239 | 240 | if report_speed: 241 | target_size = 3000 # repeat the sample list so speed tests cover ~3000 images 242 | data_size = len(self.img_paths) 243 | extend_scale = (target_size + data_size - 1) // data_size 244 | self.img_paths = (self.img_paths * extend_scale)[:target_size] 245 | self.gt_paths = (self.gt_paths * extend_scale)[:target_size] 246 |
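# max_word_num caps how many text instances are kept per image (see prepare_train_data below).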
247 | self.max_word_num = 200 248 | 249 | def __len__(self): 250 | return len(self.img_paths) 251 | 252 | def prepare_train_data(self, index): 253 | img_path = self.img_paths[index] 254 | gt_path = self.gt_paths[index] 255 | 256 | img = get_img(img_path, self.read_type) 257 | bboxes, words = get_ann(img, gt_path) 258 | 259 | if len(bboxes) > self.max_word_num: 260 | bboxes = bboxes[:self.max_word_num] 261 | 262 | if self.is_transform: 263 | img = random_scale(img, self.short_size) 264 | 265 | gt_instance = np.zeros(img.shape[0:2], dtype='uint8') 266 | training_mask = np.ones(img.shape[0:2], dtype='uint8') 267 | if len(bboxes) > 0: 268 | for i in range(len(bboxes)): 269 | bboxes[i] = np.reshape(bboxes[i] * ([img.shape[1], img.shape[0]] * (bboxes[i].shape[0] // 2)), 270 | (bboxes[i].shape[0] // 2, 2)).astype('int32') 271 | for i in range(len(bboxes)): 272 | cv2.drawContours(gt_instance, [bboxes[i]], -1, i + 1, -1) 273 | if words[i] == '###': 274 | cv2.drawContours(training_mask, [bboxes[i]], -1, 0, -1) 275 | 276 | gt_kernels = [] 277 | for i in range(1, self.kernel_num): 278 | rate = 1.0 - (1.0 - self.min_scale) / (self.kernel_num - 1) * i 279 | gt_kernel = np.zeros(img.shape[0:2], dtype='uint8') 280 | kernel_bboxes = shrink(bboxes, rate) 281 | for i in range(len(bboxes)): 282 | cv2.drawContours(gt_kernel, [kernel_bboxes[i].astype(int)], -1, 1, -1) 283 | gt_kernels.append(gt_kernel) 284 | 285 | if self.is_transform: 286 | imgs = [img, gt_instance, training_mask] 287 | imgs.extend(gt_kernels) 288 | 289 | imgs = random_horizontal_flip(imgs) 290 | imgs = random_rotate(imgs) 291 | imgs = random_crop_padding(imgs, self.img_size) 292 | img, gt_instance, training_mask, gt_kernels = imgs[0], imgs[1], imgs[2], imgs[3:] 293 | 294 | gt_text = gt_instance.copy() 295 | gt_text[gt_text > 0] = 1 296 | gt_kernels = np.array(gt_kernels) 297 | 298 | if self.is_transform: 299 | img = Image.fromarray(img) 300 | img = img.convert('RGB') 301 | img = transforms.ColorJitter(brightness=32.0 / 255, saturation=0.5)(img) 302 | else: 303 | img = Image.fromarray(img) 304 | img = img.convert('RGB') 305 | 306 | img = transforms.ToTensor()(img) 307 | img = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])(img) 308 | 309 | gt_text = torch.from_numpy(gt_text).long() 310 | gt_kernels = torch.from_numpy(gt_kernels).long() 311 | training_mask = torch.from_numpy(training_mask).long() 312 | 313 | data = dict( 314 | imgs=img, 315 | gt_texts=gt_text, 316 | gt_kernels=gt_kernels, 317 | training_masks=training_mask, 318 | ) 319 | 320 | return data 321 | # return img, gt_text, gt_kernels, training_mask 322 | 323 | def prepare_test_data(self, index): 324 | img_path = self.img_paths[index] 325 | 326 | img = get_img(img_path, self.read_type) 327 | img_meta = dict( 328 | org_img_size=np.array(img.shape[:2]) 329 | ) 330 | 331 | img = scale_aligned_short(img, self.short_size) 332 | img_meta.update(dict( 333 | img_size=np.array(img.shape[:2]) 334 | )) 335 | 336 | img = Image.fromarray(img) 337 | img = img.convert('RGB') 338 | img = transforms.ToTensor()(img) 339 | img = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])(img) 340 | data = dict( 341 | imgs=img, 342 | img_metas=img_meta 343 | ) 344 | 345 | return data 346 | 347 | def __getitem__(self, index): 348 | if self.split == 'train': 349 | return self.prepare_train_data(index) 350 | elif self.split == 'test': 351 | return self.prepare_test_data(index) 352 | -------------------------------------------------------------------------------- 
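For orientation, here is a minimal usage sketch (not a file in this repo) of how a config above reaches a dataset class through dataset/builder.py. It assumes the config is loaded with mmcv's Config, which provides the attribute-style cfg.type access that build_data_loader relies on; the shuffle flag and worker count are illustrative, not values taken from this repo.

import torch
from mmcv import Config
from dataset import build_data_loader

cfg = Config.fromfile('config/psenet/psenet_r50_ctw.py')
# build_data_loader() strips the 'type' key and forwards the rest as kwargs
# to the registered dataset class (here PSENET_CTW).
train_set = build_data_loader(cfg.data.train)
train_loader = torch.utils.data.DataLoader(
    train_set,
    batch_size=cfg.data.batch_size,
    shuffle=True,
    num_workers=2,
    drop_last=True)
batch = next(iter(train_loader))  # dict with imgs, gt_texts, gt_kernels, training_masks

--------------------------------------------------------------------------------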
/dataset/psenet/psenet_ic15.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from PIL import Image 3 | from torch.utils import data 4 | import cv2 5 | import random 6 | import torchvision.transforms as transforms 7 | import torch 8 | import pyclipper 9 | import Polygon as plg 10 | import math 11 | import mmcv 12 | import string 13 | 14 | ic15_root_dir = './data/ICDAR2015/Challenge4/' 15 | ic15_train_data_dir = ic15_root_dir + 'ch4_training_images/' 16 | ic15_train_gt_dir = ic15_root_dir + 'ch4_training_localization_transcription_gt/' 17 | 18 | ic15_test_data_dir = ic15_root_dir + 'ch4_test_images/' 19 | ic15_test_gt_dir = ic15_root_dir + 'ch4_test_localization_transcription_gt/' 20 | 21 | 22 | def get_img(img_path, read_type='pil'): 23 | try: 24 | if read_type == 'cv2': 25 | img = cv2.imread(img_path) 26 | img = img[:, :, [2, 1, 0]] 27 | elif read_type == 'pil': 28 | img = np.array(Image.open(img_path)) 29 | except Exception as e: 30 | print('Cannot read image: %s.' % img_path) 31 | raise 32 | return img 33 | 34 | 35 | def dist(a, b): 36 | return np.linalg.norm((a - b), ord=2, axis=0) 37 | 38 | 39 | def perimeter(bbox): 40 | peri = 0.0 41 | for i in range(bbox.shape[0]): 42 | peri += dist(bbox[i], bbox[(i + 1) % bbox.shape[0]]) 43 | return peri 44 | 45 | 46 | def shrink(bboxes, rate, max_shr=20): 47 | rate = rate * rate 48 | shrinked_bboxes = [] 49 | for bbox in bboxes: 50 | area = plg.Polygon(bbox).area() 51 | peri = perimeter(bbox) 52 | 53 | try: 54 | pco = pyclipper.PyclipperOffset() 55 | pco.AddPath(bbox, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON) 56 | offset = min(int(area * (1 - rate) / (peri + 0.001) + 0.5), max_shr) 57 | 58 | shrinked_bbox = pco.Execute(-offset) 59 | if len(shrinked_bbox) == 0: 60 | shrinked_bboxes.append(bbox) 61 | continue 62 | 63 | shrinked_bbox = np.array(shrinked_bbox[0]) # index the list first; np.array() on a ragged list of polygons can raise 64 | if shrinked_bbox.shape[0] <= 2: 65 | shrinked_bboxes.append(bbox) 66 | continue 67 | 68 | shrinked_bboxes.append(shrinked_bbox) 69 | except Exception as e: 70 | print('area:', area, 'peri:', peri) 71 | shrinked_bboxes.append(bbox) 72 | 73 | return shrinked_bboxes 74 | 75 | 76 | def get_ann(img, gt_path): 77 | h, w = img.shape[0:2] 78 | lines = mmcv.list_from_file(gt_path) 79 | bboxes = [] 80 | words = [] 81 | for line in lines: 82 | line = line.encode('utf-8').decode('utf-8-sig') 83 | # any BOM was already stripped by the utf-8-sig decode above 84 | gt = line.split(',') 85 | word = gt[8].replace('\r', '').replace('\n', '') 86 | if len(word) == 0 or word[0] == '#': 87 | words.append('###') 88 | else: 89 | words.append(word) 90 | bbox = [int(gt[i]) for i in range(8)] 91 | bbox = np.array(bbox) / ([w * 1.0, h * 1.0] * 4) 92 | bboxes.append(bbox) 93 | return np.array(bboxes), words 94 | 95 | 96 | def random_horizontal_flip(imgs): 97 | if random.random() < 0.5: 98 | for i in range(len(imgs)): 99 | imgs[i] = np.flip(imgs[i], axis=1).copy() 100 | return imgs 101 | 102 | 103 | def random_rotate(imgs): 104 | max_angle = 10 105 | angle = random.random() * 2 * max_angle - max_angle 106 | for i in range(len(imgs)): 107 | img = imgs[i] 108 | w, h = img.shape[:2] 109 | rotation_matrix = cv2.getRotationMatrix2D((h / 2, w / 2), angle, 1) 110 | img_rotation = cv2.warpAffine(img, rotation_matrix, (h, w), flags=cv2.INTER_NEAREST) 111 | imgs[i] = img_rotation 112 | return imgs 113 | 114 | 115 | def scale_aligned_short(img, short_size=736): 116 | # print('original img_size:', img.shape) 117 | h, w = img.shape[0:2] 118 | scale = short_size * 1.0 / min(h, w) 119 | h = int(h *
scale + 0.5) 120 | w = int(w * scale + 0.5) 121 | if h % 32 != 0: 122 | h = h + (32 - h % 32) 123 | if w % 32 != 0: 124 | w = w + (32 - w % 32) 125 | img = cv2.resize(img, dsize=(w, h)) 126 | # print('img_size:', img.shape) 127 | return img 128 | 129 | 130 | def scale_aligned(img, h_scale, w_scale): 131 | h, w = img.shape[0:2] 132 | h = int(h * h_scale + 0.5) 133 | w = int(w * w_scale + 0.5) 134 | if h % 32 != 0: 135 | h = h + (32 - h % 32) 136 | if w % 32 != 0: 137 | w = w + (32 - w % 32) 138 | img = cv2.resize(img, dsize=(w, h)) 139 | return img 140 | 141 | 142 | def random_scale(img, short_size=736): 143 | h, w = img.shape[0:2] 144 | 145 | scale = np.random.choice(np.array([0.7, 0.8, 0.9, 1.0, 1.1, 1.2, 1.3])) 146 | scale = (scale * short_size) / min(h, w) 147 | 148 | aspect = np.random.choice(np.array([0.9, 0.95, 1.0, 1.05, 1.1])) 149 | h_scale = scale * math.sqrt(aspect) 150 | w_scale = scale / math.sqrt(aspect) 151 | 152 | img = scale_aligned(img, h_scale, w_scale) 153 | return img 154 | 155 | 156 | def random_crop_padding(imgs, target_size): 157 | h, w = imgs[0].shape[0:2] 158 | t_w, t_h = target_size 159 | p_w, p_h = target_size 160 | if w == t_w and h == t_h: 161 | return imgs 162 | 163 | t_h = t_h if t_h < h else h 164 | t_w = t_w if t_w < w else w 165 | 166 | if random.random() > 3.0 / 8.0 and np.max(imgs[1]) > 0: 167 | # make sure to crop the text region 168 | tl = np.min(np.where(imgs[1] > 0), axis=1) - (t_h, t_w) 169 | tl[tl < 0] = 0 170 | br = np.max(np.where(imgs[1] > 0), axis=1) - (t_h, t_w) 171 | br[br < 0] = 0 172 | br[0] = min(br[0], h - t_h) 173 | br[1] = min(br[1], w - t_w) 174 | 175 | i = random.randint(tl[0], br[0]) if tl[0] < br[0] else 0 176 | j = random.randint(tl[1], br[1]) if tl[1] < br[1] else 0 177 | else: 178 | i = random.randint(0, h - t_h) if h - t_h > 0 else 0 179 | j = random.randint(0, w - t_w) if w - t_w > 0 else 0 180 | 181 | n_imgs = [] 182 | for idx in range(len(imgs)): 183 | if len(imgs[idx].shape) == 3: 184 | s3_length = int(imgs[idx].shape[-1]) 185 | img = imgs[idx][i:i + t_h, j:j + t_w, :] 186 | img_p = cv2.copyMakeBorder(img, 0, p_h - t_h, 0, p_w - t_w, borderType=cv2.BORDER_CONSTANT, 187 | value=tuple(0 for i in range(s3_length))) 188 | else: 189 | img = imgs[idx][i:i + t_h, j:j + t_w] 190 | img_p = cv2.copyMakeBorder(img, 0, p_h - t_h, 0, p_w - t_w, borderType=cv2.BORDER_CONSTANT, value=(0,)) 191 | n_imgs.append(img_p) 192 | return n_imgs 193 | 194 | 195 | class PSENET_IC15(data.Dataset): 196 | def __init__(self, 197 | split='train', 198 | is_transform=False, 199 | img_size=None, 200 | short_size=736, 201 | kernel_num=7, 202 | min_scale=0.4, 203 | with_rec=False, 204 | read_type='pil', 205 | report_speed=False): 206 | self.split = split 207 | self.is_transform = is_transform 208 | 209 | self.img_size = img_size if (img_size is None or isinstance(img_size, tuple)) else (img_size, img_size) 210 | self.short_size = short_size 211 | self.with_rec = with_rec 212 | self.kernel_num = kernel_num 213 | self.min_scale = min_scale 214 | self.read_type = read_type 215 | 216 | if split == 'train': 217 | data_dirs = [ic15_train_data_dir] 218 | gt_dirs = [ic15_train_gt_dir] 219 | elif split == 'test': 220 | data_dirs = [ic15_test_data_dir] 221 | gt_dirs = [ic15_test_gt_dir] 222 | else: 223 | raise ValueError('Error: split must be train or test!') 224 | 225 | 226 | self.img_paths = [] 227 | self.gt_paths = [] 228 | 229 | for data_dir, gt_dir in zip(data_dirs, gt_dirs): 230 | img_names = [img_name for img_name in mmcv.utils.scandir(data_dir, '.jpg')] 231 |
img_names.extend([img_name for img_name in mmcv.utils.scandir(data_dir, '.png')]) 232 | 233 | img_paths = [] 234 | gt_paths = [] 235 | for idx, img_name in enumerate(img_names): 236 | img_path = data_dir + img_name 237 | img_paths.append(img_path) 238 | 239 | gt_name = 'gt_' + img_name.split('.')[0] + '.txt' 240 | gt_path = gt_dir + gt_name 241 | gt_paths.append(gt_path) 242 | 243 | self.img_paths.extend(img_paths) 244 | self.gt_paths.extend(gt_paths) 245 | 246 | # sample for speed test 247 | if report_speed: 248 | target_size = 3000 249 | extend_scale = (target_size + len(self.img_paths) - 1) // len(self.img_paths) 250 | self.img_paths = (self.img_paths * extend_scale)[:target_size] 251 | self.gt_paths = (self.gt_paths * extend_scale)[:target_size] 252 | 253 | self.max_word_num = 200 254 | # self.max_word_len = 32 255 | 256 | def __len__(self): 257 | return len(self.img_paths) 258 | 259 | def prepare_train_data(self, index): 260 | img_path = self.img_paths[index] 261 | gt_path = self.gt_paths[index] 262 | 263 | img = get_img(img_path, self.read_type) 264 | bboxes, words = get_ann(img, gt_path) 265 | 266 | # max line in gt 267 | if bboxes.shape[0] > self.max_word_num: 268 | bboxes = bboxes[:self.max_word_num] 269 | words = words[:self.max_word_num] 270 | 271 | if self.is_transform: 272 | img = random_scale(img, self.short_size) 273 | 274 | gt_instance = np.zeros(img.shape[0:2], dtype='uint8') 275 | training_mask = np.ones(img.shape[0:2], dtype='uint8') 276 | if bboxes.shape[0] > 0: # line 277 | bboxes = np.reshape(bboxes * ([img.shape[1], img.shape[0]] * 4), 278 | (bboxes.shape[0], -1, 2)).astype('int32') 279 | for i in range(bboxes.shape[0]): 280 | cv2.drawContours(gt_instance, [bboxes[i]], -1, i + 1, -1) 281 | if words[i] == '###': 282 | cv2.drawContours(training_mask, [bboxes[i]], -1, 0, -1) 283 | 284 | gt_kernels = [] 285 | for i in range(1, self.kernel_num): 286 | rate = 1.0 - (1.0 - self.min_scale) / (self.kernel_num - 1) * i 287 | gt_kernel = np.zeros(img.shape[0:2], dtype='uint8') 288 | kernel_bboxes = shrink(bboxes, rate) 289 | for i in range(bboxes.shape[0]): 290 | cv2.drawContours(gt_kernel, [kernel_bboxes[i].astype(int)], -1, 1, -1) 291 | gt_kernels.append(gt_kernel) 292 | 293 | if self.is_transform: 294 | imgs = [img, gt_instance, training_mask] 295 | imgs.extend(gt_kernels) 296 | 297 | if not self.with_rec: 298 | imgs = random_horizontal_flip(imgs) 299 | imgs = random_rotate(imgs) 300 | imgs = random_crop_padding(imgs, self.img_size) 301 | img, gt_instance, training_mask, gt_kernels = imgs[0], imgs[1], imgs[2], imgs[3:] 302 | 303 | gt_text = gt_instance.copy() 304 | gt_text[gt_text > 0] = 1 305 | gt_kernels = np.array(gt_kernels) 306 | 307 | img = Image.fromarray(img) 308 | img = img.convert('RGB') 309 | if self.is_transform: 310 | img = transforms.ColorJitter(brightness=32.0 / 255, saturation=0.5)(img) 311 | 312 | img = transforms.ToTensor()(img) 313 | img = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])(img) 314 | gt_text = torch.from_numpy(gt_text).long() 315 | gt_kernels = torch.from_numpy(gt_kernels).long() 316 | training_mask = torch.from_numpy(training_mask).long() 317 | 318 | data = dict( 319 | imgs=img, 320 | gt_texts=gt_text, 321 | gt_kernels=gt_kernels, 322 | training_masks=training_mask, 323 | ) 324 | 325 | return data 326 | 327 | def prepare_test_data(self, index): 328 | img_path = self.img_paths[index] 329 | 330 | img = get_img(img_path, self.read_type) 331 | img_meta = dict( 332 | org_img_size=np.array(img.shape[:2]) 333 | ) 334 | 
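# Resize so the short side equals short_size, rounding H and W up to multiples of 32 to match the 32-pixel stride of the ResNet+FPN backbone.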
335 | img = scale_aligned_short(img, self.short_size) 336 | img_meta.update(dict( 337 | img_size=np.array(img.shape[:2]) 338 | )) 339 | 340 | img = Image.fromarray(img) 341 | img = img.convert('RGB') 342 | img = transforms.ToTensor()(img) 343 | img = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])(img) 344 | 345 | data = dict( 346 | imgs=img, 347 | img_metas=img_meta 348 | ) 349 | 350 | return data 351 | 352 | def __getitem__(self, index): 353 | if self.split == 'train': 354 | return self.prepare_train_data(index) 355 | elif self.split == 'test': 356 | return self.prepare_test_data(index) 357 | -------------------------------------------------------------------------------- /dataset/psenet/psenet_synth.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from PIL import Image 3 | from torch.utils import data 4 | import cv2 5 | import random 6 | import torchvision.transforms as transforms 7 | import torch 8 | import pyclipper 9 | import Polygon as plg 10 | import math 11 | import string 12 | import scipy.io as scio 13 | 14 | synth_root_dir = './data/SynthText/' 15 | synth_train_data_dir = synth_root_dir 16 | synth_train_gt_path = synth_root_dir + 'gt.mat' 17 | 18 | 19 | def get_img(img_path, read_type='pil'): 20 | try: 21 | if read_type == 'cv2': 22 | img = cv2.imread(img_path) 23 | img = img[:, :, [2, 1, 0]] 24 | elif read_type == 'pil': 25 | img = np.array(Image.open(img_path)) 26 | except Exception as e: 27 | print(img_path) 28 | raise 29 | return img 30 | 31 | 32 | def get_ann(img, gts, texts, index): 33 | bboxes = np.array(gts[index]) 34 | bboxes = np.reshape(bboxes, (bboxes.shape[0], bboxes.shape[1], -1)) 35 | bboxes = bboxes.transpose(2, 1, 0) 36 | bboxes = np.reshape(bboxes, (bboxes.shape[0], -1)) / ([img.shape[1], img.shape[0]] * 4) 37 | 38 | words = [] 39 | for text in texts[index]: 40 | text = text.replace('\n', ' ').replace('\r', ' ') 41 | words.extend([w for w in text.split(' ') if len(w) > 0]) 42 | 43 | return bboxes, words 44 | 45 | 46 | def random_horizontal_flip(imgs): 47 | if random.random() < 0.5: 48 | for i in range(len(imgs)): 49 | imgs[i] = np.flip(imgs[i], axis=1).copy() 50 | return imgs 51 | 52 | 53 | def random_rotate(imgs): 54 | max_angle = 10 55 | angle = random.random() * 2 * max_angle - max_angle 56 | for i in range(len(imgs)): 57 | img = imgs[i] 58 | w, h = img.shape[:2] 59 | rotation_matrix = cv2.getRotationMatrix2D((h / 2, w / 2), angle, 1) 60 | img_rotation = cv2.warpAffine(img, rotation_matrix, (h, w), flags=cv2.INTER_NEAREST) 61 | imgs[i] = img_rotation 62 | return imgs 63 | 64 | 65 | def scale_aligned(img, h_scale, w_scale): 66 | h, w = img.shape[0:2] 67 | h = int(h * h_scale + 0.5) 68 | w = int(w * w_scale + 0.5) 69 | if h % 32 != 0: 70 | h = h + (32 - h % 32) 71 | if w % 32 != 0: 72 | w = w + (32 - w % 32) 73 | img = cv2.resize(img, dsize=(w, h)) 74 | return img 75 | 76 | 77 | def random_scale(img, short_size=736): 78 | h, w = img.shape[0:2] 79 | 80 | scale = np.random.choice(np.array([0.7, 0.8, 0.9, 1.0, 1.1, 1.2, 1.3])) 81 | scale = (scale * short_size) / min(h, w) 82 | 83 | aspect = np.random.choice(np.array([0.9, 0.95, 1.0, 1.05, 1.1])) 84 | h_scale = scale * math.sqrt(aspect) 85 | w_scale = scale / math.sqrt(aspect) 86 | 87 | img = scale_aligned(img, h_scale, w_scale) 88 | return img 89 | 90 | 91 | def random_crop_padding(imgs, target_size): 92 | h, w = imgs[0].shape[0:2] 93 | t_w, t_h = target_size 94 | p_w, p_h = target_size 95 | if w == t_w and h == t_h: 96 | 
return imgs 97 | 98 | t_h = t_h if t_h < h else h 99 | t_w = t_w if t_w < w else w 100 | 101 | if random.random() > 3.0 / 8.0 and np.max(imgs[1]) > 0: 102 | # make sure to crop the text region 103 | tl = np.min(np.where(imgs[1] > 0), axis=1) - (t_h, t_w) 104 | tl[tl < 0] = 0 105 | br = np.max(np.where(imgs[1] > 0), axis=1) - (t_h, t_w) 106 | br[br < 0] = 0 107 | br[0] = min(br[0], h - t_h) 108 | br[1] = min(br[1], w - t_w) 109 | 110 | i = random.randint(tl[0], br[0]) if tl[0] < br[0] else 0 111 | j = random.randint(tl[1], br[1]) if tl[1] < br[1] else 0 112 | else: 113 | i = random.randint(0, h - t_h) if h - t_h > 0 else 0 114 | j = random.randint(0, w - t_w) if w - t_w > 0 else 0 115 | 116 | n_imgs = [] 117 | for idx in range(len(imgs)): 118 | if len(imgs[idx].shape) == 3: 119 | s3_length = int(imgs[idx].shape[-1]) 120 | img = imgs[idx][i:i + t_h, j:j + t_w, :] 121 | img_p = cv2.copyMakeBorder(img, 0, p_h - t_h, 0, p_w - t_w, borderType=cv2.BORDER_CONSTANT, 122 | value=tuple(0 for i in range(s3_length))) 123 | else: 124 | img = imgs[idx][i:i + t_h, j:j + t_w] 125 | img_p = cv2.copyMakeBorder(img, 0, p_h - t_h, 0, p_w - t_w, borderType=cv2.BORDER_CONSTANT, value=(0,)) 126 | n_imgs.append(img_p) 127 | return n_imgs 128 | 129 | 130 | def update_word_mask(instance, instance_before_crop, word_mask): 131 | labels = np.unique(instance) 132 | 133 | for label in labels: 134 | if label == 0: 135 | continue 136 | ind = instance == label 137 | if np.sum(ind) == 0: 138 | word_mask[label] = 0 139 | continue 140 | ind_before_crop = instance_before_crop == label 141 | # print(np.sum(ind), np.sum(ind_before_crop)) 142 | if float(np.sum(ind)) / np.sum(ind_before_crop) > 0.9: 143 | continue 144 | word_mask[label] = 0 145 | 146 | return word_mask 147 | 148 | 149 | def dist(a, b): 150 | return np.linalg.norm((a - b), ord=2, axis=0) 151 | 152 | 153 | def perimeter(bbox): 154 | peri = 0.0 155 | for i in range(bbox.shape[0]): 156 | peri += dist(bbox[i], bbox[(i + 1) % bbox.shape[0]]) 157 | return peri 158 | 159 | 160 | def shrink(bboxes, rate, max_shr=20): 161 | rate = rate * rate 162 | shrinked_bboxes = [] 163 | for bbox in bboxes: 164 | area = plg.Polygon(bbox).area() 165 | peri = perimeter(bbox) 166 | 167 | try: 168 | pco = pyclipper.PyclipperOffset() 169 | pco.AddPath(bbox, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON) 170 | offset = min(int(area * (1 - rate) / (peri + 0.001) + 0.5), max_shr) 171 | 172 | shrinked_bbox = pco.Execute(-offset) 173 | if len(shrinked_bbox) == 0: 174 | shrinked_bboxes.append(bbox) 175 | continue 176 | 177 | shrinked_bbox = np.array(shrinked_bbox[0]) 178 | if shrinked_bbox.shape[0] <= 2: 179 | shrinked_bboxes.append(bbox) 180 | continue 181 | 182 | shrinked_bboxes.append(shrinked_bbox) 183 | except Exception as e: 184 | print('area:', area, 'peri:', peri) 185 | shrinked_bboxes.append(bbox) 186 | 187 | return shrinked_bboxes 188 | 189 | 190 | class PSENET_Synth(data.Dataset): 191 | def __init__(self, 192 | is_transform=False, 193 | img_size=None, 194 | short_size=736, 195 | kernel_num=7, 196 | min_scale=0.7, 197 | read_type='pil'): 198 | self.is_transform = is_transform 199 | 200 | self.img_size = img_size if (img_size is None or isinstance(img_size, tuple)) else (img_size, img_size) 201 | self.kernel_num = kernel_num 202 | self.min_scale = min_scale 203 | self.short_size = short_size 204 | self.read_type = read_type 205 | 206 | data = scio.loadmat(synth_train_gt_path) 207 | 208 | self.img_paths = data['imnames'][0] 209 | self.gts = data['wordBB'][0] 210 | self.texts = 
data['txt'][0] 211 | 212 | self.max_word_num = 200 213 | 214 | def __len__(self): 215 | return len(self.img_paths) 216 | 217 | def __getitem__(self, index): 218 | img_path = synth_train_data_dir + self.img_paths[index][0] 219 | img = get_img(img_path, read_type=self.read_type) 220 | bboxes, words = get_ann(img, self.gts, self.texts, index) 221 | 222 | if bboxes.shape[0] > self.max_word_num: 223 | bboxes = bboxes[:self.max_word_num] 224 | words = words[:self.max_word_num] 225 | 226 | if self.is_transform: 227 | img = random_scale(img, self.short_size) 228 | 229 | gt_instance = np.zeros(img.shape[0:2], dtype='uint8') 230 | training_mask = np.ones(img.shape[0:2], dtype='uint8') 231 | if bboxes.shape[0] > 0: 232 | bboxes = np.reshape(bboxes * ([img.shape[1], img.shape[0]] * 4), 233 | (bboxes.shape[0], -1, 2)).astype('int32') 234 | for i in range(bboxes.shape[0]): 235 | cv2.drawContours(gt_instance, [bboxes[i]], -1, i + 1, -1) 236 | if words[i] == '###': 237 | cv2.drawContours(training_mask, [bboxes[i]], -1, 0, -1) 238 | 239 | gt_kernels = [] 240 | for i in range(1, self.kernel_num): 241 | rate = 1.0 - (1.0 - self.min_scale) / (self.kernel_num - 1) * i 242 | gt_kernel = np.zeros(img.shape[0:2], dtype='uint8') 243 | kernel_bboxes = shrink(bboxes, rate) 244 | for i in range(len(bboxes)): 245 | cv2.drawContours(gt_kernel, [kernel_bboxes[i].astype(int)], -1, 1, -1) 246 | gt_kernels.append(gt_kernel) 247 | 248 | if self.is_transform: 249 | imgs = [img, gt_instance, training_mask] 250 | imgs.extend(gt_kernels) 251 | 252 | imgs = random_rotate(imgs) 253 | imgs = random_crop_padding(imgs, self.img_size) 254 | img, gt_instance, training_mask, gt_kernels = imgs[0], imgs[1], imgs[2], imgs[3:] 255 | 256 | gt_text = gt_instance.copy() 257 | gt_text[gt_text > 0] = 1 258 | gt_kernels = np.array(gt_kernels) 259 | 260 | max_instance = np.max(gt_instance) 261 | gt_bboxes = np.zeros((self.max_word_num, 4), dtype=np.int32) 262 | for i in range(1, max_instance + 1): 263 | ind = gt_instance == i 264 | if np.sum(ind) == 0: 265 | continue 266 | points = np.array(np.where(ind)).transpose((1, 0)) 267 | tl = np.min(points, axis=0) 268 | br = np.max(points, axis=0) + 1 269 | gt_bboxes[i] = (tl[0], tl[1], br[0], br[1]) 270 | 271 | img = Image.fromarray(img) 272 | img = img.convert('RGB') 273 | 274 | if self.is_transform: 275 | img = transforms.ColorJitter(brightness=32.0 / 255, saturation=0.5)(img) 276 | 277 | img = transforms.ToTensor()(img) 278 | img = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])(img) 279 | 280 | gt_text = torch.from_numpy(gt_text).long() 281 | gt_kernels = torch.from_numpy(gt_kernels).long() 282 | training_mask = torch.from_numpy(training_mask).long() 283 | 284 | data = dict( 285 | imgs=img, 286 | gt_texts=gt_text, 287 | gt_kernels=gt_kernels, 288 | training_masks=training_mask, 289 | ) 290 | 291 | return data 292 | -------------------------------------------------------------------------------- /dataset/psenet/psenet_tt.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from PIL import Image 3 | from torch.utils import data 4 | import cv2 5 | import random 6 | import torchvision.transforms as transforms 7 | import torch 8 | import pyclipper 9 | import Polygon as plg 10 | import math 11 | import string 12 | import scipy.io as scio 13 | import mmcv 14 | from mmcv.parallel import DataContainer as DC 15 | 16 | tt_root_dir = './data/total_text/' 17 | tt_train_data_dir = tt_root_dir + 'Images/Train/' 18 | 
tt_train_gt_dir = tt_root_dir + 'Groundtruth/Polygon/Train/' 19 | tt_test_data_dir = tt_root_dir + 'Images/Test/' 20 | tt_test_gt_dir = tt_root_dir + 'Groundtruth/Polygon/Test/' 21 | 22 | 23 | def get_img(img_path, read_type='pil'): 24 | try: 25 | if read_type == 'cv2': 26 | img = cv2.imread(img_path) 27 | img = img[:, :, [2, 1, 0]] 28 | elif read_type == 'pil': 29 | img = np.array(Image.open(img_path)) 30 | except Exception as e: 31 | print(img_path) 32 | raise 33 | return img 34 | 35 | 36 | def read_mat_lindes(path): 37 | f = scio.loadmat(path) 38 | return f 39 | 40 | 41 | def get_ann(img, gt_path): 42 | h, w = img.shape[0:2] 43 | bboxes = [] 44 | words = [] 45 | data = read_mat_lindes(gt_path) 46 | data_polygt = data['polygt'] 47 | for i, lines in enumerate(data_polygt): 48 | X = np.array(lines[1]) 49 | Y = np.array(lines[3]) 50 | 51 | point_num = len(X[0]) 52 | word = lines[4] 53 | if len(word) == 0: 54 | word = '???' 55 | else: 56 | word = word[0] 57 | # word = word[0].encode("utf-8") 58 | 59 | if word == '#': 60 | word = '###' 61 | 62 | words.append(word) 63 | 64 | arr = np.concatenate([X, Y]).T 65 | bbox = [] 66 | for i in range(point_num): 67 | bbox.append(arr[i][0]) 68 | bbox.append(arr[i][1]) 69 | bbox = np.asarray(bbox) / ([w * 1.0, h * 1.0] * point_num) 70 | bboxes.append(bbox) 71 | 72 | return bboxes, words 73 | 74 | 75 | def random_horizontal_flip(imgs): 76 | if random.random() < 0.5: 77 | for i in range(len(imgs)): 78 | imgs[i] = np.flip(imgs[i], axis=1).copy() 79 | return imgs 80 | 81 | 82 | def random_rotate(imgs): 83 | max_angle = 10 84 | angle = random.random() * 2 * max_angle - max_angle 85 | for i in range(len(imgs)): 86 | img = imgs[i] 87 | w, h = img.shape[:2] 88 | rotation_matrix = cv2.getRotationMatrix2D((h / 2, w / 2), angle, 1) 89 | img_rotation = cv2.warpAffine(img, rotation_matrix, (h, w), flags=cv2.INTER_NEAREST) 90 | imgs[i] = img_rotation 91 | return imgs 92 | 93 | 94 | def scale_aligned(img, scale): 95 | h, w = img.shape[0:2] 96 | h = int(h * scale + 0.5) 97 | w = int(w * scale + 0.5) 98 | if h % 32 != 0: 99 | h = h + (32 - h % 32) 100 | if w % 32 != 0: 101 | w = w + (32 - w % 32) 102 | img = cv2.resize(img, dsize=(w, h)) 103 | return img 104 | 105 | 106 | def random_scale(img, short_size=736): 107 | h, w = img.shape[0:2] 108 | 109 | random_scale = np.array([0.7, 0.8, 0.9, 1.0, 1.1, 1.2, 1.3]) 110 | scale = (np.random.choice(random_scale) * short_size) / min(h, w) 111 | 112 | img = scale_aligned(img, scale) 113 | return img 114 | 115 | 116 | def scale_aligned_short(img, short_size=736): 117 | h, w = img.shape[0:2] 118 | scale = short_size * 1.0 / min(h, w) 119 | h = int(h * scale + 0.5) 120 | w = int(w * scale + 0.5) 121 | if h % 32 != 0: 122 | h = h + (32 - h % 32) 123 | if w % 32 != 0: 124 | w = w + (32 - w % 32) 125 | img = cv2.resize(img, dsize=(w, h)) 126 | return img 127 | 128 | 129 | def random_crop_padding(imgs, target_size): 130 | h, w = imgs[0].shape[0:2] 131 | t_w, t_h = target_size 132 | p_w, p_h = target_size 133 | if w == t_w and h == t_h: 134 | return imgs 135 | 136 | t_h = t_h if t_h < h else h 137 | t_w = t_w if t_w < w else w 138 | 139 | if random.random() > 3.0 / 8.0 and np.max(imgs[1]) > 0: 140 | # make sure to crop the text region 141 | tl = np.min(np.where(imgs[1] > 0), axis=1) - (t_h, t_w) 142 | tl[tl < 0] = 0 143 | br = np.max(np.where(imgs[1] > 0), axis=1) - (t_h, t_w) 144 | br[br < 0] = 0 145 | br[0] = min(br[0], h - t_h) 146 | br[1] = min(br[1], w - t_w) 147 | 148 | i = random.randint(tl[0], br[0]) if tl[0] < br[0] else 0 149 | j 
= random.randint(tl[1], br[1]) if tl[1] < br[1] else 0 150 | else: 151 | i = random.randint(0, h - t_h) if h - t_h > 0 else 0 152 | j = random.randint(0, w - t_w) if w - t_w > 0 else 0 153 | 154 | n_imgs = [] 155 | for idx in range(len(imgs)): 156 | if len(imgs[idx].shape) == 3: 157 | s3_length = int(imgs[idx].shape[-1]) 158 | img = imgs[idx][i:i + t_h, j:j + t_w, :] 159 | img_p = cv2.copyMakeBorder(img, 0, p_h - t_h, 0, p_w - t_w, borderType=cv2.BORDER_CONSTANT, 160 | value=tuple(0 for i in range(s3_length))) 161 | else: 162 | img = imgs[idx][i:i + t_h, j:j + t_w] 163 | img_p = cv2.copyMakeBorder(img, 0, p_h - t_h, 0, p_w - t_w, borderType=cv2.BORDER_CONSTANT, value=(0,)) 164 | n_imgs.append(img_p) 165 | return n_imgs 166 | 167 | 168 | def update_word_mask(instance, instance_before_crop, word_mask): 169 | labels = np.unique(instance) 170 | 171 | for label in labels: 172 | if label == 0: 173 | continue 174 | ind = instance == label 175 | if np.sum(ind) == 0: 176 | word_mask[label] = 0 177 | continue 178 | ind_before_crop = instance_before_crop == label 179 | # print(np.sum(ind), np.sum(ind_before_crop)) 180 | if float(np.sum(ind)) / np.sum(ind_before_crop) > 0.9: 181 | continue 182 | word_mask[label] = 0 183 | 184 | return word_mask 185 | 186 | 187 | def dist(a, b): 188 | return np.linalg.norm((a - b), ord=2, axis=0) 189 | 190 | 191 | def perimeter(bbox): 192 | peri = 0.0 193 | for i in range(bbox.shape[0]): 194 | peri += dist(bbox[i], bbox[(i + 1) % bbox.shape[0]]) 195 | return peri 196 | 197 | 198 | def shrink(bboxes, rate, max_shr=20): 199 | rate = rate * rate 200 | shrinked_bboxes = [] 201 | for bbox in bboxes: 202 | area = plg.Polygon(bbox).area() 203 | peri = perimeter(bbox) 204 | 205 | try: 206 | pco = pyclipper.PyclipperOffset() 207 | pco.AddPath(bbox, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON) 208 | offset = min(int(area * (1 - rate) / (peri + 0.001) + 0.5), max_shr) 209 | 210 | shrinked_bbox = pco.Execute(-offset) 211 | if len(shrinked_bbox) == 0: 212 | shrinked_bboxes.append(bbox) 213 | continue 214 | 215 | shrinked_bbox = np.array(shrinked_bbox[0]) 216 | if shrinked_bbox.shape[0] <= 2: 217 | shrinked_bboxes.append(bbox) 218 | continue 219 | 220 | shrinked_bboxes.append(shrinked_bbox) 221 | except Exception as e: 222 | print(type(shrinked_bbox), shrinked_bbox) 223 | print('area:', area, 'peri:', peri) 224 | shrinked_bboxes.append(bbox) 225 | 226 | return shrinked_bboxes 227 | 228 | 229 | def get_vocabulary(voc_type, EOS='EOS', PADDING='PAD', UNKNOWN='UNK'): 230 | if voc_type == 'LOWERCASE': 231 | voc = list(string.digits + string.ascii_lowercase) 232 | elif voc_type == 'ALLCASES': 233 | voc = list(string.digits + string.ascii_letters) 234 | elif voc_type == 'ALLCASES_SYMBOLS': 235 | voc = list(string.printable[:-6]) 236 | else: 237 | raise KeyError('voc_type must be one of "LOWERCASE", "ALLCASES", "ALLCASES_SYMBOLS"') 238 | 239 | # update the voc with specifical chars 240 | voc.append(EOS) 241 | voc.append(PADDING) 242 | voc.append(UNKNOWN) 243 | 244 | char2id = dict(zip(voc, range(len(voc)))) 245 | id2char = dict(zip(range(len(voc)), voc)) 246 | 247 | return voc, char2id, id2char 248 | 249 | 250 | class PSENET_TT(data.Dataset): 251 | def __init__(self, 252 | split='train', 253 | is_transform=False, 254 | img_size=None, 255 | short_size=736, 256 | kernel_num=7, 257 | min_scale=0.7, 258 | with_rec=False, 259 | read_type='pil', 260 | report_speed=False): 261 | self.split = split 262 | self.is_transform = is_transform 263 | 264 | self.img_size = img_size if (img_size is None 
or isinstance(img_size, tuple)) else (img_size, img_size) 265 | self.kernel_num = kernel_num 266 | self.min_scale = min_scale 267 | self.short_size = short_size 268 | self.with_rec = with_rec 269 | self.read_type = read_type 270 | 271 | if split == 'train': 272 | data_dirs = [tt_train_data_dir] 273 | gt_dirs = [tt_train_gt_dir] 274 | elif split == 'test': 275 | data_dirs = [tt_test_data_dir] 276 | gt_dirs = [tt_test_gt_dir] 277 | else: 278 | print('Error: split must be train or test!') 279 | raise 280 | 281 | self.img_paths = [] 282 | self.gt_paths = [] 283 | 284 | for data_dir, gt_dir in zip(data_dirs, gt_dirs): 285 | img_names = [img_name for img_name in mmcv.utils.scandir(data_dir, '.jpg')] 286 | img_names.extend([img_name for img_name in mmcv.utils.scandir(data_dir, '.png')]) 287 | 288 | img_paths = [] 289 | gt_paths = [] 290 | for idx, img_name in enumerate(img_names): 291 | img_path = data_dir + img_name 292 | img_paths.append(img_path) 293 | 294 | gt_name = 'poly_gt_' + img_name.split('.')[0] + '.mat' 295 | gt_path = gt_dir + gt_name 296 | gt_paths.append(gt_path) 297 | 298 | self.img_paths.extend(img_paths) 299 | self.gt_paths.extend(gt_paths) 300 | 301 | if report_speed: 302 | target_size = 3000 303 | data_size = len(self.img_paths) 304 | extend_scale = (target_size + data_size - 1) // data_size 305 | self.img_paths = (self.img_paths * extend_scale)[:target_size] 306 | self.gt_paths = (self.gt_paths * extend_scale)[:target_size] 307 | self.max_word_num = 200 308 | 309 | def __len__(self): 310 | return len(self.img_paths) 311 | 312 | def prepare_train_data(self, index): 313 | img_path = self.img_paths[index] 314 | gt_path = self.gt_paths[index] 315 | 316 | img = get_img(img_path, self.read_type) 317 | bboxes, words = get_ann(img, gt_path) 318 | 319 | if len(bboxes) > self.max_word_num: 320 | bboxes = bboxes[:self.max_word_num] 321 | words = words[:self.max_word_num] 322 | 323 | if self.is_transform: 324 | img = random_scale(img, self.short_size) 325 | 326 | gt_instance = np.zeros(img.shape[0:2], dtype='uint8') 327 | training_mask = np.ones(img.shape[0:2], dtype='uint8') 328 | if len(bboxes) > 0: 329 | for i in range(len(bboxes)): 330 | bboxes[i] = np.reshape(bboxes[i] * ([img.shape[1], img.shape[0]] * (bboxes[i].shape[0] // 2)), 331 | (bboxes[i].shape[0] // 2, 2)).astype('int32') 332 | for i in range(len(bboxes)): 333 | cv2.drawContours(gt_instance, [bboxes[i]], -1, i + 1, -1) 334 | if words[i] == '###': 335 | cv2.drawContours(training_mask, [bboxes[i]], -1, 0, -1) 336 | 337 | gt_kernels = [] 338 | 339 | for i in range(1, self.kernel_num): 340 | rate = 1.0 - (1.0 - self.min_scale) / (self.kernel_num - 1) * i 341 | gt_kernel = np.zeros(img.shape[0:2], dtype='uint8') 342 | kernel_bboxes = shrink(bboxes, rate) 343 | for i in range(len(bboxes)): 344 | cv2.drawContours(gt_kernel, [kernel_bboxes[i].astype(int)], -1, 1, -1) 345 | gt_kernels.append(gt_kernel) 346 | 347 | if self.is_transform: 348 | imgs = [img, gt_instance, training_mask] 349 | imgs.extend(gt_kernels) 350 | 351 | if not self.with_rec: 352 | imgs = random_horizontal_flip(imgs) 353 | imgs = random_rotate(imgs) 354 | imgs = random_crop_padding(imgs, self.img_size) 355 | img, gt_instance, training_mask, gt_kernels = imgs[0], imgs[1], imgs[2], imgs[3:] 356 | 357 | gt_text = gt_instance.copy() 358 | gt_text[gt_text > 0] = 1 359 | gt_kernels = np.array(gt_kernels) 360 | 361 | 362 | if self.is_transform: 363 | img = Image.fromarray(img) 364 | img = img.convert('RGB') 365 | img = transforms.ColorJitter(brightness=32.0 / 255, 
saturation=0.5)(img) 366 | else: 367 | img = Image.fromarray(img) 368 | img = img.convert('RGB') 369 | 370 | img = transforms.ToTensor()(img) 371 | img = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])(img) 372 | gt_text = torch.from_numpy(gt_text).long() 373 | gt_kernels = torch.from_numpy(gt_kernels).long() 374 | training_mask = torch.from_numpy(training_mask).long() 375 | 376 | data = dict( 377 | imgs=img, 378 | gt_texts=gt_text, 379 | gt_kernels=gt_kernels, 380 | training_masks=training_mask, 381 | ) 382 | 383 | return data 384 | 385 | def prepare_test_data(self, index): 386 | img_path = self.img_paths[index] 387 | 388 | img = get_img(img_path, self.read_type) 389 | img_meta = dict( 390 | org_img_size=np.array(img.shape[:2]) 391 | ) 392 | 393 | img = scale_aligned_short(img, self.short_size) 394 | img_meta.update(dict( 395 | img_size=np.array(img.shape[:2]) 396 | )) 397 | 398 | img = Image.fromarray(img) 399 | img = img.convert('RGB') 400 | img = transforms.ToTensor()(img) 401 | img = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])(img) 402 | 403 | data = dict( 404 | imgs=img, 405 | img_metas=img_meta 406 | ) 407 | 408 | return data 409 | 410 | def __getitem__(self, index): 411 | if self.split == 'train': 412 | return self.prepare_train_data(index) 413 | elif self.split == 'test': 414 | return self.prepare_test_data(index) 415 | -------------------------------------------------------------------------------- /eval/ctw/eval.py: -------------------------------------------------------------------------------- 1 | import file_util 2 | import Polygon as plg 3 | import numpy as np 4 | import mmcv 5 | 6 | project_root = '../../' 7 | 8 | pred_root = project_root + 'outputs/submit_ctw5' 9 | gt_root = project_root + 'data/ctw1500/test/text_label_circum/' 10 | 11 | 12 | def get_pred(path): 13 | lines = file_util.read_file(path).split('\n') 14 | bboxes = [] 15 | for line in lines: 16 | if line == '': 17 | continue 18 | bbox = line.split(',') 19 | if len(bbox) % 2 == 1: 20 | print(path) 21 | bbox = [int(x) for x in bbox] 22 | bboxes.append(bbox) 23 | return bboxes 24 | 25 | 26 | def get_gt(path): 27 | lines = file_util.read_file(path).split('\n') 28 | bboxes = [] 29 | for line in lines: 30 | if line == '': 31 | continue 32 | # line = util.str.remove_all(line, '\xef\xbb\xbf') 33 | # gt = util.str.split(line, ',') 34 | gt = line.split(',') 35 | 36 | x1 = np.int(gt[0]) 37 | y1 = np.int(gt[1]) 38 | 39 | bbox = [np.int(gt[i]) for i in range(4, 32)] 40 | bbox = np.asarray(bbox) + ([x1, y1] * 14) 41 | 42 | bboxes.append(bbox) 43 | return bboxes 44 | 45 | 46 | def get_union(pD, pG): 47 | areaA = pD.area() 48 | areaB = pG.area() 49 | return areaA + areaB - get_intersection(pD, pG); 50 | 51 | 52 | def get_intersection(pD, pG): 53 | pInt = pD & pG 54 | if len(pInt) == 0: 55 | return 0 56 | return pInt.area() 57 | 58 | 59 | if __name__ == '__main__': 60 | th = 0.5 61 | pred_list = file_util.read_dir(pred_root) 62 | 63 | tp, fp, npos = 0, 0, 0 64 | 65 | for pred_path in pred_list: 66 | preds = get_pred(pred_path) 67 | gt_path = gt_root + pred_path.split('/')[-1] 68 | gts = get_gt(gt_path) 69 | npos += len(gts) 70 | 71 | cover = set() 72 | for pred_id, pred in enumerate(preds): 73 | pred = np.array(pred) 74 | pred = pred.reshape(pred.shape[0] / 2, 2)[:, ::-1] 75 | 76 | pred_p = plg.Polygon(pred) 77 | 78 | flag = False 79 | for gt_id, gt in enumerate(gts): 80 | gt = np.array(gt) 81 | gt = gt.reshape(gt.shape[0] / 2, 2) 82 | gt_p = plg.Polygon(gt) 83 | 84 | 
union = get_union(pred_p, gt_p) 85 | inter = get_intersection(pred_p, gt_p) 86 | 87 | if inter * 1.0 / union >= th: 88 | if gt_id not in cover: 89 | flag = True 90 | cover.add(gt_id) 91 | if flag: 92 | tp += 1.0 93 | else: 94 | fp += 1.0 95 | 96 | # print tp, fp, npos 97 | precision = tp / (tp + fp) 98 | recall = tp / npos 99 | hmean = 0 if (precision + recall) == 0 else 2.0 * precision * recall / (precision + recall) 100 | 101 | print('p: %.4f, r: %.4f, f: %.4f' % (precision, recall, hmean)) 102 | -------------------------------------------------------------------------------- /eval/ctw/file_util.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | def read_dir(root): 4 | file_path_list = [] 5 | for file_path, dirs, files in os.walk(root): 6 | for file in files: 7 | file_path_list.append(os.path.join(file_path, file).replace('\\', '/')) 8 | file_path_list.sort() 9 | return file_path_list 10 | 11 | def read_file(file_path): 12 | file_object = open(file_path, 'r') 13 | file_content = file_object.read() 14 | file_object.close() 15 | return file_content 16 | 17 | def write_file(file_path, file_content): 18 | if file_path.find('/') != -1: 19 | father_dir = '/'.join(file_path.split('/')[0:-1]) 20 | if not os.path.exists(father_dir): 21 | os.makedirs(father_dir) 22 | file_object = open(file_path, 'w') 23 | file_object.write(file_content) 24 | file_object.close() 25 | 26 | 27 | def write_file_not_cover(file_path, file_content): 28 | father_dir = '/'.join(file_path.split('/')[0:-1]) 29 | if not os.path.exists(father_dir): 30 | os.makedirs(father_dir) 31 | file_object = open(file_path, 'a') 32 | file_object.write(file_content) 33 | file_object.close() -------------------------------------------------------------------------------- /eval/eval_ctw.sh: -------------------------------------------------------------------------------- 1 | cd ctw && python2 eval.py && cd .. 2 | -------------------------------------------------------------------------------- /eval/eval_ic15.sh: -------------------------------------------------------------------------------- 1 | cd ic15 && python2 script.py -g=gt.zip -s=../../outputs/submit_ic15.zip&& cd .. 2 | -------------------------------------------------------------------------------- /eval/eval_ic15_rec.sh: -------------------------------------------------------------------------------- 1 | cd ic15_rec && python2 script.py -g=gt.zip -s=../../outputs/submit_ic15_rec.zip && cd .. -------------------------------------------------------------------------------- /eval/eval_msra.sh: -------------------------------------------------------------------------------- 1 | cd msra && python2 eval.py && cd .. 2 | -------------------------------------------------------------------------------- /eval/eval_tt.sh: -------------------------------------------------------------------------------- 1 | cd tt && python2 Deteval.py && cd .. -------------------------------------------------------------------------------- /eval/eval_tt_rec.sh: -------------------------------------------------------------------------------- 1 | cd tt_rec && python2 script.py -g=gt.zip -s=../../outputs/submit_ic15_rec.zip && cd .. 
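
The CTW evaluation above (eval/ctw/eval.py) implements the standard polygon-IoU protocol: a predicted polygon counts as a true positive when its intersection-over-union with a not-yet-matched ground-truth polygon reaches the threshold th = 0.5, and precision, recall, and H-mean are computed from the matched counts. Below is a minimal Python 3 sketch of that matching rule, assuming the same Polygon (Polygon3) package used above; the helper names (polygon_iou, match_polygons) are illustrative and not part of the repository:

import numpy as np
import Polygon as plg  # Polygon3 package, as used by eval/ctw/eval.py

def polygon_iou(pd, pg):
    # pd, pg: plg.Polygon objects; '&' is polygon intersection
    p_int = pd & pg
    inter = p_int.area() if len(p_int) > 0 else 0.0
    union = pd.area() + pg.area() - inter
    return inter / union if union > 0 else 0.0

def match_polygons(preds, gts, th=0.5):
    # preds, gts: lists of (N, 2) vertex arrays for one image
    pred_polys = [plg.Polygon(np.asarray(p)) for p in preds]
    gt_polys = [plg.Polygon(np.asarray(g)) for g in gts]
    tp, covered = 0.0, set()
    for pd in pred_polys:
        for gt_id, pg in enumerate(gt_polys):
            if gt_id not in covered and polygon_iou(pd, pg) >= th:
                covered.add(gt_id)
                tp += 1.0
                break
    fp = len(pred_polys) - tp
    precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
    recall = tp / len(gt_polys) if gt_polys else 0.0
    hmean = 0.0 if precision + recall == 0 else 2.0 * precision * recall / (precision + recall)
    return precision, recall, hmean

Unlike eval.py, which targets Python 2 (note the integer-division reshape pred.shape[0] / 2 and the python2 invocations in the shell scripts above), this sketch runs under Python 3; it also breaks after the first match per prediction, a simplification of eval.py's cover-set loop.
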
-------------------------------------------------------------------------------- /eval/ic15/gt.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/whai362/PSENet/5a42734dc56df42b7192494933ea8fcb3f486494/eval/ic15/gt.zip -------------------------------------------------------------------------------- /eval/ic15/rrc_evaluation_funcs_v1.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python2 2 | #encoding: UTF-8 3 | import json 4 | import sys;sys.path.append('./') 5 | import zipfile 6 | import re 7 | import sys 8 | import os 9 | import codecs 10 | import importlib 11 | from StringIO import StringIO 12 | 13 | def print_help(): 14 | sys.stdout.write('Usage: python %s.py -g= -s= [-o= -p=]' %sys.argv[0]) 15 | sys.exit(2) 16 | 17 | 18 | def load_zip_file_keys(file,fileNameRegExp=''): 19 | """ 20 | Returns an array with the entries of the ZIP file that match with the regular expression. 21 | The key's are the names or the file or the capturing group definied in the fileNameRegExp 22 | """ 23 | try: 24 | archive=zipfile.ZipFile(file, mode='r', allowZip64=True) 25 | except : 26 | raise Exception('Error loading the ZIP archive.') 27 | 28 | pairs = [] 29 | 30 | for name in archive.namelist(): 31 | addFile = True 32 | keyName = name 33 | if fileNameRegExp!="": 34 | m = re.match(fileNameRegExp,name) 35 | if m == None: 36 | addFile = False 37 | else: 38 | if len(m.groups())>0: 39 | keyName = m.group(1) 40 | 41 | if addFile: 42 | pairs.append( keyName ) 43 | 44 | return pairs 45 | 46 | 47 | def load_zip_file(file,fileNameRegExp='',allEntries=False): 48 | """ 49 | Returns an array with the contents (filtered by fileNameRegExp) of a ZIP file. 50 | The key's are the names or the file or the capturing group definied in the fileNameRegExp 51 | allEntries validates that all entries in the ZIP file pass the fileNameRegExp 52 | """ 53 | try: 54 | archive=zipfile.ZipFile(file, mode='r', allowZip64=True) 55 | except : 56 | raise Exception('Error loading the ZIP archive') 57 | 58 | pairs = [] 59 | for name in archive.namelist(): 60 | addFile = True 61 | keyName = name 62 | if fileNameRegExp!="": 63 | m = re.match(fileNameRegExp,name) 64 | if m == None: 65 | addFile = False 66 | else: 67 | if len(m.groups())>0: 68 | keyName = m.group(1) 69 | 70 | if addFile: 71 | pairs.append( [ keyName , archive.read(name)] ) 72 | else: 73 | if allEntries: 74 | raise Exception('ZIP entry not valid: %s' %name) 75 | 76 | return dict(pairs) 77 | 78 | def decode_utf8(raw): 79 | """ 80 | Returns a Unicode object on success, or None on failure 81 | """ 82 | try: 83 | raw = codecs.decode(raw,'utf-8', 'replace') 84 | #extracts BOM if exists 85 | raw = raw.encode('utf8') 86 | if raw.startswith(codecs.BOM_UTF8): 87 | raw = raw.replace(codecs.BOM_UTF8, '', 1) 88 | return raw.decode('utf-8') 89 | except: 90 | return None 91 | 92 | def validate_lines_in_file(fileName,file_contents,CRLF=True,LTRB=True,withTranscription=False,withConfidence=False,imWidth=0,imHeight=0): 93 | """ 94 | This function validates that all lines of the file calling the Line validation function for each line 95 | """ 96 | utf8File = decode_utf8(file_contents) 97 | if (utf8File is None) : 98 | raise Exception("The file %s is not UTF-8" %fileName) 99 | 100 | lines = utf8File.split( "\r\n" if CRLF else "\n" ) 101 | for line in lines: 102 | line = line.replace("\r","").replace("\n","") 103 | if(line != ""): 104 | try: 105 | 
validate_tl_line(line,LTRB,withTranscription,withConfidence,imWidth,imHeight) 106 | except Exception as e: 107 | raise Exception(("Line in sample not valid. Sample: %s Line: %s Error: %s" %(fileName,line,str(e))).encode('utf-8', 'replace')) 108 | 109 | 110 | 111 | def validate_tl_line(line,LTRB=True,withTranscription=True,withConfidence=True,imWidth=0,imHeight=0): 112 | """ 113 | Validate the format of the line. If the line is not valid an exception will be raised. 114 | If maxWidth and maxHeight are specified, all points must be inside the image bounds. 115 | Possible values are: 116 | LTRB=True: xmin,ymin,xmax,ymax[,confidence][,transcription] 117 | LTRB=False: x1,y1,x2,y2,x3,y3,x4,y4[,confidence][,transcription] 118 | """ 119 | get_tl_line_values(line,LTRB,withTranscription,withConfidence,imWidth,imHeight) 120 | 121 | 122 | def get_tl_line_values(line,LTRB=True,withTranscription=False,withConfidence=False,imWidth=0,imHeight=0): 123 | """ 124 | Validate the format of the line. If the line is not valid an exception will be raised. 125 | If maxWidth and maxHeight are specified, all points must be inside the image bounds. 126 | Possible values are: 127 | LTRB=True: xmin,ymin,xmax,ymax[,confidence][,transcription] 128 | LTRB=False: x1,y1,x2,y2,x3,y3,x4,y4[,confidence][,transcription] 129 | Returns values from a textline. Points , [Confidences], [Transcriptions] 130 | """ 131 | confidence = 0.0 132 | transcription = ""; 133 | points = [] 134 | 135 | numPoints = 4; 136 | 137 | if LTRB: 138 | 139 | numPoints = 4; 140 | 141 | if withTranscription and withConfidence: 142 | m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-1].?[0-9]*)\s*,(.*)$',line) 143 | if m == None : 144 | m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-1].?[0-9]*)\s*,(.*)$',line) 145 | raise Exception("Format incorrect. Should be: xmin,ymin,xmax,ymax,confidence,transcription") 146 | elif withConfidence: 147 | m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-1].?[0-9]*)\s*$',line) 148 | if m == None : 149 | raise Exception("Format incorrect. Should be: xmin,ymin,xmax,ymax,confidence") 150 | elif withTranscription: 151 | m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*,(.*)$',line) 152 | if m == None : 153 | raise Exception("Format incorrect. Should be: xmin,ymin,xmax,ymax,transcription") 154 | else: 155 | m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*,?\s*$',line) 156 | if m == None : 157 | raise Exception("Format incorrect. Should be: xmin,ymin,xmax,ymax") 158 | 159 | xmin = int(m.group(1)) 160 | ymin = int(m.group(2)) 161 | xmax = int(m.group(3)) 162 | ymax = int(m.group(4)) 163 | if(xmax<xmin): 164 | raise Exception("Xmax value (%s) not valid (Xmax < Xmin)." %(xmax)) 165 | if(ymax<ymin): 166 | raise Exception("Ymax value (%s) not valid (Ymax < Ymin)." %(ymax)) 167 | 168 | points = [ float(m.group(i)) for i in range(1, (numPoints+1) ) ] 169 | 170 | if (imWidth>0 and imHeight>0): 171 | validate_point_inside_bounds(xmin,ymin,imWidth,imHeight); 172 | validate_point_inside_bounds(xmax,ymax,imWidth,imHeight); 173 | 174 | else: 175 | 176 | numPoints = 8; 177 | 178 | if withTranscription and withConfidence: 179 | m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*([0-1].?[0-9]*)\s*,(.*)$',line) 180 | if m == None : 181 | raise Exception("Format incorrect. 
Should be: x1,y1,x2,y2,x3,y3,x4,y4,confidence,transcription") 182 | elif withConfidence: 183 | m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*([0-1].?[0-9]*)\s*$',line) 184 | if m == None : 185 | raise Exception("Format incorrect. Should be: x1,y1,x2,y2,x3,y3,x4,y4,confidence") 186 | elif withTranscription: 187 | m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,(.*)$',line) 188 | if m == None : 189 | raise Exception("Format incorrect. Should be: x1,y1,x2,y2,x3,y3,x4,y4,transcription") 190 | else: 191 | m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*$',line) 192 | if m == None : 193 | raise Exception("Format incorrect. Should be: x1,y1,x2,y2,x3,y3,x4,y4") 194 | 195 | points = [ float(m.group(i)) for i in range(1, (numPoints+1) ) ] 196 | 197 | validate_clockwise_points(points) 198 | 199 | if (imWidth>0 and imHeight>0): 200 | validate_point_inside_bounds(points[0],points[1],imWidth,imHeight); 201 | validate_point_inside_bounds(points[2],points[3],imWidth,imHeight); 202 | validate_point_inside_bounds(points[4],points[5],imWidth,imHeight); 203 | validate_point_inside_bounds(points[6],points[7],imWidth,imHeight); 204 | 205 | 206 | if withConfidence: 207 | try: 208 | confidence = float(m.group(numPoints+1)) 209 | except ValueError: 210 | raise Exception("Confidence value must be a float") 211 | 212 | if withTranscription: 213 | posTranscription = numPoints + (2 if withConfidence else 1) 214 | transcription = m.group(posTranscription) 215 | m2 = re.match(r'^\s*\"(.*)\"\s*$',transcription) 216 | if m2 != None : #Transcription with double quotes, we extract the value and replace escaped characters 217 | transcription = m2.group(1).replace("\\\\", "\\").replace("\\\"", "\"") 218 | 219 | return points,confidence,transcription 220 | 221 | 222 | def validate_point_inside_bounds(x,y,imWidth,imHeight): 223 | if(x<0 or x>imWidth): 224 | raise Exception("X value (%s) not valid. Image dimensions: (%s,%s)" %(x,imWidth,imHeight)) 225 | if(y<0 or y>imHeight): 226 | raise Exception("Y value (%s) not valid. Image dimensions: (%s,%s)" %(y,imWidth,imHeight)) 227 | 228 | def validate_clockwise_points(points): 229 | """ 230 | Validates that the 4 points that delimit a polygon are in clockwise order. 231 | """ 232 | 233 | if len(points) != 8: 234 | raise Exception("Points list not valid." + str(len(points))) 235 | 236 | point = [ 237 | [int(points[0]) , int(points[1])], 238 | [int(points[2]) , int(points[3])], 239 | [int(points[4]) , int(points[5])], 240 | [int(points[6]) , int(points[7])] 241 | ] 242 | edge = [ 243 | ( point[1][0] - point[0][0])*( point[1][1] + point[0][1]), 244 | ( point[2][0] - point[1][0])*( point[2][1] + point[1][1]), 245 | ( point[3][0] - point[2][0])*( point[3][1] + point[2][1]), 246 | ( point[0][0] - point[3][0])*( point[0][1] + point[3][1]) 247 | ] 248 | 249 | summatory = edge[0] + edge[1] + edge[2] + edge[3]; 250 | if summatory>0: 251 | raise Exception("Points are not clockwise. The coordinates of bounding quadrilaterals have to be given in clockwise order. 
Regarding the correct interpretation of 'clockwise' remember that the image coordinate system used is the standard one, with the image origin at the upper left, the X axis extending to the right and Y axis extending downwards.") 252 | 253 | def get_tl_line_values_from_file_contents(content,CRLF=True,LTRB=True,withTranscription=False,withConfidence=False,imWidth=0,imHeight=0,sort_by_confidences=True): 254 | """ 255 | Returns all points, confindences and transcriptions of a file in lists. Valid line formats: 256 | xmin,ymin,xmax,ymax,[confidence],[transcription] 257 | x1,y1,x2,y2,x3,y3,x4,y4,[confidence],[transcription] 258 | """ 259 | pointsList = [] 260 | transcriptionsList = [] 261 | confidencesList = [] 262 | 263 | lines = content.split( "\r\n" if CRLF else "\n" ) 264 | for line in lines: 265 | line = line.replace("\r","").replace("\n","") 266 | if(line != "") : 267 | points, confidence, transcription = get_tl_line_values(line,LTRB,withTranscription,withConfidence,imWidth,imHeight); 268 | pointsList.append(points) 269 | transcriptionsList.append(transcription) 270 | confidencesList.append(confidence) 271 | 272 | if withConfidence and len(confidencesList)>0 and sort_by_confidences: 273 | import numpy as np 274 | sorted_ind = np.argsort(-np.array(confidencesList)) 275 | confidencesList = [confidencesList[i] for i in sorted_ind] 276 | pointsList = [pointsList[i] for i in sorted_ind] 277 | transcriptionsList = [transcriptionsList[i] for i in sorted_ind] 278 | 279 | return pointsList,confidencesList,transcriptionsList 280 | 281 | def main_evaluation(p,default_evaluation_params_fn,validate_data_fn,evaluate_method_fn,show_result=True,per_sample=True): 282 | """ 283 | This process validates a method, evaluates it and if it succed generates a ZIP file with a JSON entry for each sample. 284 | Params: 285 | p: Dictionary of parmeters with the GT/submission locations. If None is passed, the parameters send by the system are used. 
286 | default_evaluation_params_fn: points to a function that returns a dictionary with the default parameters used for the evaluation 287 | validate_data_fn: points to a method that validates the corrct format of the submission 288 | evaluate_method_fn: points to a function that evaluated the submission and return a Dictionary with the results 289 | """ 290 | 291 | if (p == None): 292 | p = dict([s[1:].split('=') for s in sys.argv[1:]]) 293 | if(len(sys.argv)<3): 294 | print_help() 295 | 296 | evalParams = default_evaluation_params_fn() 297 | if 'p' in p.keys(): 298 | evalParams.update( p['p'] if isinstance(p['p'], dict) else json.loads(p['p'][1:-1]) ) 299 | 300 | resDict={'calculated':True,'Message':'','method':'{}','per_sample':'{}'} 301 | try: 302 | validate_data_fn(p['g'], p['s'], evalParams) 303 | evalData = evaluate_method_fn(p['g'], p['s'], evalParams) 304 | resDict.update(evalData) 305 | 306 | except Exception, e: 307 | resDict['Message']= str(e) 308 | resDict['calculated']=False 309 | 310 | if 'o' in p: 311 | if not os.path.exists(p['o']): 312 | os.makedirs(p['o']) 313 | 314 | resultsOutputname = p['o'] + '/results.zip' 315 | outZip = zipfile.ZipFile(resultsOutputname, mode='w', allowZip64=True) 316 | 317 | del resDict['per_sample'] 318 | if 'output_items' in resDict.keys(): 319 | del resDict['output_items'] 320 | 321 | outZip.writestr('method.json',json.dumps(resDict)) 322 | 323 | if not resDict['calculated']: 324 | if show_result: 325 | sys.stderr.write('Error!\n'+ resDict['Message']+'\n\n') 326 | if 'o' in p: 327 | outZip.close() 328 | return resDict 329 | 330 | if 'o' in p: 331 | if per_sample == True: 332 | for k,v in evalData['per_sample'].iteritems(): 333 | outZip.writestr( k + '.json',json.dumps(v)) 334 | 335 | if 'output_items' in evalData.keys(): 336 | for k, v in evalData['output_items'].iteritems(): 337 | outZip.writestr( k,v) 338 | 339 | outZip.close() 340 | 341 | if show_result: 342 | sys.stdout.write("Calculated!") 343 | sys.stdout.write(json.dumps(resDict['method'])) 344 | 345 | return resDict 346 | 347 | 348 | def main_validation(default_evaluation_params_fn,validate_data_fn): 349 | """ 350 | This process validates a method 351 | Params: 352 | default_evaluation_params_fn: points to a function that returns a dictionary with the default parameters used for the evaluation 353 | validate_data_fn: points to a method that validates the corrct format of the submission 354 | """ 355 | try: 356 | p = dict([s[1:].split('=') for s in sys.argv[1:]]) 357 | evalParams = default_evaluation_params_fn() 358 | if 'p' in p.keys(): 359 | evalParams.update( p['p'] if isinstance(p['p'], dict) else json.loads(p['p'][1:-1]) ) 360 | 361 | validate_data_fn(p['g'], p['s'], evalParams) 362 | print 'SUCCESS' 363 | sys.exit(0) 364 | except Exception as e: 365 | print str(e) 366 | sys.exit(101) -------------------------------------------------------------------------------- /eval/ic15/rrc_evaluation_funcs_v2.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python2 2 | #encoding: UTF-8 3 | import json 4 | import sys;sys.path.append('./') 5 | import zipfile 6 | import re 7 | import sys 8 | import os 9 | import codecs 10 | import importlib 11 | from StringIO import StringIO 12 | 13 | def print_help(): 14 | sys.stdout.write('Usage: python %s.py -g= -s= [-o= -p=]' %sys.argv[0]) 15 | sys.exit(2) 16 | 17 | 18 | def load_zip_file_keys(file,fileNameRegExp=''): 19 | """ 20 | Returns an array with the entries of the ZIP file that match with the 
regular expression. 21 | The key's are the names or the file or the capturing group definied in the fileNameRegExp 22 | """ 23 | try: 24 | archive=zipfile.ZipFile(file, mode='r', allowZip64=True) 25 | except : 26 | raise Exception('Error loading the ZIP archive.') 27 | 28 | pairs = [] 29 | 30 | for name in archive.namelist(): 31 | addFile = True 32 | keyName = name 33 | if fileNameRegExp!="": 34 | m = re.match(fileNameRegExp,name) 35 | if m == None: 36 | addFile = False 37 | else: 38 | if len(m.groups())>0: 39 | keyName = m.group(1) 40 | 41 | if addFile: 42 | pairs.append( keyName ) 43 | 44 | return pairs 45 | 46 | 47 | def load_zip_file(file,fileNameRegExp='',allEntries=False): 48 | """ 49 | Returns an array with the contents (filtered by fileNameRegExp) of a ZIP file. 50 | The key's are the names or the file or the capturing group definied in the fileNameRegExp 51 | allEntries validates that all entries in the ZIP file pass the fileNameRegExp 52 | """ 53 | try: 54 | archive=zipfile.ZipFile(file, mode='r', allowZip64=True) 55 | except : 56 | raise Exception('Error loading the ZIP archive') 57 | 58 | pairs = [] 59 | for name in archive.namelist(): 60 | addFile = True 61 | keyName = name 62 | if fileNameRegExp!="": 63 | m = re.match(fileNameRegExp,name) 64 | if m == None: 65 | addFile = False 66 | else: 67 | if len(m.groups())>0: 68 | keyName = m.group(1) 69 | 70 | if addFile: 71 | pairs.append( [ keyName , archive.read(name)] ) 72 | else: 73 | if allEntries: 74 | raise Exception('ZIP entry not valid: %s' %name) 75 | 76 | return dict(pairs) 77 | 78 | def decode_utf8(raw): 79 | """ 80 | Returns a Unicode object on success, or None on failure 81 | """ 82 | try: 83 | raw = codecs.decode(raw,'utf-8', 'replace') 84 | #extracts BOM if exists 85 | raw = raw.encode('utf8') 86 | if raw.startswith(codecs.BOM_UTF8): 87 | raw = raw.replace(codecs.BOM_UTF8, '', 1) 88 | return raw.decode('utf-8') 89 | except: 90 | return None 91 | 92 | def validate_lines_in_file(fileName,file_contents,CRLF=True,LTRB=True,withTranscription=False,withConfidence=False,imWidth=0,imHeight=0): 93 | """ 94 | This function validates that all lines of the file calling the Line validation function for each line 95 | """ 96 | utf8File = decode_utf8(file_contents) 97 | if (utf8File is None) : 98 | raise Exception("The file %s is not UTF-8" %fileName) 99 | 100 | lines = utf8File.split( "\r\n" if CRLF else "\n" ) 101 | for line in lines: 102 | line = line.replace("\r","").replace("\n","") 103 | if(line != ""): 104 | try: 105 | validate_tl_line(line,LTRB,withTranscription,withConfidence,imWidth,imHeight) 106 | except Exception as e: 107 | raise Exception(("Line in sample not valid. Sample: %s Line: %s Error: %s" %(fileName,line,str(e))).encode('utf-8', 'replace')) 108 | 109 | 110 | 111 | def validate_tl_line(line,LTRB=True,withTranscription=True,withConfidence=True,imWidth=0,imHeight=0): 112 | """ 113 | Validate the format of the line. If the line is not valid an exception will be raised. 114 | If maxWidth and maxHeight are specified, all points must be inside the imgage bounds. 115 | Posible values are: 116 | LTRB=True: xmin,ymin,xmax,ymax[,confidence][,transcription] 117 | LTRB=False: x1,y1,x2,y2,x3,y3,x4,y4[,confidence][,transcription] 118 | """ 119 | get_tl_line_values(line,LTRB,withTranscription,withConfidence,imWidth,imHeight) 120 | 121 | 122 | def get_tl_line_values(line,LTRB=True,withTranscription=False,withConfidence=False,imWidth=0,imHeight=0): 123 | """ 124 | Validate the format of the line. 
If the line is not valid an exception will be raised. 125 | If maxWidth and maxHeight are specified, all points must be inside the image bounds. 126 | Possible values are: 127 | LTRB=True: xmin,ymin,xmax,ymax[,confidence][,transcription] 128 | LTRB=False: x1,y1,x2,y2,x3,y3,x4,y4[,confidence][,transcription] 129 | Returns values from a textline. Points , [Confidences], [Transcriptions] 130 | """ 131 | confidence = 0.0 132 | transcription = ""; 133 | points = [] 134 | 135 | numPoints = 4; 136 | 137 | if LTRB: 138 | 139 | numPoints = 4; 140 | 141 | if withTranscription and withConfidence: 142 | m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-1].?[0-9]*)\s*,(.*)$',line) 143 | if m == None : 144 | m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-1].?[0-9]*)\s*,(.*)$',line) 145 | raise Exception("Format incorrect. Should be: xmin,ymin,xmax,ymax,confidence,transcription") 146 | elif withConfidence: 147 | m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-1].?[0-9]*)\s*$',line) 148 | if m == None : 149 | raise Exception("Format incorrect. Should be: xmin,ymin,xmax,ymax,confidence") 150 | elif withTranscription: 151 | m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*,(.*)$',line) 152 | if m == None : 153 | raise Exception("Format incorrect. Should be: xmin,ymin,xmax,ymax,transcription") 154 | else: 155 | m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*,?\s*$',line) 156 | if m == None : 157 | raise Exception("Format incorrect. Should be: xmin,ymin,xmax,ymax") 158 | 159 | xmin = int(m.group(1)) 160 | ymin = int(m.group(2)) 161 | xmax = int(m.group(3)) 162 | ymax = int(m.group(4)) 163 | if(xmax<xmin): 164 | raise Exception("Xmax value (%s) not valid (Xmax < Xmin)." %(xmax)) 165 | if(ymax<ymin): 166 | raise Exception("Ymax value (%s) not valid (Ymax < Ymin)." %(ymax)) 167 | 168 | points = [ float(m.group(i)) for i in range(1, (numPoints+1) ) ] 169 | 170 | if (imWidth>0 and imHeight>0): 171 | validate_point_inside_bounds(xmin,ymin,imWidth,imHeight); 172 | validate_point_inside_bounds(xmax,ymax,imWidth,imHeight); 173 | 174 | else: 175 | 176 | numPoints = 8; 177 | 178 | # if withTranscription and withConfidence: 179 | # m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*([0-1].?[0-9]*)\s*,(.*)$',line) 180 | # if m == None : 181 | # raise Exception("Format incorrect. Should be: x1,y1,x2,y2,x3,y3,x4,y4,confidence,transcription") 182 | # elif withConfidence: 183 | # m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*([0-1].?[0-9]*)\s*$',line) 184 | # if m == None : 185 | # raise Exception("Format incorrect. Should be: x1,y1,x2,y2,x3,y3,x4,y4,confidence") 186 | # elif withTranscription: 187 | # m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,(.*)$',line) 188 | # if m == None : 189 | # raise Exception("Format incorrect. Should be: x1,y1,x2,y2,x3,y3,x4,y4,transcription") 190 | # else: 191 | # m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*$',line) 192 | # if m == None : 193 | # raise Exception("Format incorrect. 
Should be: x1,y1,x2,y2,x3,y3,x4,y4") 194 | 195 | # points = [ float(m.group(i)) for i in range(1, (numPoints+1) ) ] 196 | # print line 197 | nums = line.split(',')[:8] 198 | points = [(float)(nums[i]) for i in range(8)] 199 | 200 | # validate_clockwise_points(points) 201 | 202 | if (imWidth>0 and imHeight>0): 203 | validate_point_inside_bounds(points[0],points[1],imWidth,imHeight); 204 | validate_point_inside_bounds(points[2],points[3],imWidth,imHeight); 205 | validate_point_inside_bounds(points[4],points[5],imWidth,imHeight); 206 | validate_point_inside_bounds(points[6],points[7],imWidth,imHeight); 207 | 208 | 209 | # if withConfidence: 210 | # try: 211 | # confidence = float(m.group(numPoints+1)) 212 | # except ValueError: 213 | # raise Exception("Confidence value must be a float") 214 | 215 | # if withTranscription: 216 | # posTranscription = numPoints + (2 if withConfidence else 1) 217 | # transcription = m.group(posTranscription) 218 | # m2 = re.match(r'^\s*\"(.*)\"\s*$',transcription) 219 | # if m2 != None : #Transcription with double quotes, we extract the value and replace escaped characters 220 | # transcription = m2.group(1).replace("\\\\", "\\").replace("\\\"", "\"") 221 | 222 | return points,confidence,transcription 223 | 224 | 225 | def validate_point_inside_bounds(x,y,imWidth,imHeight): 226 | if(x<0 or x>imWidth): 227 | raise Exception("X value (%s) not valid. Image dimensions: (%s,%s)" %(x,imWidth,imHeight)) 228 | if(y<0 or y>imHeight): 229 | raise Exception("Y value (%s) not valid. Image dimensions: (%s,%s)" %(y,imWidth,imHeight)) 230 | 231 | def validate_clockwise_points(points): 232 | """ 233 | Validates that the 4 points that delimit a polygon are in clockwise order. 234 | """ 235 | 236 | if len(points) != 8: 237 | raise Exception("Points list not valid." + str(len(points))) 238 | 239 | point = [ 240 | [int(points[0]) , int(points[1])], 241 | [int(points[2]) , int(points[3])], 242 | [int(points[4]) , int(points[5])], 243 | [int(points[6]) , int(points[7])] 244 | ] 245 | edge = [ 246 | ( point[1][0] - point[0][0])*( point[1][1] + point[0][1]), 247 | ( point[2][0] - point[1][0])*( point[2][1] + point[1][1]), 248 | ( point[3][0] - point[2][0])*( point[3][1] + point[2][1]), 249 | ( point[0][0] - point[3][0])*( point[0][1] + point[3][1]) 250 | ] 251 | 252 | summatory = edge[0] + edge[1] + edge[2] + edge[3]; 253 | if summatory>0: 254 | raise Exception("Points are not clockwise. The coordinates of bounding quadrilaterals have to be given in clockwise order. Regarding the correct interpretation of 'clockwise' remember that the image coordinate system used is the standard one, with the image origin at the upper left, the X axis extending to the right and Y axis extending downwards.") 255 | 256 | def get_tl_line_values_from_file_contents(content,CRLF=True,LTRB=True,withTranscription=False,withConfidence=False,imWidth=0,imHeight=0,sort_by_confidences=True): 257 | """ 258 | Returns all points, confidences and transcriptions of a file in lists. 
Valid line formats: 259 | xmin,ymin,xmax,ymax,[confidence],[transcription] 260 | x1,y1,x2,y2,x3,y3,x4,y4,[confidence],[transcription] 261 | """ 262 | pointsList = [] 263 | transcriptionsList = [] 264 | confidencesList = [] 265 | 266 | lines = content.split( "\r\n" if CRLF else "\n" ) 267 | for line in lines: 268 | line = line.replace("\r","").replace("\n","") 269 | if(line != "") : 270 | points, confidence, transcription = get_tl_line_values(line,LTRB,withTranscription,withConfidence,imWidth,imHeight); 271 | pointsList.append(points) 272 | transcriptionsList.append(transcription) 273 | confidencesList.append(confidence) 274 | 275 | if withConfidence and len(confidencesList)>0 and sort_by_confidences: 276 | import numpy as np 277 | sorted_ind = np.argsort(-np.array(confidencesList)) 278 | confidencesList = [confidencesList[i] for i in sorted_ind] 279 | pointsList = [pointsList[i] for i in sorted_ind] 280 | transcriptionsList = [transcriptionsList[i] for i in sorted_ind] 281 | 282 | return pointsList,confidencesList,transcriptionsList 283 | 284 | def main_evaluation(p,default_evaluation_params_fn,validate_data_fn,evaluate_method_fn,show_result=True,per_sample=True): 285 | """ 286 | This process validates a method, evaluates it and if it succed generates a ZIP file with a JSON entry for each sample. 287 | Params: 288 | p: Dictionary of parmeters with the GT/submission locations. If None is passed, the parameters send by the system are used. 289 | default_evaluation_params_fn: points to a function that returns a dictionary with the default parameters used for the evaluation 290 | validate_data_fn: points to a method that validates the corrct format of the submission 291 | evaluate_method_fn: points to a function that evaluated the submission and return a Dictionary with the results 292 | """ 293 | 294 | if (p == None): 295 | p = dict([s[1:].split('=') for s in sys.argv[1:]]) 296 | if(len(sys.argv)<3): 297 | print_help() 298 | 299 | evalParams = default_evaluation_params_fn() 300 | if 'p' in p.keys(): 301 | evalParams.update( p['p'] if isinstance(p['p'], dict) else json.loads(p['p'][1:-1]) ) 302 | 303 | resDict={'calculated':True,'Message':'','method':'{}','per_sample':'{}'} 304 | try: 305 | validate_data_fn(p['g'], p['s'], evalParams) 306 | evalData = evaluate_method_fn(p['g'], p['s'], evalParams) 307 | resDict.update(evalData) 308 | 309 | except Exception, e: 310 | resDict['Message']= str(e) 311 | resDict['calculated']=False 312 | 313 | if 'o' in p: 314 | if not os.path.exists(p['o']): 315 | os.makedirs(p['o']) 316 | 317 | resultsOutputname = p['o'] + '/results.zip' 318 | outZip = zipfile.ZipFile(resultsOutputname, mode='w', allowZip64=True) 319 | 320 | del resDict['per_sample'] 321 | if 'output_items' in resDict.keys(): 322 | del resDict['output_items'] 323 | 324 | outZip.writestr('method.json',json.dumps(resDict)) 325 | 326 | if not resDict['calculated']: 327 | if show_result: 328 | sys.stderr.write('Error!\n'+ resDict['Message']+'\n\n') 329 | if 'o' in p: 330 | outZip.close() 331 | return resDict 332 | 333 | if 'o' in p: 334 | if per_sample == True: 335 | for k,v in evalData['per_sample'].iteritems(): 336 | outZip.writestr( k + '.json',json.dumps(v)) 337 | 338 | if 'output_items' in evalData.keys(): 339 | for k, v in evalData['output_items'].iteritems(): 340 | outZip.writestr( k,v) 341 | 342 | outZip.close() 343 | 344 | if show_result: 345 | sys.stdout.write("Calculated!") 346 | sys.stdout.write(json.dumps(resDict['method'])) 347 | 348 | return resDict 349 | 350 | 351 | def 
main_validation(default_evaluation_params_fn,validate_data_fn): 352 | """ 353 | This process validates a method 354 | Params: 355 | default_evaluation_params_fn: points to a function that returns a dictionary with the default parameters used for the evaluation 356 | validate_data_fn: points to a method that validates the corrct format of the submission 357 | """ 358 | try: 359 | p = dict([s[1:].split('=') for s in sys.argv[1:]]) 360 | evalParams = default_evaluation_params_fn() 361 | if 'p' in p.keys(): 362 | evalParams.update( p['p'] if isinstance(p['p'], dict) else json.loads(p['p'][1:-1]) ) 363 | 364 | validate_data_fn(p['g'], p['s'], evalParams) 365 | print 'SUCCESS' 366 | sys.exit(0) 367 | except Exception as e: 368 | print str(e) 369 | sys.exit(101) -------------------------------------------------------------------------------- /eval/ic15/script.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | from collections import namedtuple 4 | import rrc_evaluation_funcs 5 | import importlib 6 | 7 | def evaluation_imports(): 8 | """ 9 | evaluation_imports: Dictionary ( key = module name , value = alias ) with python modules used in the evaluation. 10 | """ 11 | return { 12 | 'Polygon':'plg', 13 | 'numpy':'np' 14 | } 15 | 16 | def default_evaluation_params(): 17 | """ 18 | default_evaluation_params: Default parameters to use for the validation and evaluation. 19 | """ 20 | return { 21 | 'IOU_CONSTRAINT' :0.5, 22 | 'AREA_PRECISION_CONSTRAINT' :0.5, 23 | 'GT_SAMPLE_NAME_2_ID':'gt_img_([0-9]+).txt', 24 | 'DET_SAMPLE_NAME_2_ID':'res_img_([0-9]+).txt', 25 | 'LTRB':False, #LTRB:2points(left,top,right,bottom) or 4 points(x1,y1,x2,y2,x3,y3,x4,y4) 26 | 'CRLF':False, # Lines are delimited by Windows CRLF format 27 | 'CONFIDENCES':False, #Detections must include confidence value. AP will be calculated 28 | 'PER_SAMPLE_RESULTS':True #Generate per sample results and produce data for visualization 29 | } 30 | 31 | def validate_data(gtFilePath, submFilePath,evaluationParams): 32 | """ 33 | Method validate_data: validates that all files in the results folder are correct (have the correct name contents). 34 | Validates also that there are no missing files in the folder. 35 | If some error detected, the method raises the error 36 | """ 37 | gt = rrc_evaluation_funcs.load_zip_file(gtFilePath,evaluationParams['GT_SAMPLE_NAME_2_ID']) 38 | 39 | subm = rrc_evaluation_funcs.load_zip_file(submFilePath,evaluationParams['DET_SAMPLE_NAME_2_ID'],True) 40 | 41 | #Validate format of GroundTruth 42 | for k in gt: 43 | rrc_evaluation_funcs.validate_lines_in_file(k,gt[k],evaluationParams['CRLF'],evaluationParams['LTRB'],True) 44 | 45 | #Validate format of results 46 | for k in subm: 47 | if (k in gt) == False : 48 | raise Exception("The sample %s not present in GT" %k) 49 | 50 | rrc_evaluation_funcs.validate_lines_in_file(k,subm[k],evaluationParams['CRLF'],evaluationParams['LTRB'],False,evaluationParams['CONFIDENCES']) 51 | 52 | 53 | def evaluate_method(gtFilePath, submFilePath, evaluationParams): 54 | """ 55 | Method evaluate_method: evaluate method and returns the results 56 | Results. Dictionary with the following values: 57 | - method (required) Global method metrics. Ex: { 'Precision':0.8,'Recall':0.9 } 58 | - samples (optional) Per sample metrics. 
Ex: {'sample1' : { 'Precision':0.8,'Recall':0.9 } , 'sample2' : { 'Precision':0.8,'Recall':0.9 } 59 | """ 60 | 61 | for module,alias in evaluation_imports().iteritems(): 62 | globals()[alias] = importlib.import_module(module) 63 | 64 | def polygon_from_points(points): 65 | """ 66 | Returns a Polygon object to use with the Polygon2 class from a list of 8 points: x1,y1,x2,y2,x3,y3,x4,y4 67 | """ 68 | # resBoxes=np.empty([1,8],dtype='int32') 69 | # resBoxes[0,0]=int(points[0]) 70 | # resBoxes[0,4]=int(points[1]) 71 | # resBoxes[0,1]=int(points[2]) 72 | # resBoxes[0,5]=int(points[3]) 73 | # resBoxes[0,2]=int(points[4]) 74 | # resBoxes[0,6]=int(points[5]) 75 | # resBoxes[0,3]=int(points[6]) 76 | # resBoxes[0,7]=int(points[7]) 77 | # pointMat = resBoxes[0].reshape([2,4]).T 78 | # return plg.Polygon( pointMat) 79 | 80 | p = np.array(points) 81 | p = p.reshape(p.shape[0]//2, 2) 82 | p = plg.Polygon(p) 83 | return p 84 | 85 | def rectangle_to_polygon(rect): 86 | resBoxes=np.empty([1,8],dtype='int32') 87 | resBoxes[0,0]=int(rect.xmin) 88 | resBoxes[0,4]=int(rect.ymax) 89 | resBoxes[0,1]=int(rect.xmin) 90 | resBoxes[0,5]=int(rect.ymin) 91 | resBoxes[0,2]=int(rect.xmax) 92 | resBoxes[0,6]=int(rect.ymin) 93 | resBoxes[0,3]=int(rect.xmax) 94 | resBoxes[0,7]=int(rect.ymax) 95 | 96 | pointMat = resBoxes[0].reshape([2,4]).T 97 | 98 | return plg.Polygon( pointMat) 99 | 100 | def rectangle_to_points(rect): 101 | points = [int(rect.xmin), int(rect.ymax), int(rect.xmax), int(rect.ymax), int(rect.xmax), int(rect.ymin), int(rect.xmin), int(rect.ymin)] 102 | return points 103 | 104 | def get_union(pD,pG): 105 | areaA = pD.area(); 106 | areaB = pG.area(); 107 | return areaA + areaB - get_intersection(pD, pG); 108 | 109 | def get_intersection_over_union(pD,pG): 110 | try: 111 | return get_intersection(pD, pG) / get_union(pD, pG); 112 | except: 113 | return 0 114 | 115 | def get_intersection(pD,pG): 116 | pInt = pD & pG 117 | if len(pInt) == 0: 118 | return 0 119 | return pInt.area() 120 | 121 | def compute_ap(confList, matchList,numGtCare): 122 | correct = 0 123 | AP = 0 124 | if len(confList)>0: 125 | confList = np.array(confList) 126 | matchList = np.array(matchList) 127 | sorted_ind = np.argsort(-confList) 128 | confList = confList[sorted_ind] 129 | matchList = matchList[sorted_ind] 130 | for n in range(len(confList)): 131 | match = matchList[n] 132 | if match: 133 | correct += 1 134 | AP += float(correct)/(n + 1) 135 | 136 | if numGtCare>0: 137 | AP /= numGtCare 138 | 139 | return AP 140 | 141 | perSampleMetrics = {} 142 | 143 | matchedSum = 0 144 | 145 | Rectangle = namedtuple('Rectangle', 'xmin ymin xmax ymax') 146 | 147 | gt = rrc_evaluation_funcs.load_zip_file(gtFilePath,evaluationParams['GT_SAMPLE_NAME_2_ID']) 148 | subm = rrc_evaluation_funcs.load_zip_file(submFilePath,evaluationParams['DET_SAMPLE_NAME_2_ID'],True) 149 | 150 | numGlobalCareGt = 0; 151 | numGlobalCareDet = 0; 152 | 153 | arrGlobalConfidences = []; 154 | arrGlobalMatches = []; 155 | 156 | for resFile in gt: 157 | 158 | gtFile = rrc_evaluation_funcs.decode_utf8(gt[resFile]) 159 | recall = 0 160 | precision = 0 161 | hmean = 0 162 | 163 | detMatched = 0 164 | 165 | iouMat = np.empty([1,1]) 166 | 167 | gtPols = [] 168 | detPols = [] 169 | 170 | gtPolPoints = [] 171 | detPolPoints = [] 172 | 173 | #Array of Ground Truth Polygons' keys marked as don't Care 174 | gtDontCarePolsNum = [] 175 | #Array of Detected Polygons' matched with a don't Care GT 176 | detDontCarePolsNum = [] 177 | 178 | pairs = [] 179 | detMatchedNums = [] 180 | 181 | 
arrSampleConfidences = []; 182 | arrSampleMatch = []; 183 | sampleAP = 0; 184 | 185 | evaluationLog = "" 186 | 187 | pointsList,_,transcriptionsList = rrc_evaluation_funcs.get_tl_line_values_from_file_contents(gtFile,evaluationParams['CRLF'],evaluationParams['LTRB'],True,False) 188 | for n in range(len(pointsList)): 189 | points = pointsList[n] 190 | transcription = transcriptionsList[n] 191 | dontCare = transcription == "###" 192 | if evaluationParams['LTRB']: 193 | gtRect = Rectangle(*points) 194 | gtPol = rectangle_to_polygon(gtRect) 195 | else: 196 | gtPol = polygon_from_points(points) 197 | gtPols.append(gtPol) 198 | gtPolPoints.append(points) 199 | if dontCare: 200 | gtDontCarePolsNum.append( len(gtPols)-1 ) 201 | 202 | evaluationLog += "GT polygons: " + str(len(gtPols)) + (" (" + str(len(gtDontCarePolsNum)) + " don't care)\n" if len(gtDontCarePolsNum)>0 else "\n") 203 | 204 | if resFile in subm: 205 | 206 | detFile = rrc_evaluation_funcs.decode_utf8(subm[resFile]) 207 | def get_pred(file): 208 | lines = file.split('\n') 209 | pointsList = [] 210 | for line in lines: 211 | if line == '': 212 | continue 213 | bbox = line.split(',') 214 | if len(bbox) % 2 == 1: 215 | print(path) 216 | bbox = [int(x) for x in bbox] 217 | pointsList.append(bbox) 218 | return pointsList 219 | 220 | # pointsList,confidencesList,_ = rrc_evaluation_funcs.get_tl_line_values_from_file_contents(detFile,evaluationParams['CRLF'],evaluationParams['LTRB'],False,evaluationParams['CONFIDENCES']) 221 | # print(pointsList) 222 | # print(confidencesList) 223 | 224 | pointsList = get_pred(detFile) 225 | confidencesList = [0.0] * len(pointsList) 226 | 227 | for n in range(len(pointsList)): 228 | points = pointsList[n] 229 | 230 | if evaluationParams['LTRB']: 231 | detRect = Rectangle(*points) 232 | detPol = rectangle_to_polygon(detRect) 233 | else: 234 | detPol = polygon_from_points(points) 235 | detPols.append(detPol) 236 | detPolPoints.append(points) 237 | if len(gtDontCarePolsNum)>0 : 238 | for dontCarePol in gtDontCarePolsNum: 239 | dontCarePol = gtPols[dontCarePol] 240 | intersected_area = get_intersection(dontCarePol,detPol) 241 | pdDimensions = detPol.area() 242 | precision = 0 if pdDimensions == 0 else intersected_area / pdDimensions 243 | if (precision > evaluationParams['AREA_PRECISION_CONSTRAINT'] ): 244 | detDontCarePolsNum.append( len(detPols)-1 ) 245 | break 246 | 247 | evaluationLog += "DET polygons: " + str(len(detPols)) + (" (" + str(len(detDontCarePolsNum)) + " don't care)\n" if len(detDontCarePolsNum)>0 else "\n") 248 | 249 | if len(gtPols)>0 and len(detPols)>0: 250 | #Calculate IoU and precision matrixs 251 | outputShape=[len(gtPols),len(detPols)] 252 | iouMat = np.empty(outputShape) 253 | gtRectMat = np.zeros(len(gtPols),np.int8) 254 | detRectMat = np.zeros(len(detPols),np.int8) 255 | for gtNum in range(len(gtPols)): 256 | for detNum in range(len(detPols)): 257 | pG = gtPols[gtNum] 258 | pD = detPols[detNum] 259 | iouMat[gtNum,detNum] = get_intersection_over_union(pD,pG) 260 | 261 | for gtNum in range(len(gtPols)): 262 | for detNum in range(len(detPols)): 263 | if gtRectMat[gtNum] == 0 and detRectMat[detNum] == 0 and gtNum not in gtDontCarePolsNum and detNum not in detDontCarePolsNum : 264 | if iouMat[gtNum,detNum]>evaluationParams['IOU_CONSTRAINT']: 265 | gtRectMat[gtNum] = 1 266 | detRectMat[detNum] = 1 267 | detMatched += 1 268 | pairs.append({'gt':gtNum,'det':detNum}) 269 | detMatchedNums.append(detNum) 270 | evaluationLog += "Match GT #" + str(gtNum) + " with Det #" + str(detNum) + "\n" 271 | 272 
| if evaluationParams['CONFIDENCES']: 273 | for detNum in range(len(detPols)): 274 | if detNum not in detDontCarePolsNum : 275 | #we exclude the don't care detections 276 | match = detNum in detMatchedNums 277 | 278 | arrSampleConfidences.append(confidencesList[detNum]) 279 | arrSampleMatch.append(match) 280 | 281 | arrGlobalConfidences.append(confidencesList[detNum]); 282 | arrGlobalMatches.append(match); 283 | 284 | numGtCare = (len(gtPols) - len(gtDontCarePolsNum)) 285 | numDetCare = (len(detPols) - len(detDontCarePolsNum)) 286 | if numGtCare == 0: 287 | recall = float(1) 288 | precision = float(0) if numDetCare >0 else float(1) 289 | sampleAP = precision 290 | else: 291 | recall = float(detMatched) / numGtCare 292 | precision = 0 if numDetCare==0 else float(detMatched) / numDetCare 293 | if evaluationParams['CONFIDENCES'] and evaluationParams['PER_SAMPLE_RESULTS']: 294 | sampleAP = compute_ap(arrSampleConfidences, arrSampleMatch, numGtCare ) 295 | 296 | hmean = 0 if (precision + recall)==0 else 2.0 * precision * recall / (precision + recall) 297 | 298 | matchedSum += detMatched 299 | numGlobalCareGt += numGtCare 300 | numGlobalCareDet += numDetCare 301 | 302 | if evaluationParams['PER_SAMPLE_RESULTS']: 303 | perSampleMetrics[resFile] = { 304 | 'precision':precision, 305 | 'recall':recall, 306 | 'hmean':hmean, 307 | 'pairs':pairs, 308 | 'AP':sampleAP, 309 | 'iouMat':[] if len(detPols)>100 else iouMat.tolist(), 310 | 'gtPolPoints':gtPolPoints, 311 | 'detPolPoints':detPolPoints, 312 | 'gtDontCare':gtDontCarePolsNum, 313 | 'detDontCare':detDontCarePolsNum, 314 | 'evaluationParams': evaluationParams, 315 | 'evaluationLog': evaluationLog 316 | } 317 | 318 | # Compute MAP and MAR 319 | AP = 0 320 | if evaluationParams['CONFIDENCES']: 321 | AP = compute_ap(arrGlobalConfidences, arrGlobalMatches, numGlobalCareGt) 322 | 323 | methodRecall = 0 if numGlobalCareGt == 0 else float(matchedSum)/numGlobalCareGt 324 | methodPrecision = 0 if numGlobalCareDet == 0 else float(matchedSum)/numGlobalCareDet 325 | methodHmean = 0 if methodRecall + methodPrecision==0 else 2* methodRecall * methodPrecision / (methodRecall + methodPrecision) 326 | 327 | methodMetrics = {'precision':methodPrecision, 'recall':methodRecall,'hmean': methodHmean, 'AP': AP } 328 | 329 | resDict = {'calculated':True,'Message':'','method': methodMetrics,'per_sample': perSampleMetrics} 330 | 331 | return resDict 332 | 333 | 334 | 335 | if __name__=='__main__': 336 | 337 | rrc_evaluation_funcs.main_evaluation(None,default_evaluation_params,validate_data,evaluate_method) 338 | print('') 339 | -------------------------------------------------------------------------------- /eval/ic15/script_self_adapt.py: -------------------------------------------------------------------------------- 1 | import os 2 | import mmcv 3 | import argparse 4 | 5 | parser = argparse.ArgumentParser(description='Hyperparams') 6 | # parser.add_argument('--gt', nargs='?', type=str, default=None) 7 | parser.add_argument('--pred', nargs='?', type=str, default=None) 8 | args = parser.parse_args() 9 | 10 | output_root = '../outputs/tmp_results/' 11 | pred = mmcv.load(args.pred) 12 | 13 | def write_result_as_txt(image_name, bboxes, path, words=None): 14 | if not os.path.exists(path): 15 | os.makedirs(path) 16 | 17 | file_path = path + 'res_%s.txt'%(image_name) 18 | lines = [] 19 | for i, bbox in enumerate(bboxes): 20 | values = [int(v) for v in bbox] 21 | if words is None: 22 | line = "%d,%d,%d,%d,%d,%d,%d,%d\n"%tuple(values) 23 | lines.append(line) 24 | elif words[i] is 
not None: 25 | line = "%d,%d,%d,%d,%d,%d,%d,%d"%tuple(values) + ",%s\n"%words[i] 26 | lines.append(line) 27 | with open(file_path, 'w') as f: 28 | for line in lines: 29 | f.write(line) 30 | 31 | def eval(thr): 32 | for key in pred: 33 | pred_ = pred[key] 34 | line_num = len(pred_['scores']) 35 | bboxes = [] 36 | # words = [] 37 | for i in range(line_num): 38 | if pred_['scores'][i] < thr: 39 | continue 40 | bboxes.append(pred_['bboxes'][i]) 41 | # words.append(pred_['words'][i]) 42 | 43 | write_result_as_txt(key, bboxes, output_root) 44 | 45 | cmd = 'cd %s;zip -j %s %s/*' % ('../outputs/', 'tmp_results.zip', 'tmp_results') 46 | res_cmd = os.popen(cmd) 47 | res_cmd.read() 48 | 49 | cmd = 'cd ic15 && python2 script.py -g=gt.zip -s=../../outputs/tmp_results.zip && cd ..' 50 | res_cmd = os.popen(cmd) 51 | res_cmd = res_cmd.read() 52 | h_mean = float(res_cmd.split(',')[-2].split(':')[-1]) 53 | return res_cmd, h_mean 54 | 55 | max_h_mean = 0 56 | best_thr = 0 57 | best_res = '' 58 | for i in range(85, 100): 59 | thr = float(i) / 100 60 | # print('Testing thr: %f'%thr) 61 | res, h_mean = eval(thr) 62 | # print(thr, h_mean) 63 | if h_mean > max_h_mean: 64 | max_h_mean = h_mean 65 | best_thr = thr 66 | best_res = res 67 | 68 | print('thr: %f | %s'%(best_thr, best_res)) 69 | -------------------------------------------------------------------------------- /eval/ic15_rec/gt.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/whai362/PSENet/5a42734dc56df42b7192494933ea8fcb3f486494/eval/ic15_rec/gt.zip -------------------------------------------------------------------------------- /eval/ic15_rec/readme.txt: -------------------------------------------------------------------------------- 1 | INSTRUCTIONS FOR THE STANDALONE SCRIPTS 2 | Requirements: 3 | - Python version 3. 4 | - Each Task requires different Python modules. When running the script, if some module is not installed you will see a notification and installation instructions. 5 | 6 | Procedure: 7 | Download the ZIP file for the requested script and unzip it to a directory. 8 | 9 | Open a terminal in the directory and run the command: 10 | python script.py –g=gt.zip –s=submit.zip 11 | 12 | If you have already installed all the required modules, then you will see the method’s results or an error message if the submitted file is not correct. 13 | 14 | If a module is not present, you should install them with PIP: pip install 'module' 15 | 16 | In case of Polygon module, use: 'pip install Polygon3' 17 | 18 | parameters: 19 | -g: Path of the Ground Truth file. In most cases, the Ground Truth will be included in the same Zip file named 'gt.zip', gt.txt' or 'gt.json'. If not, you will be able to get it on the Downloads page of the Task. 20 | -s: Path of your method's results file. 21 | 22 | Optional parameters: 23 | -o: Path to a directory where to copy the file ‘results.zip’ that contains per-sample results. 24 | -p: JSON string parameters to override the script default parameters. The parameters that can be overrided are inside the function 'default_evaluation_params' located at the begining of the evaluation Script. 
25 | 26 | Example: python script.py –g=gt.zip –s=submit.zip –o=./ -p={\"IOU_CONSTRAINT\":0.8} -------------------------------------------------------------------------------- /eval/ic15_rec/script_self_adapt.py: -------------------------------------------------------------------------------- 1 | import os 2 | import mmcv 3 | import argparse 4 | 5 | parser = argparse.ArgumentParser(description='Hyperparams') 6 | # parser.add_argument('--gt', nargs='?', type=str, default=None) 7 | parser.add_argument('--pred', nargs='?', type=str, default=None) 8 | args = parser.parse_args() 9 | 10 | output_root = '../outputs/tmp_results/' 11 | pred = mmcv.load(args.pred) 12 | 13 | def write_result_as_txt(image_name, bboxes, path, words=None): 14 | if not os.path.exists(path): 15 | os.makedirs(path) 16 | 17 | file_path = path + 'res_%s.txt'%(image_name) 18 | lines = [] 19 | for i, bbox in enumerate(bboxes): 20 | values = [int(v) for v in bbox] 21 | if words is None: 22 | line = "%d,%d,%d,%d,%d,%d,%d,%d\n"%tuple(values) 23 | lines.append(line) 24 | elif words[i] is not None: 25 | line = "%d,%d,%d,%d,%d,%d,%d,%d"%tuple(values) + ",%s\n"%words[i] 26 | lines.append(line) 27 | with open(file_path, 'w') as f: 28 | for line in lines: 29 | f.write(line.encode('utf-8')) 30 | 31 | def eval(thr): 32 | for key in pred: 33 | pred_ = pred[key] 34 | line_num = len(pred_['scores']) 35 | bboxes = [] 36 | words = [] 37 | for i in range(line_num): 38 | if pred_['word_scores'][i] < thr: 39 | continue 40 | bboxes.append(pred_['bboxes'][i]) 41 | words.append(pred_['words'][i]) 42 | 43 | write_result_as_txt(key, bboxes, output_root, words) 44 | 45 | cmd = 'cd %s;zip -j %s %s/*' % ('../outputs/', 'tmp_results.zip', 'tmp_results') 46 | res_cmd = os.popen(cmd) 47 | res_cmd.read() 48 | 49 | cmd = 'cd ic15_rec && python2 script.py -g=gt.zip -s=../../outputs/tmp_results.zip && cd ..' 
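    # Note (editorial, an assumption about the eval script's output): script.py,
    # run through main_evaluation, prints a JSON-like summary whose second-to-last
    # comma-separated field is the '"hmean": H' entry; the parsing below relies on
    # exactly that layout, so any change to the summary format would break the
    # float(...) extraction.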
50 | res_cmd = os.popen(cmd) 51 | res_cmd = res_cmd.read() 52 | h_mean = float(res_cmd.split(',')[-2].split(':')[-1]) 53 | return res_cmd, h_mean 54 | 55 | max_h_mean = 0 56 | best_thr = 0 57 | best_res = '' 58 | for i in range(80, 100): 59 | thr = float(i) / 100 60 | # print('Testing thr: %f'%thr) 61 | res, h_mean = eval(thr) 62 | if h_mean >= max_h_mean: 63 | max_h_mean = h_mean 64 | best_thr = thr 65 | best_res = res 66 | 67 | print('thr: %f | %s'%(best_thr, best_res)) 68 | -------------------------------------------------------------------------------- /eval/msra/eval.py: -------------------------------------------------------------------------------- 1 | import file_util 2 | import Polygon as plg 3 | import numpy as np 4 | import math 5 | import cv2 6 | 7 | project_root = '../../' 8 | 9 | pred_root = project_root + 'outputs/submit_msra/' 10 | gt_root = project_root + 'data/MSRA-TD500/test/' 11 | 12 | 13 | def get_pred(path): 14 | lines = file_util.read_file(path).split('\n') 15 | bboxes = [] 16 | for line in lines: 17 | if line == '': 18 | continue 19 | bbox = line.split(',') 20 | if len(bbox) % 2 == 1: 21 | print(path) 22 | bbox = [int(x) for x in bbox] 23 | bboxes.append(bbox) 24 | return bboxes 25 | 26 | 27 | def get_gt(path): 28 | lines = file_util.read_file(path).split('\n') 29 | bboxes = [] 30 | tags = [] 31 | for line in lines: 32 | if line == '': 33 | continue 34 | # line = util.str.remove_all(line, '\xef\xbb\xbf') 35 | # gt = util.str.split(line, ' ') 36 | gt = line.split(' ') 37 | 38 | w_ = np.float(gt[4]) 39 | h_ = np.float(gt[5]) 40 | x1 = np.float(gt[2]) + w_ / 2.0 41 | y1 = np.float(gt[3]) + h_ / 2.0 42 | theta = np.float(gt[6]) / math.pi * 180 43 | 44 | bbox = cv2.boxPoints(((x1, y1), (w_, h_), theta)) 45 | bbox = bbox.reshape(-1) 46 | 47 | bboxes.append(bbox) 48 | tags.append(np.int(gt[1])) 49 | return np.array(bboxes), tags 50 | 51 | 52 | def get_union(pD, pG): 53 | areaA = pD.area() 54 | areaB = pG.area() 55 | return areaA + areaB - get_intersection(pD, pG) 56 | 57 | 58 | def get_intersection(pD, pG): 59 | pInt = pD & pG 60 | if len(pInt) == 0: 61 | return 0 62 | return pInt.area() 63 | 64 | 65 | if __name__ == '__main__': 66 | th = 0.5 67 | pred_list = file_util.read_dir(pred_root) 68 | 69 | count, tp, fp, tn, ta = 0, 0, 0, 0, 0 70 | for pred_path in pred_list: 71 | count = count + 1 72 | preds = get_pred(pred_path) 73 | gt_path = gt_root + pred_path.split('/')[-1].split('.')[0] + '.gt' 74 | gts, tags = get_gt(gt_path) 75 | 76 | ta = ta + len(preds) 77 | for gt, tag in zip(gts, tags): 78 | gt = np.array(gt) 79 | gt = gt.reshape(gt.shape[0] / 2, 2) 80 | gt_p = plg.Polygon(gt) 81 | difficult = tag 82 | flag = 0 83 | for pred in preds: 84 | pred = np.array(pred) 85 | pred = pred.reshape(pred.shape[0] / 2, 2) 86 | pred_p = plg.Polygon(pred) 87 | 88 | union = get_union(pred_p, gt_p) 89 | inter = get_intersection(pred_p, gt_p) 90 | iou = float(inter) / union 91 | if iou >= th: 92 | flag = 1 93 | tp = tp + 1 94 | break 95 | 96 | if flag == 0 and difficult == 0: 97 | fp = fp + 1 98 | 99 | recall = float(tp) / (tp + fp) 100 | precision = float(tp) / ta 101 | hmean = 0 if (precision + recall) == 0 else 2.0 * precision * recall / (precision + recall) 102 | 103 | print('p: %.4f, r: %.4f, f: %.4f' % (precision, recall, hmean)) 104 | -------------------------------------------------------------------------------- /eval/msra/file_util.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | def read_dir(root): 4 | file_path_list = 
[] 5 | for file_path, dirs, files in os.walk(root): 6 | for file in files: 7 | file_path_list.append(os.path.join(file_path, file).replace('\\', '/')) 8 | file_path_list.sort() 9 | return file_path_list 10 | 11 | def read_file(file_path): 12 | file_object = open(file_path, 'r') 13 | file_content = file_object.read() 14 | file_object.close() 15 | return file_content 16 | 17 | def write_file(file_path, file_content): 18 | if file_path.find('/') != -1: 19 | father_dir = '/'.join(file_path.split('/')[0:-1]) 20 | if not os.path.exists(father_dir): 21 | os.makedirs(father_dir) 22 | file_object = open(file_path, 'w') 23 | file_object.write(file_content) 24 | file_object.close() 25 | 26 | 27 | def write_file_not_cover(file_path, file_content): 28 | father_dir = '/'.join(file_path.split('/')[0:-1]) 29 | if not os.path.exists(father_dir): 30 | os.makedirs(father_dir) 31 | file_object = open(file_path, 'a') 32 | file_object.write(file_content) 33 | file_object.close() -------------------------------------------------------------------------------- /eval/tt/polygon_wrapper.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy as np 3 | from skimage.draw import polygon 4 | 5 | """ 6 | :param det_x: [1, N] Xs of detection's vertices 7 | :param det_y: [1, N] Ys of detection's vertices 8 | :param gt_x: [1, N] Xs of groundtruth's vertices 9 | :param gt_y: [1, N] Ys of groundtruth's vertices 10 | ############## 11 | All the calculation of 'AREA' in this script is handled by: 12 | 1) First generating a binary mask with the polygon area filled up with 1's 13 | 2) Summing up all the 1's 14 | """ 15 | 16 | 17 | def area(x, y): 18 | """ 19 | This helper calculates the area given x and y vertices. 20 | """ 21 | ymax = np.max(y) 22 | xmax = np.max(x) 23 | bin_mask = np.zeros((ymax, xmax)) 24 | rr, cc = polygon(y, x) 25 | bin_mask[rr, cc] = 1 26 | area = np.sum(bin_mask) 27 | return area 28 | #return np.round(area, 2) 29 | 30 | 31 | def approx_area_of_intersection(det_x, det_y, gt_x, gt_y): 32 | """ 33 | This helper determine if both polygons are intersecting with each others with an approximation method. 34 | Area of intersection represented by the minimum bounding rectangular [xmin, ymin, xmax, ymax] 35 | """ 36 | det_ymax = np.max(det_y) 37 | det_xmax = np.max(det_x) 38 | det_ymin = np.min(det_y) 39 | det_xmin = np.min(det_x) 40 | 41 | gt_ymax = np.max(gt_y) 42 | gt_xmax = np.max(gt_x) 43 | gt_ymin = np.min(gt_y) 44 | gt_xmin = np.min(gt_x) 45 | 46 | all_min_ymax = np.minimum(det_ymax, gt_ymax) 47 | all_max_ymin = np.maximum(det_ymin, gt_ymin) 48 | 49 | intersect_heights = np.maximum(0.0, (all_min_ymax - all_max_ymin)) 50 | 51 | all_min_xmax = np.minimum(det_xmax, gt_xmax) 52 | all_max_xmin = np.maximum(det_xmin, gt_xmin) 53 | intersect_widths = np.maximum(0.0, (all_min_xmax - all_max_xmin)) 54 | 55 | return intersect_heights * intersect_widths 56 | 57 | def area_of_intersection(det_x, det_y, gt_x, gt_y): 58 | """ 59 | This helper calculates the area of intersection. 
60 | """ 61 | if approx_area_of_intersection(det_x, det_y, gt_x, gt_y) > 1: #only proceed if it passes the approximation test 62 | ymax = np.maximum(np.max(det_y), np.max(gt_y)) + 1 63 | xmax = np.maximum(np.max(det_x), np.max(gt_x)) + 1 64 | bin_mask = np.zeros((ymax, xmax)) 65 | det_bin_mask = np.zeros_like(bin_mask) 66 | gt_bin_mask = np.zeros_like(bin_mask) 67 | 68 | rr, cc = polygon(det_y, det_x) 69 | det_bin_mask[rr, cc] = 1 70 | 71 | rr, cc = polygon(gt_y, gt_x) 72 | gt_bin_mask[rr, cc] = 1 73 | 74 | final_bin_mask = det_bin_mask + gt_bin_mask 75 | 76 | inter_map = np.where(final_bin_mask == 2, 1, 0) 77 | inter = np.sum(inter_map) 78 | return inter 79 | # return np.round(inter, 2) 80 | else: 81 | return 0 82 | 83 | 84 | def iou(det_x, det_y, gt_x, gt_y): 85 | """ 86 | This helper determine the intersection over union of two polygons. 87 | """ 88 | 89 | if approx_area_of_intersection(det_x, det_y, gt_x, gt_y) > 1: #only proceed if it passes the approximation test 90 | ymax = np.maximum(np.max(det_y), np.max(gt_y)) + 1 91 | xmax = np.maximum(np.max(det_x), np.max(gt_x)) + 1 92 | bin_mask = np.zeros((ymax, xmax)) 93 | det_bin_mask = np.zeros_like(bin_mask) 94 | gt_bin_mask = np.zeros_like(bin_mask) 95 | 96 | rr, cc = polygon(det_y, det_x) 97 | det_bin_mask[rr, cc] = 1 98 | 99 | rr, cc = polygon(gt_y, gt_x) 100 | gt_bin_mask[rr, cc] = 1 101 | 102 | final_bin_mask = det_bin_mask + gt_bin_mask 103 | 104 | #inter_map = np.zeros_like(final_bin_mask) 105 | inter_map = np.where(final_bin_mask == 2, 1, 0) 106 | inter = np.sum(inter_map) 107 | 108 | #union_map = np.zeros_like(final_bin_mask) 109 | union_map = np.where(final_bin_mask > 0, 1, 0) 110 | union = np.sum(union_map) 111 | return inter / float(union + 1.0) 112 | #return np.round(inter / float(union + 1.0), 2) 113 | else: 114 | return 0 115 | 116 | def iod(det_x, det_y, gt_x, gt_y): 117 | """ 118 | This helper determine the fraction of intersection area over detection area 119 | """ 120 | 121 | if approx_area_of_intersection(det_x, det_y, gt_x, gt_y) > 1: #only proceed if it passes the approximation test 122 | ymax = np.maximum(np.max(det_y), np.max(gt_y)) + 1 123 | xmax = np.maximum(np.max(det_x), np.max(gt_x)) + 1 124 | bin_mask = np.zeros((ymax, xmax)) 125 | det_bin_mask = np.zeros_like(bin_mask) 126 | gt_bin_mask = np.zeros_like(bin_mask) 127 | 128 | rr, cc = polygon(det_y, det_x) 129 | det_bin_mask[rr, cc] = 1 130 | 131 | rr, cc = polygon(gt_y, gt_x) 132 | gt_bin_mask[rr, cc] = 1 133 | 134 | final_bin_mask = det_bin_mask + gt_bin_mask 135 | 136 | inter_map = np.where(final_bin_mask == 2, 1, 0) 137 | inter = np.round(np.sum(inter_map), 2) 138 | 139 | det = np.round(np.sum(det_bin_mask), 2) 140 | return inter / float(det + 1.0) 141 | #return np.round(inter / float(det + 1.0), 2) 142 | else: 143 | return 0 144 | -------------------------------------------------------------------------------- /eval/tt_rec/gt.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/whai362/PSENet/5a42734dc56df42b7192494933ea8fcb3f486494/eval/tt_rec/gt.zip -------------------------------------------------------------------------------- /eval/tt_rec/readme.txt: -------------------------------------------------------------------------------- 1 | INSTRUCTIONS FOR THE STANDALONE SCRIPTS 2 | Requirements: 3 | - Python version 3. 4 | - Each Task requires different Python modules. When running the script, if some module is not installed you will see a notification and installation instructions. 
5 | 6 | Procedure: 7 | Download the ZIP file for the requested script and unzip it to a directory. 8 | 9 | Open a terminal in the directory and run the command: 10 | python script.py –g=gt.zip –s=submit.zip 11 | 12 | If you have already installed all the required modules, then you will see the method’s results or an error message if the submitted file is not correct. 13 | 14 | If a module is not present, you should install them with PIP: pip install 'module' 15 | 16 | In case of Polygon module, use: 'pip install Polygon3' 17 | 18 | parameters: 19 | -g: Path of the Ground Truth file. In most cases, the Ground Truth will be included in the same Zip file named 'gt.zip', gt.txt' or 'gt.json'. If not, you will be able to get it on the Downloads page of the Task. 20 | -s: Path of your method's results file. 21 | 22 | Optional parameters: 23 | -o: Path to a directory where to copy the file ‘results.zip’ that contains per-sample results. 24 | -p: JSON string parameters to override the script default parameters. The parameters that can be overrided are inside the function 'default_evaluation_params' located at the begining of the evaluation Script. 25 | 26 | Example: python script.py –g=gt.zip –s=submit.zip –o=./ -p={\"IOU_CONSTRAINT\":0.8} -------------------------------------------------------------------------------- /logo.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/whai362/PSENet/5a42734dc56df42b7192494933ea8fcb3f486494/logo.jpg -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .psenet import PSENet 2 | from .builder import build_model 3 | 4 | __all__ = ['PSENet'] 5 | -------------------------------------------------------------------------------- /models/backbone/__init__.py: -------------------------------------------------------------------------------- 1 | from .resnet import resnet18, resnet50, resnet101 2 | from .builder import build_backbone 3 | 4 | __all__ = ['resnet18', 'resnet50', 'resnet101'] 5 | -------------------------------------------------------------------------------- /models/backbone/builder.py: -------------------------------------------------------------------------------- 1 | import models 2 | 3 | 4 | def build_backbone(cfg): 5 | param = dict() 6 | for key in cfg: 7 | if key == 'type': 8 | continue 9 | param[key] = cfg[key] 10 | 11 | backbone = models.backbone.__dict__[cfg.type](**param) 12 | 13 | return backbone 14 | -------------------------------------------------------------------------------- /models/backbone/resnet.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import torch 4 | import torch.nn as nn 5 | import math 6 | 7 | try: 8 | from urllib import urlretrieve 9 | except ImportError: 10 | from urllib.request import urlretrieve 11 | 12 | __all__ = ['resnet18', 'resnet50', 'resnet101'] 13 | 14 | model_urls = { 15 | 'resnet18': 'http://sceneparsing.csail.mit.edu/model/pretrained_resnet/resnet18-imagenet.pth', 16 | 'resnet50': 'http://sceneparsing.csail.mit.edu/model/pretrained_resnet/resnet50-imagenet.pth', 17 | 'resnet101': 'http://sceneparsing.csail.mit.edu/model/pretrained_resnet/resnet101-imagenet.pth' 18 | } 19 | 20 | 21 | def conv3x3(in_planes, out_planes, stride=1): 22 | "3x3 convolution with padding" 23 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 
24 | padding=1, bias=False) 25 | 26 | 27 | 28 | class BasicBlock(nn.Module): 29 | expansion = 1 30 | 31 | def __init__(self, inplanes, planes, stride=1, downsample=None): 32 | super(BasicBlock, self).__init__() 33 | self.conv1 = conv3x3(inplanes, planes, stride) 34 | self.bn1 = nn.BatchNorm2d(planes) 35 | self.relu = nn.ReLU(inplace=True) 36 | self.conv2 = conv3x3(planes, planes) 37 | self.bn2 = nn.BatchNorm2d(planes) 38 | self.downsample = downsample 39 | self.stride = stride 40 | 41 | def forward(self, x): 42 | residual = x 43 | 44 | out = self.conv1(x) 45 | out = self.bn1(out) 46 | out = self.relu(out) 47 | 48 | out = self.conv2(out) 49 | out = self.bn2(out) 50 | 51 | if self.downsample is not None: 52 | residual = self.downsample(x) 53 | 54 | out += residual 55 | out = self.relu(out) 56 | 57 | return out 58 | 59 | 60 | class Bottleneck(nn.Module): 61 | expansion = 4 62 | 63 | def __init__(self, inplanes, planes, stride=1, downsample=None): 64 | super(Bottleneck, self).__init__() 65 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 66 | self.bn1 = nn.BatchNorm2d(planes) 67 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 68 | padding=1, bias=False) 69 | self.bn2 = nn.BatchNorm2d(planes) 70 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 71 | self.bn3 = nn.BatchNorm2d(planes * 4) 72 | self.relu = nn.ReLU(inplace=True) 73 | self.downsample = downsample 74 | self.stride = stride 75 | 76 | def forward(self, x): 77 | residual = x 78 | 79 | out = self.conv1(x) 80 | out = self.bn1(out) 81 | out = self.relu(out) 82 | 83 | out = self.conv2(out) 84 | out = self.bn2(out) 85 | out = self.relu(out) 86 | 87 | out = self.conv3(out) 88 | out = self.bn3(out) 89 | 90 | if self.downsample is not None: 91 | residual = self.downsample(x) 92 | 93 | out += residual 94 | out = self.relu(out) 95 | 96 | return out 97 | 98 | 99 | class Convkxk(nn.Module): 100 | def __init__(self, in_planes, out_planes, kernel_size=1, stride=1, padding=0): 101 | super(Convkxk, self).__init__() 102 | self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, 103 | bias=False) 104 | self.bn = nn.BatchNorm2d(out_planes) 105 | self.relu = nn.ReLU(inplace=True) 106 | 107 | for m in self.modules(): 108 | if isinstance(m, nn.Conv2d): 109 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 110 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 111 | elif isinstance(m, nn.BatchNorm2d): 112 | m.weight.data.fill_(1) 113 | m.bias.data.zero_() 114 | 115 | def forward(self, x): 116 | return self.relu(self.bn(self.conv(x))) 117 | 118 | 119 | class ResNet(nn.Module): 120 | 121 | def __init__(self, block, layers): 122 | super(ResNet, self).__init__() 123 | self.inplanes = 128 124 | self.conv1 = conv3x3(3, 64, stride=2) 125 | self.bn1 = nn.BatchNorm2d(64) 126 | self.relu1 = nn.ReLU(inplace=True) 127 | self.conv2 = conv3x3(64, 64) 128 | self.bn2 = nn.BatchNorm2d(64) 129 | self.relu2 = nn.ReLU(inplace=True) 130 | self.conv3 = conv3x3(64, 128) 131 | self.bn3 = nn.BatchNorm2d(128) 132 | self.relu3 = nn.ReLU(inplace=True) 133 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 134 | 135 | self.layer1 = self._make_layer(block, 64, layers[0]) 136 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 137 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 138 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 139 | 140 | for m in self.modules(): 141 | if isinstance(m, nn.Conv2d): 142 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 143 | m.weight.data.normal_(0, math.sqrt(2. / n)) 144 | elif isinstance(m, nn.BatchNorm2d): 145 | m.weight.data.fill_(1) 146 | m.bias.data.zero_() 147 | 148 | def _make_layer(self, block, planes, blocks, stride=1): 149 | downsample = None 150 | if stride != 1 or self.inplanes != planes * block.expansion: 151 | downsample = nn.Sequential( 152 | nn.Conv2d(self.inplanes, planes * block.expansion, 153 | kernel_size=1, stride=stride, bias=False), 154 | nn.BatchNorm2d(planes * block.expansion), 155 | ) 156 | 157 | layers = [] 158 | layers.append(block(self.inplanes, planes, stride, downsample)) 159 | self.inplanes = planes * block.expansion 160 | for i in range(1, blocks): 161 | layers.append(block(self.inplanes, planes)) 162 | 163 | return nn.Sequential(*layers) 164 | 165 | def forward(self, x): 166 | x = self.relu1(self.bn1(self.conv1(x))) 167 | x = self.relu2(self.bn2(self.conv2(x))) 168 | x = self.relu3(self.bn3(self.conv3(x))) 169 | x = self.maxpool(x) 170 | 171 | f = [] 172 | x = self.layer1(x) 173 | f.append(x) 174 | x = self.layer2(x) 175 | f.append(x) 176 | x = self.layer3(x) 177 | f.append(x) 178 | x = self.layer4(x) 179 | f.append(x) 180 | 181 | return tuple(f) 182 | 183 | # x = self.avgpool(x) 184 | # x = x.view(x.size(0), -1) 185 | # x = self.fc(x) 186 | 187 | # return x 188 | 189 | 190 | def resnet18(pretrained=False, **kwargs): 191 | """Constructs a ResNet-18 model. 192 | 193 | Args: 194 | pretrained (bool): If True, returns a model pre-trained on Places 195 | """ 196 | model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs) 197 | if pretrained: 198 | model.load_state_dict(load_url(model_urls['resnet18']), strict=False) 199 | return model 200 | 201 | 202 | def resnet50(pretrained=False, **kwargs): 203 | """Constructs a ResNet-50 model. 204 | 205 | Args: 206 | pretrained (bool): If True, returns a model pre-trained on Places 207 | """ 208 | model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) 209 | if pretrained: 210 | model.load_state_dict(load_url(model_urls['resnet50']), strict=False) 211 | return model 212 | 213 | 214 | def resnet101(pretrained=False, **kwargs): 215 | """Constructs a ResNet-101 model. 
216 | 217 | Args: 218 | pretrained (bool): If True, returns a model pre-trained on Places 219 | """ 220 | model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs) 221 | if pretrained: 222 | model.load_state_dict(load_url(model_urls['resnet101']), strict=False) 223 | return model 224 | 225 | 226 | def load_url(url, model_dir='./pretrained', map_location=None): 227 | if not os.path.exists(model_dir): 228 | os.makedirs(model_dir) 229 | filename = url.split('/')[-1] 230 | cached_file = os.path.join(model_dir, filename) 231 | if not os.path.exists(cached_file): 232 | sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file)) 233 | urlretrieve(url, cached_file) 234 | return torch.load(cached_file, map_location=map_location) 235 | -------------------------------------------------------------------------------- /models/builder.py: -------------------------------------------------------------------------------- 1 | import models 2 | 3 | 4 | def build_model(cfg): 5 | param = dict() 6 | for key in cfg: 7 | if key == 'type': 8 | continue 9 | param[key] = cfg[key] 10 | 11 | model = models.__dict__[cfg.type](**param) 12 | 13 | return model 14 | -------------------------------------------------------------------------------- /models/head/__init__.py: -------------------------------------------------------------------------------- 1 | from .psenet_head import PSENet_Head 2 | from .builder import build_head 3 | 4 | __all__ = ['PSENet_Head'] 5 | -------------------------------------------------------------------------------- /models/head/builder.py: -------------------------------------------------------------------------------- 1 | import models 2 | 3 | 4 | def build_head(cfg): 5 | param = dict() 6 | for key in cfg: 7 | if key == 'type': 8 | continue 9 | param[key] = cfg[key] 10 | 11 | head = models.head.__dict__[cfg.type](**param) 12 | 13 | return head 14 | -------------------------------------------------------------------------------- /models/head/psenet_head.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import math 5 | import numpy as np 6 | import cv2 7 | import time 8 | from ..loss import build_loss, ohem_batch, iou 9 | from ..post_processing import pse 10 | 11 | 12 | class PSENet_Head(nn.Module): 13 | def __init__(self, 14 | in_channels, 15 | hidden_dim, 16 | num_classes, 17 | loss_text, 18 | loss_kernel): 19 | super(PSENet_Head, self).__init__() 20 | self.conv1 = nn.Conv2d(in_channels, hidden_dim, kernel_size=3, stride=1, padding=1) 21 | self.bn1 = nn.BatchNorm2d(hidden_dim) 22 | self.relu1 = nn.ReLU(inplace=True) 23 | 24 | self.conv2 = nn.Conv2d(hidden_dim, num_classes, kernel_size=1, stride=1, padding=0) 25 | 26 | self.text_loss = build_loss(loss_text) 27 | self.kernel_loss = build_loss(loss_kernel) 28 | 29 | for m in self.modules(): 30 | if isinstance(m, nn.Conv2d): 31 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 32 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 33 | elif isinstance(m, nn.BatchNorm2d): 34 | m.weight.data.fill_(1) 35 | m.bias.data.zero_() 36 | 37 | def forward(self, f): 38 | out = self.conv1(f) 39 | out = self.relu1(self.bn1(out)) 40 | out = self.conv2(out) 41 | 42 | return out 43 | 44 | def get_results(self, out, img_meta, cfg): 45 | outputs = dict() 46 | 47 | if not self.training and cfg.report_speed: 48 | torch.cuda.synchronize() 49 | start = time.time() 50 | 51 | score = torch.sigmoid(out[:, 0, :, :]) 52 | # out = (torch.sign(out - 1) + 1) / 2 # 0 1 53 | # 54 | # text_mask = out[:, 0, :, :] 55 | # kernels = out[:, 1:cfg.test_cfg.kernel_num, :, :] * text_mask 56 | 57 | kernels = out[:, :cfg.test_cfg.kernel_num, :, :] > 0 58 | text_mask = kernels[:, :1, :, :] 59 | kernels[:, 1:, :, :] = kernels[:, 1:, :, :] * text_mask 60 | 61 | score = score.data.cpu().numpy()[0].astype(np.float32) 62 | kernels = kernels.data.cpu().numpy()[0].astype(np.uint8) 63 | # kernel_1 = kernels[1] 64 | # kernel_2 = kernels[2] 65 | # kernel_3 = kernels[3] 66 | # kernel_4 = kernels[4] 67 | # kernel_5 = kernels[5] 68 | # kernel_6 = kernels[6] 69 | # 70 | # kernel_1 = kernel_1.reshape(736, 1120, 1) 71 | # kernel_2 = kernel_2.reshape(736, 1120, 1) 72 | # kernel_3 = kernel_3.reshape(736, 1120, 1) 73 | # kernel_4 = kernel_4.reshape(736, 1120, 1) 74 | # kernel_5 = kernel_5.reshape(736, 1120, 1) 75 | # kernel_6 = kernel_6.reshape(736, 1120, 1) 76 | # 77 | # kernel_1 = np.concatenate((kernel_1, kernel_1, kernel_1), axis=2) * 255 78 | # kernel_2 = np.concatenate((kernel_2, kernel_2, kernel_2), axis=2) * 255 79 | # kernel_3 = np.concatenate((kernel_3, kernel_3, kernel_3), axis=2) * 255 80 | # kernel_4 = np.concatenate((kernel_4, kernel_4, kernel_4), axis=2) * 255 81 | # kernel_5 = np.concatenate((kernel_5, kernel_5, kernel_5), axis=2) * 255 82 | # kernel_6 = np.concatenate((kernel_6, kernel_6, kernel_6), axis=2) * 255 83 | # 84 | # kernel_1 = cv2.copyMakeBorder(kernel_1, 3, 3, 3, 3, cv2.BORDER_CONSTANT, value=[255, 0, 0]) 85 | # kernel_2 = cv2.copyMakeBorder(kernel_2, 3, 3, 3, 3, cv2.BORDER_CONSTANT, value=[255, 0, 0]) 86 | # kernel_3 = cv2.copyMakeBorder(kernel_3, 3, 3, 3, 3, cv2.BORDER_CONSTANT, value=[255, 0, 0]) 87 | # kernel_4 = cv2.copyMakeBorder(kernel_4, 3, 3, 3, 3, cv2.BORDER_CONSTANT, value=[255, 0, 0]) 88 | # kernel_5 = cv2.copyMakeBorder(kernel_5, 3, 3, 3, 3, cv2.BORDER_CONSTANT, value=[255, 0, 0]) 89 | # kernel_6 = cv2.copyMakeBorder(kernel_6, 3, 3, 3, 3, cv2.BORDER_CONSTANT, value=[255, 0, 0]) 90 | # 91 | # res = np.concatenate((kernel_1, kernel_2, kernel_3, kernel_4, kernel_5, kernel_6), axis=1) 92 | # print('saved kernels.') 93 | # cv2.imwrite('vis_kernels.png', res) 94 | # exit() 95 | 96 | label = pse(kernels, cfg.test_cfg.min_area) 97 | 98 | # image size 99 | org_img_size = img_meta['org_img_size'][0] 100 | img_size = img_meta['img_size'][0] 101 | 102 | label_num = np.max(label) + 1 103 | label = cv2.resize(label, (img_size[1], img_size[0]), interpolation=cv2.INTER_NEAREST) 104 | score = cv2.resize(score, (img_size[1], img_size[0]), interpolation=cv2.INTER_NEAREST) 105 | 106 | if not self.training and cfg.report_speed: 107 | torch.cuda.synchronize() 108 | outputs.update(dict( 109 | det_pse_time=time.time() - start 110 | )) 111 | 112 | scale = (float(org_img_size[1]) / float(img_size[1]), 113 | float(org_img_size[0]) / float(img_size[0])) 114 | 115 | bboxes = [] 116 | scores = [] 117 | for i in range(1, label_num): 118 | ind = label == i 119 | points = np.array(np.where(ind)).transpose((1, 0)) 120 | 121 | if points.shape[0] < 
cfg.test_cfg.min_area: 122 | label[ind] = 0 123 | continue 124 | 125 | score_i = np.mean(score[ind]) 126 | if score_i < cfg.test_cfg.min_score: 127 | label[ind] = 0 128 | continue 129 | 130 | if cfg.test_cfg.bbox_type == 'rect': 131 | rect = cv2.minAreaRect(points[:, ::-1]) 132 | bbox = cv2.boxPoints(rect) * scale 133 | elif cfg.test_cfg.bbox_type == 'poly': 134 | binary = np.zeros(label.shape, dtype='uint8') 135 | binary[ind] = 1 136 | _, contours, _ = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) 137 | bbox = contours[0] * scale 138 | 139 | bbox = bbox.astype('int32') 140 | bboxes.append(bbox.reshape(-1)) 141 | scores.append(score_i) 142 | 143 | outputs.update(dict( 144 | bboxes=bboxes, 145 | scores=scores 146 | )) 147 | 148 | return outputs 149 | 150 | def loss(self, out, gt_texts, gt_kernels, training_masks): 151 | # output 152 | texts = out[:, 0, :, :] 153 | kernels = out[:, 1:, :, :] 154 | # text loss 155 | selected_masks = ohem_batch(texts, gt_texts, training_masks) 156 | 157 | loss_text = self.text_loss(texts, gt_texts, selected_masks, reduce=False) 158 | iou_text = iou((texts > 0).long(), gt_texts, training_masks, reduce=False) 159 | losses = dict( 160 | loss_text=loss_text, 161 | iou_text=iou_text 162 | ) 163 | 164 | # kernel loss 165 | loss_kernels = [] 166 | selected_masks = gt_texts * training_masks 167 | for i in range(kernels.size(1)): 168 | kernel_i = kernels[:, i, :, :] 169 | gt_kernel_i = gt_kernels[:, i, :, :] 170 | loss_kernel_i = self.kernel_loss(kernel_i, gt_kernel_i, selected_masks, reduce=False) 171 | loss_kernels.append(loss_kernel_i) 172 | loss_kernels = torch.mean(torch.stack(loss_kernels, dim=1), dim=1) 173 | iou_kernel = iou( 174 | (kernels[:, -1, :, :] > 0).long(), gt_kernels[:, -1, :, :], training_masks * gt_texts, reduce=False) 175 | losses.update(dict( 176 | loss_kernels=loss_kernels, 177 | iou_kernel=iou_kernel 178 | )) 179 | 180 | return losses 181 | -------------------------------------------------------------------------------- /models/loss/__init__.py: -------------------------------------------------------------------------------- 1 | from .dice_loss import DiceLoss 2 | from .builder import build_loss 3 | from .ohem import ohem_batch 4 | from .iou import iou 5 | from .acc import acc 6 | 7 | __all__ = ['DiceLoss'] 8 | -------------------------------------------------------------------------------- /models/loss/acc.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | EPS = 1e-6 4 | 5 | def acc_single(a, b, mask): 6 | ind = mask == 1 7 | if torch.sum(ind) == 0: 8 | return 0 9 | correct = (a[ind] == b[ind]).float() 10 | acc = torch.sum(correct) / correct.size(0) 11 | return acc 12 | 13 | def acc(a, b, mask, reduce=True): 14 | batch_size = a.size(0) 15 | 16 | a = a.view(batch_size, -1) 17 | b = b.view(batch_size, -1) 18 | mask = mask.view(batch_size, -1) 19 | 20 | acc = a.new_zeros((batch_size,), dtype=torch.float32) 21 | for i in range(batch_size): 22 | acc[i] = acc_single(a[i], b[i], mask[i]) 23 | 24 | if reduce: 25 | acc = torch.mean(acc) 26 | return acc -------------------------------------------------------------------------------- /models/loss/builder.py: -------------------------------------------------------------------------------- 1 | import models 2 | 3 | 4 | def build_loss(cfg): 5 | param = dict() 6 | for key in cfg: 7 | if key == 'type': 8 | continue 9 | param[key] = cfg[key] 10 | 11 | loss = models.loss.__dict__[cfg.type](**param) 12 | 13 | return loss 14 | 
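# A minimal usage sketch of the reflection-based builder above (an assumption,
# not repo code): 'type' selects a class out of models.loss.__dict__ and every
# remaining key is forwarded as a constructor kwarg. An attribute-style dict
# (addict / mmcv Config) is assumed, since the builder reads cfg.type.
from addict import Dict

from models.loss import build_loss

# weights mirror the psenet_r50 configs: 0.7 for the text map, 0.3 for kernels
loss_text = build_loss(Dict(type='DiceLoss', loss_weight=0.7))
loss_kernel = build_loss(Dict(type='DiceLoss', loss_weight=0.3))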
-------------------------------------------------------------------------------- /models/loss/dice_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class DiceLoss(nn.Module): 6 | def __init__(self, loss_weight=1.0): 7 | super(DiceLoss, self).__init__() 8 | self.loss_weight = loss_weight 9 | 10 | def forward(self, input, target, mask, reduce=True): 11 | batch_size = input.size(0) 12 | 13 | input = torch.sigmoid(input) 14 | 15 | input = input.contiguous().view(batch_size, -1) 16 | target = target.contiguous().view(batch_size, -1).float() 17 | mask = mask.contiguous().view(batch_size, -1).float() 18 | 19 | 20 | input = input * mask 21 | target = target * mask 22 | 23 | a = torch.sum(input * target, dim=1) 24 | b = torch.sum(input * input, dim=1) + 0.001 25 | c = torch.sum(target * target, dim=1) + 0.001 26 | d = (2 * a) / (b + c) 27 | loss = 1 - d 28 | 29 | loss = self.loss_weight * loss 30 | 31 | 32 | 33 | if reduce: 34 | loss = torch.mean(loss) 35 | 36 | 37 | return loss 38 | -------------------------------------------------------------------------------- /models/loss/emb_loss_v1.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import numpy as np 6 | from torch.autograd import Function, Variable 7 | 8 | 9 | class EmbLoss_v1(nn.Module): 10 | def __init__(self, feature_dim=4, loss_weight=1.0): 11 | super(EmbLoss_v1, self).__init__() 12 | self.feature_dim = feature_dim 13 | self.loss_weight = loss_weight 14 | self.delta_v = 0.5 15 | self.delta_d = 1.5 16 | self.weights = (1.0, 1.0) 17 | 18 | def forward_single(self, emb, instance, kernel, training_mask, bboxes): 19 | training_mask = (training_mask > 0.5).long() 20 | kernel = (kernel > 0.5).long() 21 | instance = instance * training_mask 22 | instance_kernel = (instance * kernel).view(-1) 23 | instance = instance.view(-1) 24 | emb = emb.view(self.feature_dim, -1) 25 | 26 | unique_labels, unique_ids = torch.unique(instance_kernel, sorted=True, return_inverse=True) 27 | num_instance = unique_labels.size(0) 28 | if num_instance <= 1: 29 | return 0 30 | 31 | emb_mean = emb.new_zeros((self.feature_dim, num_instance), dtype=torch.float32) 32 | for i, lb in enumerate(unique_labels): 33 | if lb == 0: 34 | continue 35 | ind_k = instance_kernel == lb 36 | emb_mean[:, i] = torch.mean(emb[:, ind_k], dim=1) 37 | 38 | l_agg = emb.new_zeros(num_instance, dtype=torch.float32) # bug 39 | for i, lb in enumerate(unique_labels): 40 | if lb == 0: 41 | continue 42 | ind = instance == lb 43 | emb_ = emb[:, ind] 44 | dist = (emb_ - emb_mean[:, i:i + 1]).norm(p=2, dim=0) 45 | dist = F.relu(dist - self.delta_v) ** 2 46 | l_agg[i] = torch.mean(torch.log(dist + 1.0)) 47 | l_agg = torch.mean(l_agg[1:]) 48 | 49 | if num_instance > 2: 50 | emb_interleave = emb_mean.permute(1, 0).repeat(num_instance, 1) 51 | emb_band = emb_mean.permute(1, 0).repeat(1, num_instance).view(-1, self.feature_dim) 52 | # print(seg_band) 53 | 54 | mask = (1 - torch.eye(num_instance, dtype=torch.int8)).view(-1, 1).repeat(1, self.feature_dim) 55 | mask = mask.view(num_instance, num_instance, -1) 56 | mask[0, :, :] = 0 57 | mask[:, 0, :] = 0 58 | mask = mask.view(num_instance * num_instance, -1) 59 | # print(mask) 60 | 61 | dist = emb_interleave - emb_band 62 | dist = dist[mask > 0].view(-1, self.feature_dim).norm(p=2, dim=1) 63 | dist = F.relu(2 * self.delta_d - dist) ** 
2 64 | l_dis = torch.mean(torch.log(dist + 1.0)) 65 | else: 66 | l_dis = 0 67 | 68 | l_agg = self.weights[0] * l_agg 69 | l_dis = self.weights[1] * l_dis 70 | l_reg = torch.mean(torch.log(torch.norm(emb_mean, 2, 0) + 1.0)) * 0.001 71 | loss = l_agg + l_dis + l_reg 72 | return loss 73 | 74 | def forward(self, emb, instance, kernel, training_mask, bboxes, reduce=True): 75 | loss_batch = emb.new_zeros((emb.size(0)), dtype=torch.float32) 76 | 77 | for i in range(loss_batch.size(0)): 78 | loss_batch[i] = self.forward_single(emb[i], instance[i], kernel[i], training_mask[i], bboxes[i]) 79 | 80 | loss_batch = self.loss_weight * loss_batch 81 | 82 | if reduce: 83 | loss_batch = torch.mean(loss_batch) 84 | 85 | return loss_batch 86 | -------------------------------------------------------------------------------- /models/loss/iou.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | EPS = 1e-6 4 | 5 | def iou_single(a, b, mask, n_class): 6 | valid = mask == 1 7 | a = a[valid] 8 | b = b[valid] 9 | miou = [] 10 | for i in range(n_class): 11 | inter = ((a == i) & (b == i)).float() 12 | union = ((a == i) | (b == i)).float() 13 | 14 | miou.append(torch.sum(inter) / (torch.sum(union) + EPS)) 15 | miou = sum(miou) / len(miou) 16 | return miou 17 | 18 | def iou(a, b, mask, n_class=2, reduce=True): 19 | batch_size = a.size(0) 20 | 21 | a = a.view(batch_size, -1) 22 | b = b.view(batch_size, -1) 23 | mask = mask.view(batch_size, -1) 24 | 25 | iou = a.new_zeros((batch_size,), dtype=torch.float32) 26 | for i in range(batch_size): 27 | iou[i] = iou_single(a[i], b[i], mask[i], n_class) 28 | 29 | if reduce: 30 | iou = torch.mean(iou) 31 | return iou -------------------------------------------------------------------------------- /models/loss/ohem.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | def ohem_single(score, gt_text, training_mask): 4 | pos_num = int(torch.sum(gt_text > 0.5)) - int(torch.sum((gt_text > 0.5) & (training_mask <= 0.5))) 5 | 6 | if pos_num == 0: 7 | # selected_mask = gt_text.copy() * 0 # may be not good 8 | selected_mask = training_mask 9 | selected_mask = selected_mask.view(1, selected_mask.shape[0], selected_mask.shape[1]).float() 10 | return selected_mask 11 | 12 | neg_num = int(torch.sum(gt_text <= 0.5)) 13 | neg_num = int(min(pos_num * 3, neg_num)) 14 | 15 | if neg_num == 0: 16 | selected_mask = training_mask 17 | selected_mask = selected_mask.view(1, selected_mask.shape[0], selected_mask.shape[1]).float() 18 | return selected_mask 19 | 20 | neg_score = score[gt_text <= 0.5] 21 | neg_score_sorted, _ = torch.sort(-neg_score) 22 | threshold = -neg_score_sorted[neg_num - 1] 23 | 24 | selected_mask = ((score >= threshold) | (gt_text > 0.5)) & (training_mask > 0.5) 25 | selected_mask = selected_mask.reshape(1, selected_mask.shape[0], selected_mask.shape[1]).float() 26 | return selected_mask 27 | 28 | def ohem_batch(scores, gt_texts, training_masks): 29 | selected_masks = [] 30 | for i in range(scores.shape[0]): 31 | selected_masks.append(ohem_single(scores[i, :, :], gt_texts[i, :, :], training_masks[i, :, :])) 32 | 33 | selected_masks = torch.cat(selected_masks, 0).float() 34 | return selected_masks 35 | -------------------------------------------------------------------------------- /models/neck/__init__.py: -------------------------------------------------------------------------------- 1 | from .fpn import FPN 2 | from .builder import build_neck 3 | 4 | __all__ = ['FPN'] 5 | 
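# A sketch (an assumption, not repo code) of how the OHEM + dice pieces in
# models/loss above fit together, mirroring PSENet_Head.loss: ohem_batch keeps
# every positive pixel plus up to 3x as many hardest negatives, and DiceLoss is
# then evaluated only inside that selected mask.
import torch
from addict import Dict  # assumption: attribute-style cfg, as in mmcv configs

from models.loss import build_loss, ohem_batch

texts = torch.randn(2, 640, 640)                       # raw text-map logits
gt_texts = torch.randint(0, 2, (2, 640, 640)).float()  # 1 = text pixel
training_masks = torch.ones_like(gt_texts)             # 0 would ignore a pixel

selected_masks = ohem_batch(texts, gt_texts, training_masks)
dice = build_loss(Dict(type='DiceLoss', loss_weight=0.7))
loss_text = dice(texts, gt_texts, selected_masks, reduce=True)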
-------------------------------------------------------------------------------- /models/neck/builder.py: -------------------------------------------------------------------------------- 1 | import models 2 | 3 | 4 | def build_neck(cfg): 5 | param = dict() 6 | for key in cfg: 7 | if key == 'type': 8 | continue 9 | param[key] = cfg[key] 10 | 11 | neck = models.neck.__dict__[cfg.type](**param) 12 | 13 | return neck 14 | -------------------------------------------------------------------------------- /models/neck/fpn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import math 4 | import torch.nn.functional as F 5 | from ..utils import Conv_BN_ReLU 6 | 7 | 8 | class FPN(nn.Module): 9 | def __init__(self, in_channels, out_channels): 10 | super(FPN, self).__init__() 11 | 12 | # Top layer 13 | self.toplayer_ = Conv_BN_ReLU(2048, 256, kernel_size=1, stride=1, padding=0) 14 | 15 | # Smooth layers 16 | self.smooth1_ = Conv_BN_ReLU(256, 256, kernel_size=3, stride=1, padding=1) 17 | 18 | self.smooth2_ = Conv_BN_ReLU(256, 256, kernel_size=3, stride=1, padding=1) 19 | 20 | self.smooth3_ = Conv_BN_ReLU(256, 256, kernel_size=3, stride=1, padding=1) 21 | 22 | # Lateral layers 23 | self.latlayer1_ = Conv_BN_ReLU(1024, 256, kernel_size=1, stride=1, padding=0) 24 | 25 | self.latlayer2_ = Conv_BN_ReLU(512, 256, kernel_size=1, stride=1, padding=0) 26 | 27 | self.latlayer3_ = Conv_BN_ReLU(256, 256, kernel_size=1, stride=1, padding=0) 28 | 29 | for m in self.modules(): 30 | if isinstance(m, nn.Conv2d): 31 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 32 | m.weight.data.normal_(0, math.sqrt(2. / n)) 33 | elif isinstance(m, nn.BatchNorm2d): 34 | m.weight.data.fill_(1) 35 | m.bias.data.zero_() 36 | 37 | def _upsample(self, x, y, scale=1): 38 | _, _, H, W = y.size() 39 | return F.upsample(x, size=(H // scale, W // scale), mode='bilinear') 40 | 41 | def _upsample_add(self, x, y): 42 | _, _, H, W = y.size() 43 | return F.upsample(x, size=(H, W), mode='bilinear') + y 44 | 45 | def forward(self, f2, f3, f4, f5): 46 | p5 = self.toplayer_(f5) 47 | 48 | f4 = self.latlayer1_(f4) 49 | p4 = self._upsample_add(p5, f4) 50 | p4 = self.smooth1_(p4) 51 | 52 | f3 = self.latlayer2_(f3) 53 | p3 = self._upsample_add(p4, f3) 54 | p3 = self.smooth2_(p3) 55 | 56 | f2 = self.latlayer3_(f2) 57 | p2 = self._upsample_add(p3, f2) 58 | p2 = self.smooth3_(p2) 59 | 60 | p3 = self._upsample(p3, p2) 61 | p4 = self._upsample(p4, p2) 62 | p5 = self._upsample(p5, p2) 63 | 64 | return p2, p3, p4, p5 65 | -------------------------------------------------------------------------------- /models/post_processing/__init__.py: -------------------------------------------------------------------------------- 1 | from .pse import pse 2 | -------------------------------------------------------------------------------- /models/post_processing/pse/__init__.py: -------------------------------------------------------------------------------- 1 | from .pse import pse -------------------------------------------------------------------------------- /models/post_processing/pse/pse.pyx: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | cimport numpy as np 4 | cimport cython 5 | cimport libcpp 6 | cimport libcpp.pair 7 | cimport libcpp.queue 8 | from libcpp.pair cimport * 9 | from libcpp.queue cimport * 10 | 11 | @cython.boundscheck(False) 12 | @cython.wraparound(False) 13 | cdef np.ndarray[np.int32_t, 
ndim=2] _pse(np.ndarray[np.uint8_t, ndim=3] kernels, 14 | np.ndarray[np.int32_t, ndim=2] label, 15 | int kernel_num, 16 | int label_num, 17 | float min_area=0): 18 | cdef np.ndarray[np.int32_t, ndim=2] pred 19 | pred = np.zeros((label.shape[0], label.shape[1]), dtype=np.int32) 20 | 21 | for label_idx in range(1, label_num): 22 | if np.sum(label == label_idx) < min_area: 23 | label[label == label_idx] = 0 24 | 25 | cdef libcpp.queue.queue[libcpp.pair.pair[np.int16_t,np.int16_t]] que = \ 26 | queue[libcpp.pair.pair[np.int16_t,np.int16_t]]() 27 | cdef libcpp.queue.queue[libcpp.pair.pair[np.int16_t,np.int16_t]] nxt_que = \ 28 | queue[libcpp.pair.pair[np.int16_t,np.int16_t]]() 29 | cdef np.int16_t* dx = [-1, 1, 0, 0] 30 | cdef np.int16_t* dy = [0, 0, -1, 1] 31 | cdef np.int16_t tmpx, tmpy 32 | 33 | points = np.array(np.where(label > 0)).transpose((1, 0)) 34 | for point_idx in range(points.shape[0]): 35 | tmpx, tmpy = points[point_idx, 0], points[point_idx, 1] 36 | que.push(pair[np.int16_t,np.int16_t](tmpx, tmpy)) 37 | pred[tmpx, tmpy] = label[tmpx, tmpy] 38 | 39 | cdef libcpp.pair.pair[np.int16_t,np.int16_t] cur 40 | cdef int cur_label 41 | for kernel_idx in range(kernel_num - 1, -1, -1): 42 | while not que.empty(): 43 | cur = que.front() 44 | que.pop() 45 | cur_label = pred[cur.first, cur.second] 46 | 47 | is_edge = True 48 | for j in range(4): 49 | tmpx = cur.first + dx[j] 50 | tmpy = cur.second + dy[j] 51 | if tmpx < 0 or tmpx >= label.shape[0] or tmpy < 0 or tmpy >= label.shape[1]: 52 | continue 53 | if kernels[kernel_idx, tmpx, tmpy] == 0 or pred[tmpx, tmpy] > 0: 54 | continue 55 | 56 | que.push(pair[np.int16_t,np.int16_t](tmpx, tmpy)) 57 | pred[tmpx, tmpy] = cur_label 58 | is_edge = False 59 | if is_edge: 60 | nxt_que.push(cur) 61 | 62 | que, nxt_que = nxt_que, que 63 | 64 | return pred 65 | 66 | def pse(kernels, min_area): 67 | kernel_num = kernels.shape[0] 68 | label_num, label = cv2.connectedComponents(kernels[-1], connectivity=4) 69 | return _pse(kernels[:-1], label, kernel_num, label_num, min_area) -------------------------------------------------------------------------------- /models/post_processing/pse/readme.txt: -------------------------------------------------------------------------------- 1 | python setup.py build_ext --inplace -------------------------------------------------------------------------------- /models/post_processing/pse/setup.py: -------------------------------------------------------------------------------- 1 | from distutils.core import setup, Extension 2 | from Cython.Build import cythonize 3 | import numpy 4 | 5 | setup(ext_modules=cythonize(Extension( 6 | 'pse', 7 | sources=['pse.pyx'], 8 | language='c++', 9 | include_dirs=[numpy.get_include()], 10 | library_dirs=[], 11 | libraries=[], 12 | extra_compile_args=['-O3'], 13 | extra_link_args=[] 14 | ))) 15 | -------------------------------------------------------------------------------- /models/psenet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import math 4 | import torch.nn.functional as F 5 | import numpy as np 6 | import time 7 | 8 | from .backbone import build_backbone 9 | from .neck import build_neck 10 | from .head import build_head 11 | from .utils import Conv_BN_ReLU 12 | 13 | 14 | class PSENet(nn.Module): 15 | def __init__(self, 16 | backbone, 17 | neck, 18 | detection_head): 19 | super(PSENet, self).__init__() 20 | self.backbone = build_backbone(backbone) 21 | self.fpn = build_neck(neck) 22 | 23 | 
self.det_head = build_head(detection_head) 24 | 25 | def _upsample(self, x, size, scale=1): 26 | _, _, H, W = size 27 | return F.upsample(x, size=(H // scale, W // scale), mode='bilinear') 28 | 29 | def forward(self, 30 | imgs, 31 | gt_texts=None, 32 | gt_kernels=None, 33 | training_masks=None, 34 | img_metas=None, 35 | cfg=None): 36 | outputs = dict() 37 | 38 | if not self.training and cfg.report_speed: 39 | torch.cuda.synchronize() 40 | start = time.time() 41 | 42 | # backbone 43 | f = self.backbone(imgs) 44 | if not self.training and cfg.report_speed: 45 | torch.cuda.synchronize() 46 | outputs.update(dict( 47 | backbone_time=time.time() - start 48 | )) 49 | start = time.time() 50 | 51 | # FPN 52 | f1, f2, f3, f4, = self.fpn(f[0], f[1], f[2], f[3]) 53 | 54 | f = torch.cat((f1, f2, f3, f4), 1) 55 | 56 | if not self.training and cfg.report_speed: 57 | torch.cuda.synchronize() 58 | outputs.update(dict( 59 | neck_time=time.time() - start 60 | )) 61 | start = time.time() 62 | 63 | # detection 64 | 65 | det_out = self.det_head(f) 66 | 67 | if not self.training and cfg.report_speed: 68 | torch.cuda.synchronize() 69 | outputs.update(dict( 70 | det_head_time=time.time() - start 71 | )) 72 | 73 | if self.training: 74 | det_out = self._upsample(det_out, imgs.size()) 75 | det_loss = self.det_head.loss(det_out, gt_texts, gt_kernels, training_masks) 76 | outputs.update(det_loss) 77 | else: 78 | det_out = self._upsample(det_out, imgs.size(), 1) 79 | det_res = self.det_head.get_results(det_out, img_metas, cfg) 80 | outputs.update(det_res) 81 | 82 | return outputs 83 | -------------------------------------------------------------------------------- /models/pypse.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | from queue import Queue 4 | 5 | 6 | def pse(kernals, min_area): 7 | kernal_num = len(kernals) 8 | # print('kernal_num', kernal_num) 9 | pred = np.zeros(kernals[0].shape, dtype='int32') 10 | 11 | label_num, label = cv2.connectedComponents(kernals[kernal_num - 1], connectivity=4) 12 | 13 | for label_idx in range(1, label_num): 14 | if np.sum(label == label_idx) < min_area: 15 | label[label == label_idx] = 0 16 | 17 | queue = Queue(maxsize=0) 18 | next_queue = Queue(maxsize=0) 19 | points = np.array(np.where(label > 0)).transpose((1, 0)) 20 | 21 | for point_idx in range(points.shape[0]): 22 | x, y = points[point_idx, 0], points[point_idx, 1] 23 | l = label[x, y] 24 | queue.put((x, y, l)) 25 | pred[x, y] = l 26 | 27 | dx = [-1, 1, 0, 0] 28 | dy = [0, 0, -1, 1] 29 | for kernal_idx in range(kernal_num - 2, -1, -1): 30 | kernal = kernals[kernal_idx].copy() 31 | while not queue.empty(): 32 | (x, y, l) = queue.get() 33 | 34 | is_edge = True 35 | for j in range(4): 36 | tmpx = x + dx[j] 37 | tmpy = y + dy[j] 38 | if tmpx < 0 or tmpx >= kernal.shape[0] or tmpy < 0 or tmpy >= kernal.shape[1]: 39 | continue 40 | if kernal[tmpx, tmpy] == 0 or pred[tmpx, tmpy] > 0: 41 | continue 42 | 43 | queue.put((tmpx, tmpy, l)) 44 | pred[tmpx, tmpy] = l 45 | is_edge = False 46 | if is_edge: 47 | next_queue.put((x, y, l)) 48 | 49 | # kernal[pred > 0] = 0 50 | queue, next_queue = next_queue, queue 51 | 52 | # points = np.array(np.where(pred > 0)).transpose((1, 0)) 53 | # for point_idx in range(points.shape[0]): 54 | # x, y = points[point_idx, 0], points[point_idx, 1] 55 | # l = pred[x, y] 56 | # queue.put((x, y, l)) 57 | 58 | return pred 59 | -------------------------------------------------------------------------------- /models/utils/__init__.py: 
--------------------------------------------------------------------------------
1 | from .conv_bn_relu import Conv_BN_ReLU
2 | from .fuse_conv_bn import fuse_module
3 | 
--------------------------------------------------------------------------------
/models/utils/conv_bn_relu.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import math
4 | 
5 | 
6 | class Conv_BN_ReLU(nn.Module):
7 |     def __init__(self, in_planes, out_planes, kernel_size=1, stride=1, padding=0):
8 |         super(Conv_BN_ReLU, self).__init__()
9 |         self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding,
10 |                               bias=False)
11 |         self.bn = nn.BatchNorm2d(out_planes)
12 |         self.relu = nn.ReLU(inplace=True)
13 | 
14 |         for m in self.modules():  # Kaiming-style init for conv, constant init for BN
15 |             if isinstance(m, nn.Conv2d):
16 |                 n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
17 |                 m.weight.data.normal_(0, math.sqrt(2. / n))
18 |             elif isinstance(m, nn.BatchNorm2d):
19 |                 m.weight.data.fill_(1)
20 |                 m.bias.data.zero_()
21 | 
22 |     def forward(self, x):
23 |         return self.relu(self.bn(self.conv(x)))
--------------------------------------------------------------------------------
/models/utils/fuse_conv_bn.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | 
4 | 
5 | def fuse_conv_bn(conv, bn):
6 |     """During inference, batch norm layers no longer normalize with batch
7 |     statistics but only apply the per-channel running mean and variance,
8 |     which makes it possible to fuse them into the preceding conv layers to
9 |     save computation and simplify the network structure."""
10 |     conv_w = conv.weight
11 |     conv_b = conv.bias if conv.bias is not None else torch.zeros_like(
12 |         bn.running_mean)
13 | 
14 |     factor = bn.weight / torch.sqrt(bn.running_var + bn.eps)
15 |     conv.weight = nn.Parameter(conv_w *
16 |                                factor.reshape([conv.out_channels, 1, 1, 1]))
17 |     conv.bias = nn.Parameter((conv_b - bn.running_mean) * factor + bn.bias)
18 |     return conv
19 | 
20 | 
21 | def fuse_module(m):
22 |     last_conv = None
23 |     last_conv_name = None
24 | 
25 |     for name, child in m.named_children():
26 |         if isinstance(child, (nn.BatchNorm2d, nn.SyncBatchNorm)):
27 |             if last_conv is None:  # only fuse BN that comes after a Conv
28 |                 continue
29 |             fused_conv = fuse_conv_bn(last_conv, child)
30 |             m._modules[last_conv_name] = fused_conv
31 |             # To reduce changes, set BN as Identity instead of deleting it.
32 |             m._modules[name] = nn.Identity()
33 |             last_conv = None
34 |         elif isinstance(child, nn.Conv2d):
35 |             last_conv = child
36 |             last_conv_name = name
37 |         else:
38 |             fuse_module(child)
39 |     return m
--------------------------------------------------------------------------------
/readme.md:
--------------------------------------------------------------------------------
1 | ## News
2 | - PSENet is included in [MMOCR](https://github.com/open-mmlab/mmocr).
3 | - We have upgraded PSENet from Python 2 to Python 3. You can find the old version [here](https://github.com/whai362/PSENet/tree/python2).
4 | - We have also implemented PSENet in Paddle. You can find it [here](https://github.com/RoseSakurai/PSENet_paddle).
5 | - You can find the code of PAN [here](https://github.com/whai362/pan_pp.pytorch).
6 | - Another group has also implemented PSENet in Paddle; you can find it [here](https://github.com/PaddleEdu/OCR-models-PaddlePaddle/tree/main/PSENet). You can also try it online, with the environment already set up, [here](https://aistudio.baidu.com/aistudio/projectdetail/1945560).
7 | 
8 | ## Introduction
9 | The official PyTorch implementation of PSENet [1].
10 | 
11 | [1] W. Wang, E. Xie, X. Li, W. Hou, T. Lu, G. Yu, and S. Shao. Shape robust text detection with progressive scale expansion network. In Proc. IEEE Conf. Comp. Vis. Patt. Recogn., pages 9336–9345, 2019.
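PSENet's name refers to its post-processing: each text instance is predicted as a stack of kernels of decreasing size, and the connected components of the smallest kernel are progressively expanded through the larger ones. Below is a minimal sketch of driving the pure-Python reference implementation in `models/pypse.py` on synthetic kernel maps; the two toy text regions are invented for illustration, and the ordering (largest kernel first, smallest last) follows the code in this repo.

```python
import numpy as np
from models.pypse import pse

H, W = 32, 32
largest = np.zeros((H, W), dtype=np.uint8)   # full text mask: two separate regions
largest[4:12, 4:28] = 1
largest[20:28, 4:28] = 1
smallest = np.zeros((H, W), dtype=np.uint8)  # shrunk "seed" kernel inside each region
smallest[6:10, 6:26] = 1
smallest[22:26, 6:26] = 1

# Seeds come from the last (smallest) kernel and are grown outward
# through the earlier (larger) kernels, one BFS pass per kernel.
pred = pse([largest, smallest], min_area=5)
print(np.unique(pred))  # [0 1 2]: background plus one label per text instance
```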
12 | 13 | 14 | ## Recommended environment 15 | ``` 16 | Python 3.6+ 17 | Pytorch 1.1.0 18 | torchvision 0.3 19 | mmcv 0.2.12 20 | editdistance 21 | Polygon3 22 | pyclipper 23 | opencv-python 3.4.2.17 24 | Cython 25 | ``` 26 | 27 | ## Install 28 | ```shell script 29 | pip install -r requirement.txt 30 | ./compile.sh 31 | ``` 32 | 33 | ## Training 34 | ```shell script 35 | CUDA_VISIBLE_DEVICES=0,1,2,3 python train.py ${CONFIG_FILE} 36 | ``` 37 | For example: 38 | ```shell script 39 | CUDA_VISIBLE_DEVICES=0,1,2,3 python train.py config/psenet/psenet_r50_ic15_736.py 40 | ``` 41 | 42 | ## Test 43 | ``` 44 | python test.py ${CONFIG_FILE} ${CHECKPOINT_FILE} 45 | ``` 46 | For example: 47 | ```shell script 48 | python test.py config/psenet/psenet_r50_ic15_736.py checkpoints/psenet_r50_ic15_736/checkpoint.pth.tar 49 | ``` 50 | 51 | ## Speed 52 | ```shell script 53 | python test.py ${CONFIG_FILE} ${CHECKPOINT_FILE} --report_speed 54 | ``` 55 | For example: 56 | ```shell script 57 | python test.py config/psenet/psenet_r50_ic15_736.py checkpoints/psenet_r50_ic15_736/checkpoint.pth.tar --report_speed 58 | ``` 59 | 60 | ## Evaluation 61 | ## Introduction 62 | The evaluation scripts of ICDAR 2015 (IC15), Total-Text (TT) and CTW1500 (CTW) datasets. 63 | ## [ICDAR 2015](https://rrc.cvc.uab.es/?ch=4) 64 | Text detection 65 | ```shell script 66 | ./eval_ic15.sh 67 | ``` 68 | 69 | 70 | ## [Total-Text](https://github.com/cs-chan/Total-Text-Dataset) 71 | Text detection 72 | ```shell script 73 | ./eval_tt.sh 74 | ``` 75 | 76 | ## [CTW1500](https://github.com/Yuliang-Liu/Curve-Text-Detector) 77 | Text detection 78 | ```shell script 79 | ./eval_ctw.sh 80 | ``` 81 | 82 | ## Benchmark 83 | ## Results 84 | 85 | [ICDAR 2015](https://rrc.cvc.uab.es/?ch=4) 86 | 87 | | Method | Backbone | Fine-tuning | Scale | Config | Precision (%) | Recall (%) | F-measure (%) | Model | 88 | | :-: | :-: | :-: | :-: | :-: | :-: | :-: | :-: | :-: | 89 | | PSENet | ResNet50 | N | Shorter Side: 736 | [psenet_r50_ic15_736.py](https://github.com/whai362/PSENet/blob/python3/config/psenet/psenet_r50_ic15_736.py) | 83.6 | 74.0 | 78.5 | [Releases](https://github.com/whai362/PSENet/releases/download/checkpoint/psenet_r50_ic15_736.pth.tar) | 90 | | PSENet | ResNet50 | N | Shorter Side: 1024 | [psenet_r50_ic15_1024.py](https://github.com/whai362/PSENet/blob/python3/config/psenet/psenet_r50_ic15_1024.py) | 84.4 | 76.3 | 80.2 | [Releases](https://github.com/whai362/PSENet/releases/download/checkpoint/psenet_r50_ic15_1024.pth.tar) | 91 | | PSENet (paper) | ResNet50 | N | Longer Side: 2240 | - | 81.5 | 79.7 | 80.6 | - | 92 | | PSENet | ResNet50 | Y | Shorter Side: 736 | [psenet_r50_ic15_736_finetune.py](https://github.com/whai362/PSENet/blob/python3/config/psenet/psenet_r50_ic15_736_finetune.py) | 85.3 | 76.8 | 80.9 | [Releases](https://github.com/whai362/PSENet/releases/download/checkpoint/psenet_r50_ic15_736_finetune.pth.tar) | 93 | | PSENet | ResNet50 | Y | Shorter Side: 1024 | [psenet_r50_ic15_1024_finetune.py](https://github.com/whai362/PSENet/blob/python3/config/psenet/psenet_r50_ic15_1024_finetune.py) | 86.2 | 79.4 | 82.7 | [Releases](https://github.com/whai362/PSENet/releases/download/checkpoint/psenet_r50_ic15_1024_finetune.pth.tar) | 94 | | PSENet (paper) | ResNet50 | Y | Longer Side: 2240 | - | 86.9 | 84.5 | 85.7 | - | 95 | 96 | [CTW1500](https://github.com/Yuliang-Liu/Curve-Text-Detector) 97 | 98 | | Method | Backbone | Fine-tuning | Config | Precision (%) | Recall (%) | F-measure (%) | Model | 99 | | :-: | :-: | :-: | :-: | :-: | :-: | :-: | 
:-: | 100 | | PSENet | ResNet50 | N | [psenet_r50_ctw.py](https://github.com/whai362/PSENet/blob/python3/config/psenet/psenet_r50_ctw.py) | 82.6 | 76.4 | 79.4 | [Releases](https://github.com/whai362/PSENet/releases/download/checkpoint/psenet_r50_ctw.pth.tar) | 101 | | PSENet (paper) | ResNet50 | N | - | 80.6 | 75.6 | 78 | - | 102 | | PSENet | ResNet50 | Y | [psenet_r50_ctw_finetune.py](https://github.com/whai362/PSENet/blob/python3/config/psenet/psenet_r50_ctw_finetune.py) | 84.5 | 79.2 | 81.8 | [Releases](https://github.com/whai362/PSENet/releases/download/checkpoint/psenet_r50_ctw_finetune.pth.tar) | 103 | | PSENet (paper) | ResNet50 | Y | - | 84.8 | 79.7 | 82.2 | - | 104 | 105 | [Total-Text](https://github.com/cs-chan/Total-Text-Dataset) 106 | 107 | | Method | Backbone | Fine-tuning | Config | Precision (%) | Recall (%) | F-measure (%) | Model | 108 | | :-: | :-: | :-: | :-: | :-: | :-: | :-: | :-: | 109 | | PSENet | ResNet50 | N | [psenet_r50_tt.py](https://github.com/whai362/PSENet/blob/python3/config/psenet/psenet_r50_tt.py) | 87.3 | 77.9 | 82.3 | [Releases](https://github.com/whai362/PSENet/releases/download/checkpoint/psenet_r50_tt.pth.tar) | 110 | | PSENet (paper) | ResNet50 | N | - | 81.8 | 75.1 | 78.3 | - | 111 | | PSENet | ResNet50 | Y | [psenet_r50_tt_finetune.py](https://github.com/whai362/PSENet/blob/python3/config/psenet/psenet_r50_tt_finetune.py) | 89.3 | 79.6 | 84.2 | [Releases](https://github.com/whai362/PSENet/releases/download/checkpoint/psenet_r50_tt_finetune.pth.tar) | 112 | | PSENet (paper) | ResNet50 | Y | - | 84.0 | 78.0 | 80.9 | - | 113 | 114 | ## Citation 115 | ``` 116 | @inproceedings{wang2019shape, 117 | title={Shape robust text detection with progressive scale expansion network}, 118 | author={Wang, Wenhai and Xie, Enze and Li, Xiang and Hou, Wenbo and Lu, Tong and Yu, Gang and Shao, Shuai}, 119 | booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition}, 120 | pages={9336--9345}, 121 | year={2019} 122 | } 123 | ``` 124 | 125 | ## License 126 | This project is developed and maintained by [IMAGINE Lab@National Key Laboratory for Novel Software Technology, Nanjing University](https://cs.nju.edu.cn/lutong/ImagineLab.html). 127 | 128 | IMAGINE Lab 129 | 130 | This project is released under the [Apache 2.0 license](https://github.com/whai362/pan_pp.pytorch/blob/master/LICENSE). 
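For reference on the IC15 evaluation above: `test.py` writes detections through `utils/result_format.py` (shown later in this listing) as a zip of `res_<image>.txt` files, one quadrilateral of eight comma-separated integer coordinates (`x1,y1,x2,y2,x3,y3,x4,y4`) per line. The coordinates below are invented for illustration:

```
377,117,463,117,465,130,378,130
493,115,519,115,519,131,493,131
```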
131 | 
--------------------------------------------------------------------------------
/requirement.txt:
--------------------------------------------------------------------------------
1 | torch==1.1.0
2 | torchvision==0.3
3 | mmcv==0.2.12
4 | editdistance
5 | Polygon3
6 | pyclipper
7 | opencv-python==3.4.2.17
8 | Cython
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import argparse
4 | import os
5 | import os.path as osp
6 | import sys
7 | import time
8 | import json
9 | from mmcv import Config
10 | 
11 | from dataset import build_data_loader
12 | from models import build_model
13 | from models.utils import fuse_module
14 | from utils import ResultFormat, AverageMeter
15 | 
16 | 
17 | def report_speed(outputs, speed_meters):
18 |     total_time = 0
19 |     for key in outputs:
20 |         if 'time' in key:
21 |             total_time += outputs[key]
22 |             speed_meters[key].update(outputs[key])
23 |             print('%s: %.4f' % (key, speed_meters[key].avg))
24 | 
25 |     speed_meters['total_time'].update(total_time)
26 |     print('FPS: %.1f' % (1.0 / speed_meters['total_time'].avg))
27 | 
28 | 
29 | def test(test_loader, model, cfg):
30 |     model.eval()
31 | 
32 |     rf = ResultFormat(cfg.data.test.type, cfg.test_cfg.result_path)
33 | 
34 |     if cfg.report_speed:
35 |         speed_meters = dict(
36 |             backbone_time=AverageMeter(500),
37 |             neck_time=AverageMeter(500),
38 |             det_head_time=AverageMeter(500),
39 |             det_pse_time=AverageMeter(500),
40 |             rec_time=AverageMeter(500),
41 |             total_time=AverageMeter(500)
42 |         )
43 | 
44 |     for idx, data in enumerate(test_loader):
45 |         print('Testing %d/%d' % (idx, len(test_loader)))
46 |         sys.stdout.flush()
47 | 
48 |         # prepare input
49 |         data['imgs'] = data['imgs'].cuda()
50 |         data.update(dict(
51 |             cfg=cfg
52 |         ))
53 | 
54 |         # forward
55 |         with torch.no_grad():
56 |             outputs = model(**data)
57 | 
58 |         if cfg.report_speed:
59 |             report_speed(outputs, speed_meters)
60 | 
61 |         # save result
62 |         image_name, _ = osp.splitext(osp.basename(test_loader.dataset.img_paths[idx]))
63 |         # print('image_name', image_name)
64 |         rf.write_result(image_name, outputs)
65 | 
66 | 
67 | def main(args):
68 |     cfg = Config.fromfile(args.config)
69 |     for d in [cfg, cfg.data.test]:
70 |         d.update(dict(
71 |             report_speed=args.report_speed
72 |         ))
73 |     print(json.dumps(cfg._cfg_dict, indent=4))
74 |     sys.stdout.flush()
75 | 
76 |     # data loader
77 |     data_loader = build_data_loader(cfg.data.test)
78 |     test_loader = torch.utils.data.DataLoader(
79 |         data_loader,
80 |         batch_size=1,
81 |         shuffle=False,
82 |         num_workers=2,
83 |     )
84 |     # model
85 |     model = build_model(cfg.model)
86 |     model = model.cuda()
87 | 
88 |     if args.checkpoint is not None:
89 |         if os.path.isfile(args.checkpoint):
90 |             print("Loading model and optimizer from checkpoint '{}'".format(args.checkpoint))
91 |             sys.stdout.flush()
92 | 
93 |             checkpoint = torch.load(args.checkpoint)
94 | 
95 |             d = dict()
96 |             for key, value in checkpoint['state_dict'].items():
97 |                 tmp = key[7:]  # strip the 'module.' prefix added by nn.DataParallel during training
98 |                 d[tmp] = value
99 |             model.load_state_dict(d)
100 |         else:
101 |             print("No checkpoint found at '{}'".format(args.checkpoint))
102 |             raise FileNotFoundError(args.checkpoint)
103 | 
104 |     # fuse conv and bn
105 |     model = fuse_module(model)
106 | 
107 |     # test
108 |     test(test_loader, model, cfg)
109 | 
110 | 
111 | if __name__ == '__main__':
112 |     parser = argparse.ArgumentParser(description='Hyperparams')
113 |     parser.add_argument('config', help='config file path')
114 |     parser.add_argument('checkpoint', 
nargs='?', type=str, default=None) 115 | parser.add_argument('--report_speed', action='store_true') 116 | args = parser.parse_args() 117 | 118 | main(args) 119 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import random 4 | import argparse 5 | import os 6 | import os.path as osp 7 | import sys 8 | import time 9 | import json 10 | from mmcv import Config 11 | 12 | from dataset import build_data_loader 13 | from models import build_model 14 | from utils import AverageMeter 15 | 16 | torch.manual_seed(123456) 17 | torch.cuda.manual_seed(123456) 18 | np.random.seed(123456) 19 | random.seed(123456) 20 | 21 | 22 | def train(train_loader, model, optimizer, epoch, start_iter, cfg): 23 | model.train() 24 | 25 | # meters 26 | batch_time = AverageMeter() 27 | data_time = AverageMeter() 28 | 29 | losses = AverageMeter() 30 | losses_text = AverageMeter() 31 | losses_kernels = AverageMeter() 32 | 33 | ious_text = AverageMeter() 34 | ious_kernel = AverageMeter() 35 | accs_rec = AverageMeter() 36 | 37 | # start time 38 | start = time.time() 39 | for iter, data in enumerate(train_loader): 40 | # skip previous iterations 41 | if iter < start_iter: 42 | print('Skipping iter: %d' % iter) 43 | sys.stdout.flush() 44 | continue 45 | 46 | # time cost of data loader 47 | data_time.update(time.time() - start) 48 | 49 | # adjust learning rate 50 | adjust_learning_rate(optimizer, train_loader, epoch, iter, cfg) 51 | 52 | # prepare input 53 | data.update(dict(cfg=cfg)) 54 | 55 | # forward 56 | outputs = model(**data) 57 | # 58 | # print(outputs['loss_text'].shape) 59 | # print(outputs['loss_kernels'].shape) 60 | 61 | # detection loss 62 | loss_text = torch.mean(outputs['loss_text']) 63 | losses_text.update(loss_text.item()) 64 | 65 | loss_kernels = torch.mean(outputs['loss_kernels']) 66 | losses_kernels.update(loss_kernels.item()) 67 | 68 | loss = loss_text + loss_kernels 69 | 70 | iou_text = torch.mean(outputs['iou_text']) 71 | ious_text.update(iou_text.item()) 72 | iou_kernel = torch.mean(outputs['iou_kernel']) 73 | ious_kernel.update(iou_kernel.item()) 74 | 75 | losses.update(loss.item()) 76 | # backward 77 | optimizer.zero_grad() 78 | loss.backward() 79 | optimizer.step() 80 | 81 | batch_time.update(time.time() - start) 82 | 83 | # update start time 84 | start = time.time() 85 | 86 | # print log 87 | if iter % 20 == 0: 88 | output_log = '({batch}/{size}) LR: {lr:.6f} | Batch: {bt:.3f}s | Total: {total:.0f}min | ' \ 89 | 'ETA: {eta:.0f}min | Loss: {loss:.3f} | ' \ 90 | 'Loss(text/kernel): {loss_text:.3f}/{loss_kernel:.3f} ' \ 91 | '| IoU(text/kernel): {iou_text:.3f}/{iou_kernel:.3f} | Acc rec: {acc_rec:.3f}'.format( 92 | batch=iter + 1, 93 | size=len(train_loader), 94 | lr=optimizer.param_groups[0]['lr'], 95 | bt=batch_time.avg, 96 | total=batch_time.avg * iter / 60.0, 97 | eta=batch_time.avg * (len(train_loader) - iter) / 60.0, 98 | loss_text=losses_text.avg, 99 | loss_kernel=losses_kernels.avg, 100 | loss=losses.avg, 101 | iou_text=ious_text.avg, 102 | iou_kernel=ious_kernel.avg, 103 | acc_rec=accs_rec.avg, 104 | ) 105 | print(output_log) 106 | sys.stdout.flush() 107 | 108 | 109 | def adjust_learning_rate(optimizer, dataloader, epoch, iter, cfg): 110 | schedule = cfg.train_cfg.schedule 111 | if isinstance(schedule, str): 112 | assert schedule == 'polylr', 'Error: schedule should be polylr!' 
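        # Poly decay: lr = base_lr * (1 - cur_iter / max_iter) ** 0.9, reaching 0
        # on the last iteration. For example (numbers for illustration only), with
        # base_lr = 1e-3 the halfway point gives lr ≈ 1e-3 * 0.5 ** 0.9 ≈ 5.4e-4.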
113 | cur_iter = epoch * len(dataloader) + iter 114 | max_iter_num = cfg.train_cfg.epoch * len(dataloader) 115 | lr = cfg.train_cfg.lr * (1 - float(cur_iter) / max_iter_num) ** 0.9 116 | elif isinstance(schedule, tuple): 117 | lr = cfg.train_cfg.lr 118 | for i in range(len(schedule)): 119 | if epoch < schedule[i]: 120 | break 121 | lr = lr * 0.1 122 | 123 | for param_group in optimizer.param_groups: 124 | param_group['lr'] = lr 125 | 126 | 127 | def save_checkpoint(state, checkpoint_path, cfg): 128 | file_path = osp.join(checkpoint_path, 'checkpoint.pth.tar') 129 | torch.save(state, file_path) 130 | 131 | if cfg.data.train.type in ['synth'] or \ 132 | (state['iter'] == 0 and state['epoch'] > cfg.train_cfg.epoch - 100 and state['epoch'] % 10 == 0): 133 | file_name = 'checkpoint_%dep.pth.tar' % state['epoch'] 134 | file_path = osp.join(checkpoint_path, file_name) 135 | torch.save(state, file_path) 136 | 137 | 138 | def main(args): 139 | cfg = Config.fromfile(args.config) 140 | print(json.dumps(cfg._cfg_dict, indent=4)) 141 | 142 | if args.checkpoint is not None: 143 | checkpoint_path = args.checkpoint 144 | else: 145 | cfg_name, _ = osp.splitext(osp.basename(args.config)) 146 | checkpoint_path = osp.join('checkpoints', cfg_name) 147 | if not osp.isdir(checkpoint_path): 148 | os.makedirs(checkpoint_path) 149 | print('Checkpoint path: %s.' % checkpoint_path) 150 | sys.stdout.flush() 151 | 152 | # data loader 153 | data_loader = build_data_loader(cfg.data.train) 154 | train_loader = torch.utils.data.DataLoader( 155 | data_loader, 156 | batch_size=cfg.data.batch_size, 157 | shuffle=True, 158 | num_workers=8, 159 | drop_last=True, 160 | pin_memory=True 161 | ) 162 | 163 | # model 164 | model = build_model(cfg.model) 165 | model = torch.nn.DataParallel(model).cuda() 166 | 167 | # Check if model has custom optimizer / loss 168 | if hasattr(model.module, 'optimizer'): 169 | optimizer = model.module.optimizer 170 | else: 171 | if cfg.train_cfg.optimizer == 'SGD': 172 | optimizer = torch.optim.SGD(model.parameters(), lr=cfg.train_cfg.lr, momentum=0.99, weight_decay=5e-4) 173 | elif cfg.train_cfg.optimizer == 'Adam': 174 | optimizer = torch.optim.Adam(model.parameters(), lr=cfg.train_cfg.lr) 175 | 176 | start_epoch = 0 177 | start_iter = 0 178 | if hasattr(cfg.train_cfg, 'pretrain'): 179 | assert osp.isfile(cfg.train_cfg.pretrain), 'Error: no pretrained weights found!' 180 | print('Finetuning from pretrained model %s.' % cfg.train_cfg.pretrain) 181 | checkpoint = torch.load(cfg.train_cfg.pretrain) 182 | model.load_state_dict(checkpoint['state_dict']) 183 | if args.resume: 184 | assert osp.isfile(args.resume), 'Error: no checkpoint directory found!' 185 | print('Resuming from checkpoint %s.' 
% args.resume) 186 | checkpoint = torch.load(args.resume) 187 | start_epoch = checkpoint['epoch'] 188 | start_iter = checkpoint['iter'] 189 | model.load_state_dict(checkpoint['state_dict']) 190 | optimizer.load_state_dict(checkpoint['optimizer']) 191 | 192 | for epoch in range(start_epoch, cfg.train_cfg.epoch): 193 | print('\nEpoch: [%d | %d]' % (epoch + 1, cfg.train_cfg.epoch)) 194 | 195 | train(train_loader, model, optimizer, epoch, start_iter, cfg) 196 | 197 | state = dict( 198 | epoch=epoch + 1, 199 | iter=0, 200 | state_dict=model.state_dict(), 201 | optimizer=optimizer.state_dict() 202 | ) 203 | save_checkpoint(state, checkpoint_path, cfg) 204 | 205 | 206 | if __name__ == '__main__': 207 | parser = argparse.ArgumentParser(description='Hyperparams') 208 | parser.add_argument('config', help='config file path') 209 | parser.add_argument('--checkpoint', nargs='?', type=str, default=None) 210 | parser.add_argument('--resume', nargs='?', type=str, default=None) 211 | args = parser.parse_args() 212 | 213 | main(args) 214 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .logger import Logger 2 | from .average_meter import AverageMeter 3 | from .result_format import ResultFormat 4 | -------------------------------------------------------------------------------- /utils/average_meter.py: -------------------------------------------------------------------------------- 1 | class AverageMeter(object): 2 | """Computes and stores the average and current value""" 3 | def __init__(self, max_len=-1): 4 | self.val = [] 5 | self.count = [] 6 | self.max_len = max_len 7 | self.avg = 0 8 | 9 | def update(self, val, n=1): 10 | self.val.append(val * n) 11 | self.count.append(n) 12 | if self.max_len > 0 and len(self.val) > self.max_len: 13 | self.val = self.val[-self.max_len:] 14 | self.count = self.count[-self.max_len:] 15 | self.avg = sum(self.val) / sum(self.count) -------------------------------------------------------------------------------- /utils/logger.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import numpy as np 4 | 5 | class Logger(object): 6 | def __init__(self, fpath, title=None, resume=False): 7 | self.file = None 8 | self.resume = resume 9 | self.title = '' if title == None else title 10 | if fpath is not None: 11 | if resume: 12 | self.file = open(fpath, 'r') 13 | name = self.file.readline() 14 | self.names = name.rstrip().split('\t') 15 | self.numbers = {} 16 | for _, name in enumerate(self.names): 17 | self.numbers[name] = [] 18 | 19 | for numbers in self.file: 20 | numbers = numbers.rstrip().split('\t') 21 | for i in range(0, len(numbers)): 22 | self.numbers[self.names[i]].append(numbers[i]) 23 | self.file.close() 24 | self.file = open(fpath, 'a') 25 | else: 26 | self.file = open(fpath, 'w') 27 | 28 | def set_names(self, names): 29 | if self.resume: 30 | pass 31 | # initialize numbers as empty list 32 | self.numbers = {} 33 | self.names = names 34 | for _, name in enumerate(self.names): 35 | self.file.write(name) 36 | self.file.write('\t') 37 | self.numbers[name] = [] 38 | self.file.write('\n') 39 | self.file.flush() 40 | 41 | 42 | def append(self, numbers): 43 | assert len(self.names) == len(numbers) 44 | for index, num in enumerate(numbers): 45 | if type(num) == str: 46 | self.file.write(num) 47 | else: 48 | self.file.write("{0:.6f}".format(num)) 49 | self.file.write('\t') 
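            # mirror the written value in memory, keyed by column name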
50 | self.numbers[self.names[index]].append(num) 51 | self.file.write('\n') 52 | self.file.flush() 53 | 54 | def close(self): 55 | if self.file is not None: 56 | self.file.close() -------------------------------------------------------------------------------- /utils/result_format.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | import zipfile 4 | 5 | 6 | class ResultFormat(object): 7 | def __init__(self, data_type, result_path): 8 | self.data_type = data_type 9 | self.result_path = result_path 10 | 11 | if osp.isfile(result_path): 12 | os.remove(result_path) 13 | 14 | if result_path.endswith('.zip'): 15 | result_path = result_path.replace('.zip', '') 16 | 17 | if not osp.exists(result_path): 18 | os.makedirs(result_path) 19 | 20 | def write_result(self, img_name, outputs): 21 | if 'IC15' in self.data_type: 22 | self._write_result_ic15(img_name, outputs) 23 | elif 'TT' in self.data_type: 24 | self._write_result_tt(img_name, outputs) 25 | elif 'CTW' in self.data_type: 26 | self._write_result_ctw(img_name, outputs) 27 | elif 'MSRA' in self.data_type: 28 | self._write_result_msra(img_name, outputs) 29 | 30 | def _write_result_ic15(self, img_name, outputs): 31 | assert self.result_path.endswith('.zip'), 'Error: ic15 result should be a zip file!' 32 | 33 | tmp_folder = self.result_path.replace('.zip', '') 34 | 35 | bboxes = outputs['bboxes'] 36 | 37 | lines = [] 38 | for i, bbox in enumerate(bboxes): 39 | values = [int(v) for v in bbox] 40 | line = "%d,%d,%d,%d,%d,%d,%d,%d\n" % tuple(values) 41 | lines.append(line) 42 | 43 | file_name = 'res_%s.txt' % img_name 44 | file_path = osp.join(tmp_folder, file_name) 45 | with open(file_path, 'w') as f: 46 | for line in lines: 47 | f.write(line) 48 | 49 | z = zipfile.ZipFile(self.result_path, 'a', zipfile.ZIP_DEFLATED) 50 | z.write(file_path, file_name) 51 | z.close() 52 | 53 | def _write_result_tt(self, image_name, outputs): 54 | bboxes = outputs['bboxes'] 55 | 56 | lines = [] 57 | for i, bbox in enumerate(bboxes): 58 | bbox = bbox.reshape(-1, 2)[:, ::-1].reshape(-1) 59 | values = [int(v) for v in bbox] 60 | line = "%d" % values[0] 61 | for v_id in range(1, len(values)): 62 | line += ",%d" % values[v_id] 63 | line += '\n' 64 | lines.append(line) 65 | 66 | file_name = '%s.txt' % image_name 67 | file_path = osp.join(self.result_path, file_name) 68 | with open(file_path, 'w') as f: 69 | for line in lines: 70 | f.write(line) 71 | 72 | def _write_result_ctw(self, image_name, outputs): 73 | bboxes = outputs['bboxes'] 74 | 75 | lines = [] 76 | for i, bbox in enumerate(bboxes): 77 | bbox = bbox.reshape(-1, 2)[:, ::-1].reshape(-1) 78 | values = [int(v) for v in bbox] 79 | line = "%d" % values[0] 80 | for v_id in range(1, len(values)): 81 | line += ",%d" % values[v_id] 82 | line += '\n' 83 | lines.append(line) 84 | 85 | tmp_folder = self.result_path.replace('.zip', '') 86 | 87 | file_name = '%s.txt' % image_name 88 | file_path = osp.join(tmp_folder, file_name) 89 | with open(file_path, 'w') as f: 90 | for line in lines: 91 | f.write(line) 92 | 93 | z = zipfile.ZipFile(self.result_path, 'a', zipfile.ZIP_DEFLATED) 94 | z.write(file_path, file_name) 95 | z.close() 96 | 97 | 98 | 99 | 100 | def _write_result_msra(self, image_name, outputs): 101 | bboxes = outputs['bboxes'] 102 | 103 | lines = [] 104 | for b_idx, bbox in enumerate(bboxes): 105 | values = [int(v) for v in bbox] 106 | line = "%d" % values[0] 107 | for v_id in range(1, len(values)): 108 | line += ", %d" % values[v_id] 109 | 
line += '\n' 110 | lines.append(line) 111 | 112 | file_name = '%s.txt' % image_name 113 | file_path = osp.join(self.result_path, file_name) 114 | with open(file_path, 'w') as f: 115 | for line in lines: 116 | f.write(line) 117 | --------------------------------------------------------------------------------
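For completeness, here is a minimal sketch checking that `fuse_module` (defined in `models/utils/fuse_conv_bn.py` above and applied to the model in `test.py`) preserves inference outputs; the toy module and the tolerance are invented for illustration:

```python
import torch
import torch.nn as nn
from models.utils import fuse_module

# A toy conv + BN block; eval() matters because fusion uses the BN running stats.
m = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.BatchNorm2d(8)).eval()
x = torch.randn(1, 3, 16, 16)

with torch.no_grad():
    y_ref = m(x)
    y_fused = fuse_module(m)(x)  # BN folded into the conv, then replaced by nn.Identity

print(torch.allclose(y_ref, y_fused, atol=1e-5))  # expected: True, up to float error
```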