├── .gitignore
├── LICENSE
├── README.md
├── ctpn
├── __init__.py
├── config.py
├── layers
│ ├── __init__.py
│ ├── anchor.py
│ ├── base_net.py
│ ├── gt.py
│ ├── losses.py
│ ├── models.py
│ ├── target.py
│ └── text_proposals.py
├── preprocess
│ ├── __init__.py
│ └── reader.py
└── utils
│ ├── __init__.py
│ ├── detector.py
│ ├── file_utils.py
│ ├── generator.py
│ ├── gt_utils.py
│ ├── image_utils.py
│ ├── np_utils.py
│ ├── text_proposal_connector.py
│ ├── text_proposal_graph_builder.py
│ ├── tf_utils.py
│ └── visualize.py
├── evaluate.py
├── image_examples
├── a0.png
├── a1.png
├── a2.png
├── a3.png
├── bkgd_1_0_generated_0.1.jpg
├── flip1.png
├── flip2.png
├── icdar2015
│ ├── img_200.0.jpg
│ ├── img_200.1.jpg
│ ├── img_5.0.jpg
│ ├── img_5.1.jpg
│ ├── img_8.0.jpg
│ └── img_8.1.jpg
└── icdar2017
│ ├── ts_img_01000.0.jpg
│ ├── ts_img_01000.1.jpg
│ ├── ts_img_01001.0.jpg
│ └── ts_img_01001.1.jpg
├── predict.py
└── train.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | *.egg-info/
24 | .installed.cfg
25 | *.egg
26 | MANIFEST
27 |
28 | # PyInstaller
29 | # Usually these files are written by a python script from a template
30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest
32 | *.spec
33 |
34 | # Installer logs
35 | pip-log.txt
36 | pip-delete-this-directory.txt
37 |
38 | # Unit test / coverage reports
39 | htmlcov/
40 | .tox/
41 | .coverage
42 | .coverage.*
43 | .cache
44 | nosetests.xml
45 | coverage.xml
46 | *.cover
47 | .hypothesis/
48 | .pytest_cache/
49 |
50 | # Translations
51 | *.mo
52 | *.pot
53 |
54 | # Django stuff:
55 | *.log
56 | local_settings.py
57 | db.sqlite3
58 |
59 | # Flask stuff:
60 | instance/
61 | .webassets-cache
62 |
63 | # Scrapy stuff:
64 | .scrapy
65 |
66 | # Sphinx documentation
67 | docs/_build/
68 |
69 | # PyBuilder
70 | target/
71 |
72 | # Jupyter Notebook
73 | .ipynb_checkpoints
74 |
75 | # pyenv
76 | .python-version
77 |
78 | # celery beat schedule file
79 | celerybeat-schedule
80 |
81 | # SageMath parsed files
82 | *.sage.py
83 |
84 | # Environments
85 | .env
86 | .venv
87 | env/
88 | venv/
89 | ENV/
90 | env.bak/
91 | venv.bak/
92 |
93 | # Spyder project settings
94 | .spyderproject
95 | .spyproject
96 |
97 | # Rope project settings
98 | .ropeproject
99 |
100 | # mkdocs documentation
101 | /site
102 |
103 | # mypy
104 | .mypy_cache/
105 | .idea
106 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # keras-ctpn
2 |
3 | [TOC]
4 |
5 | 1. [说明](#说明)
6 | 2. [预测](#预测)
7 | 3. [训练](#训练)
8 | 4. [例子](#例子)
9 | 4.1 [ICDAR2015](#ICDAR2015)
10 | 4.1.1 [带侧边细化](#带侧边细化)
11 | 4.1.2 [不带侧边细化](#不带侧边细化)
12 | 4.1.3 [做数据增广-水平翻转](#做数据增广-水平翻转)
13 | 4.2 [ICDAR2017](#ICDAR2017)
14 | 4.3 [其它数据集](#其它数据集)
15 | 5. [toDoList](#toDoList)
16 | 6. [总结](#总结)
17 |
18 | ## 说明
19 |
20 | 本工程是keras实现的[CTPN: Detecting Text in Natural Image with Connectionist Text Proposal Network](https://arxiv.org/abs/1609.03605) . 本工程实现主要参考了[keras-faster-rcnn](https://github.com/yizt/keras-faster-rcnn) ; 并在ICDAR2015和ICDAR2017数据集上训练和测试。
21 |
22 | 工程地址: [keras-ctpn](https://github.com/yizt/keras-ctpn)
23 |
24 | ctpn论文翻译:[CTPN.md](https://github.com/yizt/cv-papers/blob/master/CTPN.md)
25 |
26 | **效果**:
27 |
28 | 使用ICDAR2015的1000张图像训练在500张测试集上结果为:Recall: 37.07 % Precision: 42.94 % Hmean: 39.79 %;
29 | 原文中的F值为61%;使用了额外的3000张图像训练。
30 |
31 | **关键点说明**:
32 |
33 | a.骨干网络使用的是resnet50
34 |
35 | b.训练输入图像大小为720\*720; 将图像的长边缩放到720,保持长宽比,短边padding;原文是短边600;预测时使用1024*1024
36 |
37 | c.batch_size为4, 每张图像训练128个anchor,正负样本比为1:1;
38 |
39 | d.分类、边框回归以及侧边细化的损失函数权重为1:1:1;原论文中是1:1:2
40 |
41 | e.侧边细化与边框回归选择一样的正样本anchor;原文中应该是分开选择的
42 |
43 | f.侧边细化还是有效果的(注:网上很多人说没有啥效果)
44 |
45 | g.由于有双向GRU,水平翻转会影响效果(见样例[做数据增广-水平翻转](#做数据增广-水平翻转))
46 |
47 | h.随机裁剪做数据增广,网络不收敛
48 |
49 |
50 |
51 |
52 | ## 预测
53 |
54 | a. 工程下载
55 |
56 | ```bash
57 | git clone https://github.com/yizt/keras-ctpn
58 | ```
59 |
60 |
61 |
62 | b. 预训练模型下载
63 |
64 | ICDAR2015训练集上训练好的模型下载地址: [google drive](https://drive.google.com/open?id=12t-PFYvYwx4In2aRv7OgRFkHa9rCjjn7),[百度云盘](https://pan.baidu.com/s/1GnDATacvBeFXpAwnBW6RaQ) 取码:wm47
65 |
66 | c.修改配置类config.py中如下属性
67 |
68 | ```python
69 | WEIGHT_PATH = '/tmp/ctpn.h5'
70 | ```
71 |
72 | d. 检测文本
73 |
74 | ```shell
75 | python predict.py --image_path image_3.jpg
76 | ```
77 |
78 | ## 评估
79 |
80 | a. 执行如下命令,并将输出的txt压缩为zip包
81 | ```shell
82 | python evaluate.py --weight_path /tmp/ctpn.100.h5 --image_dir /opt/dataset/OCR/ICDAR_2015/test_images/ --output_dir /tmp/output_2015/
83 | ```
84 |
85 | b. 提交在线评估
86 | 将压缩的zip包提交评估,评估地址:http://rrc.cvc.uab.es/?ch=4&com=mymethods&task=1
87 |
88 | ## 训练
89 |
90 | a. 训练数据下载
91 | ```shell
92 | #icdar2013
93 | wget http://rrc.cvc.uab.es/downloads/Challenge2_Training_Task12_Images.zip
94 | wget http://rrc.cvc.uab.es/downloads/Challenge2_Training_Task1_GT.zip
95 | wget http://rrc.cvc.uab.es/downloads/Challenge2_Test_Task12_Images.zip
96 | ```
97 |
98 | ```shell
99 | #icdar2015
100 | wget http://rrc.cvc.uab.es/downloads/ch4_training_images.zip
101 | wget http://rrc.cvc.uab.es/downloads/ch4_training_localization_transcription_gt.zip
102 | wget http://rrc.cvc.uab.es/downloads/ch4_test_images.zip
103 | ```
104 |
105 | ```shell
106 | #icdar2017
107 | wget -c -t 0 http://datasets.cvc.uab.es/rrc/ch8_training_images_1~8.zip
108 | wget -c -t 0 http://datasets.cvc.uab.es/rrc/ch8_training_localization_transcription_gt_v2.zip
109 | wget -c -t 0 http://datasets.cvc.uab.es/rrc/ch8_test_images.zip
110 | ```
111 |
112 |
113 |
114 | b. resnet50预训练模型下载
115 |
116 | ```shell
117 | wget https://github.com/fchollet/deep-learning-models/releases/download/v0.2/resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5
118 | ```
119 |
120 |
121 |
122 | c. 修改配置类config.py中,如下属性
123 |
124 | ```python
125 | # 预训练模型
126 | PRE_TRAINED_WEIGHT = '/opt/pretrained_model/resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5'
127 |
128 | # 数据集路径
129 | IMAGE_DIR = '/opt/dataset/OCR/ICDAR_2015/train_images'
130 | IMAGE_GT_DIR = '/opt/dataset/OCR/ICDAR_2015/train_gt'
131 | ```
132 |
133 | d.训练
134 |
135 | ```shell
136 | python train.py --epochs 50
137 | ```
138 |
139 |
140 |
141 |
142 |
143 | ## 例子
144 |
145 | ### ICDAR2015
146 |
147 | #### 带侧边细化
148 |
149 | 
150 |
151 | 
152 |
153 | #### 不带侧边细化
154 | 
155 |
156 | 
157 |
158 | #### 做数据增广-水平翻转
159 | 
160 | 
161 |
162 | ### ICDAR2017
163 |
164 |
165 | 
166 |
167 | 
168 |
169 | ### 其它数据集
170 | 
171 | 
172 | 
173 | 
174 | 
175 |
176 | ## toDoList
177 |
178 | 1. 侧边细化(已完成)
179 | 2. ICDAR2017数据集训练(已完成)
180 | 3. 检测文本行坐标映射到原图(已完成)
181 | 4. 精度评估(已完成)
182 | 5. 侧边回归,限制在边框内(已完成)
183 | 6. 增加水平翻转(已完成)
184 | 7. 增加随机裁剪(已完成)
185 |
186 |
187 |
188 | ### 总结
189 |
190 | 1. ctpn对水平文字检测效果不错
191 | 2. 整个网络对于数据集很敏感;在2017上训练的模型到2015上测试效果很不好;同样2015训练的在2013上测试效果也很差
192 | 3. 推测由于双向GRU,网络有存储记忆的缘故?在使用随机裁剪作数据增广时网络不收敛,使用水平翻转时预测结果也水平对称出现
193 |
--------------------------------------------------------------------------------
/ctpn/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | File Name: __init__.py
4 | Description :
5 | Author : mick.yi
6 | date: 2019/3/13
7 | """
--------------------------------------------------------------------------------
/ctpn/config.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | File Name: config
4 | Description : 配置类
5 | Author : mick.yi
6 | date: 2019/3/14
7 | """
8 |
9 |
class Config(object):
    """Configuration for CTPN training and inference.

    All values are class attributes; ``cur_config`` below is the shared
    default instance used across the project.
    """

    # Images processed per GPU per step (effective batch size)
    IMAGES_PER_GPU = 4
    # Network input shape (H, W, C): the long side is scaled to 720,
    # the short side padded, keeping the aspect ratio
    IMAGE_SHAPE = (720, 720, 3)
    # Maximum number of GT boxes per image after padding
    MAX_GT_INSTANCES = 1000

    # background + text
    NUM_CLASSES = 1 + 1
    CLASS_MAPPING = {'bg': 0,
                     'text': 1}
    # Training anchors: 10 heights, fixed width of 16 pixels
    ANCHORS_HEIGHT = [11, 16, 23, 33, 48, 68, 97, 139, 198, 283]
    ANCHORS_WIDTH = 16
    # Anchors sampled per image; positives capped at ANCHOR_POSITIVE_RATIO
    TRAIN_ANCHORS_PER_IMAGE = 128
    ANCHOR_POSITIVE_RATIO = 0.5
    # Total stride of the backbone network
    NET_STRIDE = 16
    # Text proposal output thresholds
    TEXT_PROPOSALS_MIN_SCORE = 0.7
    TEXT_PROPOSALS_NMS_THRESH = 0.3
    TEXT_PROPOSALS_MAX_NUM = 500
    TEXT_PROPOSALS_WIDTH = 16
    # Text-line box building hyper-parameters
    LINE_MIN_SCORE = 0.7
    MAX_HORIZONTAL_GAP = 50
    TEXT_LINE_NMS_THRESH = 0.3
    MIN_NUM_PROPOSALS = 1
    MIN_RATIO = 1.2
    MIN_V_OVERLAPS = 0.7
    MIN_SIZE_SIM = 0.7

    # Optimizer hyper-parameters
    LEARNING_RATE = 0.01
    LEARNING_MOMENTUM = 0.9
    # Weight decay.
    # BUG FIX: a stray trailing comma previously turned this into the
    # tuple (0.0005,) instead of the float 0.0005.
    WEIGHT_DECAY = 0.0005
    GRADIENT_CLIP_NORM = 5.0

    # Relative weights of the three losses (cls : bbox : side refine)
    LOSS_WEIGHTS = {
        "ctpn_regress_loss": 1.,
        "ctpn_class_loss": 1,
        "side_regress_loss": 1
    }
    # Whether to apply side refinement
    USE_SIDE_REFINE = True
    # Pretrained backbone weights
    PRE_TRAINED_WEIGHT = '/opt/pretrained_model/resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5'

    WEIGHT_PATH = '/tmp/ctpn.h5'

    # Dataset paths
    IMAGE_DIR = '/opt/dataset/OCR/ICDAR_2015/train_images'
    IMAGE_GT_DIR = '/opt/dataset/OCR/ICDAR_2015/train_gt'


cur_config = Config()
64 |
--------------------------------------------------------------------------------
/ctpn/layers/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | File Name: __init__.py
4 | Description : layer层
5 | Author : mick.yi
6 | date: 2019/3/13
7 | """
--------------------------------------------------------------------------------
/ctpn/layers/anchor.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | File Name: anchor
4 | Description : ctpn anchor层,在输入图像边框外的anchors丢弃
5 | Author : mick.yi
6 | date: 2019/3/13
7 | """
8 | import keras
9 | import tensorflow as tf
10 | import numpy as np
11 |
12 |
def generate_anchors(heights, width):
    """Build the base anchors centred at the origin.

    :param heights: list of anchor heights
    :param width: scalar width shared by all anchors
    :return: ndarray (len(heights), 4) of (y1, x1, y2, x2)
    """
    half_h = np.array(heights) / 2.0
    half_w = np.array([width] * len(heights)) / 2.0
    return np.stack([-half_h, -half_w, half_h, half_w], axis=1)
23 |
24 |
def shift(shape, stride, base_anchors):
    """Tile the base anchors over every position of the feature map.

    :param shape: feature map size (H, W)
    :param stride: backbone stride in pixels
    :param base_anchors: base anchors, (anchor_num, 4)
    :return: all anchors, (H*W*anchor_num, 4)
    """
    height, width = shape[0], shape[1]
    half = tf.constant(0.5, dtype=tf.float32)
    # anchor centres: pixel coordinates of each feature-map cell centre
    center_x = (tf.cast(tf.range(width), tf.float32) + half) * stride
    center_y = (tf.cast(tf.range(height), tf.float32) + half) * stride

    center_x, center_y = tf.meshgrid(center_x, center_y)

    # flatten to one coordinate pair per cell
    center_x = tf.reshape(center_x, [-1])
    center_y = tf.reshape(center_y, [-1])

    # (H*W, 1, 4) offsets broadcast against (1, anchor_num, 4) base boxes
    offsets = tf.expand_dims(tf.stack([center_y, center_x, center_y, center_x], axis=1), axis=1)
    bases = tf.expand_dims(tf.constant(base_anchors, dtype=tf.float32), axis=0)

    # (H*W, anchor_num, 4) -> flattened to (H*W*anchor_num, 4)
    all_anchors = offsets + bases
    return tf.reshape(all_anchors, [-1, 4])
51 |
52 |
def filter_out_of_bound_boxes(boxes, feature_shape, stride):
    """Drop anchors that extend outside the input image.

    :param boxes: anchors, (n, (y1, x1, y2, x2))
    :param feature_shape: feature map size (h, w)
    :param stride: backbone stride
    :return: (kept boxes, indices of the kept boxes in the input tensor)
    """
    # original image size = feature map size * stride
    img_h = tf.cast(feature_shape[0] * stride, tf.float32)
    img_w = tf.cast(feature_shape[1] * stride, tf.float32)

    inside = tf.logical_and(
        tf.logical_and(boxes[:, 0] >= 0, boxes[:, 1] >= 0),
        tf.logical_and(boxes[:, 2] <= img_h, boxes[:, 3] <= img_w))
    kept_boxes = tf.boolean_mask(boxes, inside)
    kept_indices = tf.where(inside)[:, 0]
    return kept_boxes, kept_indices
73 |
74 |
class CtpnAnchor(keras.layers.Layer):
    """Keras layer producing CTPN anchors; anchors outside the image are discarded."""

    def __init__(self, heights, width, stride, **kwargs):
        """
        :param heights: list of anchor heights
        :param width: anchor width, e.g. 16
        :param stride: backbone stride
        """
        self.heights = heights
        self.width = width
        self.stride = stride
        # number of kept base anchors; filled in by call()
        self.num_anchors = None
        super(CtpnAnchor, self).__init__(**kwargs)

    def call(self, inputs, **kwargs):
        """
        :param inputs: feature map the anchors live on, [batch_size, H, W, C]
        :param kwargs:
        :return: [anchors, valid anchor indices], tiled over the batch
        """
        feature_maps = inputs
        shape = tf.shape(feature_maps)
        print("feature_shape:{}".format(shape))

        bases = generate_anchors(self.heights, self.width)
        all_anchors = shift(shape[1:3], self.stride, bases)
        kept_anchors, kept_indices = filter_out_of_bound_boxes(all_anchors, shape[1:3], self.stride)
        self.num_anchors = tf.shape(kept_anchors)[0]
        # every sample in the batch shares the same anchor set
        batch = shape[0]
        kept_anchors = tf.tile(tf.expand_dims(kept_anchors, axis=0), [batch, 1, 1])
        kept_indices = tf.tile(tf.expand_dims(kept_indices, axis=0), [batch, 1])

        return [kept_anchors, kept_indices]

    def compute_output_shape(self, input_shape):
        """
        :param input_shape: [batch_size, H, W, C]
        :return: shapes of the two outputs from call()
        """
        total = self.num_anchors
        return [(input_shape[0], total, 4),
                (input_shape[0], total)]
122 |
123 |
def main():
    """Print the base anchors for a quick visual sanity check."""
    print(generate_anchors([11, 16, 23, 33, 48, 68, 97, 139, 198, 283], 16))


if __name__ == '__main__':
    main()
131 |
--------------------------------------------------------------------------------
/ctpn/layers/base_net.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | File Name: base_net
4 | Description : 基网络
5 | Author : mick.yi
6 | date: 2019/3/13
7 | """
8 | from keras import backend, layers
9 | from keras.models import Model
10 |
11 |
def identity_block(input_tensor, kernel_size, filters, stage, block):
    """Residual block whose shortcut has no convolution.

    # Arguments
        input_tensor: input tensor
        kernel_size: kernel size of the middle conv layer on the main path
        filters: list of three integers, the filters of the conv layers
        stage: integer, current stage label, used for layer names
        block: 'a','b'..., current block label, used for layer names

    # Returns
        Output tensor for the block.
    """
    nb1, nb2, nb3 = filters
    channel_axis = 3 if backend.image_data_format() == 'channels_last' else 1
    conv_prefix = 'res' + str(stage) + block + '_branch'
    bn_prefix = 'bn' + str(stage) + block + '_branch'

    # main path: 1x1 -> kxk -> 1x1 bottleneck
    y = layers.Conv2D(nb1, (1, 1),
                      kernel_initializer='he_normal',
                      name=conv_prefix + '2a')(input_tensor)
    y = layers.BatchNormalization(axis=channel_axis, name=bn_prefix + '2a')(y)
    y = layers.Activation('relu')(y)

    y = layers.Conv2D(nb2, kernel_size,
                      padding='same',
                      kernel_initializer='he_normal',
                      name=conv_prefix + '2b')(y)
    y = layers.BatchNormalization(axis=channel_axis, name=bn_prefix + '2b')(y)
    y = layers.Activation('relu')(y)

    y = layers.Conv2D(nb3, (1, 1),
                      kernel_initializer='he_normal',
                      name=conv_prefix + '2c')(y)
    y = layers.BatchNormalization(axis=channel_axis, name=bn_prefix + '2c')(y)

    # identity shortcut, then final activation
    y = layers.add([y, input_tensor])
    return layers.Activation('relu')(y)
55 |
56 |
def conv_block(input_tensor,
               kernel_size,
               filters,
               stage,
               block,
               strides=(2, 2)):
    """Residual block with a convolutional shortcut.

    # Arguments
        input_tensor: input tensor
        kernel_size: kernel size of the middle conv layer on the main path
        filters: list of three integers, the filters of the conv layers
        stage: integer, current stage label, used for layer names
        block: 'a','b'..., current block label, used for layer names
        strides: strides for the first conv layer (and the shortcut)

    # Returns
        Output tensor for the block.

    From stage 3 on, the first conv layer of the main path uses
    strides=(2, 2); the shortcut must use the same strides so the two
    branches keep matching spatial sizes.
    """
    nb1, nb2, nb3 = filters
    channel_axis = 3 if backend.image_data_format() == 'channels_last' else 1
    conv_prefix = 'res' + str(stage) + block + '_branch'
    bn_prefix = 'bn' + str(stage) + block + '_branch'

    # main path: strided 1x1 -> kxk -> 1x1 bottleneck
    y = layers.Conv2D(nb1, (1, 1), strides=strides,
                      kernel_initializer='he_normal',
                      name=conv_prefix + '2a')(input_tensor)
    y = layers.BatchNormalization(axis=channel_axis, name=bn_prefix + '2a')(y)
    y = layers.Activation('relu')(y)

    y = layers.Conv2D(nb2, kernel_size, padding='same',
                      kernel_initializer='he_normal',
                      name=conv_prefix + '2b')(y)
    y = layers.BatchNormalization(axis=channel_axis, name=bn_prefix + '2b')(y)
    y = layers.Activation('relu')(y)

    y = layers.Conv2D(nb3, (1, 1),
                      kernel_initializer='he_normal',
                      name=conv_prefix + '2c')(y)
    y = layers.BatchNormalization(axis=channel_axis, name=bn_prefix + '2c')(y)

    # projection shortcut matches channels and spatial size
    shortcut = layers.Conv2D(nb3, (1, 1), strides=strides,
                             kernel_initializer='he_normal',
                             name=conv_prefix + '1')(input_tensor)
    shortcut = layers.BatchNormalization(
        axis=channel_axis, name=bn_prefix + '1')(shortcut)

    y = layers.add([y, shortcut])
    return layers.Activation('relu')(y)
115 |
116 |
def resnet50(image_input):
    """Build a truncated ResNet-50 backbone (stages 1-4; stage 5 omitted).

    Stages 1-2 are frozen, except their BatchNormalization layers which
    remain trainable.

    :param image_input: input image tensor
    :return: stage-4 feature map (stride 16 w.r.t. the input)
    """
    channel_axis = 3

    # stage 1
    y = layers.ZeroPadding2D(padding=(3, 3), name='conv1_pad')(image_input)
    y = layers.Conv2D(64, (7, 7),
                      strides=(2, 2),
                      padding='valid',
                      name='conv1')(y)
    y = layers.BatchNormalization(axis=channel_axis, name='bn_conv1')(y)
    y = layers.Activation('relu')(y)
    y = layers.MaxPooling2D((3, 3), strides=(2, 2))(y)
    # stage 2
    y = conv_block(y, 3, [64, 64, 256], stage=2, block='a', strides=(1, 1))
    y = identity_block(y, 3, [64, 64, 256], stage=2, block='b')
    y = identity_block(y, 3, [64, 64, 256], stage=2, block='c')
    # freeze everything built so far; only BatchNormalization stays trainable
    frozen = Model(inputs=image_input, outputs=y)
    for layer in frozen.layers:
        layer.trainable = isinstance(layer, layers.BatchNormalization)
    # stage 3
    y = conv_block(y, 3, [128, 128, 512], stage=3, block='a')
    y = identity_block(y, 3, [128, 128, 512], stage=3, block='b')
    y = identity_block(y, 3, [128, 128, 512], stage=3, block='c')
    y = identity_block(y, 3, [128, 128, 512], stage=3, block='d')
    # stage 4
    y = conv_block(y, 3, [256, 256, 1024], stage=4, block='a')
    y = identity_block(y, 3, [256, 256, 1024], stage=4, block='b')
    y = identity_block(y, 3, [256, 256, 1024], stage=4, block='c')
    y = identity_block(y, 3, [256, 256, 1024], stage=4, block='d')
    y = identity_block(y, 3, [256, 256, 1024], stage=4, block='e')
    y = identity_block(y, 3, [256, 256, 1024], stage=4, block='f')
    # stage 5 intentionally omitted so the output stride stays at 16

    return y
159 |
--------------------------------------------------------------------------------
/ctpn/layers/gt.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | File Name: gt
4 | Description : 根据GT生成宽度为16的系列GT
5 | Author : mick.yi
6 | date: 2019/3/18
7 | """
8 | import keras
9 | import tensorflow as tf
10 | from ..utils import tf_utils, gt_utils
11 | from deprecated import deprecated
12 |
13 |
def generate_gt_graph(gt_quadrilaterals, input_gt_class_ids, image_shape, width_stride, max_gt_num):
    """Split GT quadrilaterals into a series of fixed-width GT boxes.

    :param gt_quadrilaterals: [max_gt_num, (x1,y1,x2,y2,x3,y3,x4,y4,tag)],
        corners ordered clockwise from the top-left
    :param input_gt_class_ids: padded GT class ids
    :param image_shape: input image shape
    :param width_stride: width of the generated boxes, e.g. 16
    :param max_gt_num: padding size of the outputs
    :return: padded [gt_boxes, gt_class_ids]
    """
    # strip the padding rows before handing off to the numpy implementation
    quads = tf_utils.remove_pad(gt_quadrilaterals)
    class_ids = tf_utils.remove_pad(input_gt_class_ids)
    boxes, ids = tf.py_func(func=gt_utils.gen_gt_from_quadrilaterals,
                            inp=[quads, class_ids, image_shape, width_stride],
                            Tout=[tf.float32, tf.float32])
    # re-pad to a fixed size so outputs can be batched
    return tf_utils.pad_list_to_fixed_size([boxes, ids], max_gt_num)
31 |
32 |
@deprecated(reason='目前没有用')
class GenGT(keras.layers.Layer):
    """
    Keras layer that splits quadrilateral GTs into fixed-width GT boxes per image.

    Fixes vs. the previous version: `batch_size` is now a constructor argument
    (it was read from a never-assigned attribute in `call`), and the batch_slice
    callback accepts both sliced inputs (it previously took only one, so the
    class-ids tensor was never forwarded to generate_gt_graph).
    """

    def __init__(self, image_shape, width_stride, max_gt_num, batch_size=1, **kwargs):
        """
        :param image_shape: image shape
        :param width_stride: width of each generated GT slice
        :param max_gt_num: fixed row count to pad the outputs to
        :param batch_size: images per batch (new parameter, defaults to 1 for compatibility)
        """
        self.image_shape = image_shape
        self.width_stride = width_stride
        self.max_gt_num = max_gt_num
        self.batch_size = batch_size
        super(GenGT, self).__init__(**kwargs)

    def call(self, inputs, **kwargs):
        """
        :param inputs:
            inputs[0]: gt_quadrilaterals [batch_size, max_gt_num, (x1,y1,x2,y2,x3,y3,x4,y4,tag)]
            inputs[1]: gt class ids [batch_size, max_gt_num, ...]
        :param kwargs:
        :return: [gt_boxes, gt_class_ids] per image, padded to max_gt_num
        """
        gt_quadrilaterals = inputs[0]
        input_gt_class_ids = inputs[1]
        # batch_slice slices BOTH input tensors, so the callback must take two arguments
        outputs = tf_utils.batch_slice([gt_quadrilaterals, input_gt_class_ids],
                                       lambda x, y: generate_gt_graph(x, y, self.image_shape,
                                                                      self.width_stride, self.max_gt_num),
                                       batch_size=self.batch_size)
        return outputs

    def compute_output_shape(self, input_shape):
        return [(input_shape[0][0], self.max_gt_num, 5),
                (input_shape[0][0], self.max_gt_num, 2)]
59 |
--------------------------------------------------------------------------------
/ctpn/layers/losses.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | File Name: losses
4 | Description : 损失函数层
5 | Author : mick.yi
6 | date: 2019/3/13
7 | """
8 | import tensorflow as tf
9 | from keras import backend as K
10 |
11 |
def ctpn_cls_loss(predict_cls_ids, true_cls_ids, indices):
    """
    CTPN classification loss: fg/bg cross-entropy over the sampled anchors.
    :param predict_cls_ids: predicted anchor class logits, (batch_num, anchors_num, 2), fg or bg
    :param true_cls_ids: ground-truth anchor classes, (batch_num, rpn_train_anchors, (class_id, tag));
           tag 1: positive/negative sample, 0: padding
    :param indices: sample indices, (batch_num, rpn_train_anchors, (idx, tag));
           idx: position of the anchor, tag 1: positive, -1: negative, 0: padding
    :return: scalar mean cross-entropy loss
    """
    # drop padding entries (tag == 0)
    train_indices = tf.where(tf.not_equal(indices[:, :, -1], 0))  # 0 means padding
    train_anchor_indices = tf.gather_nd(indices[..., 0], train_indices)  # 1-D (batch*train_num,), anchor index of each sampled anchor
    true_cls_ids = tf.gather_nd(true_cls_ids[..., 0], train_indices)  # 1-D (batch*train_num,)
    # convert to one-hot: every foreground class collapses to 1
    true_cls_ids = tf.where(true_cls_ids >= 1,
                            tf.ones_like(true_cls_ids, dtype=tf.uint8),
                            tf.zeros_like(true_cls_ids, dtype=tf.uint8))
    true_cls_ids = tf.one_hot(true_cls_ids, depth=2)
    # batch index (the first dim of train_indices is the batch index)
    batch_indices = train_indices[:, 0]
    # 2-D (batch, anchor) index of every sampled anchor
    train_indices_2d = tf.stack([batch_indices, tf.cast(train_anchor_indices, dtype=tf.int64)], axis=1)
    # gather the predicted logits of the sampled anchors
    predict_cls_ids = tf.gather_nd(predict_cls_ids, train_indices_2d)  # (batch*train_num,2)

    # softmax cross-entropy
    losses = tf.nn.softmax_cross_entropy_with_logits_v2(
        labels=true_cls_ids, logits=predict_cls_ids)
    return tf.reduce_mean(losses)
42 |
43 |
def smooth_l1_loss(y_true, y_predict, sigma2=9.0):
    """
    Smooth-L1 loss: 0.5 * sigma2 * x^2 if |x| < 1/sigma2 else |x| - 0.5/sigma2, where x is the diff.
    :param y_true: ground truth, any shape
    :param y_predict: prediction, same shape as y_true
    :param sigma2: controls the quadratic-to-linear transition point
    :return: element-wise loss tensor with the same shape as the inputs
    """
    diff = tf.abs(y_true - y_predict, name='abs_diff')
    quadratic = 0.5 * sigma2 * tf.pow(diff, 2)
    linear = diff - 0.5 / sigma2
    return tf.where(tf.less(diff, 1. / sigma2), quadratic, linear)
55 |
56 |
def ctpn_regress_loss(predict_deltas, deltas, indices):
    """
    Regression loss for the vertical center offset and height scaling.
    :param predict_deltas: predicted regression targets, (batch_num, anchors_num, 2)
    :param deltas: ground-truth regression targets, (batch_num, ctpn_train_anchors, 3+1);
           the last position is the tag, tag=0 means padding
    :param indices: sample indices, (batch_num, ctpn_train_anchors, (idx, tag));
           idx: position of the anchor; tag=0 padding, 1 positive, -1 negative
    :return: scalar mean smooth-L1 loss
    """
    # keep positive samples only (drops padding and negatives)
    positive_indices = tf.where(tf.equal(indices[:, :, -1], 1))
    deltas = tf.gather_nd(deltas[..., :-2], positive_indices)  # keeps (dy,dh) only; dx and tag are dropped
    true_positive_indices = tf.gather_nd(indices[..., 0], positive_indices)  # 1-D, positive anchor indices

    # batch index
    batch_indices = positive_indices[:, 0]
    # 2-D (batch, anchor) index of each positive anchor
    train_indices_2d = tf.stack([batch_indices, tf.cast(true_positive_indices, dtype=tf.int64)], axis=1)
    # predicted regression targets of the positive anchors
    predict_deltas = tf.gather_nd(predict_deltas, train_indices_2d, name='ctpn_regress_loss_predict_deltas')

    # Smooth-L1; the empty-tensor guard is essential, otherwise the mean is NaN
    loss = K.switch(tf.size(deltas) > 0,
                    smooth_l1_loss(deltas, predict_deltas),
                    tf.constant(0.0))
    loss = K.mean(loss)
    return loss
84 |
85 |
def side_regress_loss(predict_deltas, deltas, indices):
    """
    Side-refinement regression loss (horizontal offset of side anchors).
    :param predict_deltas: predicted x-offset regression targets, (batch_num, anchors_num, 1)
    :param deltas: ground-truth regression targets, (batch_num, ctpn_train_anchors, 3+1);
           the last position is the tag, tag=0 means padding
    :param indices: sample indices, (batch_num, ctpn_train_anchors, (idx, tag));
           idx: position of the anchor; tag=0 padding, 1 positive, -1 negative
    :return: scalar mean smooth-L1 loss
    """
    # keep positive samples only (drops padding and negatives)
    positive_indices = tf.where(tf.equal(indices[:, :, -1], 1))
    deltas = tf.gather_nd(deltas[..., 2:3], positive_indices)  # take dx out of (dy,dh,dx,tag)
    true_positive_indices = tf.gather_nd(indices[..., 0], positive_indices)  # 1-D, positive anchor indices

    # batch index
    batch_indices = positive_indices[:, 0]
    # 2-D (batch, anchor) index of each positive anchor
    train_indices_2d = tf.stack([batch_indices, tf.cast(true_positive_indices, dtype=tf.int64)], axis=1)
    # predicted regression targets of the positive anchors
    predict_deltas = tf.gather_nd(predict_deltas, train_indices_2d, name='ctpn_regress_loss_predict_side_deltas')

    # Smooth-L1; the empty-tensor guard is essential, otherwise the mean is NaN
    loss = K.switch(tf.size(deltas) > 0,
                    smooth_l1_loss(deltas, predict_deltas),
                    tf.constant(0.0))
    loss = K.mean(loss)
    return loss
113 |
--------------------------------------------------------------------------------
/ctpn/layers/models.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | File Name: models
4 | Description : 模型
5 | Author : mick.yi
6 | date: 2019/3/13
7 | """
8 | import keras
9 | from keras import layers
10 | from keras import Input, Model
11 | import tensorflow as tf
12 | from .base_net import resnet50
13 | from .anchor import CtpnAnchor
14 | from .target import CtpnTarget
15 | from .losses import ctpn_cls_loss, ctpn_regress_loss, side_regress_loss
16 | from .text_proposals import TextProposal
17 |
18 |
def ctpn_net(config, stage='train'):
    """
    Build the CTPN keras model.
    :param config: config object providing image shape, anchor settings and thresholds
    :param stage: 'train' builds the loss-emitting model; anything else builds the inference model
    :return: keras Model
    """
    # network inputs
    # input_image = Input(batch_shape=(config.IMAGES_PER_GPU,) + config.IMAGE_SHAPE, name='input_image')
    # input_image_meta = Input(batch_shape=(config.IMAGES_PER_GPU, 12), name='input_image_meta')
    # gt_class_ids = Input(batch_shape=(config.IMAGES_PER_GPU, config.MAX_GT_INSTANCES, 2), name='gt_class_ids')
    # gt_boxes = Input(batch_shape=(config.IMAGES_PER_GPU, config.MAX_GT_INSTANCES, 5), name='gt_boxes')
    input_image = Input(shape=config.IMAGE_SHAPE, name='input_image')
    input_image_meta = Input(shape=(12,), name='input_image_meta')
    gt_class_ids = Input(shape=(config.MAX_GT_INSTANCES, 2), name='gt_class_ids')
    gt_boxes = Input(shape=(config.MAX_GT_INSTANCES, 5), name='gt_boxes')

    # predictions
    base_features = resnet50(input_image)
    num_anchors = len(config.ANCHORS_HEIGHT)
    predict_class_logits, predict_deltas, predict_side_deltas = ctpn(base_features, num_anchors, 64, 256)

    # anchor generation
    anchors, valid_anchors_indices = CtpnAnchor(config.ANCHORS_HEIGHT, config.ANCHORS_WIDTH, config.NET_STRIDE,
                                                name='gen_ctpn_anchors')(base_features)

    if stage == 'train':
        targets = CtpnTarget(config.IMAGES_PER_GPU,
                             train_anchors_num=config.TRAIN_ANCHORS_PER_IMAGE,
                             positive_ratios=config.ANCHOR_POSITIVE_RATIO,
                             max_gt_num=config.MAX_GT_INSTANCES,
                             name='ctpn_target')([gt_boxes, gt_class_ids, anchors, valid_anchors_indices])
        deltas, class_ids, anchors_indices = targets[:3]
        # loss layers (scalar outputs; attached as model losses by `compile` below)
        regress_loss = layers.Lambda(lambda x: ctpn_regress_loss(*x),
                                     name='ctpn_regress_loss')([predict_deltas, deltas, anchors_indices])
        side_loss = layers.Lambda(lambda x: side_regress_loss(*x),
                                  name='side_regress_loss')([predict_side_deltas, deltas, anchors_indices])
        cls_loss = layers.Lambda(lambda x: ctpn_cls_loss(*x),
                                 name='ctpn_class_loss')([predict_class_logits, class_ids, anchors_indices])
        model = Model(inputs=[input_image, gt_boxes, gt_class_ids],
                      outputs=[regress_loss, cls_loss, side_loss])

    else:
        text_boxes, text_scores, text_class_logits = TextProposal(config.IMAGES_PER_GPU,
                                                                  score_threshold=config.TEXT_PROPOSALS_MIN_SCORE,
                                                                  output_box_num=config.TEXT_PROPOSALS_MAX_NUM,
                                                                  iou_threshold=config.TEXT_PROPOSALS_NMS_THRESH,
                                                                  use_side_refine=config.USE_SIDE_REFINE,
                                                                  name='text_proposals')(
            [predict_deltas, predict_side_deltas, predict_class_logits, anchors, valid_anchors_indices])
        image_meta = layers.Lambda(lambda x: x)(input_image_meta)  # passed through unchanged
        model = Model(inputs=[input_image, input_image_meta], outputs=[text_boxes, text_scores, image_meta])
    return model
67 |
68 |
def ctpn(base_features, num_anchors, rnn_units=128, fc_units=512):
    """
    CTPN head: 3x3 conv, bidirectional GRU along the width axis, then three
    1x1-conv branches (classification, vertical regression, side refinement).
    :param base_features: (B,H,W,C)
    :param num_anchors: number of anchors per feature-map position
    :param rnn_units: GRU units per direction
    :param fc_units: units of the conv-implemented fully-connected layer
    :return: (class_logits, predict_deltas, predict_side_deltas)
    """
    conv = layers.Conv2D(512, kernel_size=(3, 3), padding='same', name='pre_fc')(base_features)  # [B,H,W,512]
    # run the RNN along the width axis, one sequence per feature-map row
    gru_fw = layers.TimeDistributed(layers.GRU(rnn_units, return_sequences=True, kernel_initializer='he_normal'),
                                    name='gru_forward')(conv)
    gru_bw = layers.TimeDistributed(
        layers.GRU(rnn_units, return_sequences=True, kernel_initializer='he_normal', go_backwards=True),
        name='gru_backward')(conv)

    merged = layers.Concatenate(name='gru_concat')([gru_fw, gru_bw])  # (B,H,W,2*rnn_units)

    # fully-connected layer implemented as a 1x1 conv
    fc = layers.Conv2D(fc_units, kernel_size=(1, 1), activation='relu', name='fc_output')(merged)  # (B,H,W,fc_units)

    # classification branch: fg/bg per anchor
    class_logits = layers.Conv2D(2 * num_anchors, kernel_size=(1, 1), name='cls')(fc)
    class_logits = layers.Reshape(target_shape=(-1, 2), name='cls_reshape')(class_logits)
    # vertical center / height regression branch
    predict_deltas = layers.Conv2D(2 * num_anchors, kernel_size=(1, 1), name='deltas')(fc)
    predict_deltas = layers.Reshape(target_shape=(-1, 2), name='deltas_reshape')(predict_deltas)
    # side-refinement branch: only the x offset is predicted
    predict_side_deltas = layers.Conv2D(num_anchors, kernel_size=(1, 1), name='side_deltas')(fc)
    predict_side_deltas = layers.Reshape(target_shape=(-1, 1), name='side_deltas_reshape')(predict_side_deltas)
    return class_logits, predict_deltas, predict_side_deltas
103 |
104 |
def get_layer(model, name):
    """Return the first layer of `model` whose name equals `name`, or None if absent."""
    matches = (candidate for candidate in model.layers if candidate.name == name)
    return next(matches, None)
110 |
111 |
def compile(keras_model, config, loss_names=None):
    """
    Compile the model: attach the named loss layers as model losses, add L2
    regularization and register one metric per loss.
    NOTE(review): the function name shadows the builtin `compile`; kept unchanged
    for caller compatibility.
    :param keras_model: model to compile
    :param config: config providing LEARNING_RATE, LEARNING_MOMENTUM,
           GRADIENT_CLIP_NORM, LOSS_WEIGHTS and WEIGHT_DECAY
    :param loss_names: names of the loss layers to attach (default: none)
    :return: None
    """
    loss_names = loss_names or []  # avoid the mutable-default-argument pitfall
    # optimizer
    optimizer = keras.optimizers.SGD(
        lr=config.LEARNING_RATE, momentum=config.LEARNING_MOMENTUM,
        clipnorm=config.GRADIENT_CLIP_NORM)
    # add the loss functions; clear the previous ones first to avoid duplicates
    keras_model._losses = []
    keras_model._per_input_losses = {}

    for name in loss_names:
        layer = get_layer(keras_model, name)
        if layer is None or layer.output in keras_model.losses:
            continue
        # weight each loss by its configured weight (default 1.0)
        loss = (tf.reduce_mean(layer.output, keepdims=True)
                * config.LOSS_WEIGHTS.get(name, 1.))
        keras_model.add_loss(loss)

    # add L2 regularization,
    # skipping the gamma and beta weights of the batch-normalization layers
    reg_losses = [
        keras.regularizers.l2(config.WEIGHT_DECAY)(w) / tf.cast(tf.size(w), tf.float32)
        for w in keras_model.trainable_weights
        if 'gamma' not in w.name and 'beta' not in w.name]
    keras_model.add_loss(tf.add_n(reg_losses))

    # compile with dummy losses; the real losses were added via add_loss above
    keras_model.compile(
        optimizer=optimizer,
        loss=[None] * len(keras_model.outputs))

    # register one metric per loss so it shows up in training logs
    for name in loss_names:
        if name in keras_model.metrics_names:
            continue
        layer = get_layer(keras_model, name)
        if layer is None:
            continue
        keras_model.metrics_names.append(name)
        loss = (
            tf.reduce_mean(layer.output, keepdims=True)
            * config.LOSS_WEIGHTS.get(name, 1.))
        keras_model.metrics_tensors.append(loss)
161 |
162 |
def add_metrics(keras_model, metric_name_list, metric_tensor_list):
    """
    Register additional named metrics on a compiled keras model.
    :param keras_model: the model
    :param metric_name_list: metric names
    :param metric_tensor_list: metric tensors, parallel to the names
    :return: None
    """
    for metric_name, metric_tensor in zip(metric_name_list, metric_tensor_list):
        keras_model.metrics_names.append(metric_name)
        # reduce to a scalar so keras can log it
        keras_model.metrics_tensors.append(tf.reduce_mean(metric_tensor, keepdims=False))
174 |
--------------------------------------------------------------------------------
/ctpn/layers/target.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | File Name: Target
4 | Description : 分类和回归目标层
5 | Author : mick.yi
6 | date: 2019/3/13
7 | """
8 |
9 | from keras import layers
10 | import tensorflow as tf
11 | from ..utils import tf_utils
12 |
13 |
def compute_iou(gt_boxes, anchors):
    """
    Compute the pairwise IoU matrix between GT boxes and anchors.
    :param gt_boxes: [N,(y1,x1,y2,x2)]
    :param anchors: [M,(y1,x1,y2,x2)]
    :return: IoU [N,M]
    """
    gt = tf.expand_dims(gt_boxes, axis=1)  # [N,1,4]
    ac = tf.expand_dims(anchors, axis=0)  # [1,M,4]
    # overlap extents, clamped at zero for disjoint boxes
    overlap_w = tf.maximum(0.0,
                           tf.minimum(gt[:, :, 3], ac[:, :, 3]) -
                           tf.maximum(gt[:, :, 1], ac[:, :, 1]))
    overlap_h = tf.maximum(0.0,
                           tf.minimum(gt[:, :, 2], ac[:, :, 2]) -
                           tf.maximum(gt[:, :, 0], ac[:, :, 0]))
    intersect = overlap_h * overlap_w

    # individual box areas
    gt_area = (gt[:, :, 3] - gt[:, :, 1]) * (gt[:, :, 2] - gt[:, :, 0])
    anchor_area = (ac[:, :, 3] - ac[:, :, 1]) * (ac[:, :, 2] - ac[:, :, 0])

    # union and intersection-over-union
    union = gt_area + anchor_area - intersect
    return tf.divide(intersect, union, name='regress_target_iou')
43 |
44 |
def ctpn_regress_target(anchors, gt_boxes):
    """
    Compute the regression targets of positive anchors.
    :param anchors: [N,(y1,x1,y2,x2)]
    :param gt_boxes: [N,(y1,x1,y2,x2)] GT box matched to each anchor
    :return: [N, (dy, dh, dx)]; dx is the side-refinement target
    """
    # anchor heights
    h = anchors[:, 2] - anchors[:, 0]
    # GT heights
    gt_h = gt_boxes[:, 2] - gt_boxes[:, 0]

    # anchor center y
    center_y = (anchors[:, 2] + anchors[:, 0]) * 0.5
    # GT center y
    gt_center_y = (gt_boxes[:, 2] + gt_boxes[:, 0]) * 0.5

    # regression targets
    dy = (gt_center_y - center_y) / h
    dh = tf.log(gt_h / h)
    dx = side_regress_target(anchors, gt_boxes)  # side refinement
    target = tf.stack([dy, dh, dx], axis=1)
    # normalize the targets; apply_regress multiplies by the same constants to undo this
    target /= tf.constant([0.1, 0.2, 0.1])

    return target
70 |
71 |
def side_regress_target(anchors, gt_boxes):
    """
    Side-refinement regression target: horizontal center offset, scaled so that
    applying it moves the anchor's side edge onto the GT's side edge.
    :param anchors: [N,(y1,x1,y2,x2)]
    :param gt_boxes: GT boxes matched to the anchors [N,(y1,x1,y2,x2)]
    :return: [N] x offsets (0 for anchors that are not side anchors)
    """
    anchor_w = anchors[:, 3] - anchors[:, 1]  # in practice a fixed width of 16
    anchor_cx = (anchors[:, 3] + anchors[:, 1]) * 0.5
    gt_cx = (gt_boxes[:, 3] + gt_boxes[:, 1]) * 0.5
    # moving a side edge onto the GT side edge equals twice the center shift
    return (gt_cx - anchor_cx) * 2 / anchor_w
85 |
86 |
def ctpn_target_graph(gt_boxes, gt_cls, anchors, valid_anchors_indices, train_anchors_num=128, positive_ratios=0.5,
                      max_gt_num=50):
    """
    Build the CTPN training targets for a single image.
    a) positives: anchors with IoU > 0.7 against some GT, or the anchor(s) with
       the highest IoU for each GT
    b) this guarantees every GT is matched by at least one anchor
    :param gt_boxes: GT boxes [gt_num, (y1,x1,y2,x2,tag)], tag=0 means padding
    :param gt_cls: GT classes [gt_num, 1+1], last position is the tag, tag=0 means padding
    :param anchors: [anchor_num, (y1,x1,y2,x2)]
    :param valid_anchors_indices: indices of the valid anchors [anchor_num]
    :param train_anchors_num: number of anchors sampled per image
    :param positive_ratios: fraction of positives among the sampled anchors
    :param max_gt_num: unused in this body; kept for signature compatibility
    :return:
        deltas: [train_anchors_num, (dy,dh,dx,tag)], tag=1 sample, tag=0 padding
        class_id: [train_anchors_num, (class_id,tag)]
        indices: [train_anchors_num, (anchors_index,tag)] tag=1 positive, 0 padding, -1 negative
        followed by float metrics: gt_num, positive_num, negative_num,
        gt_match_min_iou, gt_match_mean_iou
    """
    # strip the padding rows and the tag column
    gt_boxes = tf_utils.remove_pad(gt_boxes)
    gt_cls = tf_utils.remove_pad(gt_cls)[:, 0]  # [N,1] -> [N]

    gt_num = tf.shape(gt_cls)[0]  # number of GT boxes

    # IoU matrix [gt_num, anchor_num]
    iou = compute_iou(gt_boxes, anchors)
    # each GT's max-IoU anchor(s) are positives (possibly several per GT)
    gt_iou_max = tf.reduce_max(iou, axis=1, keep_dims=True)  # max IoU per GT [gt_num,1]
    gt_iou_max_bool = tf.equal(iou, gt_iou_max)  # bool [gt_num,num_anchors]; marks each GT's max-IoU anchors

    # anchors whose own max IoU is >= 0.7 are positives too
    anchors_iou_max = tf.reduce_max(iou, axis=0, keep_dims=True)  # max IoU per anchor; [1,num_anchors]
    # below-threshold entries are replaced with 1.0 so the equality below cannot match them
    anchors_iou_max = tf.where(tf.greater_equal(anchors_iou_max, 0.7),
                               anchors_iou_max,
                               tf.ones_like(anchors_iou_max))
    anchors_iou_max_bool = tf.equal(iou, anchors_iou_max)

    # merge both kinds of positives
    positive_bool_matrix = tf.logical_or(gt_iou_max_bool, anchors_iou_max_bool)
    # min/mean IoU over the matched pairs, exposed as metrics
    gt_match_min_iou = tf.reduce_min(tf.boolean_mask(iou, positive_bool_matrix), keep_dims=True)[0]  # 1-D
    gt_match_mean_iou = tf.reduce_mean(tf.boolean_mask(iou, positive_bool_matrix), keep_dims=True)[0]
    # positive sample indices: first dim is the GT index, second the anchor index
    positive_indices = tf.where(positive_bool_matrix)
    # before_sample_positive_indices = positive_indices  # positives before sampling
    # subsample the positives
    positive_num = tf.minimum(tf.shape(positive_indices)[0], int(train_anchors_num * positive_ratios))
    positive_indices = tf.random_shuffle(positive_indices)[:positive_num]

    # gather the sampled positives and their matched GTs
    positive_gt_indices = positive_indices[:, 0]
    positive_anchor_indices = positive_indices[:, 1]
    positive_anchors = tf.gather(anchors, positive_anchor_indices)
    positive_gt_boxes = tf.gather(gt_boxes, positive_gt_indices)
    positive_gt_cls = tf.gather(gt_cls, positive_gt_indices)

    # regression targets of the positives
    deltas = ctpn_regress_target(positive_anchors, positive_gt_boxes)

    # negatives: anchors with max IoU < 0.5 that are not positives
    negative_bool = tf.less(tf.reduce_max(iou, axis=0), 0.5)
    positive_bool = tf.reduce_any(positive_bool_matrix, axis=0)  # positive anchors [num_anchors]
    negative_bool = tf.logical_and(negative_bool, tf.logical_not(positive_bool))

    # subsample the negatives
    negative_num = tf.minimum(int(train_anchors_num * (1. - positive_ratios)), train_anchors_num - positive_num)
    negative_indices = tf.random_shuffle(tf.where(negative_bool)[:, 0])[:negative_num]

    negative_gt_cls = tf.zeros([negative_num])  # negatives get class id 0
    negative_deltas = tf.zeros([negative_num, 3])

    # concatenate positives and negatives
    deltas = tf.concat([deltas, negative_deltas], axis=0, name='ctpn_target_deltas')
    class_ids = tf.concat([positive_gt_cls, negative_gt_cls], axis=0, name='ctpn_target_class_ids')
    indices = tf.concat([positive_anchor_indices, negative_indices], axis=0,
                        name='ctpn_train_anchor_indices')
    indices = tf.gather(valid_anchors_indices, indices)  # map back to the full anchor index space

    # pad the outputs to a fixed size
    deltas, class_ids = tf_utils.pad_list_to_fixed_size([deltas, tf.expand_dims(class_ids, 1)],
                                                        train_anchors_num)
    # negatives get tag=-1 to simplify downstream processing
    indices = tf_utils.pad_to_fixed_size_with_negative(tf.expand_dims(indices, 1), train_anchors_num,
                                                       negative_num=negative_num, data_type=tf.int64)

    return [deltas, class_ids, indices, tf.cast(  # metrics must be float typed
        gt_num, dtype=tf.float32), tf.cast(
        positive_num, dtype=tf.float32), tf.cast(negative_num, dtype=tf.float32),
        gt_match_min_iou, gt_match_mean_iou]
176 |
177 |
class CtpnTarget(layers.Layer):
    """
    Keras layer that samples anchors and builds the CTPN training targets
    (regression deltas, class ids, sampled anchor indices, plus metrics) per image.
    """

    def __init__(self, batch_size, train_anchors_num=128, positive_ratios=0.5, max_gt_num=50, **kwargs):
        # batch_size: images per batch; train_anchors_num: anchors sampled per image
        # positive_ratios: fraction of positives; max_gt_num: max GT boxes per image
        self.batch_size = batch_size
        self.train_anchors_num = train_anchors_num
        self.positive_ratios = positive_ratios
        self.max_gt_num = max_gt_num
        super(CtpnTarget, self).__init__(**kwargs)

    def call(self, inputs, **kwargs):
        """
        :param inputs:
            inputs[0]: GT box coordinates [batch_size, MAX_GT_BOXs,(y1,x1,y2,x2,tag)], tag=0 means padding
            inputs[1]: GT classes [batch_size, MAX_GT_BOXs,num_class+1]; last position is the tag, tag=0 means padding
            inputs[2]: anchors [batch_size, anchor_num,(y1,x1,y2,x2)]
            inputs[3]: valid_anchors_indices [batch_size, anchor_num]
        :param kwargs:
        :return: per-image outputs of ctpn_target_graph, stacked over the batch
        """
        gt_boxes, gt_cls_ids, anchors, valid_anchors_indices = inputs
        # options = {"train_anchors_num": self.train_anchors_num,
        #            "positive_ratios": self.positive_ratios,
        #            "max_gt_num": self.max_gt_num}
        #
        # outputs = tf.map_fn(fn=lambda x: ctpn_target_graph(*x, **options),
        #                     elems=[gt_boxes, gt_cls_ids, anchors, valid_anchors_indices],
        #                     dtype=[tf.float32] * 2 + [tf.int64] + [tf.float32] + [tf.int64] + [tf.float32] * 3)
        outputs = tf_utils.batch_slice([gt_boxes, gt_cls_ids, anchors, valid_anchors_indices],
                                       lambda x, y, z, s: ctpn_target_graph(x, y, z, s,
                                                                            self.train_anchors_num,
                                                                            self.positive_ratios,
                                                                            self.max_gt_num),
                                       batch_size=self.batch_size)
        return outputs

    def compute_output_shape(self, input_shape):
        return [(input_shape[0][0], self.train_anchors_num, 4),  # deltas (dy,dh,dx,tag)
                (input_shape[0][0], self.train_anchors_num, 2),  # cls
                (input_shape[0][0], self.train_anchors_num, 2),  # indices
                (input_shape[0][0],),  # gt_num
                (input_shape[0][0],),  # positive_num
                (input_shape[0][0],),  # negative_num
                (input_shape[0][0], 1),  # gt_match_min_iou
                (input_shape[0][0], 1)]  # gt_match_mean_iou
222 |
--------------------------------------------------------------------------------
/ctpn/layers/text_proposals.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | File Name: text_proposals
4 | Description : 文本提议框生成
5 | Author : mick.yi
6 | date: 2019/3/14
7 | """
8 | from keras import layers
9 | import tensorflow as tf
10 | from ..utils import tf_utils
11 |
12 |
def apply_regress(deltas, side_deltas, anchors, use_side_refine=False):
    """
    Apply the regression targets to the anchors: vertical center offset, height
    scaling and optional horizontal side refinement.
    :param deltas: regression targets [N,(dy,dh)]
    :param side_deltas: regression targets [N,(dx)]
    :param anchors: anchor boxes [N,(y1,x1,y2,x2)]
    :param use_side_refine: whether to apply the side (x) regression
    :return: regressed boxes [N,(y1,x1,y2,x2)]
    """
    # anchor heights and widths
    h = anchors[:, 2] - anchors[:, 0]
    w = anchors[:, 3] - anchors[:, 1]

    # anchor center coordinates
    cy = (anchors[:, 2] + anchors[:, 0]) * 0.5
    cx = (anchors[:, 3] + anchors[:, 1]) * 0.5

    deltas = tf.concat([deltas, side_deltas], axis=1)
    # undo the target normalization applied in ctpn_regress_target
    deltas *= tf.constant([0.1, 0.2, 0.1])
    dy, dh, dx = deltas[:, 0], deltas[:, 1], deltas[:, 2]

    # regress the vertical center
    cy += dy * h
    # side refinement of the horizontal center
    cx += dx * w
    # regress the height
    h *= tf.exp(dh)

    # back to y1,x1,y2,x2
    y1 = cy - h * 0.5
    y2 = cy + h * 0.5
    x1 = tf.maximum(cx - w * 0.5, 0.)  # clip to the window; fixes successor nodes losing their predecessors
    x2 = cx + w * 0.5

    if use_side_refine:
        return tf.stack([y1, x1, y2, x2], axis=1)
    else:
        # keep the anchors' original x extents when side refinement is disabled
        return tf.stack([y1, anchors[:, 1], y2, anchors[:, 3]], axis=1)
52 |
53 |
def get_valid_predicts(deltas, side_deltas, class_logits, valid_anchors_indices):
    """Select the rows of each prediction tensor belonging to the valid anchors."""
    gathered = [tf.gather(tensor, valid_anchors_indices)
                for tensor in (deltas, side_deltas, class_logits)]
    return tuple(gathered)
58 |
59 |
def nms(boxes, scores, class_logits, max_output_size, iou_threshold=0.5, score_threshold=0.05,
        name=None):
    """
    Non-maximum suppression, with outputs padded to a fixed size.
    :param boxes: 2-D float tensor [num_boxes, 4]
    :param scores: 1-D float tensor [num_boxes], one score per box
    :param class_logits: raw class predictions [num_boxes, num_classes]
    :param max_output_size: scalar int, max number of boxes selected by NMS
    :param iou_threshold: float, IoU threshold
    :param score_threshold: float, boxes scoring below are dropped
    :param name: op name
    :return: [boxes, scores, class_logits] for the kept boxes, each padded to max_output_size rows
    """
    indices = tf.image.non_max_suppression(boxes, scores, max_output_size, iou_threshold, score_threshold,
                                           name)  # 1-D indices
    output_boxes = tf.gather(boxes, indices)  # (M,4)
    class_scores = tf.expand_dims(tf.gather(scores, indices), axis=1)  # expand to 2-D (M,1)
    class_logits = tf.gather(class_logits, indices)
    # pad to a fixed size
    return [tf_utils.pad_to_fixed_size(output_boxes, max_output_size),
            tf_utils.pad_to_fixed_size(class_scores, max_output_size),
            tf_utils.pad_to_fixed_size(class_logits, max_output_size)]
82 |
83 |
class TextProposal(layers.Layer):
    """
    Generate text proposal boxes: apply the box regression to the anchors and
    run per-image NMS on the results.
    """

    def __init__(self, batch_size, score_threshold=0.7, output_box_num=500, iou_threshold=0.3,
                 use_side_refine=False, **kwargs):
        """
        :param batch_size: images per batch
        :param score_threshold: score threshold
        :param output_box_num: number of proposal boxes to generate
        :param iou_threshold: NMS IoU threshold
        :param use_side_refine: whether to apply side refinement at prediction time
        """
        self.batch_size = batch_size
        self.score_threshold = score_threshold
        self.output_box_num = output_box_num
        self.iou_threshold = iou_threshold
        self.use_side_refine = use_side_refine
        super(TextProposal, self).__init__(**kwargs)

    def call(self, inputs, **kwargs):
        """
        Apply the box regression and generate the final boxes with NMS.
        :param inputs:
            inputs[0]: deltas, [batch_size,N,(dy,dh)]; N is the total number of anchors
            inputs[1]: side_deltas, [batch_size,N,(dx)]
            inputs[2]: class logits [batch_size,N,num_classes]
            inputs[3]: anchors [batch_size,N,(y1,x1,y2,x2)]
            inputs[4]: valid_anchors_indices [batch_size, anchor_num]
        :param kwargs:
        :return: [boxes, scores, class_logits], each padded to output_box_num rows
        """
        deltas, side_deltas, class_logits, anchors, val_anchors_indices = inputs
        # only consider the predictions of the valid anchors
        deltas, side_deltas, class_logits = tf_utils.batch_slice(
            [deltas, side_deltas, class_logits, val_anchors_indices],
            lambda x, y, z, u: get_valid_predicts(x, y, z, u),
            self.batch_size)
        # convert logits to class scores
        class_scores = tf.nn.softmax(logits=class_logits, axis=-1)  # [N,num_classes]
        fg_scores = tf.reduce_max(class_scores[..., 1:], axis=-1)  # class 0 is background (N,)

        # # previous tf.map_fn implementation, kept for reference:
        # proposals = tf.map_fn(fn=lambda x: apply_regress(*x),
        #                       elems=[deltas, anchors],
        #                       dtype=tf.float32)
        # # # non-maximum suppression
        #
        # options = {"max_output_size": self.output_box_num,
        #            "iou_threshold": self.iou_threshold,
        #            "score_threshold": self.score_threshold}
        # outputs = tf.map_fn(fn=lambda x: nms(*x, **options),
        #                     elems=[proposals, fg_scores, class_logits],
        #                     dtype=[tf.float32] * 3)
        proposals = tf_utils.batch_slice([deltas, side_deltas, anchors],
                                         lambda x, y, z: apply_regress(x, y, z, self.use_side_refine),
                                         self.batch_size)

        outputs = tf_utils.batch_slice([proposals, fg_scores, class_logits],
                                       lambda x, y, z: nms(x, y, z,
                                                           max_output_size=self.output_box_num,
                                                           iou_threshold=self.iou_threshold,
                                                           score_threshold=self.score_threshold),
                                       self.batch_size)
        return outputs

    def compute_output_shape(self, input_shape):
        """
        Note: with multiple outputs, call must return a list.
        :param input_shape:
        :return:
        """
        return [(input_shape[0][0], self.output_box_num, 4 + 1),
                (input_shape[0][0], self.output_box_num, 1 + 1),
                (input_shape[0][0], self.output_box_num, input_shape[1][-1])]
160 |
--------------------------------------------------------------------------------
/ctpn/preprocess/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | File Name: __init__.py
4 | Description :
5 | Author : mick.yi
6 | date: 2019/3/14
7 | """
--------------------------------------------------------------------------------
/ctpn/preprocess/reader.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | File Name: reader
4 | Description :
5 | Author : mick.yi
6 | date: 2019/3/14
7 | """
8 | import numpy as np
9 | import os
10 | import glob
11 |
12 |
def load_annotation(annotation_path, image_dir):
    """
    Load one ICDAR-style annotation file and locate its image.
    :param annotation_path: path of a gt_*.txt annotation file
    :param image_dir: directory containing the corresponding image
    :return: dict with keys annotation_path, image_path, file_name,
             boxes [n,(y1,x1,y2,x2)], quadrilaterals [n,8] and labels [n]
    """
    annotation = {"annotation_path": annotation_path}
    # annotation gt_img_3.txt maps to image img_3.jpg / img_3.png via a glob pattern
    base_name = os.path.basename(annotation_path)
    pattern = base_name[3:-3] + '*'
    annotation["image_path"] = glob.glob(os.path.join(image_dir, pattern))[0]
    annotation["file_name"] = os.path.basename(annotation["image_path"])  # image file name

    boxes = []
    quads = []  # quadrilaterals
    with open(annotation_path, "r", encoding='utf-8') as f:
        # round-trip through utf-8-sig to strip a possible BOM
        lines = f.read().encode('utf-8').decode('utf-8-sig').splitlines()
    for line in lines:
        fields = line.strip().split(",")
        # four clockwise corners, e.g. 377,117,463,117,465,130,378,130:
        # top-left, top-right, bottom-right, bottom-left
        lt_x, lt_y, rt_x, rt_y, rb_x, rb_y, lb_x, lb_y = map(float, fields[:8])
        x_min = min(lt_x, lb_x)
        y_min = min(lt_y, rt_y)
        x_max = max(rt_x, rb_x)
        y_max = max(lb_y, rb_y)
        boxes.append([y_min, x_min, y_max, x_max])
        quads.append([lt_x, lt_y, rt_x, rt_y, rb_x, rb_y, lb_x, lb_y])

    annotation["boxes"] = np.asarray(boxes, np.float32).reshape((-1, 4))
    annotation["quadrilaterals"] = np.asarray(quads, np.float32).reshape((-1, 8))
    annotation["labels"] = np.ones(shape=(len(boxes)), dtype=np.uint8)
    return annotation
47 |
--------------------------------------------------------------------------------
/ctpn/utils/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | File Name: __init__.py
4 | Description :
5 | Author : mick.yi
6 | date: 2019/3/13
7 | """
--------------------------------------------------------------------------------
/ctpn/utils/detector.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | File Name: detector
4 | Description : 文本行检测器
5 | Author : mick.yi
6 | date: 2019/3/14
7 | """
8 | import numpy as np
9 | from .text_proposal_connector import TextProposalConnector
10 | from ..utils import np_utils
11 |
12 |
def normalize(data):
    """Min-max scale a numpy array to [0,1]; a constant array maps to all zeros,
    and an empty array is returned unchanged."""
    if data.shape[0] == 0:
        return data
    lo, hi = data.min(), data.max()
    span = hi - lo
    return (data - lo) / span if span != 0 else data - lo
19 |
20 |
class TextDetector:
    """
    Detect text lines from an image by connecting text proposals.
    """

    def __init__(self, config):
        # config provides the thresholds used below
        # (TEXT_LINE_NMS_THRESH, LINE_MIN_SCORE, TEXT_PROPOSALS_WIDTH, MIN_NUM_PROPOSALS)
        self.config = config
        self.text_proposal_connector = TextProposalConnector()

    def detect(self, text_proposals, scores, image_shape, window):
        """
        Detect text lines.
        :param text_proposals: text proposal boxes
        :param scores: proposal scores
        :param image_shape: image shape
        :param window: [y1,x1,y2,x2] window after removing padding
        :return: text_lines; the slicing below implies [num, 9]
                 (8 quadrilateral coords + score)
        """
        # NOTE(review): original comment says normalizing seemed to hurt visually
        # but slightly improved evaluation results — kept enabled
        scores = normalize(scores)
        text_lines = self.text_proposal_connector.get_text_lines(text_proposals, scores, image_shape)
        keep_indices = self.filter_boxes(text_lines)
        text_lines = text_lines[keep_indices]
        text_lines = filter_out_of_window(text_lines, window)

        # NMS over whole text lines (quadrangles)
        if text_lines.shape[0] != 0:
            keep_indices = np_utils.quadrangle_nms(text_lines[:, :8], text_lines[:, 8],
                                                   self.config.TEXT_LINE_NMS_THRESH)
            text_lines = text_lines[keep_indices]

        return text_lines

    def filter_boxes(self, text_lines):
        # keep lines that score high enough and are wide enough;
        # width is measured between columns 0 and 2 (x coords of the first two quad points)
        widths = text_lines[:, 2] - text_lines[:, 0]
        scores = text_lines[:, -1]
        return np.where((scores > self.config.LINE_MIN_SCORE) &
                        (widths > (self.config.TEXT_PROPOSALS_WIDTH * self.config.MIN_NUM_PROPOSALS)))[0]
59 |
60 |
def filter_out_of_window(text_lines, window):
    """
    Drop text lines that fall (even partly) outside the given window.
    :param text_lines: [n,9] — four (x,y) corner points followed by a score
    :param window: [y1,x1,y2,x2]
    :return: the surviving rows of text_lines
    """
    win_y1, win_x1, win_y2, win_x2 = window

    points = np.reshape(text_lines[:, :8], (-1, 4, 2))  # [n, 4 points, (x,y)]
    xs = points[:, :, 0]
    ys = points[:, :, 1]
    # A line survives only if all four corners lie inside the window.
    inside = ((np.min(xs, axis=1) >= win_x1) &
              (np.max(xs, axis=1) <= win_x2) &
              (np.min(ys, axis=1) >= win_y1) &
              (np.max(ys, axis=1) <= win_y2))
    return text_lines[np.where(inside)]
81 |
--------------------------------------------------------------------------------
/ctpn/utils/file_utils.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | File Name: file_utils
4 | Description : 文件处理工具类
5 | Author : mick.yi
6 | date: 2019/2/19
7 | """
8 | import os
9 |
10 |
def get_sub_files(dir_path, recursive=False):
    """
    List all file paths under a directory.
    :param dir_path: directory to scan
    :param recursive: whether to descend into subdirectories; when False,
        subdirectory paths themselves are included as entries
    :return: list of paths
    """
    file_paths = []
    for name in os.listdir(dir_path):
        cur_path = os.path.join(dir_path, name)
        if os.path.isdir(cur_path) and recursive:
            # Bug fix: propagate the recursive flag. Previously the flag was
            # dropped in the recursive call, so nested directories were only
            # traversed one level deep.
            file_paths = file_paths + get_sub_files(cur_path, recursive)
        else:
            file_paths.append(cur_path)
    return file_paths
26 |
--------------------------------------------------------------------------------
/ctpn/utils/generator.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | File Name: generator
4 | Description : 生成器
5 | Author : mick.yi
6 | date: 2019/3/14
7 | """
8 | import random
9 | import numpy as np
10 | from ..utils import image_utils, np_utils, gt_utils
11 |
12 |
def generator(image_annotations, batch_size, image_shape, width_stride,
              max_gt_num, horizontal_flip=False, random_crop=False):
    """
    Endlessly yield training batches.
    :param image_annotations: list of annotation dicts with 'image_path',
        'quadrilaterals' and 'labels'
    :param batch_size: images per batch
    :param image_shape: network input shape
    :param width_stride: GT split stride (typically 16)
    :param max_gt_num: cap on GT boxes per image (padded to this size)
    :param horizontal_flip: enable horizontal-flip augmentation
    :param random_crop: enable random-crop augmentation
    :return: yields ({inputs dict}, None) tuples forever
    """
    num_images = len(image_annotations)
    while True:
        # Sample a batch of distinct image indices.
        chosen = np.random.choice(num_images, batch_size, replace=False)
        images, metas, boxes_list, class_ids_list = [], [], [], []
        for idx in chosen:
            annotation = image_annotations[idx]
            image, image_meta, _, quads = image_utils.load_image_gt(
                idx,
                annotation['image_path'],
                image_shape[0],
                gt_quadrilaterals=annotation['quadrilaterals'],
                horizontal_flip=horizontal_flip,
                random_crop=random_crop)
            labels = annotation['labels']
            # Split each quadrilateral into fixed-width GT boxes.
            gt_boxes, labels = gt_utils.gen_gt_from_quadrilaterals(
                quads, labels, image_shape, width_stride, box_min_size=3)
            images.append(image)
            metas.append(image_meta)
            # Cap the number of GT boxes and pad to a fixed size.
            boxes_list.append(np_utils.pad_to_fixed_size(gt_boxes[:max_gt_num], max_gt_num))
            class_ids_list.append(
                np_utils.pad_to_fixed_size(np.expand_dims(np.array(labels), axis=1), max_gt_num))

        yield {"input_image": np.asarray(images),
               "input_image_meta": np.asarray(metas),
               "gt_class_ids": np.asarray(class_ids_list),
               "gt_boxes": np.asarray(boxes_list)}, None
49 |
--------------------------------------------------------------------------------
/ctpn/utils/gt_utils.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | File Name: gt_utils
4 | Description : gt 四边形分割为固定宽度的系列gt boxes
5 | Author : mick.yi
6 | date: 2019/3/18
7 | """
8 | import numpy as np
9 |
10 |
def linear_fit_y(xs, ys, x_list):
    """
    Fit a line through (x1,y1),(x2,y2) and evaluate it at the given points.
    :param xs: [x1, x2]
    :param ys: [y1, y2]
    :param x_list: x coordinates to evaluate, numpy array [n]
    :return: fitted y values, numpy array [n]
    """
    if xs[0] == xs[1]:
        # Vertical segment: slope undefined, use the mean of the y values.
        return np.ones_like(x_list) * np.mean(ys)
    if ys[0] == ys[1]:
        # Horizontal segment: constant y.
        return np.ones_like(x_list) * ys[0]
    # General case: degree-1 polynomial fit.
    line = np.poly1d(np.polyfit(xs, ys, 1))
    return line(x_list)
26 |
27 |
def get_min_max_y(quadrilateral, xs):
    """
    For each x in xs, find the min and max y on the quadrilateral's boundary.
    :param quadrilateral: coordinates x1,y1,x2,y2,x3,y3,x4,y4
    :param xs: x coordinates, numpy array [n]
    :return: (min y values, max y values) as numpy arrays
    """
    x1, y1, x2, y2, x3, y3, x4, y4 = quadrilateral.tolist()
    # y values of each of the four edges, evaluated at every x.
    edges = [((x1, x2), linear_fit_y(np.array([x1, x2]), np.array([y1, y2]), xs)),
             ((x2, x3), linear_fit_y(np.array([x2, x3]), np.array([y2, y3]), xs)),
             ((x3, x4), linear_fit_y(np.array([x3, x4]), np.array([y3, y4]), xs)),
             ((x4, x1), linear_fit_y(np.array([x4, x1]), np.array([y4, y1]), xs))]
    mins = []
    maxs = []
    for i, x in enumerate(xs):
        # Only edges whose x-range covers this x contribute candidates.
        candidates = [ys[i] for (xa, xb), ys in edges if min(xa, xb) <= x <= max(xa, xb)]
        mins.append(min(candidates))
        maxs.append(max(candidates))

    return np.array(mins), np.array(maxs)
57 |
58 |
def get_xs_in_range(x_array, x_min, x_max):
    """
    Select the grid points inside [x_min, x_max], forcing both endpoints in.
    :param x_array: evenly spaced x coordinates, e.g. [0,16,32,...,608]
    :param x_min: quadrilateral's minimum x
    :param x_max: quadrilateral's maximum x
    :return: split points, starting at x_min and ending at x_max
    """
    mask = (x_array >= x_min) & (x_array <= x_max)
    xs = x_array[mask]
    # Make sure the exact boundaries are present at either end.
    if len(xs) == 0 or xs[0] > x_min:
        xs = np.insert(xs, 0, x_min)
    if len(xs) == 0 or xs[-1] < x_max:
        xs = np.append(xs, x_max)
    return xs
75 |
76 |
def gen_gt_from_quadrilaterals(gt_quadrilaterals, input_gt_class_ids, image_shape, width_stride, box_min_size=3):
    """
    Split GT quadrilaterals into a series of fixed-width GT boxes.
    :param gt_quadrilaterals: GT quadrilateral coordinates, [n,(x1,y1,x2,y2,x3,y3,x4,y4)]
    :param input_gt_class_ids: GT quadrilateral class ids, usually all 1 [n]
    :param image_shape: (H, W, ...) image shape
    :param width_stride: split stride, typically 16
    :param box_min_size: minimum size of the generated GT boxes.
        NOTE(review): this parameter is currently UNUSED — the filter at the
        end hard-codes height >= 8 and width >= 2 instead. Confirm whether it
        should drive that filter.
    :return:
        gt_boxes: [m,(y1,x1,y2,x2)]
        gt_class_ids: [m]
    """
    h, w = list(image_shape)[:2]
    x_array = np.arange(0, w + 1, width_stride, np.float32)  # x grid points at the fixed stride
    # Per-quadrilateral min and max x.
    x_min_np = np.min(gt_quadrilaterals[:, ::2], axis=1)
    x_max_np = np.max(gt_quadrilaterals[:, ::2], axis=1)
    gt_boxes = []
    gt_class_ids = []
    for i in np.arange(len(gt_quadrilaterals)):
        xs = get_xs_in_range(x_array, x_min_np[i], x_max_np[i])  # grid x coordinates inside the quadrilateral
        ys_min, ys_max = get_min_max_y(gt_quadrilaterals[i], xs)
        # For each quadrilateral, emit one fixed-width GT box per x interval;
        # the box spans the extreme y values of the interval's two endpoints.
        for j in range(len(xs) - 1):
            x1, x2 = xs[j], xs[j + 1]
            y1, y2 = np.min(ys_min[j:j + 2]), np.max(ys_max[j:j + 2])
            gt_boxes.append([y1, x1, y2, x2])
            gt_class_ids.append(input_gt_class_ids[i])
    gt_boxes = np.reshape(np.array(gt_boxes), (-1, 4))
    gt_class_ids = np.reshape(np.array(gt_class_ids), (-1,))
    # Drop boxes that are too small (hard-coded thresholds; see NOTE above).
    height = gt_boxes[:, 2] - gt_boxes[:, 0]
    width = gt_boxes[:, 3] - gt_boxes[:, 1]
    indices = np.where(np.logical_and(height >= 8, width >= 2))
    return gt_boxes[indices], gt_class_ids[indices]
113 |
--------------------------------------------------------------------------------
/ctpn/utils/image_utils.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | File Name: image
4 | Description : 图像处理工具类
5 | Author : mick.yi
6 | date: 2019/2/18
7 | """
8 | import skimage
9 | from skimage import io, transform
10 | import numpy as np
11 | import matplotlib.pyplot as plt
12 | import random
13 |
14 |
def load_image(image_path):
    """
    Load an image as an RGB array.
    :param image_path: path to the image file
    :return: [h,w,3] numpy array
    """
    image = plt.imread(image_path)
    # Grayscale to RGB.
    if len(image.shape) == 2:
        # No channel axis: add one, then replicate to three channels.
        image = np.tile(np.expand_dims(image, axis=2), (1, 1, 3))
    elif image.shape[-1] == 1:
        # Single channel; io.imread raises ValueError: Input image expected to be RGB, RGBA or gray
        image = skimage.color.gray2rgb(image)
    # Normalize float images (range [0,1]) to 0~255 uint8.
    if image.dtype == np.float32:
        image *= 255
        image = image.astype(np.uint8)
    # Drop any alpha channel.
    return image[..., :3]
34 |
35 |
def load_image_gt(image_id, image_path, output_size, gt_boxes=None,
                  gt_quadrilaterals=None, horizontal_flip=False, random_crop=False):
    """
    Load an image, resize/pad it to the network input size, adjust the GT
    coordinates accordingly and return the associated metadata.
    :param image_id: image id
    :param image_path: image path
    :param output_size: network input size (height and width are equal)
    :param gt_boxes: GT boxes [N,(y1,x1,y2,x2)]
    :param gt_quadrilaterals: GT quadrilaterals [N,(x1,y1,...,x4,y4)]
    :param horizontal_flip: whether to flip horizontally
    :param random_crop: whether to randomly crop
    :return:
        image: (H,W,3)
        image_meta: metadata, see compose_image_meta
        gt_boxes: GT boxes adjusted for scaling and padding [N,(y1,x1,y2,x2)]
        gt_quadrilaterals: GT quadrilaterals adjusted likewise
    """
    # Load the image.
    image = load_image(image_path)
    # Bug fix: operate on copies of the GT arrays. The crop and flip branches
    # below modify coordinates in place, which previously corrupted the
    # caller's annotation arrays (they are reused across epochs by the
    # training generator).
    if gt_boxes is not None:
        gt_boxes = np.copy(gt_boxes)
    if gt_quadrilaterals is not None:
        gt_quadrilaterals = np.copy(gt_quadrilaterals)
    # Random crop with 50% probability; the crop never cuts into the GT region.
    if random_crop and random.random() >= 0.5:
        min_x, max_x = np.min(gt_quadrilaterals[:, ::2]), np.max(gt_quadrilaterals[:, ::2])
        min_y, max_y = np.min(gt_quadrilaterals[:, 1::2]), np.max(gt_quadrilaterals[:, 1::2])
        image, crop_window = crop_image(image, [min_y, min_x, max_y, max_x])
        # Shift GT coordinates into the cropped frame.
        if gt_quadrilaterals is not None and gt_quadrilaterals.shape[0] > 0:
            gt_quadrilaterals[:, 1::2] -= crop_window[0]
            gt_quadrilaterals[:, ::2] -= crop_window[1]
    # Horizontal flip.
    if horizontal_flip:
        image = image[:, ::-1, :]
        if gt_quadrilaterals is not None and gt_quadrilaterals.shape[0] > 0:
            # Mirror the x coordinates, then swap left/right corner order so
            # the points stay in (lt, rt, rb, lb) order.
            gt_quadrilaterals[:, ::2] = image.shape[1] - gt_quadrilaterals[:, ::2]
            lt_x, lt_y, rt_x, rt_y, rb_x, rb_y, lb_x, lb_y = np.split(gt_quadrilaterals, 8, axis=1)
            gt_quadrilaterals = np.concatenate([rt_x, rt_y, lt_x, lt_y, lb_x, lb_y, rb_x, rb_y], axis=1)

    original_shape = image.shape
    # Resize (and pad) the image; collect the related metadata.
    image, window, scale, padding = resize_image(image, output_size)

    # Pack the metadata.
    image_meta = compose_image_meta(image_id, original_shape, image.shape,
                                    window, scale)
    # Adjust GT coordinates for the resize scaling and padding.
    if gt_boxes is not None and gt_boxes.shape[0] > 0:
        gt_boxes = adjust_box(gt_boxes, padding, scale)
    if gt_quadrilaterals is not None and gt_quadrilaterals.shape[0] > 0:
        gt_quadrilaterals = adjust_quadrilaterals(gt_quadrilaterals, padding, scale)

    return image, image_meta, gt_boxes, gt_quadrilaterals
87 |
88 |
def crop_image(image, gt_window):
    """
    Randomly crop the image while keeping the GT window intact.
    :param image: numpy array (H,W,C)
    :param gt_window: [y1,x1,y2,x2] bounding window of all GT content
    :return: (cropped image, crop window [wy1,wx1,wy2,wx2] in original coords)

    NOTE(review): np.random.randint(0) raises ValueError, so images with
    height or width < 5 px (so that h//5 or w//5 is 0) would fail here —
    presumably inputs are always larger; confirm upstream.
    """
    h, w = image.shape[:2]
    y1, x1, y2, x2 = gt_window
    # Each crop margin is at most 1/5 of the image and never cuts into the
    # GT window. (Removed an unused `gaps` local from the original.)
    wy1 = np.random.randint(min(y1 + 1, h // 5))
    wx1 = np.random.randint(min(x1 + 1, w // 5))
    wy2 = h - np.random.randint(min(h - y2 + 1, h // 5))
    wx2 = w - np.random.randint(min(w - x2 + 1, w // 5))
    return image[wy1:wy2, wx1:wx2], [wy1, wx1, wy2, wx2]
98 |
99 |
def resize_image(image, max_dim):
    """
    Resize the image so its long side equals max_dim, then zero-pad the short
    side so the output is a max_dim x max_dim square.
    :param image: numpy array (H,W,3)
    :param max_dim: side length of the square output
    :return: (resized image, window of the original content (y1,x1,y2,x2),
              scale factor, padding spec)
    """
    dtype = image.dtype
    h, w = image.shape[:2]
    scale = max_dim / max(h, w)  # scale factor
    image = transform.resize(image, (round(h * scale), round(w * scale)),
                             order=1, mode='constant', cval=0, clip=True, preserve_range=True)
    new_h, new_w = image.shape[:2]
    # Center the content: split the leftover pixels between the two sides.
    top_pad = (max_dim - new_h) // 2
    left_pad = (max_dim - new_w) // 2
    padding = [(top_pad, max_dim - new_h - top_pad),
               (left_pad, max_dim - new_w - left_pad),
               (0, 0)]
    image = np.pad(image, padding, mode='constant', constant_values=0)
    # Window where the resized original sits inside the padded square.
    window = (top_pad, left_pad, new_h + top_pad, new_w + left_pad)
    return image.astype(dtype), window, scale, padding
123 |
124 |
def compose_image_meta(image_id, original_image_shape, image_shape,
                       window, scale):
    """
    Pack image metadata into a single 1-D numpy array.
    :param image_id: image id
    :param original_image_shape: original shape, tuple (H,W,3)
    :param image_shape: resized shape, tuple (H,W,3)
    :param window: window of the original image inside the resized one (y1,x1,y2,x2)
    :param scale: scale factor
    :return: numpy array of 12 values
    """
    parts = ([image_id] +                    # size=1
             list(original_image_shape) +    # size=3
             list(image_shape) +             # size=3
             list(window) +                  # size=4, (y1,x1,y2,x2) in image coordinates
             [scale])                        # size=1
    return np.array(parts)
144 |
145 |
def parse_image_meta(meta):
    """
    Decode a single metadata array produced by compose_image_meta.
    :param meta: [12]
    :return: dict of typed fields
    """
    return {
        "image_id": meta[0].astype(np.int32),
        "original_image_shape": meta[1:4].astype(np.int32),
        "image_shape": meta[4:7].astype(np.int32),
        "window": meta[7:11].astype(np.int32),  # (y1,x1,y2,x2) in pixels
        "scale": meta[11].astype(np.float32),
    }
164 |
165 |
def batch_parse_image_meta(meta):
    """
    Decode a batch of metadata arrays (see compose_image_meta).
    :param meta: [batch,12]
    :return: dict of typed numpy arrays, one entry per field
    """
    return {
        "image_id": meta[:, 0].astype(np.int32),
        "original_image_shape": meta[:, 1:4].astype(np.int32),
        "image_shape": meta[:, 4:7].astype(np.int32),
        "window": meta[:, 7:11].astype(np.int32),  # (y1,x1,y2,x2) in pixels
        "scale": meta[:, 11].astype(np.float32),
    }
184 |
185 |
def adjust_box(boxes, padding, scale):
    """
    Rescale boxes and shift them by the padding offsets.
    :param boxes: numpy array; GT boxes [N,(y1,x1,y2,x2)]
    :param padding: [(top_pad, bottom_pad), (left_pad, right_pad), (0, 0)]
    :param scale: scale factor
    :return: adjusted boxes (a new array)
    """
    top, left = padding[0][0], padding[1][0]
    scaled = boxes * scale
    scaled[:, 0::2] += top   # shift the y coordinates by the top padding
    scaled[:, 1::2] += left  # shift the x coordinates by the left padding
    return scaled
198 |
199 |
def adjust_quadrilaterals(quadrilaterals, padding, scale):
    """
    Rescale quadrilaterals and shift them by the padding offsets.
    :param quadrilaterals: numpy array; GT quadrilaterals [N,(x1,y1,x2,y2,x3,y3,x4,y4)]
    :param padding: [(top_pad, bottom_pad), (left_pad, right_pad), (0, 0)]
    :param scale: scale factor
    :return: adjusted quadrilaterals (a new array)
    """
    top, left = padding[0][0], padding[1][0]
    scaled = quadrilaterals * scale
    scaled[:, 1::2] += top   # shift the y coordinates by the top padding
    scaled[:, 0::2] += left  # shift the x coordinates by the left padding
    return scaled
212 |
213 |
def recover_detect_boxes(boxes, window, scale):
    """
    Map detected boxes back onto the original image (undo padding and scaling).
    Note: modifies `boxes` in place and returns the same array.
    :param boxes: numpy array, [n,(y1,x1,y2,x2)]
    :param window: [(y1,x1,y2,x2)]
    :param scale: scalar scale factor
    :return: the shifted and unscaled boxes
    """
    top, left = window[0], window[1]
    # Remove the padding offset, then undo the resize scaling.
    boxes[:, 0::2] -= top
    boxes[:, 1::2] -= left
    boxes /= scale
    return boxes
228 |
229 |
def recover_detect_quad(boxes, window, scale):
    """
    Map detected quadrilaterals back onto the original image (undo padding
    and scaling). Note: modifies `boxes` in place and returns the same array.
    :param boxes: numpy array, [n,(x1,y1,x2,y2,x3,y3,x4,y4)]
    :param window: [(y1,x1,y2,x2)]
    :param scale: scalar scale factor
    :return: the shifted and unscaled quadrilaterals
    """
    top, left = window[0], window[1]
    # Remove the padding offset (y coords are odd columns, x even), then
    # undo the resize scaling.
    boxes[:, 1::2] -= top
    boxes[:, 0::2] -= left
    boxes /= scale
    return boxes
244 |
--------------------------------------------------------------------------------
/ctpn/utils/np_utils.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | File Name: np_utils
4 | Description : numpy 工具类
5 | Author : mick.yi
6 | date: 2019/2/19
7 | """
8 |
9 | import numpy as np
10 | from shapely.geometry import Polygon
11 |
12 |
def pad_to_fixed_size(input_np, fixed_size):
    """
    Append a tag column (1 = real row) then zero-pad rows up to fixed_size.
    :param input_np: 2-D array
    :param fixed_size: target number of rows
    :return: [max(rows, fixed_size), cols+1] array
    """
    rows = input_np.shape[0]
    # Tag column: 1 marks a real (non-padding) row.
    tagged = np.pad(input_np, ((0, 0), (0, 1)), mode='constant', constant_values=1)
    # Zero-pad the missing rows.
    missing = max(0, fixed_size - rows)
    return np.pad(tagged, ((0, missing), (0, 0)), mode='constant', constant_values=0)
26 |
27 |
def remove_pad(input_np):
    """
    Strip padding rows and the trailing tag column added by pad_to_fixed_size.
    :param input_np: 2-D array whose last column is the padding tag (1 = real row)
    :return: the real rows without the tag column
    """
    # The tag column sums to the number of real rows.
    real_rows = int(np.sum(input_np[:, -1]))
    return input_np[:real_rows, :-1]
37 |
38 |
def compute_iou(boxes_a, boxes_b):
    """
    Pairwise IoU between two sets of boxes (numpy).
    :param boxes_a: (N,4) boxes (y1,x1,y2,x2)
    :param boxes_b: (M,4)
    :return: (N,M) IoU matrix
    """
    # Broadcast to (N,M) by adding singleton axes.
    a = boxes_a[:, np.newaxis, :]  # (N,1,4)
    b = boxes_b[np.newaxis, :, :]  # (1,M,4)

    # Intersection extents along height and width.
    inter_h = np.maximum(0.0,
                         np.minimum(a[..., 2], b[..., 2]) -
                         np.maximum(a[..., 0], b[..., 0]))  # (N,M)
    inter_w = np.maximum(0.0,
                         np.minimum(a[..., 3], b[..., 3]) -
                         np.maximum(a[..., 1], b[..., 1]))  # (N,M)
    intersection = inter_h * inter_w

    # Individual box areas.
    area_a = (a[..., 2] - a[..., 0]) * (a[..., 3] - a[..., 1])
    area_b = (b[..., 2] - b[..., 0]) * (b[..., 3] - b[..., 1])

    # Intersection over union.
    return intersection / (area_a + area_b - intersection)
69 |
70 |
def compute_iou_1vn(box, boxes, box_area, boxes_area):
    """IoU of one box against an array of boxes.

    box: 1-D vector [y1, x1, y2, x2]
    boxes: [boxes_count, (y1, x1, y2, x2)]
    box_area: float, area of `box`
    boxes_area: array of length boxes_count

    Areas are passed in (not computed here) so the caller can compute them
    once and avoid duplicate work.
    """
    # Intersection rectangle per candidate box.
    inter_y1 = np.maximum(box[0], boxes[:, 0])
    inter_x1 = np.maximum(box[1], boxes[:, 1])
    inter_y2 = np.minimum(box[2], boxes[:, 2])
    inter_x2 = np.minimum(box[3], boxes[:, 3])
    intersection = np.maximum(inter_y2 - inter_y1, 0) * np.maximum(inter_x2 - inter_x1, 0)
    union = box_area + boxes_area - intersection
    return intersection / union
90 |
91 |
def threshold(coords, min_, max_):
    """Clamp coords elementwise into [min_, max_]."""
    return np.clip(coords, min_, max_)
94 |
95 |
def clip_boxes(boxes, im_shape):
    """
    Clip coordinates into the image bounds (in place).
    :param boxes: coordinate array [n, k]; even columns are clamped to
        [0, im_shape[1]] and odd columns to [0, im_shape[0]]
    :param im_shape: tuple (H, W, C)
    :return: the clipped array

    NOTE(review): the original docstring claimed the layout [n,(y1,x1,y2,x2)],
    but under that layout the y coordinates (even columns) would be clamped to
    the image WIDTH. The actual caller (text_proposal_connector.get_text_lines)
    passes [n,(x1,y1,...,x4,y4,score)] rows, for which even columns are x
    values and the mapping below is correct — though the score column
    (index 8) also gets clamped to [0, W], which is harmless for scores <= 1.
    Confirm the intended layout before reusing this helper elsewhere.
    """
    boxes[:, 0::2] = threshold(boxes[:, 0::2], 0, im_shape[1])
    boxes[:, 1::2] = threshold(boxes[:, 1::2], 0, im_shape[0])
    return boxes
106 |
107 |
def non_max_suppression(boxes, scores, iou_threshold):
    """
    Greedy non-maximum suppression.
    :param boxes: [n,(y1,x1,y2,x2)]
    :param scores: [n]
    :param iou_threshold: overlap above which a lower-scoring box is dropped
    :return: indices of the kept boxes, int32 array
    """
    assert boxes.shape[0] > 0
    if boxes.dtype.kind != "f":
        boxes = boxes.astype(np.float32)

    # Precompute all box areas once.
    areas = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])

    # Candidate indices, best score first.
    order = scores.argsort()[::-1]

    keep = []
    while order.size > 0:
        # Take the best remaining box.
        best = order[0]
        keep.append(best)
        rest = order[1:]
        # IoU of the picked box against all remaining candidates.
        iou = compute_iou_1vn(boxes[best], boxes[rest], areas[best], areas[rest])
        # Only candidates that do not overlap the picked box too much survive.
        order = rest[iou <= iou_threshold]
    return np.array(keep, dtype=np.int32)
145 |
146 |
def quadrangle_iou(quadrangle_a, quadrangle_b):
    """
    IoU of two quadrilaterals.
    :param quadrangle_a: 1-D numpy array [(x1,y1,x2,y2,x3,y3,x4,y4)]
    :param quadrangle_b: 1-D numpy array [(x1,y1,x2,y2,x3,y3,x4,y4)]
    :return: intersection-over-union; 0 for invalid or zero-area input
    """
    a = Polygon(quadrangle_a.reshape((4, 2)))
    b = Polygon(quadrangle_b.reshape((4, 2)))
    # Self-intersecting (invalid) quadrilaterals cannot be measured reliably.
    if not a.is_valid or not b.is_valid:
        return 0
    # Idiom fix: intersect the polygons directly instead of re-wrapping them
    # in Polygon(...) again, which needlessly rebuilt each geometry.
    inter = a.intersection(b).area
    union = a.area + b.area - inter
    if union == 0:
        return 0
    return inter / union
164 |
165 |
def quadrangle_nms(quadrangles, scores, iou_threshold):
    """
    Greedy NMS over quadrilaterals.
    :param quadrangles: 2-D numpy array [n,(x1,y1,x2,y2,x3,y3,x4,y4)]
    :param scores: [n]
    :param iou_threshold: IoU threshold
    :return: list of kept indices
    """
    order = np.argsort(scores)[::-1]  # best score first
    keep = []
    while order.size > 0:
        # Pick the best remaining quadrilateral.
        best = order[0]
        keep.append(best)
        rest = order[1:]
        # IoU of the picked quad against every remaining candidate.
        overlaps = np.array([quadrangle_iou(quadrangles[best], quadrangles[j]) for j in rest])
        # Candidates below the threshold move on to the next round.
        order = rest[np.where(overlaps < iou_threshold)[0]]

    return keep
187 |
188 |
def main():
    """Smoke-test pad_to_fixed_size with two constructions of an empty array."""
    for empty in (np.zeros(shape=(0, 4)),
                  np.asarray([], np.float32).reshape((-1, 4))):
        padded = pad_to_fixed_size(empty, 5)
        print(padded.shape)
196 |
197 |
198 | if __name__ == '__main__':
199 | main()
200 |
--------------------------------------------------------------------------------
/ctpn/utils/text_proposal_connector.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | File Name: text_proposal_connector
4 | Description : 文本框连接,构建文本行
5 | Author : mick.yi
6 | date: 2019/3/13
7 | """
8 | import numpy as np
9 | from .text_proposal_graph_builder import TextProposalGraphBuilder
10 | from .np_utils import clip_boxes
11 |
12 |
class TextProposalConnector:
    """
    Connect text proposal boxes into text lines.
    """

    def __init__(self):
        self.graph_builder = TextProposalGraphBuilder()

    def group_text_proposals(self, text_proposals, scores, im_size):
        """
        Group connected proposals into text lines.
        :param text_proposals: proposal boxes, [n,(y1,x1,y2,x2)]
        :param scores: proposal scores, [n]
        :param im_size: image size, tuple (H,W,C)
        :return: list of lists; each inner list holds the proposal indices of one line
        """
        graph = self.graph_builder.build_graph(text_proposals, scores, im_size)
        return graph.sub_graphs_connected()

    def fit_y(self, X, Y, x1, x2):
        """
        Fit a line through the points (X, Y) and return its values at x1 and x2.
        Bug fix: the original `len(X) != 0` was a no-op expression statement;
        it is now an actual assertion guarding against empty input.
        """
        assert len(X) != 0
        # All x values identical: the fit is degenerate, return constant y.
        if np.sum(X == X[0]) == len(X):
            return Y[0], Y[0]
        p = np.poly1d(np.polyfit(X, Y, 1))
        return p(x1), p(x2)

    def get_text_lines(self, text_proposals, scores, im_size):
        """
        Build text lines from proposals.
        :param text_proposals: proposal boxes, [n,(y1,x1,y2,x2)]
        :param scores: proposal scores, [n]
        :param im_size: image size, tuple (H,W,C)
        :return: numpy array [m,(x1,y1,x2,y2,x3,y3,x4,y4,score)] — four corner
                 points (lt, rt, rb, lb) followed by the line score
        """
        tp_groups = self.group_text_proposals(text_proposals, scores, im_size)
        text_lines = np.zeros((len(tp_groups), 9), np.float32)
        # Process one text line (group of proposals) at a time.
        for index, tp_indices in enumerate(tp_groups):
            text_line_boxes = text_proposals[list(tp_indices)]
            # Horizontal extent of the line.
            x_min = np.min(text_line_boxes[:, 1])
            x_max = np.max(text_line_boxes[:, 3])
            # Half of one proposal's width: extend the line by this on both sides.
            offset = (text_line_boxes[0, 3] - text_line_boxes[0, 1]) * 0.5
            # Fit the top and bottom borders of the line with a linear function.
            lt_y, rt_y = self.fit_y(text_line_boxes[:, 1], text_line_boxes[:, 0], x_min - offset, x_max + offset)
            lb_y, rb_y = self.fit_y(text_line_boxes[:, 1], text_line_boxes[:, 2], x_min - offset, x_max + offset)

            # Line score = mean of its proposals' scores.
            score = scores[list(tp_indices)].sum() / float(len(tp_indices))
            # Corner coordinates: left-top, right-top, right-bottom, left-bottom.
            text_lines[index, 0] = x_min
            text_lines[index, 1] = lt_y
            text_lines[index, 2] = x_max
            text_lines[index, 3] = rt_y
            text_lines[index, 4] = x_max
            text_lines[index, 5] = rb_y
            text_lines[index, 6] = x_min
            text_lines[index, 7] = lb_y
            text_lines[index, 8] = score
        # Clip coordinates into the image bounds.
        text_lines = clip_boxes(text_lines, im_size)

        return text_lines
82 |
--------------------------------------------------------------------------------
/ctpn/utils/text_proposal_graph_builder.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | File Name: text_proposal_graph_builder
4 | Description : 文本提议框 图构建;构建文本框对
5 | Author : mick.yi
6 | date: 2019/3/12
7 | """
8 | import numpy as np
9 |
10 |
class Graph(object):
    def __init__(self, graph):
        # graph: [n,n] bool adjacency matrix; graph[i,j] means proposal j
        # directly follows proposal i in a text line.
        self.graph = graph

    def sub_graphs_connected(self):
        """
        Walk the pairing graph and emit the text lines.
        :return: list of lists; each inner list is the proposal indices of one line
        """
        lines = []
        for start in range(self.graph.shape[0]):
            has_predecessor = self.graph[:, start].any()
            has_successor = self.graph[start, :].any()
            # A line starts at a node with a successor but no predecessor.
            if has_successor and not has_predecessor:
                node = start
                chain = [node]
                # Follow successor links to the end of the line.
                while self.graph[node, :].any():
                    node = np.where(self.graph[node, :])[0][0]
                    chain.append(node)
                lines.append(chain)
        return lines
29 |
30 |
class TextProposalGraphBuilder(object):
    """
    Pair up text proposal boxes to build the connection graph.
    """

    def __init__(self, max_horizontal_gap=50, min_vertical_overlaps=0.7, min_size_similarity=0.7):
        """
        :param max_horizontal_gap: maximum horizontal distance (px) between
               proposals of the same text line; farther pairs belong to
               different lines
        :param min_vertical_overlaps: minimum vertical IoU between paired proposals
        :param min_size_similarity: minimum height similarity between paired proposals
        """
        self.max_horizontal_gap = max_horizontal_gap
        self.min_vertical_overlaps = min_vertical_overlaps
        self.min_size_similarity = min_size_similarity
        # The following fields are populated by build_graph.
        self.text_proposals = None
        self.scores = None
        self.im_size = None
        self.heights = None
        self.boxes_table = None

    def get_successions(self, index):
        """
        Find the successor proposals of the given proposal.
        :param index: proposal index
        :return: list of successor proposal indices
        """
        box = self.text_proposals[index]
        results = []
        # Scan columns to the right, up to max_horizontal_gap away.
        for left in range(int(box[1]) + 1, min(int(box[1]) + self.max_horizontal_gap + 1, self.im_size[1])):
            adj_box_indices = self.boxes_table[left]
            for adj_box_index in adj_box_indices:
                if self.meet_v_iou(adj_box_index, index):
                    results.append(adj_box_index)
            # Stop at the nearest column that yielded matches.
            if len(results) != 0:
                return results
        return results

    def get_precursors(self, index):
        """
        Find the precursor proposals of the given proposal.
        :param index: proposal index
        :return: list of precursor proposal indices
        """
        box = self.text_proposals[index]
        results = []
        # Scan columns to the left, up to max_horizontal_gap away.
        for left in range(int(box[1]) - 1, max(int(box[1] - self.max_horizontal_gap), 0) - 1, -1):
            adj_box_indices = self.boxes_table[left]
            for adj_box_index in adj_box_indices:
                if self.meet_v_iou(adj_box_index, index):
                    results.append(adj_box_index)
            # Stop at the nearest column that yielded matches.
            if len(results) != 0:
                return results
        return results

    def is_succession_node(self, index, succession_index):
        """
        True if proposal `index` is the best-scoring precursor of `succession_index`
        (i.e. the two proposals form a pair).
        :param index: proposal index (Bi)
        :param succession_index: successor proposal index (Bj)
        :return: bool
        """
        precursors = self.get_precursors(succession_index)
        if self.scores[index] >= np.max(self.scores[precursors]):
            return True
        return False

    def meet_v_iou(self, index1, index2):
        """
        Do the two proposals satisfy the vertical-IoU and size-similarity tests?
        :param index1: proposal index
        :param index2: proposal index
        :return: True or False
        """

        def overlaps_v(idx1, idx2):
            """Vertical IoU of two boxes (overlap over the smaller height)."""
            h1 = self.heights[idx1]
            h2 = self.heights[idx2]
            # Vertical intersection.
            max_y1 = max(self.text_proposals[idx2][0], self.text_proposals[idx1][0])
            min_y2 = min(self.text_proposals[idx2][2], self.text_proposals[idx1][2])
            return max(0, min_y2 - max_y1) / min(h1, h2)

        def size_similarity(idx1, idx2):
            """Height similarity of two boxes (smaller over larger)."""
            h1 = self.heights[idx1]
            h2 = self.heights[idx2]
            return min(h1, h2) / max(h1, h2)

        return overlaps_v(index1, index2) >= self.min_vertical_overlaps and \
               size_similarity(index1, index2) >= self.min_size_similarity

    def build_graph(self, text_proposals, scores, im_size):
        """
        Build the pairing graph over the proposals.
        :param text_proposals: proposal boxes, numpy array [n,(y1,x1,y2,x2)]
        :param scores: proposal scores, [n]
        :param im_size: image size, tuple (H,W,C)
        :return: Graph wrapping an [n,n] bool matrix that marks paired proposals
        """
        self.text_proposals = text_proposals
        self.scores = scores
        self.im_size = im_size
        self.heights = text_proposals[:, 2] - text_proposals[:, 0]  # heights of all proposals

        # Bucket proposal indices by their left x coordinate (x1).
        im_width = self.im_size[1]
        boxes_table = [[] for _ in range(im_width)]
        for index, box in enumerate(text_proposals):
            boxes_table[int(box[1])].append(index)
        self.boxes_table = boxes_table

        # Pairing matrix. Bug fix: the deprecated alias np.bool was removed in
        # NumPy 1.24; use the builtin bool dtype instead.
        graph = np.zeros((text_proposals.shape[0], text_proposals.shape[0]), bool)

        for index, box in enumerate(text_proposals):
            # Successors of the current proposal Bi.
            successions = self.get_successions(index)
            if len(successions) == 0:
                continue
            # Best-scoring successor, Bj.
            succession_index = successions[np.argmax(scores[successions])]
            # Consistency: reuse is_succession_node instead of duplicating its
            # logic inline. Bi and Bj pair up iff Bi is also the best-scoring
            # precursor of Bj.
            if self.is_succession_node(index, succession_index):
                graph[index, succession_index] = True
        return Graph(graph)
168 |
--------------------------------------------------------------------------------
/ctpn/utils/tf_utils.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | File Name: tf_utils
4 | Description : tensorflow工具类
5 | Author : mick.yi
6 | date: 2019/3/13
7 | """
8 | import tensorflow as tf
9 | from deprecated import deprecated
10 |
11 |
@deprecated(reason='建议使用原生tf.map_fn;效率更高,并且不需要显示传入batch_size参数')
def batch_slice(inputs, graph_fn, batch_size, names=None):
    """
    Slice batched inputs, run graph_fn on each slice, and restack the results.
    Useful when the per-example computation graph only supports one instance.
    :param inputs: tensor or list of tensors
    :param graph_fn: computation applied to each slice
    :param batch_size: number of slices
    :param names: optional names for the stacked output tensors
    :return: tensor, or list of tensors when graph_fn has multiple outputs
    """
    if not isinstance(inputs, list):
        inputs = [inputs]

    per_slice_outputs = []
    for idx in range(batch_size):
        slice_result = graph_fn(*[tensor[idx] for tensor in inputs])
        if not isinstance(slice_result, (tuple, list)):
            slice_result = [slice_result]
        per_slice_outputs.append(slice_result)

    # Transpose from per-slice rows to per-output columns.
    grouped = list(zip(*per_slice_outputs))

    if names is None:
        names = [None] * len(grouped)
    # Stack each output group back into a batched tensor.
    stacked = [tf.stack(group, axis=0, name=name)
               for group, name in zip(grouped, names)]
    # Unwrap a single output.
    if len(stacked) == 1:
        return stacked[0]

    return stacked
46 |
47 |
def pad_to_fixed_size_with_negative(input_tensor, fixed_size, negative_num, data_type=tf.float32):
    """
    Pad a 2-D tensor to fixed_size rows, appending a tag column:
    1 for positive rows, -1 for negative rows, 0 for padding rows.
    :param input_tensor: [n, cols]; the last negative_num rows are negatives
    :param fixed_size: target row count
    :param negative_num: number of negative rows at the end of the tensor
    :param data_type: dtype of the result
    :return: [max(n, fixed_size), cols+1] tensor
    """
    input_size = tf.shape(input_tensor)[0]
    positive_num = input_size - negative_num  # number of positive rows
    # Tag column: +1 for positives, -1 for negatives.
    tags = tf.concat([tf.ones([positive_num], data_type),
                      tf.ones([negative_num], data_type) * -1],
                     axis=0)
    # Cast everything to one dtype and append the tag column.
    tagged = tf.concat([tf.cast(input_tensor, data_type),
                        tf.expand_dims(tags, axis=1)], axis=1)
    # Zero-pad the missing rows.
    missing = tf.maximum(0, fixed_size - input_size)
    return tf.pad(tagged, [[0, missing], [0, 0]], mode='CONSTANT', constant_values=0)
63 |
64 |
def pad_to_fixed_size(input_tensor, fixed_size):
    """
    Pad a 2-D tensor to a fixed number of rows, appending a tag column:
    1 marks a real row, 0 marks a padding row.
    :param input_tensor: 2-D tensor
    :param fixed_size: target number of rows after padding
    :return: [fixed_size, cols + 1] tensor
    """
    input_size = tf.shape(input_tensor)[0]
    # append the tag column (constant 1 for every real row)
    x = tf.pad(input_tensor, [[0, 0], [0, 1]], mode='CONSTANT', constant_values=1)
    # pad any missing rows with zeros (tag 0 marks padding)
    padding_size = tf.maximum(0, fixed_size - input_size)
    x = tf.pad(x, [[0, padding_size], [0, 0]], mode='CONSTANT', constant_values=0)
    return x
79 |
80 |
def pad_list_to_fixed_size(tensor_list, fixed_size):
    """Apply `pad_to_fixed_size` to every tensor in `tensor_list`."""
    return [pad_to_fixed_size(tensor, fixed_size) for tensor in tensor_list]
83 |
84 |
def remove_pad(input_tensor):
    """
    Strip the padding rows and the tag column added by `pad_to_fixed_size`.
    :param input_tensor: 2-D tensor whose last column is the padding tag
                         (1 for real rows, 0 for padding rows)
    :return: tensor with padding rows and the tag column removed
    """
    pad_tag = input_tensor[..., -1]
    # real rows are tagged 1 and padding rows 0, so the sum is the real row count
    real_size = tf.cast(tf.reduce_sum(pad_tag), tf.int32)
    return input_tensor[:real_size, :-1]
94 |
95 |
def clip_boxes(boxes, window):
    """
    Clip box coordinates so every box lies inside the given window.
    :param boxes: box coordinates, [N,(y1,x1,y2,x2)]
    :param window: window coordinates, [(y1,x1,y2,x2)]
    :return: clipped boxes, same layout as `boxes`
    """
    win_y1, win_x1, win_y2, win_x2 = tf.split(window, 4)
    # splitting along axis=1 keeps the rank unchanged
    box_y1, box_x1, box_y2, box_x2 = tf.split(boxes, 4, axis=1)

    # clamp each coordinate into the window range, e.g. win_y1 <= y1 <= win_y2
    clipped_y1 = tf.maximum(tf.minimum(box_y1, win_y2), win_y1)
    clipped_y2 = tf.maximum(tf.minimum(box_y2, win_y2), win_y1)
    clipped_x1 = tf.maximum(tf.minimum(box_x1, win_x2), win_x1)
    clipped_x2 = tf.maximum(tf.minimum(box_x2, win_x2), win_x1)

    return tf.concat([clipped_y1, clipped_x1, clipped_y2, clipped_x2],
                     axis=1, name='clipped_boxes')
114 |
115 |
def apply_regress(deltas, anchors):
    """
    Apply regression deltas to anchor boxes.
    :param deltas: regression targets [N,(dy, dx, dh, dw)]
    :param anchors: anchor boxes [N,(y1,x1,y2,x2)]
    :return: regressed boxes [N,(y1,x1,y2,x2)]
    """
    # anchor heights, widths and center coordinates
    heights = anchors[:, 2] - anchors[:, 0]
    widths = anchors[:, 3] - anchors[:, 1]
    center_y = (anchors[:, 2] + anchors[:, 0]) * 0.5
    center_x = (anchors[:, 3] + anchors[:, 1]) * 0.5

    # scale the deltas by the regression coefficients
    scaled = deltas * tf.constant([0.1, 0.1, 0.2, 0.2])
    dy, dx, dh, dw = scaled[:, 0], scaled[:, 1], scaled[:, 2], scaled[:, 3]

    # shift the centers, then rescale the sides
    center_y = center_y + dy * heights
    center_x = center_x + dx * widths
    heights = heights * tf.exp(dh)
    widths = widths * tf.exp(dw)

    # convert back to (y1, x1, y2, x2)
    half_h = heights * 0.5
    half_w = widths * 0.5
    return tf.stack([center_y - half_h,
                     center_x - half_w,
                     center_y + half_h,
                     center_x + half_w], axis=1)
149 |
--------------------------------------------------------------------------------
/ctpn/utils/visualize.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | File Name: visualize
4 | Description : 可视化
5 | Author : mick.yi
6 | date: 2019/2/20
7 | """
8 |
9 | import matplotlib.pyplot as plt
10 | from matplotlib import patches
11 | import random
12 | import colorsys
13 | import numpy as np
14 |
15 |
def random_colors(N, bright=True):
    """
    Generate N random RGB colors by sampling evenly spaced hues in HSV space,
    converting to RGB, and shuffling the order.
    :param N: number of colors
    :param bright: use full brightness (1.0) when True, otherwise 0.7
    :return: list of (r, g, b) tuples with components in [0, 1]
    """
    value = 1.0 if bright else 0.7
    colors = [colorsys.hsv_to_rgb(i / N, 1, value) for i in range(N)]
    random.shuffle(colors)
    return colors
28 |
29 |
def display_boxes(image, boxes,
                  scores=None, title="",
                  figsize=(16, 16), ax=None,
                  show_bbox=True,
                  colors=None):
    """
    Visualize detected box instances on an image.
    :param image: numpy array, [h,w,c]
    :param boxes: box coordinates [num_instances,(y1,x1,y2,x2)]
    :param scores: (optional) prediction scores [num_instances]
    :param title: (optional) figure title
    :param figsize: (optional) figure size used when no axis is supplied
    :param ax: (optional) matplotlib axis to draw on
    :param show_bbox: (optional) whether to draw the bounding rectangles
    :param colors: (optional) list of per-instance colors
    :return:
    """

    # Number of instances
    N = boxes.shape[0]
    if not N:
        print("\n*** No instances to display *** \n")

    # If no axis is passed, create one and automatically call show()
    auto_show = False
    if not ax:
        _, ax = plt.subplots(1, figsize=figsize)
        auto_show = True

    # Generate random colors
    colors = colors or random_colors(N)

    # Show area outside image boundaries.
    height, width = image.shape[:2]
    ax.set_ylim(height + 10, -10)
    ax.set_xlim(-10, width + 10)
    ax.axis('on')
    ax.set_title(title)

    masked_image = image.astype(np.uint32).copy()
    for i in range(N):
        color = colors[i]

        # Bounding box
        if not np.any(boxes[i]):
            # Skip this instance. Has no bbox. Likely lost in image cropping.
            continue
        y1, x1, y2, x2 = boxes[i]
        if show_bbox:
            p = patches.Rectangle((x1, y1), x2 - x1, y2 - y1, linewidth=2,
                                  alpha=0.7, linestyle="dashed",
                                  edgecolor=color, facecolor='none')
            ax.add_patch(p)
        # score label just below the box's top-left corner
        ax.text(x1, y1 + 8, scores[i] if scores is not None else '',
                color='w', size=11, backgroundcolor="none")

    ax.imshow(masked_image.astype(np.uint8))
    if auto_show:
        plt.show()
89 |
90 |
def display_polygons(image, polygons, scores=None, figsize=(16, 16), ax=None, colors=None):
    """
    Visualize polygons (e.g. detected text lines) on an image.
    :param image: numpy array [h,w,c]
    :param polygons: iterable of flat coordinate arrays, each reshapeable to [n,(x,y)]
    :param scores: (optional) per-polygon scores
    :param figsize: (optional) figure size used when no axis is supplied
    :param ax: (optional) matplotlib axis to draw on
    :param colors: (optional) per-polygon colors
    """
    auto_show = False
    if ax is None:
        _, ax = plt.subplots(1, figsize=figsize)
        auto_show = True
    if colors is None:
        colors = random_colors(len(polygons))

    img_h, img_w = image.shape[:2]
    ax.set_ylim(img_h + 10, -10)
    ax.set_xlim(-10, img_w + 10)
    ax.axis('off')

    for idx, poly in enumerate(polygons):
        points = np.reshape(poly, (-1, 2))  # to [n,(x,y)]
        ax.add_patch(patches.Polygon(points, facecolor=None, fill=False, color=colors[idx]))
        # polygon score next to its first vertex
        text_x, text_y = points[0][:]
        ax.text(text_x, text_y - 1, scores[idx] if scores is not None else '',
                color='w', size=11, backgroundcolor="none")
    ax.imshow(image.astype(np.uint8))
    if auto_show:
        plt.show()
116 |
--------------------------------------------------------------------------------
/evaluate.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | File Name: evaluate
4 | Description : 评估入口
5 | Author : mick.yi
6 | date: 2019/3/21
7 | """
8 |
9 | import sys
10 | import os
11 | import numpy as np
12 | import argparse
13 | from ctpn.utils import image_utils, file_utils, np_utils
14 | from ctpn.utils.detector import TextDetector
15 | from ctpn.config import cur_config as config
16 | from ctpn.layers import models
17 | import datetime
18 |
19 |
def generator(image_path_list, image_shape):
    """
    Yield model inputs for prediction, one image per batch.
    :param image_path_list: list of image file paths
    :param image_shape: target image shape (h, w, c)
    """
    for idx, path in enumerate(image_path_list):
        image, image_meta, _, _ = image_utils.load_image_gt(np.random.randint(10),
                                                            path,
                                                            image_shape[0])
        # progress report every 200 images
        if idx % 200 == 0:
            print("开始评估第 {} 张图像".format(idx))
        yield {"input_image": np.asarray([image]),
               "input_image_meta": np.asarray([image_meta])}
29 |
30 |
def main(args):
    """Evaluate the CTPN model on a directory of images and write per-image result files."""
    # apply command-line overrides to the config
    config.USE_SIDE_REFINE = bool(args.use_side_refine)
    if args.weight_path is not None:
        config.WEIGHT_PATH = args.weight_path
    config.IMAGES_PER_GPU = 1
    config.IMAGE_SHAPE = (1024, 1024, 3)
    # collect image paths
    image_path_list = file_utils.get_sub_files(args.image_dir)

    # build the model and load weights
    model = models.ctpn_net(config, 'test')
    model.load_weights(config.WEIGHT_PATH, by_name=True)

    # run prediction over the whole image set
    start_time = datetime.datetime.now()
    text_boxes, text_scores, image_metas = model.predict_generator(
        generator=generator(image_path_list, config.IMAGE_SHAPE),
        steps=len(image_path_list),
        use_multiprocessing=True)
    end_time = datetime.datetime.now()
    print("======完成{}张图像评估,耗时:{} 秒".format(len(image_path_list), end_time - start_time))
    # strip padding
    text_boxes = [np_utils.remove_pad(box) for box in text_boxes]
    text_scores = [np_utils.remove_pad(score)[:, 0] for score in text_scores]
    image_metas = image_utils.batch_parse_image_meta(image_metas)
    # text-line detection from the predicted proposals
    detector = TextDetector(config)
    text_lines = [detector.detect(boxes, scores, config.IMAGE_SHAPE, window)
                  for boxes, scores, window in zip(text_boxes, text_scores, image_metas["window"])]
    # map detected text lines back to original image coordinates
    text_lines = [image_utils.recover_detect_quad(boxes, window, scale)
                  for boxes, window, scale in zip(text_lines, image_metas["window"], image_metas["scale"])]

    # write one result file per image (8 quad coordinates per line)
    for image_path, boxes in zip(image_path_list, text_lines):
        output_filename = os.path.splitext('res_' + os.path.basename(image_path))[0] + '.txt'
        with open(os.path.join(args.output_dir, output_filename), mode='w') as f:
            for box in boxes.astype(np.int32):
                f.write("{},{},{},{},{},{},{},{}\r\n".format(*box[:8]))
78 |
79 |
if __name__ == '__main__':
    # command-line entry point
    parser = argparse.ArgumentParser()
    parser.add_argument("--image_dir", type=str, help="image dir")
    parser.add_argument("--output_dir", type=str, help="output dir")
    parser.add_argument("--weight_path", type=str, default=None, help="weight path")
    parser.add_argument("--use_side_refine", type=int, default=1, help="1: use side refine; 0 not use")
    main(parser.parse_args(sys.argv[1:]))
88 |
--------------------------------------------------------------------------------
/image_examples/a0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yizt/keras-ctpn/3da503b52f2d4657427bfc2b4c647f49d690e8db/image_examples/a0.png
--------------------------------------------------------------------------------
/image_examples/a1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yizt/keras-ctpn/3da503b52f2d4657427bfc2b4c647f49d690e8db/image_examples/a1.png
--------------------------------------------------------------------------------
/image_examples/a2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yizt/keras-ctpn/3da503b52f2d4657427bfc2b4c647f49d690e8db/image_examples/a2.png
--------------------------------------------------------------------------------
/image_examples/a3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yizt/keras-ctpn/3da503b52f2d4657427bfc2b4c647f49d690e8db/image_examples/a3.png
--------------------------------------------------------------------------------
/image_examples/bkgd_1_0_generated_0.1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yizt/keras-ctpn/3da503b52f2d4657427bfc2b4c647f49d690e8db/image_examples/bkgd_1_0_generated_0.1.jpg
--------------------------------------------------------------------------------
/image_examples/flip1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yizt/keras-ctpn/3da503b52f2d4657427bfc2b4c647f49d690e8db/image_examples/flip1.png
--------------------------------------------------------------------------------
/image_examples/flip2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yizt/keras-ctpn/3da503b52f2d4657427bfc2b4c647f49d690e8db/image_examples/flip2.png
--------------------------------------------------------------------------------
/image_examples/icdar2015/img_200.0.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yizt/keras-ctpn/3da503b52f2d4657427bfc2b4c647f49d690e8db/image_examples/icdar2015/img_200.0.jpg
--------------------------------------------------------------------------------
/image_examples/icdar2015/img_200.1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yizt/keras-ctpn/3da503b52f2d4657427bfc2b4c647f49d690e8db/image_examples/icdar2015/img_200.1.jpg
--------------------------------------------------------------------------------
/image_examples/icdar2015/img_5.0.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yizt/keras-ctpn/3da503b52f2d4657427bfc2b4c647f49d690e8db/image_examples/icdar2015/img_5.0.jpg
--------------------------------------------------------------------------------
/image_examples/icdar2015/img_5.1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yizt/keras-ctpn/3da503b52f2d4657427bfc2b4c647f49d690e8db/image_examples/icdar2015/img_5.1.jpg
--------------------------------------------------------------------------------
/image_examples/icdar2015/img_8.0.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yizt/keras-ctpn/3da503b52f2d4657427bfc2b4c647f49d690e8db/image_examples/icdar2015/img_8.0.jpg
--------------------------------------------------------------------------------
/image_examples/icdar2015/img_8.1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yizt/keras-ctpn/3da503b52f2d4657427bfc2b4c647f49d690e8db/image_examples/icdar2015/img_8.1.jpg
--------------------------------------------------------------------------------
/image_examples/icdar2017/ts_img_01000.0.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yizt/keras-ctpn/3da503b52f2d4657427bfc2b4c647f49d690e8db/image_examples/icdar2017/ts_img_01000.0.jpg
--------------------------------------------------------------------------------
/image_examples/icdar2017/ts_img_01000.1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yizt/keras-ctpn/3da503b52f2d4657427bfc2b4c647f49d690e8db/image_examples/icdar2017/ts_img_01000.1.jpg
--------------------------------------------------------------------------------
/image_examples/icdar2017/ts_img_01001.0.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yizt/keras-ctpn/3da503b52f2d4657427bfc2b4c647f49d690e8db/image_examples/icdar2017/ts_img_01001.0.jpg
--------------------------------------------------------------------------------
/image_examples/icdar2017/ts_img_01001.1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yizt/keras-ctpn/3da503b52f2d4657427bfc2b4c647f49d690e8db/image_examples/icdar2017/ts_img_01001.1.jpg
--------------------------------------------------------------------------------
/predict.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | File Name: predict
4 | Description : 模型预测
5 | Author : mick.yi
6 | date: 2019/3/14
7 | """
8 | import os
9 | import sys
10 | import numpy as np
11 | import argparse
12 | import matplotlib
13 |
14 | matplotlib.use('Agg')
15 | from matplotlib import pyplot as plt
16 | from ctpn.utils import image_utils, np_utils, visualize
17 | from ctpn.utils.detector import TextDetector
18 | from ctpn.config import cur_config as config
19 | from ctpn.layers import models
20 |
21 |
def main(args):
    """Predict text lines for a single image and save the visualization figure."""
    # apply command-line overrides to the config
    config.USE_SIDE_REFINE = bool(args.use_side_refine)
    if args.weight_path is not None:
        config.WEIGHT_PATH = args.weight_path
    config.IMAGES_PER_GPU = 1
    config.IMAGE_SHAPE = (1024, 1024, 3)
    # load the image
    image, image_meta, _, _ = image_utils.load_image_gt(np.random.randint(10),
                                                        args.image_path,
                                                        config.IMAGE_SHAPE[0],
                                                        None)
    # build the model and load weights
    model = models.ctpn_net(config, 'test')
    model.load_weights(config.WEIGHT_PATH, by_name=True)
    # model.summary()

    # run prediction and strip padding
    text_boxes, text_scores, _ = model.predict([np.array([image]), np.array([image_meta])])
    text_boxes = np_utils.remove_pad(text_boxes[0])
    text_scores = np_utils.remove_pad(text_scores[0])[:, 0]

    # detect text lines from the predicted proposals
    image_meta = image_utils.parse_image_meta(image_meta)
    detector = TextDetector(config)
    text_lines = detector.detect(text_boxes, text_scores, config.IMAGE_SHAPE, image_meta['window'])
    # visualize the top boxes and save the figure
    boxes_num = 30
    fig = plt.figure(figsize=(16, 16))
    ax = fig.add_subplot(1, 1, 1)
    visualize.display_polygons(image, text_lines[:boxes_num, :8], text_lines[:boxes_num, 8],
                               ax=ax)
    image_name = os.path.basename(args.image_path)
    fig.savefig('{}.{}.jpg'.format(os.path.splitext(image_name)[0], int(config.USE_SIDE_REFINE)))
56 |
57 |
if __name__ == '__main__':
    # command-line entry point
    parser = argparse.ArgumentParser()
    parser.add_argument("--image_path", type=str, help="image path")
    parser.add_argument("--weight_path", type=str, default=None, help="weight path")
    parser.add_argument("--use_side_refine", type=int, default=1, help="1: use side refine; 0 not use")
    main(parser.parse_args(sys.argv[1:]))
65 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | File Name: train
4 | Description : ctpn训练
5 | Author : mick.yi
6 | date: 2019/3/14
7 | """
8 | import os
9 | import sys
10 | import tensorflow as tf
11 | import keras
12 | import argparse
13 | from keras.callbacks import TensorBoard, ModelCheckpoint, ReduceLROnPlateau
14 | from ctpn.layers import models
15 | from ctpn.config import cur_config as config
16 | from ctpn.utils import file_utils
17 | from ctpn.utils.generator import generator
18 | from ctpn.preprocess import reader
19 |
20 |
def set_gpu_growth():
    """Restrict training to GPU 0 and enable on-demand GPU memory allocation for Keras."""
    os.environ["CUDA_VISIBLE_DEVICES"] = "0"
    tf_config = tf.ConfigProto()
    tf_config.gpu_options.allow_growth = True
    keras.backend.set_session(tf.Session(config=tf_config))
27 |
28 |
def get_call_back():
    """
    Build the training callbacks: periodic checkpointing, learning-rate
    reduction on plateau, and TensorBoard logging.
    :return: list of keras callbacks
    """
    # save weights every 5 epochs
    checkpoint = ModelCheckpoint(filepath='/tmp/ctpn.{epoch:03d}.h5',
                                 monitor='val_loss',
                                 verbose=1,
                                 save_best_only=False,
                                 save_weights_only=True,
                                 period=5)

    # reduce the learning rate when the loss stops improving
    lr_reducer = ReduceLROnPlateau(monitor='loss',
                                   factor=0.1,
                                   cooldown=0,
                                   patience=10,
                                   min_lr=1e-4)
    tensorboard = TensorBoard(log_dir='log')
    return [lr_reducer, checkpoint, tensorboard]
49 |
50 |
def main(args):
    """Train the CTPN model: load annotations, build and compile the model, then fit."""
    set_gpu_growth()
    # load annotations
    annotation_files = file_utils.get_sub_files(config.IMAGE_GT_DIR)
    image_annotations = [reader.load_annotation(file,
                                                config.IMAGE_DIR) for file in annotation_files]
    # filter out missing images; some ICDAR2017 images cannot be found
    image_annotations = [ann for ann in image_annotations if os.path.exists(ann['image_path'])]
    # build and compile the model
    m = models.ctpn_net(config, 'train')
    models.compile(m, config, loss_names=['ctpn_regress_loss', 'ctpn_class_loss', 'side_regress_loss'])
    # add monitoring metrics taken from the target layer's outputs
    output = models.get_layer(m, 'ctpn_target').output
    models.add_metrics(m, ['gt_num', 'pos_num', 'neg_num', 'gt_min_iou', 'gt_avg_iou'], output[-5:])
    # resume from a checkpoint when init_epochs > 0, otherwise start from the pre-trained backbone
    if args.init_epochs > 0:
        m.load_weights(args.weight_path, by_name=True)
    else:
        m.load_weights(config.PRE_TRAINED_WEIGHT, by_name=True)
    m.summary()
    # generators: the last 100 annotations are held out for validation
    gen = generator(image_annotations[:-100],
                    config.IMAGES_PER_GPU,
                    config.IMAGE_SHAPE,
                    config.ANCHORS_WIDTH,
                    config.MAX_GT_INSTANCES,
                    horizontal_flip=False,
                    random_crop=False)
    val_gen = generator(image_annotations[-100:],
                        config.IMAGES_PER_GPU,
                        config.IMAGE_SHAPE,
                        config.ANCHORS_WIDTH,
                        config.MAX_GT_INSTANCES)

    # train
    # NOTE(review): steps_per_epoch is computed from ALL annotations (times 2) while
    # the training generator only sees len-100 of them — confirm this is intentional
    m.fit_generator(gen,
                    steps_per_epoch=len(image_annotations) // config.IMAGES_PER_GPU * 2,
                    epochs=args.epochs,
                    initial_epoch=args.init_epochs,
                    validation_data=val_gen,
                    validation_steps=100 // config.IMAGES_PER_GPU,
                    verbose=True,
                    callbacks=get_call_back(),
                    workers=2,
                    use_multiprocessing=True)

    # save the final model weights
    m.save(config.WEIGHT_PATH)
98 |
99 |
if __name__ == '__main__':
    # command-line entry point
    parser = argparse.ArgumentParser()
    parser.add_argument("--epochs", type=int, default=100, help="epochs")
    parser.add_argument("--init_epochs", type=int, default=0, help="epochs")
    parser.add_argument("--weight_path", type=str, default=None, help="weight path")
    main(parser.parse_args(sys.argv[1:]))
107 |
--------------------------------------------------------------------------------