├── .gitignore ├── LICENSE ├── README.md ├── ctpn ├── __init__.py ├── config.py ├── layers │ ├── __init__.py │ ├── anchor.py │ ├── base_net.py │ ├── gt.py │ ├── losses.py │ ├── models.py │ ├── target.py │ └── text_proposals.py ├── preprocess │ ├── __init__.py │ └── reader.py └── utils │ ├── __init__.py │ ├── detector.py │ ├── file_utils.py │ ├── generator.py │ ├── gt_utils.py │ ├── image_utils.py │ ├── np_utils.py │ ├── text_proposal_connector.py │ ├── text_proposal_graph_builder.py │ ├── tf_utils.py │ └── visualize.py ├── evaluate.py ├── image_examples ├── a0.png ├── a1.png ├── a2.png ├── a3.png ├── bkgd_1_0_generated_0.1.jpg ├── flip1.png ├── flip2.png ├── icdar2015 │ ├── img_200.0.jpg │ ├── img_200.1.jpg │ ├── img_5.0.jpg │ ├── img_5.1.jpg │ ├── img_8.0.jpg │ └── img_8.1.jpg └── icdar2017 │ ├── ts_img_01000.0.jpg │ ├── ts_img_01000.1.jpg │ ├── ts_img_01001.0.jpg │ └── ts_img_01001.1.jpg ├── predict.py └── train.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | .idea 106 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. 
For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # keras-ctpn 2 | 3 | [TOC] 4 | 5 | 1. [说明](#说明) 6 | 2. [预测](#预测) 7 | 3. [训练](#训练) 8 | 4. [例子](#例子)
9 | 4.1 [ICDAR2015](#ICDAR2015)
10 | 4.1.1 [带侧边细化](#带侧边细化)
11 | 4.1.2 [不带侧边细化](#不带侧边细化)<br>
12 | 4.1.3 [做数据增广-水平翻转](#做数据增广-水平翻转)
13 | 4.2 [ICDAR2017](#ICDAR2017)
14 | 4.3 [其它数据集](#其它数据集) 15 | 5. [toDoList](#toDoList) 16 | 6. [总结](#总结) 17 | 18 | ## 说明 19 | 20 | ​ 本工程是keras实现的[CPTN: Detecting Text in Natural Image with Connectionist Text Proposal Network](https://arxiv.org/abs/1609.03605) . 本工程实现主要参考了[keras-faster-rcnn](https://github.com/yizt/keras-faster-rcnn) ; 并在ICDAR2015和ICDAR2017数据集上训练和测试。 21 | 22 | ​ 工程地址: [keras-ctpn](https://github.com/yizt/keras-ctpn) 23 | 24 | ​ cptn论文翻译:[CTPN.md](https://github.com/yizt/cv-papers/blob/master/CTPN.md) 25 | 26 | **效果**: 27 | 28 | ​ 使用ICDAR2015的1000张图像训练在500张测试集上结果为:Recall: 37.07 % Precision: 42.94 % Hmean: 39.79 %; 29 | 原文中的F值为61%;使用了额外的3000张图像训练。 30 | 31 | **关键点说明**: 32 | 33 | a.骨干网络使用的是resnet50 34 | 35 | b.训练输入图像大小为720\*720; 将图像的长边缩放到720,保持长宽比,短边padding;原文是短边600;预测时使用1024*1024 36 | 37 | c.batch_size为4, 每张图像训练128个anchor,正负样本比为1:1; 38 | 39 | d.分类、边框回归以及侧边细化的损失函数权重为1:1:1;原论文中是1:1:2 40 | 41 | e.侧边细化与边框回归选择一样的正样本anchor;原文中应该是分开选择的 42 | 43 | f.侧边细化还是有效果的(注:网上很多人说没有啥效果) 44 | 45 | g.由于有双向GRU,水平翻转会影响效果(见样例[做数据增广-水平翻转](#做数据增广-水平翻转)) 46 | 47 | h.随机裁剪做数据增广,网络不收敛 48 | 49 | 50 | 51 | 52 | ## 预测 53 | 54 | a. 工程下载 55 | 56 | ```bash 57 | git clone https://github.com/yizt/keras-ctpn 58 | ``` 59 | 60 | 61 | 62 | b. 预训练模型下载 63 | 64 | ​ ICDAR2015训练集上训练好的模型下载地址: [google drive](https://drive.google.com/open?id=12t-PFYvYwx4In2aRv7OgRFkHa9rCjjn7),[百度云盘](https://pan.baidu.com/s/1GnDATacvBeFXpAwnBW6RaQ) 取码:wm47 65 | 66 | c.修改配置类config.py中如下属性 67 | 68 | ```python 69 | WEIGHT_PATH = '/tmp/ctpn.h5' 70 | ``` 71 | 72 | d. 检测文本 73 | 74 | ```shell 75 | python predict.py --image_path image_3.jpg 76 | ``` 77 | 78 | ## 评估 79 | 80 | a. 执行如下命令,并将输出的txt压缩为zip包 81 | ```shell 82 | python evaluate.py --weight_path /tmp/ctpn.100.h5 --image_dir /opt/dataset/OCR/ICDAR_2015/test_images/ --output_dir /tmp/output_2015/ 83 | ``` 84 | 85 | b. 提交在线评估 86 | 将压缩的zip包提交评估,评估地址:http://rrc.cvc.uab.es/?ch=4&com=mymethods&task=1 87 | 88 | ## 训练 89 | 90 | a. 
训练数据下载 91 | ```shell 92 | #icdar2013 93 | wget http://rrc.cvc.uab.es/downloads/Challenge2_Training_Task12_Images.zip 94 | wget http://rrc.cvc.uab.es/downloads/Challenge2_Training_Task1_GT.zip 95 | wget http://rrc.cvc.uab.es/downloads/Challenge2_Test_Task12_Images.zip 96 | ``` 97 | 98 | ```shell 99 | #icdar2015 100 | wget http://rrc.cvc.uab.es/downloads/ch4_training_images.zip 101 | wget http://rrc.cvc.uab.es/downloads/ch4_training_localization_transcription_gt.zip 102 | wget http://rrc.cvc.uab.es/downloads/ch4_test_images.zip 103 | ``` 104 | 105 | ```shell 106 | #icdar2017 107 | wget -c -t 0 http://datasets.cvc.uab.es/rrc/ch8_training_images_1~8.zip 108 | wget -c -t 0 http://datasets.cvc.uab.es/rrc/ch8_training_localization_transcription_gt_v2.zip 109 | wget -c -t 0 http://datasets.cvc.uab.es/rrc/ch8_test_images.zip 110 | ``` 111 | 112 | 113 | 114 | b. resnet50预训练模型下载 115 | 116 | ```shell 117 | wget https://github.com/fchollet/deep-learning-models/releases/download/v0.2/resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5 118 | ``` 119 | 120 | 121 | 122 | c. 
修改配置类config.py中,如下属性 123 | 124 | ```python 125 | # 预训练模型 126 | PRE_TRAINED_WEIGHT = '/opt/pretrained_model/resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5' 127 | 128 | # 数据集路径 129 | IMAGE_DIR = '/opt/dataset/OCR/ICDAR_2015/train_images' 130 | IMAGE_GT_DIR = '/opt/dataset/OCR/ICDAR_2015/train_gt' 131 | ``` 132 | 133 | d.训练 134 | 135 | ```shell 136 | python train.py --epochs 50 137 | ``` 138 | 139 | 140 | 141 | 142 | 143 | ## 例子 144 | 145 | ### ICDAR2015 146 | 147 | #### 带侧边细化 148 | 149 | ![](image_examples/icdar2015/img_8.1.jpg) 150 | 151 | ![](image_examples/icdar2015/img_200.1.jpg) 152 | 153 | #### 不带侧边细化 154 | ![](image_examples/icdar2015/img_8.0.jpg) 155 | 156 | ![](image_examples/icdar2015/img_200.0.jpg) 157 | 158 | #### 做数据增广-水平翻转 159 | ![](image_examples/flip1.png) 160 | ![](image_examples/flip2.png) 161 | 162 | ### ICDAR2017 163 | 164 | 165 | ![](image_examples/icdar2017/ts_img_01000.1.jpg) 166 | 167 | ![](image_examples/icdar2017/ts_img_01001.1.jpg) 168 | 169 | ### 其它数据集 170 | ![](image_examples/bkgd_1_0_generated_0.1.jpg) 171 | ![](image_examples/a2.png) 172 | ![](image_examples/a1.png) 173 | ![](image_examples/a3.png) 174 | ![](image_examples/a0.png) 175 | 176 | ## toDoList 177 | 178 | 1. 侧边细化(已完成) 179 | 2. ICDAR2017数据集训练(已完成) 180 | 3. 检测文本行坐标映射到原图(已完成) 181 | 4. 精度评估(已完成) 182 | 5. 侧边回归,限制在边框内(已完成) 183 | 6. 增加水平翻转(已完成) 184 | 7. 增加随机裁剪(已完成) 185 | 186 | 187 | 188 | ### 总结 189 | 190 | 1. ctpn对水平文字检测效果不错 191 | 2. 整个网络对于数据集很敏感;在2017上训练的模型到2015上测试效果很不好;同样2015训练的在2013上测试效果也很差 192 | 3. 
推测由于双向GRU,网络有存储记忆的缘故?在使用随机裁剪作数据增广时网络不收敛,使用水平翻转时预测结果也水平对称出现 193 | -------------------------------------------------------------------------------- /ctpn/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | File Name: __init__.py 4 | Description : 5 | Author : mick.yi 6 | date: 2019/3/13 7 | """ -------------------------------------------------------------------------------- /ctpn/config.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | File Name: config 4 | Description : 配置类 5 | Author : mick.yi 6 | date: 2019/3/14 7 | """ 8 | 9 | 10 | class Config(object): 11 | IMAGES_PER_GPU = 4 12 | IMAGE_SHAPE = (720, 720, 3) 13 | MAX_GT_INSTANCES = 1000 14 | 15 | NUM_CLASSES = 1 + 1 # 16 | CLASS_MAPPING = {'bg': 0, 17 | 'text': 1} 18 | # 训练样本 19 | ANCHORS_HEIGHT = [11, 16, 23, 33, 48, 68, 97, 139, 198, 283] 20 | ANCHORS_WIDTH = 16 21 | TRAIN_ANCHORS_PER_IMAGE = 128 22 | ANCHOR_POSITIVE_RATIO = 0.5 23 | # 步长 24 | NET_STRIDE = 16 25 | # text proposal输出 26 | TEXT_PROPOSALS_MIN_SCORE = 0.7 27 | TEXT_PROPOSALS_NMS_THRESH = 0.3 28 | TEXT_PROPOSALS_MAX_NUM = 500 29 | TEXT_PROPOSALS_WIDTH = 16 30 | # text line boxes超参数 31 | LINE_MIN_SCORE = 0.7 32 | MAX_HORIZONTAL_GAP = 50 33 | TEXT_LINE_NMS_THRESH = 0.3 34 | MIN_NUM_PROPOSALS = 1 35 | MIN_RATIO = 1.2 36 | MIN_V_OVERLAPS = 0.7 37 | MIN_SIZE_SIM = 0.7 38 | 39 | # 训练超参数 40 | LEARNING_RATE = 0.01 41 | LEARNING_MOMENTUM = 0.9 42 | # 权重衰减 43 | WEIGHT_DECAY = 0.0005, 44 | GRADIENT_CLIP_NORM = 5.0 45 | 46 | LOSS_WEIGHTS = { 47 | "ctpn_regress_loss": 1., 48 | "ctpn_class_loss": 1, 49 | "side_regress_loss": 1 50 | } 51 | # 是否使用侧边改善 52 | USE_SIDE_REFINE = True 53 | # 预训练模型 54 | PRE_TRAINED_WEIGHT = '/opt/pretrained_model/resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5' 55 | 56 | WEIGHT_PATH = '/tmp/ctpn.h5' 57 | 58 | # 数据集路径 59 | IMAGE_DIR = '/opt/dataset/OCR/ICDAR_2015/train_images' 60 
| IMAGE_GT_DIR = '/opt/dataset/OCR/ICDAR_2015/train_gt' 61 | 62 | 63 | cur_config = Config() 64 | -------------------------------------------------------------------------------- /ctpn/layers/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | File Name: __init__.py 4 | Description : layer层 5 | Author : mick.yi 6 | date: 2019/3/13 7 | """ -------------------------------------------------------------------------------- /ctpn/layers/anchor.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | File Name: anchor 4 | Description : ctpn anchor层,在输入图像边框外的anchors丢弃 5 | Author : mick.yi 6 | date: 2019/3/13 7 | """ 8 | import keras 9 | import tensorflow as tf 10 | import numpy as np 11 | 12 | 13 | def generate_anchors(heights, width): 14 | """ 15 | 生成基准anchors 16 | :param heights: 高度列表 17 | :param width: 宽度,数值 18 | :return: 19 | """ 20 | w = np.array([width] * len(heights)) 21 | h = np.array(heights) 22 | return np.stack([-0.5 * h, -0.5 * w, 0.5 * h, 0.5 * w], axis=1) 23 | 24 | 25 | def shift(shape, stride, base_anchors): 26 | """ 27 | 根据feature map的长宽,生成所有的anchors 28 | :param shape: (H,W) 29 | :param stride: 步长 30 | :param base_anchors:所有的基准anchors,(anchor_num,4) 31 | :return: 32 | """ 33 | H, W = shape[0], shape[1] 34 | ctr_x = (tf.cast(tf.range(W), tf.float32) + tf.constant(0.5, dtype=tf.float32)) * stride 35 | ctr_y = (tf.cast(tf.range(H), tf.float32) + tf.constant(0.5, dtype=tf.float32)) * stride 36 | 37 | ctr_x, ctr_y = tf.meshgrid(ctr_x, ctr_y) 38 | 39 | # 打平为1维,得到所有锚点的坐标 40 | ctr_x = tf.reshape(ctr_x, [-1]) 41 | ctr_y = tf.reshape(ctr_y, [-1]) 42 | # (H*W,1,4) 43 | shifts = tf.expand_dims(tf.stack([ctr_y, ctr_x, ctr_y, ctr_x], axis=1), axis=1) 44 | # (1,anchor_num,4) 45 | base_anchors = tf.expand_dims(tf.constant(base_anchors, dtype=tf.float32), axis=0) 46 | 47 | # (H*W,anchor_num,4) 48 | anchors = shifts + 
base_anchors 49 | # 转为(H*W*anchor_num,4) 返回 50 | return tf.reshape(anchors, [-1, 4]) 51 | 52 | 53 | def filter_out_of_bound_boxes(boxes, feature_shape, stride): 54 | """ 55 | 过滤图像边框外的anchor 56 | :param boxes: [n,y1,x1,y2,x2] 57 | :param feature_shape: 特征图的长宽 [h,w] 58 | :param stride: 网络步长 59 | :return: 60 | """ 61 | # 图像原始长宽为特征图长宽*步长 62 | h, w = feature_shape[0], feature_shape[1] 63 | h = tf.cast(h * stride, tf.float32) 64 | w = tf.cast(w * stride, tf.float32) 65 | 66 | valid_boxes_tag = tf.logical_and(tf.logical_and(tf.logical_and(boxes[:, 0] >= 0, 67 | boxes[:, 1] >= 0), 68 | boxes[:, 2] <= h), 69 | boxes[:, 3] <= w) 70 | boxes = tf.boolean_mask(boxes, valid_boxes_tag) 71 | valid_boxes_indices = tf.where(valid_boxes_tag)[:, 0] 72 | return boxes, valid_boxes_indices 73 | 74 | 75 | class CtpnAnchor(keras.layers.Layer): 76 | def __init__(self, heights, width, stride, **kwargs): 77 | """ 78 | :param heights: 高度列表 79 | :param width: 宽度,数值,如:16 80 | :param stride: 步长, 81 | :param image_shape: tuple(H,W,C) 82 | """ 83 | self.heights = heights 84 | self.width = width 85 | self.stride = stride 86 | # base anchors数量 87 | self.num_anchors = None # 初始化值 88 | super(CtpnAnchor, self).__init__(**kwargs) 89 | 90 | def call(self, inputs, **kwargs): 91 | """ 92 | 93 | :param inputs:输入 卷积层特征(锚点所在层),shape:[batch_size,H,W,C] 94 | :param kwargs: 95 | :return: 96 | """ 97 | features = inputs 98 | features_shape = tf.shape(features) 99 | print("feature_shape:{}".format(features_shape)) 100 | 101 | base_anchors = generate_anchors(self.heights, self.width) 102 | # print("len(base_anchors):".format(len(base_anchors))) 103 | anchors = shift(features_shape[1:3], self.stride, base_anchors) 104 | anchors, valid_anchors_indices = filter_out_of_bound_boxes(anchors, features_shape[1:3], self.stride) 105 | self.num_anchors = tf.shape(anchors)[0] 106 | # 扩展第一维,batch_size;每个样本都有相同的anchors 107 | anchors = tf.tile(tf.expand_dims(anchors, axis=0), [features_shape[0], 1, 1]) 108 | valid_anchors_indices 
= tf.tile(tf.expand_dims(valid_anchors_indices, axis=0), [features_shape[0], 1]) 109 | 110 | return [anchors, valid_anchors_indices] 111 | 112 | def compute_output_shape(self, input_shape): 113 | """ 114 | 115 | :param input_shape: [batch_size,H,W,C] 116 | :return: 117 | """ 118 | # 计算所有的anchors数量 119 | total = self.num_anchors 120 | return [(input_shape[0], total, 4), 121 | (input_shape[0], total)] 122 | 123 | 124 | def main(): 125 | anchors = generate_anchors([11, 16, 23, 33, 48, 68, 97, 139, 198, 283], 16) 126 | print(anchors) 127 | 128 | 129 | if __name__ == '__main__': 130 | main() 131 | -------------------------------------------------------------------------------- /ctpn/layers/base_net.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | File Name: base_net 4 | Description : 基网络 5 | Author : mick.yi 6 | date: 2019/3/13 7 | """ 8 | from keras import backend, layers 9 | from keras.models import Model 10 | 11 | 12 | def identity_block(input_tensor, kernel_size, filters, stage, block): 13 | """The identity block is the block that has no conv layer at shortcut. 14 | 15 | # Arguments 16 | input_tensor: input tensor 17 | kernel_size: default 3, the kernel size of 18 | middle conv layer at main path 19 | filters: list of integers, the filters of 3 conv layer at main path 20 | stage: integer, current stage label, used for generating layer names 21 | block: 'a','b'..., current block label, used for generating layer names 22 | 23 | # Returns 24 | Output tensor for the block. 
25 | """ 26 | filters1, filters2, filters3 = filters 27 | if backend.image_data_format() == 'channels_last': 28 | bn_axis = 3 29 | else: 30 | bn_axis = 1 31 | conv_name_base = 'res' + str(stage) + block + '_branch' 32 | bn_name_base = 'bn' + str(stage) + block + '_branch' 33 | 34 | x = layers.Conv2D(filters1, (1, 1), 35 | kernel_initializer='he_normal', 36 | name=conv_name_base + '2a')(input_tensor) 37 | x = layers.BatchNormalization(axis=bn_axis, name=bn_name_base + '2a')(x) 38 | x = layers.Activation('relu')(x) 39 | 40 | x = layers.Conv2D(filters2, kernel_size, 41 | padding='same', 42 | kernel_initializer='he_normal', 43 | name=conv_name_base + '2b')(x) 44 | x = layers.BatchNormalization(axis=bn_axis, name=bn_name_base + '2b')(x) 45 | x = layers.Activation('relu')(x) 46 | 47 | x = layers.Conv2D(filters3, (1, 1), 48 | kernel_initializer='he_normal', 49 | name=conv_name_base + '2c')(x) 50 | x = layers.BatchNormalization(axis=bn_axis, name=bn_name_base + '2c')(x) 51 | 52 | x = layers.add([x, input_tensor]) 53 | x = layers.Activation('relu')(x) 54 | return x 55 | 56 | 57 | def conv_block(input_tensor, 58 | kernel_size, 59 | filters, 60 | stage, 61 | block, 62 | strides=(2, 2)): 63 | """A block that has a conv layer at shortcut. 64 | 65 | # Arguments 66 | input_tensor: input tensor 67 | kernel_size: default 3, the kernel size of 68 | middle conv layer at main path 69 | filters: list of integers, the filters of 3 conv layer at main path 70 | stage: integer, current stage label, used for generating layer names 71 | block: 'a','b'..., current block label, used for generating layer names 72 | strides: Strides for the first conv layer in the block. 73 | 74 | # Returns 75 | Output tensor for the block. 
76 | 77 | Note that from stage 3, 78 | the first conv layer at main path is with strides=(2, 2) 79 | And the shortcut should have strides=(2, 2) as well 80 | """ 81 | filters1, filters2, filters3 = filters 82 | if backend.image_data_format() == 'channels_last': 83 | bn_axis = 3 84 | else: 85 | bn_axis = 1 86 | conv_name_base = 'res' + str(stage) + block + '_branch' 87 | bn_name_base = 'bn' + str(stage) + block + '_branch' 88 | 89 | x = layers.Conv2D(filters1, (1, 1), strides=strides, 90 | kernel_initializer='he_normal', 91 | name=conv_name_base + '2a')(input_tensor) 92 | x = layers.BatchNormalization(axis=bn_axis, name=bn_name_base + '2a')(x) 93 | x = layers.Activation('relu')(x) 94 | 95 | x = layers.Conv2D(filters2, kernel_size, padding='same', 96 | kernel_initializer='he_normal', 97 | name=conv_name_base + '2b')(x) 98 | x = layers.BatchNormalization(axis=bn_axis, name=bn_name_base + '2b')(x) 99 | x = layers.Activation('relu')(x) 100 | 101 | x = layers.Conv2D(filters3, (1, 1), 102 | kernel_initializer='he_normal', 103 | name=conv_name_base + '2c')(x) 104 | x = layers.BatchNormalization(axis=bn_axis, name=bn_name_base + '2c')(x) 105 | 106 | shortcut = layers.Conv2D(filters3, (1, 1), strides=strides, 107 | kernel_initializer='he_normal', 108 | name=conv_name_base + '1')(input_tensor) 109 | shortcut = layers.BatchNormalization( 110 | axis=bn_axis, name=bn_name_base + '1')(shortcut) 111 | 112 | x = layers.add([x, shortcut]) 113 | x = layers.Activation('relu')(x) 114 | return x 115 | 116 | 117 | def resnet50(image_input): 118 | bn_axis = 3 119 | 120 | x = layers.ZeroPadding2D(padding=(3, 3), name='conv1_pad')(image_input) 121 | x = layers.Conv2D(64, (7, 7), 122 | strides=(2, 2), 123 | padding='valid', 124 | name='conv1')(x) 125 | x = layers.BatchNormalization(axis=bn_axis, name='bn_conv1')(x) 126 | x = layers.Activation('relu')(x) 127 | x = layers.MaxPooling2D((3, 3), strides=(2, 2))(x) 128 | # block 2 129 | x = conv_block(x, 3, [64, 64, 256], stage=2, block='a', 
strides=(1, 1)) 130 | x = identity_block(x, 3, [64, 64, 256], stage=2, block='b') 131 | x = identity_block(x, 3, [64, 64, 256], stage=2, block='c') 132 | # # 确定精调层 133 | no_train_model = Model(inputs=image_input, outputs=x) 134 | for l in no_train_model.layers: 135 | if isinstance(l, layers.BatchNormalization): 136 | l.trainable = True 137 | else: 138 | l.trainable = False 139 | # block 3 140 | x = conv_block(x, 3, [128, 128, 512], stage=3, block='a') 141 | x = identity_block(x, 3, [128, 128, 512], stage=3, block='b') 142 | x = identity_block(x, 3, [128, 128, 512], stage=3, block='c') 143 | x = identity_block(x, 3, [128, 128, 512], stage=3, block='d') 144 | # block 4 145 | x = conv_block(x, 3, [256, 256, 1024], stage=4, block='a') 146 | x = identity_block(x, 3, [256, 256, 1024], stage=4, block='b') 147 | x = identity_block(x, 3, [256, 256, 1024], stage=4, block='c') 148 | x = identity_block(x, 3, [256, 256, 1024], stage=4, block='d') 149 | x = identity_block(x, 3, [256, 256, 1024], stage=4, block='e') 150 | x = identity_block(x, 3, [256, 256, 1024], stage=4, block='f') 151 | # block 5 152 | # x = conv_block(x, 3, [512, 512, 2048], stage=5, block='a') 153 | # x = identity_block(x, 3, [512, 512, 2048], stage=5, block='b') 154 | # x = identity_block(x, 3, [512, 512, 2048], stage=5, block='c') 155 | 156 | # model = Model(input, x, name='resnet50') 157 | 158 | return x 159 | -------------------------------------------------------------------------------- /ctpn/layers/gt.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | File Name: gt 4 | Description : 根据GT生成宽度为16的系列GT 5 | Author : mick.yi 6 | date: 2019/3/18 7 | """ 8 | import keras 9 | import tensorflow as tf 10 | from ..utils import tf_utils, gt_utils 11 | from deprecated import deprecated 12 | 13 | 14 | def generate_gt_graph(gt_quadrilaterals, input_gt_class_ids, image_shape, width_stride, max_gt_num): 15 | """ 16 | 17 | :param 
def generate_gt_graph(gt_quadrilaterals, input_gt_class_ids, image_shape, width_stride, max_gt_num):
    """Split GT quadrilaterals into a series of fixed-width GT boxes.

    :param gt_quadrilaterals: [max_gt_num, (x1,y1,x2,y2,x3,y3,x4,y4,tag)]
        corners ordered clockwise: top-left, top-right, bottom-right, bottom-left
    :param input_gt_class_ids: padded GT class ids
    :param image_shape: image shape
    :param width_stride: width of every generated box (16 for CTPN)
    :param max_gt_num: output is padded back to this size
    :return: [gt_boxes, gt_class_ids], each padded to ``max_gt_num``
    """
    gt_quadrilaterals = tf_utils.remove_pad(gt_quadrilaterals)
    input_gt_class_ids = tf_utils.remove_pad(input_gt_class_ids)
    # The splitting itself is numpy code, wrapped as a py_func node.
    gt_boxes, gt_class_ids = tf.py_func(func=gt_utils.gen_gt_from_quadrilaterals,
                                        inp=[gt_quadrilaterals, input_gt_class_ids, image_shape, width_stride],
                                        Tout=[tf.float32] * 2)
    return tf_utils.pad_list_to_fixed_size([gt_boxes, gt_class_ids], max_gt_num)


@deprecated(reason='目前没有用')
class GenGT(keras.layers.Layer):
    """Keras layer wrapping :func:`generate_gt_graph` per batch element."""

    def __init__(self, image_shape, width_stride, max_gt_num, batch_size=1, **kwargs):
        self.image_shape = image_shape
        self.width_stride = width_stride
        self.max_gt_num = max_gt_num
        # Fix: ``self.batch_size`` was read in call() but never initialised,
        # which raised AttributeError.  Added as a backward-compatible param.
        self.batch_size = batch_size
        super(GenGT, self).__init__(**kwargs)

    def call(self, inputs, **kwargs):
        """
        :param inputs: [gt_quadrilaterals, input_gt_class_ids]
            gt_quadrilaterals: [batch_size, max_gt_num, (x1,y1,...,x4,y4,tag)]
            input_gt_class_ids: padded GT class ids per image
        :param kwargs:
        :return: per-image [gt_boxes, gt_class_ids]
        """
        gt_quadrilaterals = inputs[0]
        input_gt_class_ids = inputs[1]
        # Fix: batch_slice hands the per-sample function BOTH sliced tensors;
        # the old ``lambda x: generate_gt_graph(x, self.image_shape, ...)``
        # dropped the class ids and shifted every argument by one position.
        outputs = tf_utils.batch_slice([gt_quadrilaterals, input_gt_class_ids],
                                       lambda x, y: generate_gt_graph(x, y, self.image_shape,
                                                                      self.width_stride, self.max_gt_num),
                                       batch_size=self.batch_size)
        return outputs

    def compute_output_shape(self, input_shape):
        return [(input_shape[0][0], self.max_gt_num, 5),
                (input_shape[0][0], self.max_gt_num, 2)]
def ctpn_cls_loss(predict_cls_ids, true_cls_ids, indices):
    """CTPN classification (fg/bg) loss.

    :param predict_cls_ids: predicted anchor logits, (batch_num, anchors_num, 2)
    :param true_cls_ids: true anchor classes, (batch_num, rpn_train_anchors, (class_id, tag));
        tag 1: sampled positive/negative sample, 0: padding
    :param indices: sampled anchor indices, (batch_num, rpn_train_anchors, (idx, tag));
        idx: anchor position, tag 1: positive, -1: negative, 0: padding
    :return: scalar mean cross-entropy loss
    """
    # Keep only the sampled (non-padding) entries.
    sampled = tf.where(tf.not_equal(indices[:, :, -1], 0))  # tag==0 is padding
    anchor_idx = tf.gather_nd(indices[..., 0], sampled)  # flat (batch*train_num,), anchor index per sample
    labels = tf.gather_nd(true_cls_ids[..., 0], sampled)  # flat (batch*train_num,)
    # Collapse every foreground class to 1, background stays 0, then one-hot.
    labels = tf.where(labels >= 1,
                      tf.ones_like(labels, dtype=tf.uint8),
                      tf.zeros_like(labels, dtype=tf.uint8))
    labels = tf.one_hot(labels, depth=2)
    # Build (batch, anchor) index pairs to pull out the matching predictions.
    batch_idx = sampled[:, 0]  # first dim of the sampled indices is the batch index
    gather_idx = tf.stack([batch_idx, tf.cast(anchor_idx, dtype=tf.int64)], axis=1)
    logits = tf.gather_nd(predict_cls_ids, gather_idx)  # (batch*train_num, 2)

    # Softmax cross-entropy over the sampled anchors.
    losses = tf.nn.softmax_cross_entropy_with_logits_v2(
        labels=labels, logits=logits)
    return tf.reduce_mean(losses)


def smooth_l1_loss(y_true, y_predict, sigma2=9.0):
    """Smooth-L1 loss: 0.5*sigma2*x^2 if |x| < 1/sigma2 else |x| - 0.5/sigma2, with x = diff.

    :param y_true: ground truth, any shape
    :param y_predict: predictions, same shape
    :param sigma2: transition-point parameter
    :return: element-wise loss tensor
    """
    abs_diff = tf.abs(y_true - y_predict, name='abs_diff')
    quadratic = 0.5 * sigma2 * tf.pow(abs_diff, 2)
    linear = abs_diff - 0.5 / sigma2
    return tf.where(tf.less(abs_diff, 1. / sigma2), quadratic, linear)


def ctpn_regress_loss(predict_deltas, deltas, indices):
    """Regression loss for the vertical centre offset and height scaling.

    :param predict_deltas: predicted targets, (batch_num, anchors_num, 2)
    :param deltas: true targets, (batch_num, ctpn_train_anchors, 3+1); last element is tag, tag=0 padding
    :param indices: sampled anchor indices, (batch_num, ctpn_train_anchors, (idx, tag));
        tag=0 padding, 1 positive, -1 negative
    :return: scalar mean smooth-L1 loss over positive anchors
    """
    # Positives only — padding (tag 0) and negatives (tag -1) have no regression target.
    pos = tf.where(tf.equal(indices[:, :, -1], 1))
    gt_deltas = tf.gather_nd(deltas[..., :-2], pos)  # keep (dy, dh) out of (dy, dh, dx, tag)
    pos_anchor_idx = tf.gather_nd(indices[..., 0], pos)  # flat positive-anchor indices

    gather_idx = tf.stack([pos[:, 0], tf.cast(pos_anchor_idx, dtype=tf.int64)], axis=1)
    predict_deltas = tf.gather_nd(predict_deltas, gather_idx, name='ctpn_regress_loss_predict_deltas')

    # Guard against an empty positive set — without it the mean yields NaN.
    loss = K.switch(tf.size(gt_deltas) > 0,
                    smooth_l1_loss(gt_deltas, predict_deltas),
                    tf.constant(0.0))
    return K.mean(loss)


def side_regress_loss(predict_deltas, deltas, indices):
    """Side-refinement regression loss (horizontal offset only).

    :param predict_deltas: predicted x-offset targets, (batch_num, anchors_num, 1)
    :param deltas: true targets, (batch_num, ctpn_train_anchors, 3+1); last element is tag, tag=0 padding
    :param indices: sampled anchor indices, (batch_num, ctpn_train_anchors, (idx, tag));
        tag=0 padding, 1 positive, -1 negative
    :return: scalar mean smooth-L1 loss over positive anchors
    """
    # Positives only — padding and negatives carry no side target.
    pos = tf.where(tf.equal(indices[:, :, -1], 1))
    gt_deltas = tf.gather_nd(deltas[..., 2:3], pos)  # take dx out of (dy, dh, dx, tag)
    pos_anchor_idx = tf.gather_nd(indices[..., 0], pos)  # flat positive-anchor indices

    gather_idx = tf.stack([pos[:, 0], tf.cast(pos_anchor_idx, dtype=tf.int64)], axis=1)
    predict_deltas = tf.gather_nd(predict_deltas, gather_idx, name='ctpn_regress_loss_predict_side_deltas')

    # Guard against an empty positive set — without it the mean yields NaN.
    loss = K.switch(tf.size(gt_deltas) > 0,
                    smooth_l1_loss(gt_deltas, predict_deltas),
                    tf.constant(0.0))
    return K.mean(loss)
def ctpn_net(config, stage='train'):
    """Build the CTPN model.

    :param config: configuration object (anchors, thresholds, shapes, ...)
    :param stage: 'train' builds the loss-producing graph, anything else the inference graph
    :return: a keras Model ready for :func:`compile`
    """
    # Inputs.
    input_image = Input(shape=config.IMAGE_SHAPE, name='input_image')
    input_image_meta = Input(shape=(12,), name='input_image_meta')
    gt_class_ids = Input(shape=(config.MAX_GT_INSTANCES, 2), name='gt_class_ids')
    gt_boxes = Input(shape=(config.MAX_GT_INSTANCES, 5), name='gt_boxes')

    # Backbone features and detection-head predictions.
    base_features = resnet50(input_image)
    num_anchors = len(config.ANCHORS_HEIGHT)
    predict_class_logits, predict_deltas, predict_side_deltas = ctpn(base_features, num_anchors, 64, 256)

    # Anchor generation.
    anchors, valid_anchors_indices = CtpnAnchor(config.ANCHORS_HEIGHT, config.ANCHORS_WIDTH, config.NET_STRIDE,
                                                name='gen_ctpn_anchors')(base_features)

    if stage == 'train':
        targets = CtpnTarget(config.IMAGES_PER_GPU,
                             train_anchors_num=config.TRAIN_ANCHORS_PER_IMAGE,
                             positive_ratios=config.ANCHOR_POSITIVE_RATIO,
                             max_gt_num=config.MAX_GT_INSTANCES,
                             name='ctpn_target')([gt_boxes, gt_class_ids, anchors, valid_anchors_indices])
        deltas, class_ids, anchors_indices = targets[:3]
        # Loss layers.
        regress_loss = layers.Lambda(lambda x: ctpn_regress_loss(*x),
                                     name='ctpn_regress_loss')([predict_deltas, deltas, anchors_indices])
        side_loss = layers.Lambda(lambda x: side_regress_loss(*x),
                                  name='side_regress_loss')([predict_side_deltas, deltas, anchors_indices])
        cls_loss = layers.Lambda(lambda x: ctpn_cls_loss(*x),
                                 name='ctpn_class_loss')([predict_class_logits, class_ids, anchors_indices])
        model = Model(inputs=[input_image, gt_boxes, gt_class_ids],
                      outputs=[regress_loss, cls_loss, side_loss])
    else:
        text_boxes, text_scores, text_class_logits = TextProposal(config.IMAGES_PER_GPU,
                                                                  score_threshold=config.TEXT_PROPOSALS_MIN_SCORE,
                                                                  output_box_num=config.TEXT_PROPOSALS_MAX_NUM,
                                                                  iou_threshold=config.TEXT_PROPOSALS_NMS_THRESH,
                                                                  use_side_refine=config.USE_SIDE_REFINE,
                                                                  name='text_proposals')(
            [predict_deltas, predict_side_deltas, predict_class_logits, anchors, valid_anchors_indices])
        image_meta = layers.Lambda(lambda x: x)(input_image_meta)  # passed through unchanged
        model = Model(inputs=[input_image, input_image_meta], outputs=[text_boxes, text_scores, image_meta])
    return model


def ctpn(base_features, num_anchors, rnn_units=128, fc_units=512):
    """CTPN head: 3x3 conv -> bidirectional GRU along width -> 1x1 "fc" -> three branches.

    :param base_features: (B, H, W, C) backbone feature map
    :param num_anchors: anchors per feature-map position
    :param rnn_units: GRU hidden size
    :param fc_units: channels of the 1x1 fully-connected conv
    :return: (class_logits [B,-1,2], predict_deltas [B,-1,2], predict_side_deltas [B,-1,1])
    """
    x = layers.Conv2D(512, kernel_size=(3, 3), padding='same', name='pre_fc')(base_features)  # [B,H,W,512]
    # Run a GRU along the width dimension, in both directions.
    rnn_forward = layers.TimeDistributed(layers.GRU(rnn_units, return_sequences=True, kernel_initializer='he_normal'),
                                         name='gru_forward')(x)
    rnn_backward = layers.TimeDistributed(
        layers.GRU(rnn_units, return_sequences=True, kernel_initializer='he_normal', go_backwards=True),
        name='gru_backward')(x)
    rnn_output = layers.Concatenate(name='gru_concat')([rnn_forward, rnn_backward])  # (B,H,W,2*rnn_units)

    # 1x1 conv acting as a fully-connected layer.
    fc_output = layers.Conv2D(fc_units, kernel_size=(1, 1), activation='relu', name='fc_output')(
        rnn_output)  # (B,H,W,fc_units)

    # Classification branch (fg/bg per anchor).
    class_logits = layers.Conv2D(2 * num_anchors, kernel_size=(1, 1), name='cls')(fc_output)
    class_logits = layers.Reshape(target_shape=(-1, 2), name='cls_reshape')(class_logits)
    # Vertical centre offset and height regression branch.
    predict_deltas = layers.Conv2D(2 * num_anchors, kernel_size=(1, 1), name='deltas')(fc_output)
    predict_deltas = layers.Reshape(target_shape=(-1, 2), name='deltas_reshape')(predict_deltas)
    # Side-refinement branch (only the horizontal offset is predicted).
    predict_side_deltas = layers.Conv2D(num_anchors, kernel_size=(1, 1), name='side_deltas')(fc_output)
    predict_side_deltas = layers.Reshape(target_shape=(-1, 1), name='side_deltas_reshape')(
        predict_side_deltas)
    return class_logits, predict_deltas, predict_side_deltas


def get_layer(model, name):
    """Return the layer called ``name``, or None if the model has no such layer."""
    for layer in model.layers:
        if layer.name == name:
            return layer
    return None


def compile(keras_model, config, loss_names=None):
    """Compile the model: attach the named loss layers, add L2 regularisation
    and register one metric per loss.

    :param keras_model: model to compile
    :param config: configuration (learning rate, weight decay, loss weights)
    :param loss_names: names of loss layers to attach (default: none)
    """
    # Fix: a mutable default argument ([]) is shared across calls; use None.
    loss_names = loss_names or []
    # Optimiser.
    optimizer = keras.optimizers.SGD(
        lr=config.LEARNING_RATE, momentum=config.LEARNING_MOMENTUM,
        clipnorm=config.GRADIENT_CLIP_NORM)
    # Clear previously registered losses to avoid duplicates.
    keras_model._losses = []
    keras_model._per_input_losses = {}

    for name in loss_names:
        layer = get_layer(keras_model, name)
        if layer is None or layer.output in keras_model.losses:
            continue
        loss = (tf.reduce_mean(layer.output, keepdims=True)
                * config.LOSS_WEIGHTS.get(name, 1.))
        keras_model.add_loss(loss)

    # L2 weight decay, skipping the BatchNormalization gamma and beta weights.
    reg_losses = [
        keras.regularizers.l2(config.WEIGHT_DECAY)(w) / tf.cast(tf.size(w), tf.float32)
        for w in keras_model.trainable_weights
        if 'gamma' not in w.name and 'beta' not in w.name]
    keras_model.add_loss(tf.add_n(reg_losses))

    # Compile with dummy losses — the real ones were attached via add_loss.
    keras_model.compile(
        optimizer=optimizer,
        loss=[None] * len(keras_model.outputs))

    # Register one metric per loss layer.
    for name in loss_names:
        if name in keras_model.metrics_names:
            continue
        layer = get_layer(keras_model, name)
        if layer is None:
            continue
        keras_model.metrics_names.append(name)
        loss = (
            tf.reduce_mean(layer.output, keepdims=True)
            * config.LOSS_WEIGHTS.get(name, 1.))
        keras_model.metrics_tensors.append(loss)


def add_metrics(keras_model, metric_name_list, metric_tensor_list):
    """Register extra scalar metrics on the model.

    :param keras_model: model
    :param metric_name_list: metric names
    :param metric_tensor_list: matching metric tensors
    :return: None
    """
    for name, tensor in zip(metric_name_list, metric_tensor_list):
        keras_model.metrics_names.append(name)
        keras_model.metrics_tensors.append(tf.reduce_mean(tensor, keepdims=False))
def compute_iou(gt_boxes, anchors):
    """Pairwise IoU between GT boxes and anchors.

    :param gt_boxes: [N, (y1, x1, y2, x2)]
    :param anchors: [M, (y1, x1, y2, x2)]
    :return: IoU matrix [N, M]
    """
    gt_boxes = tf.expand_dims(gt_boxes, axis=1)  # [N,1,4]
    anchors = tf.expand_dims(anchors, axis=0)  # [1,M,4]
    # Intersection extents (clamped at zero when the boxes do not overlap).
    overlap_w = tf.maximum(0.0,
                           tf.minimum(gt_boxes[:, :, 3], anchors[:, :, 3]) -
                           tf.maximum(gt_boxes[:, :, 1], anchors[:, :, 1]))
    overlap_h = tf.maximum(0.0,
                           tf.minimum(gt_boxes[:, :, 2], anchors[:, :, 2]) -
                           tf.maximum(gt_boxes[:, :, 0], anchors[:, :, 0]))
    intersect = overlap_h * overlap_w

    # Individual box areas.
    area_gt = (gt_boxes[:, :, 3] - gt_boxes[:, :, 1]) * \
              (gt_boxes[:, :, 2] - gt_boxes[:, :, 0])
    area_anchor = (anchors[:, :, 3] - anchors[:, :, 1]) * \
                  (anchors[:, :, 2] - anchors[:, :, 0])

    # Union and IoU.
    union = area_gt + area_anchor - intersect
    return tf.divide(intersect, union, name='regress_target_iou')


def ctpn_regress_target(anchors, gt_boxes):
    """Regression targets for matched anchor/GT pairs.

    :param anchors: [N, (y1, x1, y2, x2)]
    :param gt_boxes: [N, (y1, x1, y2, x2)]
    :return: [N, (dy, dh, dx)] — dx is the side-refinement target
    """
    # Anchor and GT heights.
    h = anchors[:, 2] - anchors[:, 0]
    gt_h = gt_boxes[:, 2] - gt_boxes[:, 0]

    # Anchor and GT vertical centres.
    center_y = (anchors[:, 2] + anchors[:, 0]) * 0.5
    gt_center_y = (gt_boxes[:, 2] + gt_boxes[:, 0]) * 0.5

    # Regression targets.
    dy = (gt_center_y - center_y) / h
    dh = tf.log(gt_h / h)
    dx = side_regress_target(anchors, gt_boxes)  # side refinement
    target = tf.stack([dy, dh, dx], axis=1)
    # Normalise by the regression std-devs (inverted again in apply_regress).
    target /= tf.constant([0.1, 0.2, 0.1])

    return target


def side_regress_target(anchors, gt_boxes):
    """Side-refinement regression target.

    :param anchors: [N, (y1, x1, y2, x2)]
    :param gt_boxes: matched GT boxes [N, (y1, x1, y2, x2)]
    :return: [N] horizontal offsets
    """
    w = anchors[:, 3] - anchors[:, 1]  # in practice a fixed width of 16
    center_x = (anchors[:, 3] + anchors[:, 1]) * 0.5
    gt_center_x = (gt_boxes[:, 3] + gt_boxes[:, 1]) * 0.5
    # Moving a side edge onto the GT edge equals twice the centre shift;
    # anchors that are not at a side get an offset of 0.
    return (gt_center_x - center_x) * 2 / w


def ctpn_target_graph(gt_boxes, gt_cls, anchors, valid_anchors_indices, train_anchors_num=128, positive_ratios=0.5,
                      max_gt_num=50):
    """Build classification/regression targets for a single image.

    a) positives: anchors with GT IoU > 0.7, or the highest-IoU anchor of a GT;
    b) this guarantees every GT has at least one matching anchor.

    :param gt_boxes: GT boxes [gt_num, (y1, x1, y2, x2, tag)], tag=0 is padding
    :param gt_cls: GT classes [gt_num, 1+1], last element is tag, tag=0 is padding
    :param anchors: [anchor_num, (y1, x1, y2, x2)]
    :param valid_anchors_indices: indices of valid anchors [anchor_num]
    :param train_anchors_num: number of anchors sampled for training
    :param positive_ratios: fraction of positives in the sample
    :param max_gt_num: kept for interface compatibility
    :return:
        deltas: [train_anchors_num, (dy, dh, dx, tag)], tag=1 sampled / 0 padding
        class_ids: [train_anchors_num, (class_id, tag)]
        indices: [train_anchors_num, (anchor_index, tag)], tag 1 positive / -1 negative / 0 padding
        plus float metrics: gt_num, positive_num, negative_num, min/mean matched IoU
    """
    # Strip the padding (tag) entries from the GT inputs.
    gt_boxes = tf_utils.remove_pad(gt_boxes)
    gt_cls = tf_utils.remove_pad(gt_cls)[:, 0]  # [N,1] -> [N]

    gt_num = tf.shape(gt_cls)[0]  # number of real GTs

    iou = compute_iou(gt_boxes, anchors)

    # (a) For every GT, its highest-IoU anchors are positives (possibly several).
    gt_iou_max = tf.reduce_max(iou, axis=1, keep_dims=True)  # [gt_num,1]
    gt_iou_max_bool = tf.equal(iou, gt_iou_max)  # bool [gt_num,num_anchors]

    # (b) Anchors whose best IoU is >= 0.7 are positives as well.
    anchors_iou_max = tf.reduce_max(iou, axis=0, keep_dims=True)  # [1,num_anchors]
    anchors_iou_max = tf.where(tf.greater_equal(anchors_iou_max, 0.7),
                               anchors_iou_max,
                               tf.ones_like(anchors_iou_max))
    anchors_iou_max_bool = tf.equal(iou, anchors_iou_max)

    # Union of both positive criteria.
    positive_bool_matrix = tf.logical_or(gt_iou_max_bool, anchors_iou_max_bool)
    # Matched-IoU statistics, exported as training metrics.
    gt_match_min_iou = tf.reduce_min(tf.boolean_mask(iou, positive_bool_matrix), keep_dims=True)[0]
    gt_match_mean_iou = tf.reduce_mean(tf.boolean_mask(iou, positive_bool_matrix), keep_dims=True)[0]

    # Positive (gt_index, anchor_index) pairs, then random sub-sampling.
    positive_indices = tf.where(positive_bool_matrix)
    positive_num = tf.minimum(tf.shape(positive_indices)[0], int(train_anchors_num * positive_ratios))
    positive_indices = tf.random_shuffle(positive_indices)[:positive_num]

    # Gather the sampled positives and their matched GTs.
    positive_gt_indices = positive_indices[:, 0]
    positive_anchor_indices = positive_indices[:, 1]
    positive_anchors = tf.gather(anchors, positive_anchor_indices)
    positive_gt_boxes = tf.gather(gt_boxes, positive_gt_indices)
    positive_gt_cls = tf.gather(gt_cls, positive_gt_indices)

    # Regression targets for the positives.
    deltas = ctpn_regress_target(positive_anchors, positive_gt_boxes)

    # Negatives: best IoU < 0.5 and not already positive.
    negative_bool = tf.less(tf.reduce_max(iou, axis=0), 0.5)
    positive_bool = tf.reduce_any(positive_bool_matrix, axis=0)  # [num_anchors]
    negative_bool = tf.logical_and(negative_bool, tf.logical_not(positive_bool))

    # Sub-sample the negatives to fill the remaining quota.
    negative_num = tf.minimum(int(train_anchors_num * (1. - positive_ratios)), train_anchors_num - positive_num)
    negative_indices = tf.random_shuffle(tf.where(negative_bool)[:, 0])[:negative_num]

    negative_gt_cls = tf.zeros([negative_num])  # background class id is 0
    negative_deltas = tf.zeros([negative_num, 3])

    # Merge positives and negatives.
    deltas = tf.concat([deltas, negative_deltas], axis=0, name='ctpn_target_deltas')
    class_ids = tf.concat([positive_gt_cls, negative_gt_cls], axis=0, name='ctpn_target_class_ids')
    indices = tf.concat([positive_anchor_indices, negative_indices], axis=0,
                        name='ctpn_train_anchor_indices')
    indices = tf.gather(valid_anchors_indices, indices)  # map back to full-anchor indexing

    # Pad to fixed sizes.
    deltas, class_ids = tf_utils.pad_list_to_fixed_size([deltas, tf.expand_dims(class_ids, 1)],
                                                        train_anchors_num)
    # Negatives get tag -1 so they can be separated downstream.
    indices = tf_utils.pad_to_fixed_size_with_negative(tf.expand_dims(indices, 1), train_anchors_num,
                                                       negative_num=negative_num, data_type=tf.int64)

    # Metric outputs must be float tensors.
    return [deltas, class_ids, indices,
            tf.cast(gt_num, dtype=tf.float32),
            tf.cast(positive_num, dtype=tf.float32),
            tf.cast(negative_num, dtype=tf.float32),
            gt_match_min_iou, gt_match_mean_iou]


class CtpnTarget(layers.Layer):
    """Keras layer producing per-image training targets (see ctpn_target_graph)."""

    def __init__(self, batch_size, train_anchors_num=128, positive_ratios=0.5, max_gt_num=50, **kwargs):
        self.batch_size = batch_size
        self.train_anchors_num = train_anchors_num
        self.positive_ratios = positive_ratios
        self.max_gt_num = max_gt_num
        super(CtpnTarget, self).__init__(**kwargs)

    def call(self, inputs, **kwargs):
        """
        :param inputs:
            inputs[0]: GT boxes [batch_size, MAX_GT_BOXs, (y1,x1,y2,x2,tag)], tag=0 padding
            inputs[1]: GT classes [batch_size, MAX_GT_BOXs, num_class+1]; last element tag, tag=0 padding
            inputs[2]: anchors [batch_size, anchor_num, (y1,x1,y2,x2)]
            inputs[3]: valid anchor indices [batch_size, anchor_num]
        :param kwargs:
        :return: per-image target list (see ctpn_target_graph)
        """
        gt_boxes, gt_cls_ids, anchors, valid_anchors_indices = inputs
        return tf_utils.batch_slice([gt_boxes, gt_cls_ids, anchors, valid_anchors_indices],
                                    lambda x, y, z, s: ctpn_target_graph(x, y, z, s,
                                                                         self.train_anchors_num,
                                                                         self.positive_ratios,
                                                                         self.max_gt_num),
                                    batch_size=self.batch_size)

    def compute_output_shape(self, input_shape):
        batch = input_shape[0][0]
        return [(batch, self.train_anchors_num, 4),  # deltas (dy,dh,dx,tag)
                (batch, self.train_anchors_num, 2),  # class ids
                (batch, self.train_anchors_num, 2),  # anchor indices
                (batch,),  # gt_num
                (batch,),  # positive_num
                (batch,),  # negative_num
                (batch, 1),  # gt_match_min_iou
                (batch, 1)]  # gt_match_mean_iou
def apply_regress(deltas, side_deltas, anchors, use_side_refine=False):
    """Apply regression targets to anchors: vertical centre shift, height scale, side shift.

    :param deltas: regression targets [N, (dy, dh)]
    :param side_deltas: regression targets [N, (dx)]
    :param anchors: anchor boxes [N, (y1, x1, y2, x2)]
    :param use_side_refine: whether to apply the horizontal side refinement
    :return: regressed boxes [N, (y1, x1, y2, x2)]
    """
    # Anchor sizes.
    h = anchors[:, 2] - anchors[:, 0]
    w = anchors[:, 3] - anchors[:, 1]

    # Anchor centres.
    cy = (anchors[:, 2] + anchors[:, 0]) * 0.5
    cx = (anchors[:, 3] + anchors[:, 1]) * 0.5

    deltas = tf.concat([deltas, side_deltas], axis=1)
    # Undo the target normalisation applied in ctpn_regress_target.
    deltas *= tf.constant([0.1, 0.2, 0.1])
    dy, dh, dx = deltas[:, 0], deltas[:, 1], deltas[:, 2]

    cy += dy * h  # vertical centre regression
    cx += dx * w  # side refinement
    h *= tf.exp(dh)  # height regression

    # Back to corner coordinates.
    y1 = cy - h * 0.5
    y2 = cy + h * 0.5
    x1 = tf.maximum(cx - w * 0.5, 0.)  # clamp inside the window so successors still find predecessors
    x2 = cx + w * 0.5

    if use_side_refine:
        return tf.stack([y1, x1, y2, x2], axis=1)
    return tf.stack([y1, anchors[:, 1], y2, anchors[:, 3]], axis=1)


def get_valid_predicts(deltas, side_deltas, class_logits, valid_anchors_indices):
    """Select the prediction rows that belong to valid anchors."""
    return tf.gather(deltas, valid_anchors_indices), tf.gather(
        side_deltas, valid_anchors_indices), tf.gather(
        class_logits, valid_anchors_indices)


def nms(boxes, scores, class_logits, max_output_size, iou_threshold=0.5, score_threshold=0.05,
        name=None):
    """Non-max suppression with outputs padded to a fixed size.

    :param boxes: 2-D float tensor [num_boxes, 4]
    :param scores: 1-D float tensor [num_boxes], one score per box
    :param class_logits: [num_boxes, num_classes] raw class predictions
    :param max_output_size: scalar int, maximum boxes selected by NMS
    :param iou_threshold: float IoU threshold
    :param score_threshold: float, boxes scoring below are dropped
    :param name: op name
    :return: [padded boxes, padded scores, padded class logits]
    """
    keep = tf.image.non_max_suppression(boxes, scores, max_output_size, iou_threshold, score_threshold,
                                        name)  # 1-D kept indices
    kept_boxes = tf.gather(boxes, keep)  # (M,4)
    kept_scores = tf.expand_dims(tf.gather(scores, keep), axis=1)  # expand to 2-D (M,1)
    kept_logits = tf.gather(class_logits, keep)
    # Pad every output to a fixed size.
    return [tf_utils.pad_to_fixed_size(kept_boxes, max_output_size),
            tf_utils.pad_to_fixed_size(kept_scores, max_output_size),
            tf_utils.pad_to_fixed_size(kept_logits, max_output_size)]


class TextProposal(layers.Layer):
    """Generate text proposal boxes from the raw network outputs."""

    def __init__(self, batch_size, score_threshold=0.7, output_box_num=500, iou_threshold=0.3,
                 use_side_refine=False, **kwargs):
        """
        :param batch_size: images per batch
        :param score_threshold: minimum proposal score
        :param output_box_num: number of proposal boxes to emit
        :param iou_threshold: NMS IoU threshold
        :param use_side_refine: apply side refinement at inference time
        """
        self.batch_size = batch_size
        self.score_threshold = score_threshold
        self.output_box_num = output_box_num
        self.iou_threshold = iou_threshold
        self.use_side_refine = use_side_refine
        super(TextProposal, self).__init__(**kwargs)

    def call(self, inputs, **kwargs):
        """Apply box regression, then NMS, to produce the final proposals.

        :param inputs:
            inputs[0]: deltas [batch_size, N, (dy, dh)], N = total anchor count
            inputs[1]: side_deltas [batch_size, N, (dx)]
            inputs[2]: class logits [batch_size, N, num_classes]
            inputs[3]: anchors [batch_size, N, (y1, x1, y2, x2)]
            inputs[4]: valid anchor indices [batch_size, anchor_num]
        :param kwargs:
        :return: [boxes, scores, logits], each padded per image
        """
        deltas, side_deltas, class_logits, anchors, val_anchors_indices = inputs
        # Restrict all predictions to the valid anchors.
        deltas, side_deltas, class_logits = tf_utils.batch_slice(
            [deltas, side_deltas, class_logits, val_anchors_indices],
            lambda x, y, z, u: get_valid_predicts(x, y, z, u),
            self.batch_size)

        # Convert logits to class scores; index 0 is the background class.
        class_scores = tf.nn.softmax(logits=class_logits, axis=-1)  # [N,num_classes]
        fg_scores = tf.reduce_max(class_scores[..., 1:], axis=-1)  # (N,)

        # Per-image box regression.
        proposals = tf_utils.batch_slice([deltas, side_deltas, anchors],
                                         lambda x, y, z: apply_regress(x, y, z, self.use_side_refine),
                                         self.batch_size)

        # Per-image non-max suppression.
        outputs = tf_utils.batch_slice([proposals, fg_scores, class_logits],
                                       lambda x, y, z: nms(x, y, z,
                                                           max_output_size=self.output_box_num,
                                                           iou_threshold=self.iou_threshold,
                                                           score_threshold=self.score_threshold),
                                       self.batch_size)
        return outputs

    def compute_output_shape(self, input_shape):
        """Multi-output layer: call() must return a list of tensors."""
        return [(input_shape[0][0], self.output_box_num, 4 + 1),
                (input_shape[0][0], self.output_box_num, 1 + 1),
                (input_shape[0][0], self.output_box_num, input_shape[1][-1])]
-------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | File Name: reader 4 | Description : 5 | Author : mick.yi 6 | date: 2019/3/14 7 | """ 8 | import numpy as np 9 | import os 10 | import glob 11 | 12 | 13 | def load_annotation(annotation_path, image_dir): 14 | """ 15 | 加载标注信息 16 | :param annotation_path: 17 | :param image_dir: 18 | :return: 19 | """ 20 | image_annotation = {} 21 | # 文件名称,路径 22 | base_name = os.path.basename(annotation_path) 23 | image_name = base_name[3:-3] + '*' # 通配符 gt_img_3.txt,img_3.jpg or png 24 | image_annotation["annotation_path"] = annotation_path 25 | image_annotation["image_path"] = glob.glob(os.path.join(image_dir, image_name))[0] 26 | image_annotation["file_name"] = os.path.basename(image_annotation["image_path"]) # 图像文件名 27 | # 读取边框标注 28 | bbox = [] 29 | quadrilateral = [] # 四边形 30 | 31 | with open(annotation_path, "r", encoding='utf-8') as f: 32 | lines = f.read().encode('utf-8').decode('utf-8-sig').splitlines() 33 | # lines = f.readlines() 34 | # print(lines) 35 | for line in lines: 36 | line = line.strip().split(",") 37 | # 左上、右上、右下、左下 四个坐标 如:377,117,463,117,465,130,378,130 38 | lt_x, lt_y, rt_x, rt_y, rb_x, rb_y, lb_x, lb_y = map(float, line[:8]) 39 | x_min, y_min, x_max, y_max = min(lt_x, lb_x), min(lt_y, rt_y), max(rt_x, rb_x), max(lb_y, rb_y) 40 | bbox.append([y_min, x_min, y_max, x_max]) 41 | quadrilateral.append([lt_x, lt_y, rt_x, rt_y, rb_x, rb_y, lb_x, lb_y]) 42 | 43 | image_annotation["boxes"] = np.asarray(bbox, np.float32).reshape((-1, 4)) 44 | image_annotation["quadrilaterals"] = np.asarray(quadrilateral, np.float32).reshape((-1, 8)) 45 | image_annotation["labels"] = np.ones(shape=(len(bbox)), dtype=np.uint8) 46 | return image_annotation 47 | -------------------------------------------------------------------------------- /ctpn/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 
2 | """ 3 | File Name: __init__.py 4 | Description : 5 | Author : mick.yi 6 | date: 2019/3/13 7 | """ -------------------------------------------------------------------------------- /ctpn/utils/detector.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | File Name: detector 4 | Description : 文本行检测器 5 | Author : mick.yi 6 | date: 2019/3/14 7 | """ 8 | import numpy as np 9 | from .text_proposal_connector import TextProposalConnector 10 | from ..utils import np_utils 11 | 12 | 13 | def normalize(data): 14 | if data.shape[0] == 0: 15 | return data 16 | max_ = data.max() 17 | min_ = data.min() 18 | return (data - min_) / (max_ - min_) if max_ - min_ != 0 else data - min_ 19 | 20 | 21 | class TextDetector: 22 | """ 23 | Detect text from an image 24 | """ 25 | 26 | def __init__(self, config): 27 | self.config = config 28 | self.text_proposal_connector = TextProposalConnector() 29 | 30 | def detect(self, text_proposals, scores, image_shape, window): 31 | """ 32 | 检测文本行 33 | :param text_proposals: 文本提议框 34 | :param scores: 文本框得分 35 | :param image_shape: 图像形状 36 | :param window [y1,x1,y2,x2] 去除padding后的窗口 37 | :return: text_lines; [ num,(y1,x1,y2,x2,score)] 38 | """ 39 | 40 | scores = normalize(scores) # 加上后,效果变差; 评估结果好像还是好一点 41 | text_lines = self.text_proposal_connector.get_text_lines(text_proposals, scores, image_shape) 42 | keep_indices = self.filter_boxes(text_lines) 43 | text_lines = text_lines[keep_indices] 44 | text_lines = filter_out_of_window(text_lines, window) 45 | 46 | # 文本行nms 47 | if text_lines.shape[0] != 0: 48 | keep_indices = np_utils.quadrangle_nms(text_lines[:, :8], text_lines[:, 8], 49 | self.config.TEXT_LINE_NMS_THRESH) 50 | text_lines = text_lines[keep_indices] 51 | 52 | return text_lines 53 | 54 | def filter_boxes(self, text_lines): 55 | widths = text_lines[:, 2] - text_lines[:, 0] 56 | scores = text_lines[:, -1] 57 | return np.where((scores > self.config.LINE_MIN_SCORE) & 58 | 
def get_sub_files(dir_path, recursive=False):
    """
    Collect paths of all entries under `dir_path`.

    :param dir_path: directory to scan
    :param recursive: when True, descend into sub-directories; when False,
        sub-directory paths themselves are returned as entries
    :return: list of paths
    """
    file_paths = []
    for dir_name in os.listdir(dir_path):
        cur_dir_path = os.path.join(dir_path, dir_name)
        if os.path.isdir(cur_dir_path) and recursive:
            # Bug fix: propagate `recursive`.  The original recursed with the
            # default (False), so the scan stopped one directory level deep.
            file_paths = file_paths + get_sub_files(cur_dir_path, recursive)
        else:
            file_paths.append(cur_dir_path)
    return file_paths
np_utils, gt_utils 11 | 12 | 13 | def generator(image_annotations, batch_size, image_shape, width_stride, 14 | max_gt_num, horizontal_flip=False, random_crop=False): 15 | image_length = len(image_annotations) 16 | while True: 17 | ids = np.random.choice(image_length, batch_size, replace=False) 18 | batch_images = [] 19 | batch_images_meta = [] 20 | batch_gt_boxes = [] 21 | batch_gt_class_ids = [] 22 | for id in ids: 23 | image_annotation = image_annotations[id] 24 | image, image_meta, _, gt_quadrilaterals = image_utils.load_image_gt(id, 25 | image_annotation['image_path'], 26 | image_shape[0], 27 | gt_quadrilaterals=image_annotation[ 28 | 'quadrilaterals'], 29 | horizontal_flip=horizontal_flip, 30 | random_crop=random_crop) 31 | class_ids = image_annotation['labels'] 32 | gt_boxes, class_ids = gt_utils.gen_gt_from_quadrilaterals(gt_quadrilaterals, 33 | class_ids, 34 | image_shape, 35 | width_stride, 36 | box_min_size=3) 37 | batch_images.append(image) 38 | batch_images_meta.append(image_meta) 39 | gt_boxes = np_utils.pad_to_fixed_size(gt_boxes[:max_gt_num], max_gt_num) # GT boxes数量防止超出阈值 40 | batch_gt_boxes.append(gt_boxes) 41 | batch_gt_class_ids.append( 42 | np_utils.pad_to_fixed_size(np.expand_dims(np.array(class_ids), axis=1), max_gt_num)) 43 | 44 | # 返回结果 45 | yield {"input_image": np.asarray(batch_images), 46 | "input_image_meta": np.asarray(batch_images_meta), 47 | "gt_class_ids": np.asarray(batch_gt_class_ids), 48 | "gt_boxes": np.asarray(batch_gt_boxes)}, None 49 | -------------------------------------------------------------------------------- /ctpn/utils/gt_utils.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | File Name: gt_utils 4 | Description : gt 四边形分割为固定宽度的系列gt boxes 5 | Author : mick.yi 6 | date: 2019/3/18 7 | """ 8 | import numpy as np 9 | 10 | 11 | def linear_fit_y(xs, ys, x_list): 12 | """ 13 | 线性函数拟合两点(x1,y1),(x2,y2);并求得x_list在的取值 14 | :param xs: [x1,x2] 15 | :param 
def get_xs_in_range(x_array, x_min, x_max):
    """
    Select the split points of `x_array` that fall inside [x_min, x_max].

    The exact endpoints x_min / x_max are prepended / appended when the grid
    does not already start / end on them, so the returned points always span
    the full interval.
    :param x_array: width-direction split grid, e.g. [0, 16, 32, ..., 608]
    :param x_min: minimum x of the quadrilateral
    :param x_max: maximum x of the quadrilateral
    :return: 1-D numpy array of split points
    """
    mask = (x_array >= x_min) & (x_array <= x_max)
    xs = x_array[mask]
    # Ensure both interval endpoints are present.
    if len(xs) == 0 or xs[0] > x_min:
        xs = np.insert(xs, 0, x_min)
    if len(xs) == 0 or xs[-1] < x_max:
        xs = np.append(xs, x_max)
    return xs
def gen_gt_from_quadrilaterals(gt_quadrilaterals, input_gt_class_ids, image_shape, width_stride, box_min_size=3):
    """
    Split GT quadrilaterals into fixed-width GT boxes.

    :param gt_quadrilaterals: GT quadrilateral coords [n,(x1,y1,x2,y2,x3,y3,x4,y4)]
    :param input_gt_class_ids: class id per quadrilateral, usually all 1 [n]
    :param image_shape: (H, W, ...)
    :param width_stride: split stride, typically 16
    :param box_min_size: intended minimum size of the split boxes.
        NOTE(review): this parameter is currently unused — the size filter at
        the bottom hard-codes height >= 8 and width >= 2 instead.  Confirm
        which thresholds are intended before relying on it.
    :return:
       gt_boxes: [m,(y1,x1,y2,x2)]
       gt_class_ids: [m]
    """
    h, w = list(image_shape)[:2]
    x_array = np.arange(0, w + 1, width_stride, np.float32)  # x grid at fixed stride
    # Per-quadrilateral min and max x.
    x_min_np = np.min(gt_quadrilaterals[:, ::2], axis=1)
    x_max_np = np.max(gt_quadrilaterals[:, ::2], axis=1)
    gt_boxes = []
    gt_class_ids = []
    for i in np.arange(len(gt_quadrilaterals)):
        xs = get_xs_in_range(x_array, x_min_np[i], x_max_np[i])  # grid points covering this quadrilateral
        ys_min, ys_max = get_min_max_y(gt_quadrilaterals[i], xs)
        # One fixed-width box per consecutive pair of grid points; its vertical
        # extent covers the quadrilateral over that x span.
        for j in range(len(xs) - 1):
            x1, x2 = xs[j], xs[j + 1]
            y1, y2 = np.min(ys_min[j:j + 2]), np.max(ys_max[j:j + 2])
            gt_boxes.append([y1, x1, y2, x2])
            gt_class_ids.append(input_gt_class_ids[i])
    gt_boxes = np.reshape(np.array(gt_boxes), (-1, 4))
    gt_class_ids = np.reshape(np.array(gt_class_ids), (-1,))
    # Drop boxes that are too small.
    height = gt_boxes[:, 2] - gt_boxes[:, 0]
    width = gt_boxes[:, 3] - gt_boxes[:, 1]
    indices = np.where(np.logical_and(height >= 8, width >= 2))
    return gt_boxes[indices], gt_class_ids[indices]
def load_image(image_path):
    """
    Load an image as an RGB uint8 array.

    :param image_path: path to the image file
    :return: [h,w,3] numpy uint8 array
    """
    image = plt.imread(image_path)
    # Grayscale to RGB.
    if len(image.shape) == 2:
        image = np.expand_dims(image, axis=2)
        image = np.tile(image, (1, 1, 3))
    elif image.shape[-1] == 1:
        image = skimage.color.gray2rgb(image)  # io.imread raises ValueError: Input image expected to be RGB, RGBA or gray
    # Normalize to 0~255.  Fix: accept any floating dtype — plt.imread yields
    # float32 for PNG, but other code paths can hand back float64; the
    # original `dtype == np.float32` test silently skipped those and returned
    # a float image in [0, 1].
    if np.issubdtype(image.dtype, np.floating):
        image = image * 255  # not in-place: imread may return a read-only array
        image = image.astype(np.uint8)
    # Drop the alpha channel if present.
    return image[..., :3]
def crop_image(image, gt_window):
    """
    Randomly crop `image` while keeping `gt_window` fully inside the crop.

    At most h//5 (resp. w//5) pixels are removed from each side, and never
    past the GT window.
    :param image: numpy array (H,W,C)
    :param gt_window: [y1,x1,y2,x2] window bounding all GTs
    :return: (cropped image, [wy1,wx1,wy2,wx2] crop window in original coords)
    """
    h, w = list(image.shape)[:2]
    y1, x1, y2, x2 = gt_window
    # Fixes vs. original: dropped the unused `gaps` local, and clamp each
    # randint upper bound to >= 1 — for tiny images (h < 5 or w < 5) the
    # bound was 0 and np.random.randint(0) raised ValueError.
    wy1 = np.random.randint(max(1, min(y1 + 1, h // 5)))
    wx1 = np.random.randint(max(1, min(x1 + 1, w // 5)))
    wy2 = h - np.random.randint(max(1, min(h - y2 + 1, h // 5)))
    wx2 = w - np.random.randint(max(1, min(w - x2 + 1, w // 5)))
    return image[wy1:wy2, wx1:wx2], [wy1, wx1, wy2, wx2]
def compose_image_meta(image_id, original_image_shape, image_shape,
                       window, scale):
    """
    Pack image metadata into a flat numpy vector of length 12.

    Layout: [id(1) | original H,W,C(3) | resized H,W,C(3) |
             window y1,x1,y2,x2(4) in image coordinates | scale(1)].
    :param image_id: numeric image id
    :param original_image_shape: tuple (H,W,3) before resizing
    :param image_shape: tuple (H,W,3) after resizing
    :param window: (y1,x1,y2,x2) of the original image inside the resized one
    :param scale: resize factor
    :return: numpy array [12]
    """
    values = [image_id]
    values.extend(original_image_shape)
    values.extend(image_shape)
    values.extend(window)
    values.append(scale)
    return np.array(values)
def adjust_quadrilaterals(quadrilaterals, padding, scale):
    """
    Map GT quadrilaterals into the resized-and-padded image frame.

    :param quadrilaterals: numpy array [N,(x1,y1,x2,y2,x3,y3,x4,y4)]
    :param padding: [(top_pad, bottom_pad), (left_pad, right_pad), (0, 0)]
    :param scale: resize factor
    :return: adjusted copy of the quadrilaterals (input is not mutated)
    """
    top_pad = padding[0][0]
    left_pad = padding[1][0]
    scaled = quadrilaterals * scale
    scaled[:, 1::2] += top_pad   # y coordinates shift by the top padding
    scaled[:, 0::2] += left_pad  # x coordinates shift by the left padding
    return scaled
def compute_iou(boxes_a, boxes_b):
    """
    Pairwise IoU between two box sets (numpy).

    :param boxes_a: (N,4) boxes as (y1,x1,y2,x2)
    :param boxes_b: (M,4) boxes as (y1,x1,y2,x2)
    :return: IoU matrix (N,M)
    """
    # Broadcast to (N,1,4) x (1,M,4).
    a = boxes_a[:, np.newaxis, :]
    b = boxes_b[np.newaxis, :, :]

    # Intersection extents along each axis, clamped at zero.
    inter_h = np.clip(np.minimum(a[..., 2], b[..., 2]) -
                      np.maximum(a[..., 0], b[..., 0]), 0.0, None)  # (N,M)
    inter_w = np.clip(np.minimum(a[..., 3], b[..., 3]) -
                      np.maximum(a[..., 1], b[..., 1]), 0.0, None)  # (N,M)
    intersection = inter_h * inter_w

    # Individual areas, broadcast over the pair grid.
    area_a = (a[..., 2] - a[..., 0]) * (a[..., 3] - a[..., 1])
    area_b = (b[..., 2] - b[..., 0]) * (b[..., 3] - b[..., 1])

    # Intersection over union.
    return intersection / (area_a + area_b - intersection)
def clip_boxes(boxes, im_shape):
    """
    Clip coordinates into the image bounds, in place.

    NOTE(review): the original docstring claimed boxes are [n,(y1,x1,y2,x2)],
    but the only visible caller (TextProposalConnector.get_text_lines) passes
    quadrilaterals [n,(x1,y1,x2,y2,x3,y3,x4,y4,score)].  As written, even
    columns are clipped to the image WIDTH and odd columns to the HEIGHT,
    which matches the x/y layout of those quadrilaterals — not (y1,x1,...)
    boxes.  The trailing score column is also clipped to [0, W]; harmless
    since scores are <= 1, but confirm the intended layout before reuse.
    :param boxes: [n, 2k(+1)] coordinate array; x at even columns, y at odd
    :param im_shape: tuple (H, W, C)
    :return: the same array, clipped
    """
    boxes[:, 0::2] = threshold(boxes[:, 0::2], 0, im_shape[1])
    boxes[:, 1::2] = threshold(boxes[:, 1::2], 0, im_shape[0])
    return boxes
def quadrangle_iou(quadrangle_a, quadrangle_b):
    """
    IoU of two quadrilaterals (via shapely polygons).

    :param quadrangle_a: 1-D numpy array (x1,y1,x2,y2,x3,y3,x4,y4)
    :param quadrangle_b: 1-D numpy array, same layout
    :return: IoU in [0, 1]; 0 when either polygon is invalid or the union is empty
    """
    a = Polygon(quadrangle_a.reshape((4, 2)))
    b = Polygon(quadrangle_b.reshape((4, 2)))
    if not a.is_valid or not b.is_valid:
        return 0
    # Fix: intersect the polygons directly.  The original wrapped the already
    # constructed polygons in Polygon(...) again, re-running construction for
    # no benefit.
    inter = a.intersection(b).area
    union = a.area + b.area - inter
    if union == 0:
        return 0
    else:
        return inter / union
-------------------------------------------------------------------------------- /ctpn/utils/text_proposal_connector.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | File Name: text_proposal_connector 4 | Description : 文本框连接,构建文本行 5 | Author : mick.yi 6 | date: 2019/3/13 7 | """ 8 | import numpy as np 9 | from .text_proposal_graph_builder import TextProposalGraphBuilder 10 | from .np_utils import clip_boxes 11 | 12 | 13 | class TextProposalConnector: 14 | """ 15 | 连接文本框构建文本行 16 | """ 17 | 18 | def __init__(self): 19 | self.graph_builder = TextProposalGraphBuilder() 20 | 21 | def group_text_proposals(self, text_proposals, scores, im_size): 22 | """ 23 | 将文本框连接起来,按照文本行分组 24 | :param text_proposals: 文本框,[n,(y1,x1,y2,x2)] 25 | :param scores: 文本框得分,[n] 26 | :param im_size: 图像尺寸,tuple(H,W,C) 27 | :return: list of list; 文本行列表,每个文本行是文本框索引号列表 28 | """ 29 | graph = self.graph_builder.build_graph(text_proposals, scores, im_size) 30 | return graph.sub_graphs_connected() 31 | 32 | def fit_y(self, X, Y, x1, x2): 33 | """ 34 | 一元线性函数拟合X,Y,并返回x1,x2的的函数值 35 | """ 36 | len(X) != 0 37 | # 只有一个点返回 y=Y[0] 38 | if np.sum(X == X[0]) == len(X): 39 | return Y[0], Y[0] 40 | p = np.poly1d(np.polyfit(X, Y, 1)) 41 | return p(x1), p(x2) 42 | 43 | def get_text_lines(self, text_proposals, scores, im_size): 44 | """ 45 | 获取文本行 46 | :param text_proposals: 文本框,[n,(y1,x1,y2,x2)] 47 | :param scores: 文本框得分,[n] 48 | :param im_size: 图像尺寸,tuple(H,W,C) 49 | :return: 文本行,边框和得分,numpy数组 [m,(y1,x1,y2,x2,score)] 50 | """ 51 | tp_groups = self.group_text_proposals(text_proposals, scores, im_size) 52 | text_lines = np.zeros((len(tp_groups), 9), np.float32) 53 | # print("len(tp_groups):{}".format(len(tp_groups))) 54 | # 逐个文本行处理 55 | for index, tp_indices in enumerate(tp_groups): 56 | text_line_boxes = text_proposals[list(tp_indices)] 57 | # 宽度方向最小值和最大值 58 | x_min = np.min(text_line_boxes[:, 1]) 59 | x_max = np.max(text_line_boxes[:, 3]) 60 
class Graph(object):
    """Pairing graph over text proposals: graph[i, j] True means i precedes j."""

    def __init__(self, graph):
        # Boolean adjacency matrix [n, n].
        self.graph = graph

    def sub_graphs_connected(self):
        """
        Walk each chain of paired proposals into a text line.

        A chain starts at a node that has outgoing but no incoming edges and
        follows the first successor repeatedly until no edge remains.
        :return: list of lists; one list of proposal indices per text line
        """
        adjacency = self.graph
        chains = []
        for start in range(adjacency.shape[0]):
            has_predecessor = adjacency[:, start].any()
            has_successor = adjacency[start, :].any()
            # Only chain heads (no predecessor, at least one successor) start a line.
            if has_predecessor or not has_successor:
                continue
            chain = [start]
            node = start
            while adjacency[node, :].any():
                node = np.where(adjacency[node, :])[0][0]
                chain.append(node)
            chains.append(chain)
        return chains
    def get_successions(self, index):
        """
        Find successor proposals of the proposal at `index`.

        Scans the columns to the right of the box's x1, up to
        max_horizontal_gap pixels (bounded by the image width), keeping every
        proposal there that passes the vertical-IoU / size-similarity test.
        :param index: proposal index
        :return: list of successor proposal indices (possibly empty)
        """
        box = self.text_proposals[index]
        results = []
        for left in range(int(box[1]) + 1, min(int(box[1]) + self.max_horizontal_gap + 1, self.im_size[1])):
            adj_box_indices = self.boxes_table[left]
            for adj_box_index in adj_box_indices:
                if self.meet_v_iou(adj_box_index, index):
                    results.append(adj_box_index)
            # Returns as soon as any column yields a match, so only the
            # nearest non-empty column contributes successors.
            if len(results) != 0:
                return results
        return results
False 106 | """ 107 | 108 | def overlaps_v(idx1, idx2): 109 | """ 110 | 两个边框垂直方向的iou 111 | """ 112 | # 边框高宽 113 | h1 = self.heights[idx1] 114 | h2 = self.heights[idx2] 115 | # 垂直方向的交集 116 | max_y1 = max(self.text_proposals[idx2][0], self.text_proposals[idx1][0]) 117 | min_y2 = min(self.text_proposals[idx2][2], self.text_proposals[idx1][2]) 118 | return max(0, min_y2 - max_y1) / min(h1, h2) 119 | 120 | def size_similarity(idx1, idx2): 121 | """ 122 | 两个边框高度尺寸相似度 123 | """ 124 | h1 = self.heights[idx1] 125 | h2 = self.heights[idx2] 126 | return min(h1, h2) / max(h1, h2) 127 | 128 | return overlaps_v(index1, index2) >= self.min_vertical_overlaps and \ 129 | size_similarity(index1, index2) >= self.min_size_similarity 130 | 131 | def build_graph(self, text_proposals, scores, im_size): 132 | """ 133 | 根据文本框构建文本框对 134 | :param text_proposals: 文本框,numpy 数组,[n,(y1,x1,y2,x2)] 135 | :param scores: 文本框得分,[n] 136 | :param im_size: 图像尺寸,tuple(H,W,C) 137 | :return: 返回二维bool类型 numpy数组,[n,n];指示文本框两两之间是否配对 138 | """ 139 | self.text_proposals = text_proposals 140 | self.scores = scores 141 | self.im_size = im_size 142 | self.heights = text_proposals[:, 2] - text_proposals[:, 0] # 所有文本框的高宽 143 | 144 | # 安装每个文本框左侧坐标x1分组 145 | im_width = self.im_size[1] 146 | boxes_table = [[] for _ in range(im_width)] 147 | for index, box in enumerate(text_proposals): 148 | boxes_table[int(box[1])].append(index) 149 | self.boxes_table = boxes_table 150 | 151 | # 构建文本对,numpy数组[N,N],bool类型;如果 152 | graph = np.zeros((text_proposals.shape[0], text_proposals.shape[0]), np.bool) 153 | 154 | for index, box in enumerate(text_proposals): 155 | # 获取当前文本框(Bi)的后继文本框 156 | successions = self.get_successions(index) 157 | if len(successions) == 0: 158 | continue 159 | # 后继文本框中得分最高的那个,记做Bj 160 | succession_index = successions[np.argmax(scores[successions])] 161 | # 获取Bj的前驱文本框 162 | precursors = self.get_precursors(succession_index) 163 | # print("{},{},{}".format(index, succession_index, precursors)) 164 | # 
如果Bi也是,也是Bj的前驱文本框中,得分最高的那个;则Bi,Bj构成文本框对 165 | if self.scores[index] >= np.max(self.scores[precursors]): 166 | graph[index, succession_index] = True 167 | return Graph(graph) 168 | -------------------------------------------------------------------------------- /ctpn/utils/tf_utils.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | File Name: tf_utils 4 | Description : tensorflow工具类 5 | Author : mick.yi 6 | date: 2019/3/13 7 | """ 8 | import tensorflow as tf 9 | from deprecated import deprecated 10 | 11 | 12 | @deprecated(reason='建议使用原生tf.map_fn;效率更高,并且不需要显示传入batch_size参数') 13 | def batch_slice(inputs, graph_fn, batch_size, names=None): 14 | """ 15 | 将输入分片,然后每个分片执行指定计算,最后组合结果;适用于批量处理计算图逻辑只支持一个实例的情况 16 | :param inputs: tensor列表 17 | :param graph_fn: 计算逻辑 18 | :param batch_size: 19 | :param names: 20 | :return: 21 | """ 22 | if not isinstance(inputs, list): 23 | inputs = [inputs] 24 | 25 | outputs = [] 26 | for i in range(batch_size): 27 | inputs_slice = [x[i] for x in inputs] 28 | output_slice = graph_fn(*inputs_slice) 29 | if not isinstance(output_slice, (tuple, list)): 30 | output_slice = [output_slice] 31 | outputs.append(output_slice) 32 | 33 | # 行转列 34 | outputs = list(zip(*outputs)) 35 | 36 | if names is None: 37 | names = [None] * len(outputs) 38 | # list转tensor 39 | result = [tf.stack(o, axis=0, name=n) 40 | for o, n in zip(outputs, names)] 41 | # 如果返回单个值,不使用list 42 | if len(result) == 1: 43 | result = result[0] 44 | 45 | return result 46 | 47 | 48 | def pad_to_fixed_size_with_negative(input_tensor, fixed_size, negative_num, data_type=tf.float32): 49 | # 输入尺寸 50 | input_size = tf.shape(input_tensor)[0] 51 | # tag 列 padding 52 | positive_num = input_size - negative_num # 正例数 53 | # 正样本padding 1,负样本padding -1 54 | column_padding = tf.concat([tf.ones([positive_num], data_type), 55 | tf.ones([negative_num], data_type) * -1], 56 | axis=0) 57 | # 都转为float,拼接 58 | x = 
def pad_to_fixed_size(input_tensor, fixed_size):
    """
    Pad a 2-D tensor to `fixed_size` rows and append a flag column:
    1 for real rows, 0 for padding rows.

    (The original docstring documented a `negative_num` parameter that this
    function does not take — it was copy-pasted from
    pad_to_fixed_size_with_negative.)
    :param input_tensor: 2-D tensor
    :param fixed_size: target number of rows
    :return: tensor of shape [fixed_size, cols + 1]
    """
    input_size = tf.shape(input_tensor)[0]
    # Append the flag column (constant 1 marks non-padding rows).
    x = tf.pad(input_tensor, [[0, 0], [0, 1]], mode='CONSTANT', constant_values=1)
    # Pad rows with zeros up to fixed_size (flag column stays 0 there).
    padding_size = tf.maximum(0, fixed_size - input_size)
    x = tf.pad(x, [[0, padding_size], [0, 0]], mode='CONSTANT', constant_values=0)
    return x
def random_colors(N, bright=True):
    """
    Generate N visually distinct RGB colors.

    Hues are spaced evenly around the HSV color wheel at full saturation,
    converted to RGB, and shuffled so adjacent indices do not receive
    adjacent hues.
    :param N: number of colors to generate
    :param bright: full brightness when True, dimmer (0.7) otherwise
    :return: list of (r, g, b) tuples, each channel in [0, 1]
    """
    value = 1.0 if bright else 0.7
    rgb_colors = [colorsys.hsv_to_rgb(hue / N, 1, value) for hue in range(N)]
    random.shuffle(rgb_colors)
    return rgb_colors
| :return: 46 | """ 47 | 48 | # Number of instances 49 | N = boxes.shape[0] 50 | if not N: 51 | print("\n*** No instances to display *** \n") 52 | 53 | # If no axis is passed, create one and automatically call show() 54 | auto_show = False 55 | if not ax: 56 | _, ax = plt.subplots(1, figsize=figsize) 57 | auto_show = True 58 | 59 | # Generate random colors 60 | colors = colors or random_colors(N) 61 | 62 | # Show area outside image boundaries. 63 | height, width = image.shape[:2] 64 | ax.set_ylim(height + 10, -10) 65 | ax.set_xlim(-10, width + 10) 66 | ax.axis('on') 67 | ax.set_title(title) 68 | 69 | masked_image = image.astype(np.uint32).copy() 70 | for i in range(N): 71 | color = colors[i] 72 | 73 | # Bounding box 74 | if not np.any(boxes[i]): 75 | # Skip this instance. Has no bbox. Likely lost in image cropping. 76 | continue 77 | y1, x1, y2, x2 = boxes[i] 78 | if show_bbox: 79 | p = patches.Rectangle((x1, y1), x2 - x1, y2 - y1, linewidth=2, 80 | alpha=0.7, linestyle="dashed", 81 | edgecolor=color, facecolor='none') 82 | ax.add_patch(p) 83 | ax.text(x1, y1 + 8, scores[i] if scores is not None else '', 84 | color='w', size=11, backgroundcolor="none") 85 | 86 | ax.imshow(masked_image.astype(np.uint8)) 87 | if auto_show: 88 | plt.show() 89 | 90 | 91 | def display_polygons(image, polygons, scores=None, figsize=(16, 16), ax=None, colors=None): 92 | auto_show = False 93 | if ax is None: 94 | _, ax = plt.subplots(1, figsize=figsize) 95 | auto_show = True 96 | if colors is None: 97 | colors = random_colors(len(polygons)) 98 | 99 | height, width = image.shape[:2] 100 | ax.set_ylim(height + 10, -10) 101 | ax.set_xlim(-10, width + 10) 102 | ax.axis('off') 103 | 104 | for i, polygon in enumerate(polygons): 105 | color = colors[i] 106 | polygon = np.reshape(polygon, (-1, 2)) # 转为[n,(x,y)] 107 | patch = patches.Polygon(polygon, facecolor=None, fill=False, color=color) 108 | ax.add_patch(patch) 109 | # 多边形得分 110 | x1, y1 = polygon[0][:] 111 | ax.text(x1, y1 - 1, scores[i] if 
def main(args):
    """
    Evaluation entry point: run the CTPN model over every image in a
    directory and write one detection result file per image.
    :param args: parsed command line arguments
                 (image_dir, output_dir, weight_path, use_side_refine)
    """
    # Override configuration from the command line.
    config.USE_SIDE_REFINE = bool(args.use_side_refine)
    if args.weight_path is not None:
        config.WEIGHT_PATH = args.weight_path
    config.IMAGES_PER_GPU = 1
    config.IMAGE_SHAPE = (1024, 1024, 3)
    # Collect the image paths to evaluate.
    image_paths = file_utils.get_sub_files(args.image_dir)

    # Build the inference model and load trained weights.
    model = models.ctpn_net(config, 'test')
    model.load_weights(config.WEIGHT_PATH, by_name=True)

    # Predict over the whole directory, timing the run.
    start_time = datetime.datetime.now()
    boxes_batch, scores_batch, metas_batch = model.predict_generator(
        generator=generator(image_paths, config.IMAGE_SHAPE),
        steps=len(image_paths),
        use_multiprocessing=True)
    end_time = datetime.datetime.now()
    print("======完成{}张图像评估,耗时:{} 秒".format(len(image_paths), end_time - start_time))
    # Strip the padding rows added for fixed-size batching.
    boxes_batch = [np_utils.remove_pad(b) for b in boxes_batch]
    scores_batch = [np_utils.remove_pad(s)[:, 0] for s in scores_batch]
    metas_batch = image_utils.batch_parse_image_meta(metas_batch)
    # Join text proposals into text lines, per image.
    detector = TextDetector(config)
    lines_batch = [detector.detect(b, s, config.IMAGE_SHAPE, w)
                   for b, s, w in zip(boxes_batch, scores_batch, metas_batch["window"])]
    # Map detected text-line quads back to original image coordinates.
    lines_batch = [image_utils.recover_detect_quad(b, w, sc)
                   for b, w, sc in zip(lines_batch, metas_batch["window"], metas_batch["scale"])]

    # Write one result file per image; one comma-separated quad per line.
    for img_path, quads in zip(image_paths, lines_batch):
        result_name = os.path.splitext('res_' + os.path.basename(img_path))[0] + '.txt'
        with open(os.path.join(args.output_dir, result_name), mode='w') as f:
            for quad in quads.astype(np.int32):
                f.write("{},{},{},{},{},{},{},{}\r\n".format(*quad[:8]))
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/yizt/keras-ctpn/3da503b52f2d4657427bfc2b4c647f49d690e8db/image_examples/a1.png -------------------------------------------------------------------------------- /image_examples/a2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yizt/keras-ctpn/3da503b52f2d4657427bfc2b4c647f49d690e8db/image_examples/a2.png -------------------------------------------------------------------------------- /image_examples/a3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yizt/keras-ctpn/3da503b52f2d4657427bfc2b4c647f49d690e8db/image_examples/a3.png -------------------------------------------------------------------------------- /image_examples/bkgd_1_0_generated_0.1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yizt/keras-ctpn/3da503b52f2d4657427bfc2b4c647f49d690e8db/image_examples/bkgd_1_0_generated_0.1.jpg -------------------------------------------------------------------------------- /image_examples/flip1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yizt/keras-ctpn/3da503b52f2d4657427bfc2b4c647f49d690e8db/image_examples/flip1.png -------------------------------------------------------------------------------- /image_examples/flip2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yizt/keras-ctpn/3da503b52f2d4657427bfc2b4c647f49d690e8db/image_examples/flip2.png -------------------------------------------------------------------------------- /image_examples/icdar2015/img_200.0.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/yizt/keras-ctpn/3da503b52f2d4657427bfc2b4c647f49d690e8db/image_examples/icdar2015/img_200.0.jpg -------------------------------------------------------------------------------- /image_examples/icdar2015/img_200.1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yizt/keras-ctpn/3da503b52f2d4657427bfc2b4c647f49d690e8db/image_examples/icdar2015/img_200.1.jpg -------------------------------------------------------------------------------- /image_examples/icdar2015/img_5.0.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yizt/keras-ctpn/3da503b52f2d4657427bfc2b4c647f49d690e8db/image_examples/icdar2015/img_5.0.jpg -------------------------------------------------------------------------------- /image_examples/icdar2015/img_5.1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yizt/keras-ctpn/3da503b52f2d4657427bfc2b4c647f49d690e8db/image_examples/icdar2015/img_5.1.jpg -------------------------------------------------------------------------------- /image_examples/icdar2015/img_8.0.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yizt/keras-ctpn/3da503b52f2d4657427bfc2b4c647f49d690e8db/image_examples/icdar2015/img_8.0.jpg -------------------------------------------------------------------------------- /image_examples/icdar2015/img_8.1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yizt/keras-ctpn/3da503b52f2d4657427bfc2b4c647f49d690e8db/image_examples/icdar2015/img_8.1.jpg -------------------------------------------------------------------------------- /image_examples/icdar2017/ts_img_01000.0.jpg: -------------------------------------------------------------------------------- 
def main(args):
    """
    Prediction entry point: detect text lines on a single image and save a
    visualization of the detection result as a JPEG.
    :param args: parsed command line arguments
                 (image_path, weight_path, use_side_refine)
    """
    # Override configuration from the command line.
    config.USE_SIDE_REFINE = bool(args.use_side_refine)
    if args.weight_path is not None:
        config.WEIGHT_PATH = args.weight_path
    config.IMAGES_PER_GPU = 1
    config.IMAGE_SHAPE = (1024, 1024, 3)
    # Load and resize the input image.
    image, image_meta, _, _ = image_utils.load_image_gt(np.random.randint(10),
                                                        args.image_path,
                                                        config.IMAGE_SHAPE[0],
                                                        None)
    # Build the inference model and load trained weights.
    model = models.ctpn_net(config, 'test')
    model.load_weights(config.WEIGHT_PATH, by_name=True)

    # Predict text proposals, then strip batching padding.
    boxes_padded, scores_padded, _ = model.predict([np.array([image]), np.array([image_meta])])
    proposals = np_utils.remove_pad(boxes_padded[0])
    proposal_scores = np_utils.remove_pad(scores_padded[0])[:, 0]

    # Join proposals into text lines inside the image window.
    image_meta = image_utils.parse_image_meta(image_meta)
    detector = TextDetector(config)
    text_lines = detector.detect(proposals, proposal_scores, config.IMAGE_SHAPE, image_meta['window'])

    # Visualize the top detections and save the figure.
    boxes_num = 30
    fig = plt.figure(figsize=(16, 16))
    ax = fig.add_subplot(1, 1, 1)
    visualize.display_polygons(image, text_lines[:boxes_num, :8], text_lines[:boxes_num, 8],
                               ax=ax)
    image_name = os.path.basename(args.image_path)
    fig.savefig('{}.{}.jpg'.format(os.path.splitext(image_name)[0], int(config.USE_SIDE_REFINE)))
def main(args):
    """
    Training entry point: load annotations, build and compile the CTPN
    model, then train it, holding out the last 100 annotations for
    validation.
    :param args: parsed command line arguments (epochs, init_epochs, weight_path)
    """
    set_gpu_growth()
    # Load the annotation files.
    annotation_files = file_utils.get_sub_files(config.IMAGE_GT_DIR)
    image_annotations = [reader.load_annotation(f, config.IMAGE_DIR)
                         for f in annotation_files]
    # Drop annotations whose image file is missing (happens in ICDAR2017).
    image_annotations = [a for a in image_annotations if os.path.exists(a['image_path'])]
    # Build and compile the training model.
    m = models.ctpn_net(config, 'train')
    models.compile(m, config, loss_names=['ctpn_regress_loss', 'ctpn_class_loss', 'side_regress_loss'])
    # Attach extra metrics exposed by the target layer.
    output = models.get_layer(m, 'ctpn_target').output
    models.add_metrics(m, ['gt_num', 'pos_num', 'neg_num', 'gt_min_iou', 'gt_avg_iou'], output[-5:])
    # Resume from a checkpoint or start from the pre-trained backbone.
    if args.init_epochs > 0:
        m.load_weights(args.weight_path, by_name=True)
    else:
        m.load_weights(config.PRE_TRAINED_WEIGHT, by_name=True)
    m.summary()
    # Data generators: last 100 annotations are held out for validation.
    train_gen = generator(image_annotations[:-100],
                          config.IMAGES_PER_GPU,
                          config.IMAGE_SHAPE,
                          config.ANCHORS_WIDTH,
                          config.MAX_GT_INSTANCES,
                          horizontal_flip=False,
                          random_crop=False)
    val_gen = generator(image_annotations[-100:],
                        config.IMAGES_PER_GPU,
                        config.IMAGE_SHAPE,
                        config.ANCHORS_WIDTH,
                        config.MAX_GT_INSTANCES)

    # Train.
    m.fit_generator(train_gen,
                    steps_per_epoch=len(image_annotations) // config.IMAGES_PER_GPU * 2,
                    epochs=args.epochs,
                    initial_epoch=args.init_epochs,
                    validation_data=val_gen,
                    validation_steps=100 // config.IMAGES_PER_GPU,
                    verbose=True,
                    callbacks=get_call_back(),
                    workers=2,
                    use_multiprocessing=True)

    # Save the final model.
    m.save(config.WEIGHT_PATH)