├── .gitignore
├── LICENSE
├── README.md
├── app.py
├── app.spec
├── app_cn.py
├── category.py
├── compat
│   ├── __init__.py
│   └── upgrade.py
├── config.py
├── constants.py
├── core.py
├── decoder.py
├── encoder.py
├── exception.py
├── fc
│   ├── __init__.py
│   ├── cnn.py
│   └── rnn.py
├── fuse_model.py
├── gui
│   ├── __init__.py
│   ├── data_augmentation.py
│   ├── pretreatment.py
│   └── utils.py
├── loss.py
├── make_dataset.py
├── middleware
│   ├── __init__.py
│   └── random_captcha.py
├── model.template
├── network
│   ├── CNN.py
│   ├── DenseNet.py
│   ├── GRU.py
│   ├── LSTM.py
│   ├── MobileNet.py
│   ├── ResNet.py
│   └── utils.py
├── optimizer
│   ├── AdaBound.py
│   ├── RAdam.py
│   └── __init__.py
├── predict_testing.py
├── pretreatment.py
├── requirements.txt
├── resource
│   ├── VERSION
│   ├── captcha_snapshot.png
│   ├── icon.ico
│   ├── logo.png
│   ├── main.png
│   ├── net_structure.png
│   └── sample_process.png
├── test
│   ├── __init__.py
│   └── test_preprocessing_by_func.py
├── tf_graph_util.py
├── tf_onnx_util2.py
├── tools
│   ├── delete_repeat_img.py
│   ├── gif_frames.py
│   └── package.py
├── trains.py
├── utils
│   ├── __init__.py
│   ├── category_frequency_statistics.py
│   ├── data.py
│   └── sparse.py
└── validation.py

/.gitignore:
--------------------------------------------------------------------------------
# Created by .ignore support plugin (hsz.mobi)
### Python template
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
env/
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
target/
# Jupyter Notebook
.ipynb_checkpoints
# pyenv
.python-version
# celery beat schedule file
celerybeat-schedule
# SageMath parsed files
*.sage.py
# dotenv
.env
# virtualenv
.venv
venv/
ENV/
# Spyder project settings
.spyderproject
# Rope project settings
.ropeproject
# JetBrains IDE project configuration files
.idea/
# image file
*.jpg
*.gif
# Files generated on Linux
*.pid
nohup.out

projects/*
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
Copyright 2019 The TensorFlow Authors. All rights reserved.

                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.

      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.

      "Work" shall mean the work of authorship, whether in Source or
      Object form, made available under the License, as indicated by a
      copyright notice that is included in or attached to the work
      (an example is provided in the Appendix below).

      "Derivative Works" shall mean any work, whether in Source or Object
      form, that is based on (or derived from) the Work and for which the
      editorial revisions, annotations, elaborations, or other modifications
      represent, as a whole, an original work of authorship. For the purposes
      of this License, Derivative Works shall not include works that remain
      separable from, or merely link (or bind by name) to the interfaces of,
      the Work and Derivative Works thereof.

      "Contribution" shall mean any work of authorship, including
      the original version of the Work and any modifications or additions
      to that Work or Derivative Works thereof, that is intentionally
      submitted to Licensor for inclusion in the Work by the copyright owner
      or by an individual or Legal Entity authorized to submit on behalf of
      the copyright owner. For the purposes of this definition, "submitted"
      means any form of electronic, verbal, or written communication sent
      to the Licensor or its representatives, including but not limited to
      communication on electronic mailing lists, source code control systems,
      and issue tracking systems that are managed by, or on behalf of, the
      Licensor for the purpose of discussing and improving the Work, but
      excluding communication that is conspicuously marked or otherwise
      designated in writing by the copyright owner as "Not a Contribution."

      "Contributor" shall mean Licensor and any individual or Legal Entity
      on behalf of whom a Contribution has been received by Licensor and
      subsequently incorporated within the Work.

   2. Grant of Copyright License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      copyright license to reproduce, prepare Derivative Works of,
      publicly display, publicly perform, sublicense, and distribute the
      Work and such Derivative Works in Source or Object form.

   3. Grant of Patent License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      (except as stated in this section) patent license to make, have made,
      use, offer to sell, sell, import, and otherwise transfer the Work,
      where such license applies only to those patent claims licensable
      by such Contributor that are necessarily infringed by their
      Contribution(s) alone or by combination of their Contribution(s)
      with the Work to which such Contribution(s) was submitted. If You
      institute patent litigation against any entity (including a
      cross-claim or counterclaim in a lawsuit) alleging that the Work
      or a Contribution incorporated within the Work constitutes direct
      or contributory patent infringement, then any patent licenses
      granted to You under this License for that Work shall terminate
      as of the date such litigation is filed.

   4. Redistribution. You may reproduce and distribute copies of the
      Work or Derivative Works thereof in any medium, with or without
      modifications, and in Source or Object form, provided that You
      meet the following conditions:

      (a) You must give any other recipients of the Work or
          Derivative Works a copy of this License; and

      (b) You must cause any modified files to carry prominent notices
          stating that You changed the files; and

      (c) You must retain, in the Source form of any Derivative Works
          that You distribute, all copyright, patent, trademark, and
          attribution notices from the Source form of the Work,
          excluding those notices that do not pertain to any part of
          the Derivative Works; and

      (d) If the Work includes a "NOTICE" text file as part of its
          distribution, then any Derivative Works that You distribute must
          include a readable copy of the attribution notices contained
          within such NOTICE file, excluding those notices that do not
          pertain to any part of the Derivative Works, in at least one
          of the following places: within a NOTICE text file distributed
          as part of the Derivative Works; within the Source form or
          documentation, if provided along with the Derivative Works; or,
          within a display generated by the Derivative Works, if and
          wherever such third-party notices normally appear. The contents
          of the NOTICE file are for informational purposes only and
          do not modify the License. You may add Your own attribution
          notices within Derivative Works that You distribute, alongside
          or as an addendum to the NOTICE text from the Work, provided
          that such additional attribution notices cannot be construed
          as modifying the License.

      You may add Your own copyright statement to Your modifications and
      may provide additional or different license terms and conditions
      for use, reproduction, or distribution of Your modifications, or
      for any such Derivative Works as a whole, provided Your use,
      reproduction, and distribution of the Work otherwise complies with
      the conditions stated in this License.

   5. Submission of Contributions. Unless You explicitly state otherwise,
      any Contribution intentionally submitted for inclusion in the Work
      by You to the Licensor shall be under the terms and conditions of
      this License, without any additional terms or conditions.
      Notwithstanding the above, nothing herein shall supersede or modify
      the terms of any separate license agreement you may have executed
      with Licensor regarding such Contributions.

   6. Trademarks. This License does not grant permission to use the trade
      names, trademarks, service marks, or product names of the Licensor,
      except as required for reasonable and customary use in describing the
      origin of the Work and reproducing the content of the NOTICE file.

   7. Disclaimer of Warranty. Unless required by applicable law or
      agreed to in writing, Licensor provides the Work (and each
      Contributor provides its Contributions) on an "AS IS" BASIS,
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
      implied, including, without limitation, any warranties or conditions
      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
      PARTICULAR PURPOSE. You are solely responsible for determining the
      appropriateness of using or redistributing the Work and assume any
      risks associated with Your exercise of permissions under this License.

   8. Limitation of Liability. In no event and under no legal theory,
      whether in tort (including negligence), contract, or otherwise,
      unless required by applicable law (such as deliberate and grossly
      negligent acts) or agreed to in writing, shall any Contributor be
      liable to You for damages, including any direct, indirect, special,
      incidental, or consequential damages of any character arising as a
      result of this License or out of the use or inability to use the
      Work (including but not limited to damages for loss of goodwill,
      work stoppage, computer failure or malfunction, or any and all
      other commercial damages or losses), even if such Contributor
      has been advised of the possibility of such damages.

   9. Accepting Warranty or Additional Liability. While redistributing
      the Work or Derivative Works thereof, You may choose to offer,
      and charge a fee for, acceptance of support, warranty, indemnity,
      or other liability obligations and/or rights consistent with this
      License. However, in accepting such obligations, You may act only
      on Your own behalf and on Your sole responsibility, not on behalf
      of any other Contributor, and only if You agree to indemnify,
      defend, and hold each Contributor harmless for any liability
      incurred by, or claims asserted against, such Contributor by reason
      of your accepting any such warranty or additional liability.

   END OF TERMS AND CONDITIONS

   APPENDIX: How to apply the Apache License to your work.

      To apply the Apache License to your work, attach the following
      boilerplate notice, with the fields enclosed by brackets "[]"
      replaced with your own identifying information. (Don't include
      the brackets!) The text should be enclosed in the appropriate
      comment syntax for the file format. We also recommend that a
      file or class name and description of purpose be included on the
      same "printed page" as the copyright notice for easier
      identification within third-party archives.

   Copyright [yyyy] [name of copyright owner]

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.
--------------------------------------------------------------------------------
/app.spec:
--------------------------------------------------------------------------------
# -*- mode: python ; coding: utf-8 -*-
import time
# block_cipher = pyi_crypto.PyiBlockCipher(key='')
block_cipher = None

added_files = [('resource/icon.ico', 'resource'), ('model.template', '.'), ('resource/VERSION', 'astor'), ('resource/VERSION', 'resource')]

a = Analysis(['app.py'],
             pathex=['.'],
             binaries=[],
             datas=added_files,
             hiddenimports=['numpy.core._dtype_ctypes', 'pkg_resources.py2_warn'],
             hookspath=[],
             runtime_hooks=[],
             excludes=[],
             win_no_prefer_redirects=False,
             win_private_assemblies=False,
             cipher=block_cipher,
             noarchive=False)
pyz = PYZ(a.pure, a.zipped_data,
          cipher=block_cipher)
exe = EXE(pyz,
          a.scripts,
          [],
          exclude_binaries=True,
          name='app',
          debug=False,
          bootloader_ignore_signals=False,
          strip=False,
          upx=True,
          console=True,
          icon='resource/icon.ico')
coll = COLLECT(exe,
               a.binaries,
               a.zipfiles,
               a.datas,
               strip=False,
               upx=True,
               upx_exclude=[],
               name='gpu-win64-{}'.format(time.strftime("%Y%m%d", time.localtime())))
--------------------------------------------------------------------------------
/compat/__init__.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
# Author: kerlomz
--------------------------------------------------------------------------------
/compat/upgrade.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
# Author: kerlomz
import yaml
import json


class ModelConfig:

    def __init__(self, model_conf: str):
        self.model_conf = model_conf
        self.system = None
        self.device = None
        self.device_usage = None
        self.charset = None
        self.split_char = None
        self.gen_charset = None
        self.char_exclude = None
        self.model_name = None
        self.model_type = None
        self.image_height = None
        self.image_width = None
        self.image_channel = None
        self.padding = None
        self.lower_padding = None
        self.resize = None
        self.binaryzation = None
        self.smooth = None
        self.blur = None
        self.replace_transparent = None
        self.model_site = None
        self.version = None
        self.color_engine = None
        self.cf_model = self.read_conf
        self.model_exists = False
        self.assignment()

    def assignment(self):

        system = self.cf_model.get('System')
        self.device = system.get('Device') if system else None
        self.device = self.device if self.device else "cpu:0"
        self.device_usage = system.get('DeviceUsage') if system else None
        self.device_usage = self.device_usage if self.device_usage else 0.02
        self.charset = self.cf_model['Model'].get('CharSet')
        self.char_exclude = self.cf_model['Model'].get('CharExclude')
        self.model_name = self.cf_model['Model'].get('ModelName')
        self.model_type = self.cf_model['Model'].get('ModelType')
        self.model_site = self.cf_model['Model'].get('Sites')
        self.model_site = self.model_site if self.model_site else []
        self.version = self.cf_model['Model'].get('Version')
        self.version = self.version if self.version else 1.0
        self.split_char = self.cf_model['Model'].get('SplitChar')
        self.split_char = '' if not self.split_char else self.split_char

        self.image_height = self.cf_model['Model'].get('ImageHeight')
        self.image_width = self.cf_model['Model'].get('ImageWidth')
        self.image_channel = self.cf_model['Model'].get('ImageChannel')
        self.image_channel = self.image_channel if self.image_channel else 1
        self.binaryzation = self.cf_model['Pretreatment'].get('Binaryzation')
        self.resize = self.cf_model['Pretreatment'].get('Resize')
        self.resize = self.resize if self.resize else [self.image_width, self.image_height]
        self.replace_transparent = self.cf_model['Pretreatment'].get('ReplaceTransparent')

    @property
    def read_conf(self):
        with open(self.model_conf, 'r', encoding="utf-8") as sys_fp:
            sys_stream = sys_fp.read()
        return yaml.load(sys_stream, Loader=yaml.SafeLoader)

    def convert(self):
        with open("../model.template", encoding="utf8") as f:
            lines = f.readlines()
        bc = "".join(lines)
        model = bc.format(
            MemoryUsage=0.7,
            CNNNetwork='CNNX',
            RecurrentNetwork='GRU',
            UnitsNum=64,
            Optimizer='Adam',
            LossFunction='CTC',
            Decoder='CTC',
            ModelName=self.model_name,
            ModelField='Image',
            ModelScene='Classification',
            Category=self.charset,
            Resize=json.dumps(self.resize),
            ImageChannel=self.image_channel,
            ImageWidth=self.image_width,
            ImageHeight=self.image_height,
            MaxLabelNum=4,
            AutoPadding=False,
            OutputSplit="",
            LabelFrom="FileName",
            ExtractRegex=".*?(?=_)",
            LabelSplit='null',
            DatasetTrainsPath="",
            DatasetValidationPath="",
            SourceTrainPath="",
            SourceValidationPath="",
            ValidationSetNum="300",
            SavedSteps="500",
            ValidationSteps="500",
            EndAcc="0.98",
            EndCost="0.05",
            EndEpochs="2",
            BatchSize="64",
            ValidationBatchSize="300",
            LearningRate="0.001",
            DA_Binaryzation="-1",
            DA_MedianBlur="-1",
            DA_GaussianBlur="-1",
            DA_EqualizeHist="False",
            DA_Laplace="False",
            DA_WarpPerspective="False",
            DA_Rotate="-1",
            DA_PepperNoise="-1",
            DA_Brightness="False",
            DA_Saturation="False",
            DA_Hue="False",
            DA_Gamma="False",
            DA_ChannelSwap="False",
            DA_RandomBlank="-1",
            DA_RandomTransition="-1",
            Pre_Binaryzation="-1",
            Pre_ReplaceTransparent="False",
            Pre_HorizontalStitching="False",
            Pre_ConcatFrames="-1",
            Pre_BlendFrames="-1",
            DA_RandomCaptcha="",
            Pre_ExecuteMap="",
        )
        open(self.model_conf.replace(".yaml", "_2.0.yaml"), "w", encoding="utf8").write(model)


if __name__ == '__main__':
    ModelConfig(model_conf="model.yaml").convert()
--------------------------------------------------------------------------------
/constants.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
# Author: kerlomz
from enum import Enum, unique

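# Each enum below mirrors a string-valued field of model.yaml (see the
# template fields written by compat/upgrade.py); @unique guarantees that
# no two members of an enum share a value.
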
@unique
class ModelType(Enum):
    """Model type enumeration"""
    PB = 'PB'
    ONNX = 'ONNX'
    TFLITE = 'TFLITE'


@unique
class DatasetType(Enum):
    """Dataset type enumeration"""
    Directory = 'Directory'
    TFRecords = 'TFRecords'


@unique
class LabelFrom(Enum):
    """Label source enumeration"""
    XML = 'XML'
    LMDB = 'LMDB'
    FileName = 'FileName'
    TXT = 'TXT'


@unique
class LossFunction(Enum):
    """Loss function enumeration"""
    CrossEntropy = 'CrossEntropy'
    CTC = 'CTC'


@unique
class ModelScene(Enum):
    """Model scene enumeration"""
    Classification = 'Classification'


@unique
class ModelField(Enum):
    """Model field enumeration"""
    Image = 'Image'
    Text = 'Text'


@unique
class RunMode(Enum):
    """Run mode enumeration"""
    Validation = 'Validation'
    Trains = 'Trains'
    Predict = 'Predict'


@unique
class CNNNetwork(Enum):
    """Convolutional backbone enumeration"""
    CNNX = 'CNNX'
    CNN5 = 'CNN5'
    # CNN3 is not recommended: it is too small and may fail to converge well.
    CNN3 = 'CNN3'
    ResNetTiny = 'ResNetTiny'
    ResNet50 = 'ResNet50'
    DenseNet = 'DenseNet'
    MobileNetV2 = 'MobileNetV2'


@unique
class RecurrentNetwork(Enum):
    """Recurrent layer enumeration"""
    NoRecurrent = 'NoRecurrent'
    GRU = 'GRU'
    BiGRU = 'BiGRU'
    GRUcuDNN = 'GRUcuDNN'
    LSTM = 'LSTM'
    BiLSTM = 'BiLSTM'
    LSTMcuDNN = 'LSTMcuDNN'
    BiLSTMcuDNN = 'BiLSTMcuDNN'


@unique
class Optimizer(Enum):
    """Optimizer enumeration"""
    RAdam = 'RAdam'
    Adam = 'Adam'
    Momentum = 'Momentum'
    AdaBound = 'AdaBound'
    SGD = 'SGD'
    AdaGrad = 'AdaGrad'
    RMSProp = 'RMSProp'


@unique
class SimpleCharset(Enum):
    """Built-in charset enumeration"""
    NUMERIC = 'NUMERIC'
    ALPHANUMERIC = 'ALPHANUMERIC'
    ALPHANUMERIC_LOWER = 'ALPHANUMERIC_LOWER'
    ALPHANUMERIC_UPPER = 'ALPHANUMERIC_UPPER'
    ALPHABET_LOWER = 'ALPHABET_LOWER'
    ALPHABET_UPPER = 'ALPHABET_UPPER'
    ALPHABET = 'ALPHABET'
    ARITHMETIC = 'ARITHMETIC'
    FLOAT = 'FLOAT'
    CHS_3500 = 'CHS_3500'
    ALPHANUMERIC_CHS_3500_LOWER = 'ALPHANUMERIC_CHS_3500_LOWER'

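A minimal usage sketch for the enums above (illustrative, not a file in this
repository); `conf` stands for a hypothetical dict parsed from model.yaml:

    from constants import CNNNetwork, RecurrentNetwork, LossFunction

    conf = {'CNNNetwork': 'CNN5', 'RecurrentNetwork': 'GRU', 'LossFunction': 'CTC'}
    backbone = CNNNetwork(conf['CNNNetwork'])               # -> CNNNetwork.CNN5
    recurrent = RecurrentNetwork(conf['RecurrentNetwork'])  # -> RecurrentNetwork.GRU
    loss = LossFunction(conf['LossFunction'])               # -> LossFunction.CTC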
--------------------------------------------------------------------------------
/core.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
# Author: kerlomz
import sys
from config import RecurrentNetwork, RESIZE_MAP, CNNNetwork, Optimizer
from network.CNN import *
from network.MobileNet import MobileNetV2
from network.DenseNet import DenseNet
from network.GRU import GRU, BiGRU, GRUcuDNN
from network.LSTM import LSTM, BiLSTM, BiLSTMcuDNN, LSTMcuDNN
from network.ResNet import ResNet50, ResNetTiny
from network.utils import NetworkUtils
from optimizer.AdaBound import AdaBoundOptimizer
from optimizer.RAdam import RAdamOptimizer
from loss import *
from encoder import *
from decoder import *
from fc import *

import tensorflow as tf
tf.compat.v1.disable_v2_behavior()
tf.compat.v1.disable_eager_execution()


class NeuralNetwork(object):

    """
    Builds the neural network graph
    """
    def __init__(self, model_conf: ModelConfig, mode: RunMode, backbone: CNNNetwork, recurrent: RecurrentNetwork):
        """
        :param model_conf: model configuration
        :param mode: run mode (Trains/Validation/Predict)
        :param backbone: convolutional backbone network
        :param recurrent: recurrent network
        """
        self.model_conf = model_conf
        self.decoder = Decoder(self.model_conf)
        self.mode = mode
        self.network = backbone
        self.recurrent = recurrent
        self.inputs = tf.keras.Input(dtype=tf.float32, shape=self.input_shape, name='input')
        self.labels = tf.keras.Input(dtype=tf.int32, shape=[None], sparse=True, name='labels')
        self.utils = NetworkUtils(mode)
        self.merged_summary = None
        self.optimizer = None
        self.dataset_size = None

    @property
    def input_shape(self):
        """
        :return: the input shape, as a tuple/list
        """
        return RESIZE_MAP[self.model_conf.loss_func](*self.model_conf.resize) + [self.model_conf.image_channel]

    def build_graph(self):
        """
        Build the network graph in the current session
        """
        self._build_model()

    def build_train_op(self, dataset_size=None):
        self.dataset_size = dataset_size
        self._build_train_op()
        self.merged_summary = tf.compat.v1.summary.merge_all()

    def _build_model(self):

        """Select the convolutional backbone"""
        if self.network == CNNNetwork.CNN3:
            x = CNN3(model_conf=self.model_conf, inputs=self.inputs, utils=self.utils).build()

        elif self.network == CNNNetwork.CNN5:
            x = CNN5(model_conf=self.model_conf, inputs=self.inputs, utils=self.utils).build()

        elif self.network == CNNNetwork.CNNX:
            x = CNNX(model_conf=self.model_conf, inputs=self.inputs, utils=self.utils).build()

        elif self.network == CNNNetwork.ResNetTiny:
            x = ResNetTiny(model_conf=self.model_conf, inputs=self.inputs, utils=self.utils).build()

        elif self.network == CNNNetwork.ResNet50:
            x = ResNet50(model_conf=self.model_conf, inputs=self.inputs, utils=self.utils).build()

        elif self.network == CNNNetwork.DenseNet:
            x = DenseNet(model_conf=self.model_conf, inputs=self.inputs, utils=self.utils).build()

        elif self.network == CNNNetwork.MobileNetV2:
            x = MobileNetV2(model_conf=self.model_conf, inputs=self.inputs, utils=self.utils).build()

        else:
            raise ValueError('This CNN network is not supported at this time.')

        """Select the recurrent network"""

        # time_major = True: [max_time_step, batch_size, num_classes]
        tf.compat.v1.logging.info("CNN Output: {}".format(x.get_shape()))

        self.seq_len = tf.compat.v1.fill([tf.shape(x)[0]], tf.shape(x)[1], name="seq_len")

        if self.recurrent == RecurrentNetwork.NoRecurrent:
            self.recurrent_network_builder = None
        elif self.recurrent == RecurrentNetwork.LSTM:
            self.recurrent_network_builder = LSTM(model_conf=self.model_conf, inputs=x, utils=self.utils)
        elif self.recurrent == RecurrentNetwork.BiLSTM:
            self.recurrent_network_builder = BiLSTM(model_conf=self.model_conf, inputs=x, utils=self.utils)
        elif self.recurrent == RecurrentNetwork.GRU:
            self.recurrent_network_builder = GRU(model_conf=self.model_conf, inputs=x, utils=self.utils)
        elif self.recurrent == RecurrentNetwork.BiGRU:
            self.recurrent_network_builder = BiGRU(model_conf=self.model_conf, inputs=x, utils=self.utils)
        elif self.recurrent == RecurrentNetwork.LSTMcuDNN:
            self.recurrent_network_builder = LSTMcuDNN(model_conf=self.model_conf, inputs=x, utils=self.utils)
        elif self.recurrent == RecurrentNetwork.BiLSTMcuDNN:
            self.recurrent_network_builder = BiLSTMcuDNN(model_conf=self.model_conf, inputs=x, utils=self.utils)
        elif self.recurrent == RecurrentNetwork.GRUcuDNN:
            self.recurrent_network_builder = GRUcuDNN(model_conf=self.model_conf, inputs=x, utils=self.utils)
        else:
            raise ValueError('This recurrent neural network is not supported at this time.')

        logits = self.recurrent_network_builder.build() if self.recurrent_network_builder else x
        if self.recurrent_network_builder and self.model_conf.loss_func != LossFunction.CTC:
            raise ValueError('Recurrent networks must be paired with the CTC loss function.')

        """Output layer, selected by loss function"""
        with tf.keras.backend.name_scope('output'):
            if self.model_conf.loss_func == LossFunction.CTC:
                self.outputs = FullConnectedRNN(model_conf=self.model_conf, outputs=logits).build()
            elif self.model_conf.loss_func == LossFunction.CrossEntropy:
                self.outputs = FullConnectedCNN(model_conf=self.model_conf, outputs=logits).build()
            return self.outputs

    @property
    def decay_steps(self):
        # Fixed for now; the dataset-proportional schedule below is disabled.
        return 10000
        # epoch_step = int(self.dataset_size / self.model_conf.batch_size)
        # return int(epoch_step / 4)

    def _build_train_op(self):
        """Build the training ops"""

        # Global step
        self.global_step = tf.compat.v1.train.get_or_create_global_step()

        # Loss function
        if self.model_conf.loss_func == LossFunction.CTC:
            self.loss = Loss.ctc(
                labels=self.labels,
                logits=self.outputs,
                sequence_length=self.seq_len
            )
        elif self.model_conf.loss_func == LossFunction.CrossEntropy:
            self.loss = Loss.cross_entropy(
                labels=self.labels,
                logits=self.outputs
            )

        self.cost = tf.reduce_mean(self.loss)

        tf.compat.v1.summary.scalar('cost', self.cost)

        # Learning rate with exponential decay
        self.lrn_rate = tf.compat.v1.train.exponential_decay(
            self.model_conf.trains_learning_rate,
            self.global_step,
            staircase=True,
            decay_steps=self.decay_steps,
            decay_rate=0.98,
        )
        tf.compat.v1.summary.scalar('learning_rate', self.lrn_rate)

        if self.model_conf.neu_optimizer == Optimizer.AdaBound:
            self.optimizer = AdaBoundOptimizer(
                learning_rate=self.lrn_rate,
                final_lr=0.001,
                beta1=0.9,
                beta2=0.999,
                amsbound=True
            )
        elif self.model_conf.neu_optimizer == Optimizer.Adam:
            self.optimizer = tf.compat.v1.train.AdamOptimizer(
                learning_rate=self.lrn_rate
            )
        elif self.model_conf.neu_optimizer == Optimizer.RAdam:
            self.optimizer = RAdamOptimizer(
                learning_rate=self.lrn_rate,
                warmup_proportion=0.1,
                min_lr=1e-6
            )
        elif self.model_conf.neu_optimizer == Optimizer.Momentum:
            self.optimizer = tf.compat.v1.train.MomentumOptimizer(
                learning_rate=self.lrn_rate,
                use_nesterov=True,
                momentum=0.9,
            )
        elif self.model_conf.neu_optimizer == Optimizer.SGD:
            self.optimizer = tf.compat.v1.train.GradientDescentOptimizer(
                learning_rate=self.lrn_rate,
            )
        elif self.model_conf.neu_optimizer == Optimizer.AdaGrad:
            self.optimizer = tf.compat.v1.train.AdagradOptimizer(
                learning_rate=self.lrn_rate,
            )
        elif self.model_conf.neu_optimizer == Optimizer.RMSProp:
            self.optimizer = tf.compat.v1.train.RMSPropOptimizer(
                learning_rate=self.lrn_rate,
            )

        # Batch-norm statistics update ops (moving_mean, moving_variance)
        update_ops = tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.UPDATE_OPS)

        # Couple train_op with the update ops
        with tf.control_dependencies(update_ops):
            self.train_op = self.optimizer.minimize(
                loss=self.cost,
                global_step=self.global_step,
            )

        # Transcription layer: decode op, selected by loss function
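A minimal sketch of driving the builder above (illustrative; the actual
training entry point appears to be trains.py, and `model_conf` stands for a
loaded ModelConfig instance):

    from constants import RunMode, CNNNetwork, RecurrentNetwork
    from core import NeuralNetwork

    model = NeuralNetwork(model_conf, RunMode.Trains, CNNNetwork.CNN5, RecurrentNetwork.GRU)
    model.build_graph()                        # forward graph + output layer
    model.build_train_op(dataset_size=50000)   # loss, LR decay, optimizer, decode op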
        if self.model_conf.loss_func == LossFunction.CTC:
            self.dense_decoded = self.decoder.ctc(
                inputs=self.outputs,
                sequence_length=self.seq_len
            )
        elif self.model_conf.loss_func == LossFunction.CrossEntropy:
            self.dense_decoded = self.decoder.cross_entropy(
                inputs=self.outputs
            )


if __name__ == '__main__':
    pass
--------------------------------------------------------------------------------
/decoder.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
# Author: kerlomz
import tensorflow as tf
from config import ModelConfig


class Decoder:
    """
    Transcription layer: decodes the network predictions
    """
    def __init__(self, model_conf: ModelConfig):
        self.model_conf = model_conf
        self.category_num = self.model_conf.category_num

    def ctc(self, inputs, sequence_length):
        """Decoding for the CTC loss"""
        ctc_decode, _ = tf.compat.v1.nn.ctc_beam_search_decoder_v2(inputs, sequence_length, beam_width=1)
        decoded_sequences = tf.sparse.to_dense(ctc_decode[0], default_value=self.category_num, name='dense_decoded')
        return decoded_sequences

    @staticmethod
    def cross_entropy(inputs):
        """Decoding for the CrossEntropy loss"""
        return tf.argmax(inputs, 2, name='dense_decoded')

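A small sketch of the decode step (illustrative; `predict` and `seq_len`
stand for the time-major output and sequence-length tensors built in core.py):

    decoder = Decoder(model_conf)
    dense_decoded = decoder.ctc(inputs=predict, sequence_length=seq_len)
    # dense_decoded: an int matrix of shape [batch, max_decoded_len]; cells
    # past each sequence's decoded length hold category_num, the blank index.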
--------------------------------------------------------------------------------
/encoder.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
# Author: kerlomz
import io
import re
import cv2
import random
import PIL.Image
import numpy as np
import tensorflow as tf
from exception import *
from constants import RunMode
from config import ModelConfig, LabelFrom, LossFunction
from category import encode_maps, FULL_ANGLE_MAP
from pretreatment import preprocessing
from pretreatment import preprocessing_by_func
from tools.gif_frames import concat_frames, blend_frame
from collections import Counter


class Encoder(object):
    """
    Encoding layer: converts raw input into data the network can consume
    """
    def __init__(self, model_conf: ModelConfig, mode: RunMode):
        self.model_conf = model_conf
        self.mode = mode
        self.category_param = self.model_conf.category_param

    @staticmethod
    def main_color_replace(im: np.ndarray, num=2, repl=(255, 255, 255)):

        red, green, blue = im.T

        colors = []
        for (r, g, b) in im[:, 1, :]:
            colors.append((r, g, b))

        most_common = [i[0] for i in Counter(colors).most_common(num)]

        areas = False

        for r, g, b in most_common:
            areas = areas | ((red == r) & (green == g) & (blue == b))

        im[:, :, :][areas.T] = repl
        return im

    def image(self, path_or_bytes):
        """Encode an image-type input"""
        # im = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
        # OpenCV cannot handle GIF images; it returns None for them.
        # if im is None:

        path_or_stream = io.BytesIO(path_or_bytes) if isinstance(path_or_bytes, bytes) else path_or_bytes
        if not path_or_stream:
            return "Picture is corrupted: {}".format(path_or_bytes)
        try:
            pil_image = PIL.Image.open(path_or_stream)
        except OSError as e:
            return "{} - {}".format(e, path_or_bytes)

        use_compress = False

        gif_handle = self.model_conf.pre_concat_frames != -1 or self.model_conf.pre_blend_frames != -1

        if pil_image.mode == 'P' and not gif_handle:
            pil_image = pil_image.convert('RGB')

        rgb = pil_image.split()

        # if self.mode == RunMode.Trains and use_compress:
        #     img_compress = io.BytesIO()
        #
        #     pil_image.convert('RGB').save(img_compress, format='JPEG', quality=random.randint(75, 100))
        #     img_compress_bytes = img_compress.getvalue()
        #     img_compress.close()
        #     path_or_stream = io.BytesIO(img_compress_bytes)
        #     pil_image = PIL.Image.open(path_or_stream)

        if len(rgb) == 1 and self.model_conf.image_channel == 3:
            return "The number of image channels {} is inconsistent with the number of configured channels {}.".format(
                len(rgb), self.model_conf.image_channel
            )

        size = pil_image.size

        # if self.mode == RunMode.Trains and len(rgb) == 3 and use_compress:
        #     new_size = [size[0] + random.randint(5, 10), size[1] + random.randint(5, 10)]
        #     background = PIL.Image.new(
        #         'RGB', new_size, (255, 255, 255)
        #     )
        #     random_offset_w = random.randint(0, 5)
        #     random_offset_h = random.randint(0, 5)
        #     background.paste(
        #         pil_image,
        #         (
        #             random_offset_w,
        #             random_offset_h,
        #             size[0] + random_offset_w,
        #             size[1] + random_offset_h
        #         ),
        #         None
        #     )
        #     background.convert('RGB')
        #     pil_image = background

        if len(rgb) > 3 and self.model_conf.pre_replace_transparent and not gif_handle and not use_compress:
            background = PIL.Image.new('RGBA', pil_image.size, (255, 255, 255))
            try:
                background.paste(pil_image, (0, 0, size[0], size[1]), pil_image)
                background.convert('RGB')
                pil_image = background
            except:
                pil_image = pil_image.convert('RGB')

        if len(pil_image.split()) > 3 and self.model_conf.image_channel == 3:
            pil_image = pil_image.convert('RGB')

        if self.model_conf.pre_concat_frames != -1:
            im = concat_frames(pil_image, need_frame=self.model_conf.pre_concat_frames)
        elif self.model_conf.pre_blend_frames != -1:
            im = blend_frame(pil_image, need_frame=self.model_conf.pre_blend_frames)
        else:
            im = np.array(pil_image)

        if isinstance(im, list):
            return None

        im = preprocessing_by_func(
            exec_map=self.model_conf.pre_exec_map,
            src_arr=im
        )

        if self.model_conf.image_channel == 1 and len(im.shape) == 3:
            if self.mode == RunMode.Trains:
                im = cv2.cvtColor(im, cv2.COLOR_RGB2GRAY if bool(random.getrandbits(1)) else cv2.COLOR_BGR2GRAY)
            else:
                im = cv2.cvtColor(im, cv2.COLOR_RGB2GRAY)

        im = preprocessing(
            image=im,
            binaryzation=self.model_conf.pre_binaryzation,
        )

        if self.model_conf.pre_horizontal_stitching:
            up_slice = im[0: int(size[1] / 2), 0: size[0]]
            down_slice = im[int(size[1] / 2): size[1], 0: size[0]]
            im = np.concatenate((up_slice, down_slice), axis=1)

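        # About half of the training samples (random coin flip below) pass
        # through the data-augmentation chain; validation and prediction
        # inputs always take the plain float-conversion branch.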
        if self.mode == RunMode.Trains and bool(random.getrandbits(1)):
            im = preprocessing(
                image=im,
                binaryzation=self.model_conf.da_binaryzation,
                median_blur=self.model_conf.da_median_blur,
                gaussian_blur=self.model_conf.da_gaussian_blur,
                equalize_hist=self.model_conf.da_equalize_hist,
                laplacian=self.model_conf.da_laplace,
                rotate=self.model_conf.da_rotate,
                warp_perspective=self.model_conf.da_warp_perspective,
                sp_noise=self.model_conf.da_sp_noise,
                random_brightness=self.model_conf.da_brightness,
                random_saturation=self.model_conf.da_saturation,
                random_hue=self.model_conf.da_hue,
                random_gamma=self.model_conf.da_gamma,
                random_channel_swap=self.model_conf.da_channel_swap,
                random_blank=self.model_conf.da_random_blank,
                random_transition=self.model_conf.da_random_transition,
            ).astype(np.float32)

        else:
            im = im.astype(np.float32)
        if self.model_conf.resize[0] == -1:
            # random_ratio = random.choice([2.5, 3, 3.5, 3.2, 2.7, 2.75])
            ratio = self.model_conf.resize[1] / size[1]
            # random_width = int(random_ratio * RESIZE[1])
            resize_width = int(ratio * size[0])
            # resize_width = random_width if is_random else resize_width
            im = cv2.resize(im, (resize_width, self.model_conf.resize[1]))
        else:
            im = cv2.resize(im, (self.model_conf.resize[0], self.model_conf.resize[1]))
        im = im.swapaxes(0, 1)

        if self.model_conf.image_channel == 1:
            return np.array((im[:, :, np.newaxis]) / 255.)
        else:
            return np.array(im[:, :]) / 255.

    def text(self, content):
        """Encode a text-type input (label)"""
        if isinstance(content, bytes):
            content = content.decode("utf8")

        found = content
        # Trigger automatic case conversion when a built-in case rule matches
        if isinstance(self.category_param, str) and '_LOWER' in self.category_param:
            found = found.lower()
        if isinstance(self.category_param, str) and '_UPPER' in self.category_param:
            found = found.upper()

        if self.model_conf.category_param == 'ARITHMETIC':
            found = found.replace("x", "×").replace('？', "?")

        # Does the label contain a separator?
        if self.model_conf.label_split:
            labels = found.split(self.model_conf.label_split)
        elif '&' in found:
            labels = found.split('&')
        elif self.model_conf.max_label_num == 1:
            labels = [found]
        else:
            labels = [_ for _ in found]
        labels = self.filter_full_angle(labels)
        try:
            if not labels:
                return [0]
            # Encode to a dense array by looking each character up in the category set
            if self.model_conf.loss_func == LossFunction.CTC:
                label = self.split_continuous_char(
                    [encode_maps(self.model_conf.category)[i] for i in labels]
                )
            else:
                label = self.auto_padding_char(
                    [encode_maps(self.model_conf.category)[i] for i in labels]
                )
            return label

        except KeyError as e:
            return dict(e=e, label=content, char=e.args[0])
            # exception(
            #     'The sample label {} contains invalid charset: {}.'.format(
            #         content, e.args[0]
            #     ), ConfigException.SAMPLE_LABEL_ERROR
            # )

    def split_continuous_char(self, content):
        # Insert a blank between consecutive identical classes
        store_list = []
        # blank_char = [self.model_conf.category_num] if bool(random.getrandbits(1)) else [0]
        blank_char = [self.model_conf.category_num]
        for i in range(len(content) - 1):
            store_list.append(content[i])
            if content[i] == content[i + 1]:
                store_list += blank_char
        store_list.append(content[-1])
        return store_list

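    # Example of the blank insertion above (illustrative): for a label "aab"
    # whose category indices are [3, 3, 7], CTC collapses repeated classes
    # unless a blank separates them, so the encoded target becomes
    # [3, blank, 3, 7] with blank == category_num.
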
    def auto_padding_char(self, content):
        if len(content) < self.model_conf.max_label_num and self.model_conf.auto_padding:
            remain_label_num = self.model_conf.max_label_num - len(content)
            return [0] * remain_label_num + content
            # return content + [0] * remain_label_num
        return content

    @staticmethod
    def filter_full_angle(content):
        return [FULL_ANGLE_MAP.get(i, i) for i in content if i != ' ']


if __name__ == '__main__':
    pass
--------------------------------------------------------------------------------
/exception.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
# Author: kerlomz
import sys
import time

"""
This module classifies the known, likely exceptions into categories so that
problems are easier to locate when they occur.
"""


class SystemException(RuntimeError):
    def __init__(self, message, code=-1):
        self.message = message
        self.code = code


class Error(object):
    def __init__(self, message, code=-1):
        self.message = message
        self.code = code
        print(self.message)
        time.sleep(5)
        sys.exit(self.code)


def exception(text, code=-1):
    raise SystemException(text, code)
    # Error(text, code)


class ConfigException:
    OPTIMIZER_NOT_SUPPORTED = -4072
    NETWORK_NOT_SUPPORTED = -4071
    LOSS_FUNC_NOT_SUPPORTED = -4061
    MODEL_FIELD_NOT_SUPPORTED = -4052
    MODEL_SCENE_NOT_SUPPORTED = -4051
    SYS_CONFIG_PATH_NOT_EXIST = -4041
    MODEL_CONFIG_PATH_NOT_EXIST = -4042
    CATEGORY_NOT_EXIST = -4043
    CATEGORY_INCORRECT = -4043
    SAMPLE_LABEL_ERROR = -4044
    GET_LABEL_REGEX_ERROR = -4045
    ERROR_LABEL_FROM = -4046
    INSUFFICIENT_SAMPLE = -5
    VALIDATION_SET_SIZE_ERROR = -6
--------------------------------------------------------------------------------
/fc/__init__.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
# Author: kerlomz

from .cnn import FullConnectedCNN
from .rnn import FullConnectedRNN
--------------------------------------------------------------------------------
/fc/cnn.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
# Author: kerlomz
import tensorflow as tf
from tensorflow.python.keras.regularizers import l1_l2
from config import ModelConfig, RunMode
from exception import exception

from network.utils import NetworkUtils


class FullConnectedCNN(object):
    """
    Output layer for CNN-only models (CrossEntropy loss)
    """
    def __init__(self, model_conf: ModelConfig, outputs):
        self.model_conf = model_conf

        self.max_label_num = self.model_conf.max_label_num
        if self.max_label_num == -1:
            exception(text="This scene requires the maximum number of labels (MaxLabelNum) to be set", code=-998)
        self.category_num = self.model_conf.category_num

        flatten = tf.keras.layers.Flatten()(outputs)
        shape_list = flatten.get_shape().as_list()

        # print(shape_list[1], self.max_label_num)
        outputs = tf.keras.layers.Reshape([self.max_label_num, int(shape_list[1] / self.max_label_num)])(flatten)
        self.outputs = tf.keras.layers.Dense(
            input_shape=outputs.shape,
            units=self.category_num,
        )(inputs=outputs)

        print("output to reshape ----------- ", self.outputs.shape)
", self.outputs.shape) 35 | self.outputs = tf.reshape(self.outputs, [-1, self.max_label_num, self.category_num], name="predict") 36 | 37 | def build(self): 38 | return self.outputs 39 | -------------------------------------------------------------------------------- /fc/rnn.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding:utf-8 -*- 3 | # Author: kerlomz 4 | import tensorflow as tf 5 | from tensorflow.python.keras.regularizers import l1_l2 6 | from config import RunMode, ModelConfig 7 | from network.utils import NetworkUtils 8 | 9 | 10 | class FullConnectedRNN(object): 11 | """ 12 | RNN的输出层 13 | """ 14 | 15 | def __init__(self, model_conf: ModelConfig, outputs): 16 | self.model_conf = model_conf 17 | 18 | self.dense = tf.compat.v1.keras.layers.Dense( 19 | units=self.model_conf.category_num + 2, 20 | kernel_initializer=tf.compat.v1.keras.initializers.he_normal(seed=None), 21 | bias_initializer='zeros', 22 | ) 23 | 24 | self.outputs = self.dense(outputs) 25 | self.predict = tf.compat.v1.transpose(self.outputs, perm=(1, 0, 2), name="predict") 26 | # self.predict = tf.keras.backend.permute_dimensions(self.outputs, pattern=(1, 0, 2)) 27 | 28 | def build(self): 29 | return self.predict 30 | -------------------------------------------------------------------------------- /fuse_model.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding:utf-8 -*- 3 | # Author: kerlomz 4 | 5 | import os 6 | import re 7 | import base64 8 | import pickle 9 | from config import ModelConfig 10 | from constants import ModelType 11 | from config import COMPILE_MODEL_MAP 12 | 13 | 14 | def parse_model(source_bytes: bytes, key=None): 15 | split_tag = b'-#||#-' 16 | 17 | if not key: 18 | key = [b"_____" + i.encode("utf8") + b"_____" for i in "&coriander"] 19 | if isinstance(key, str): 20 | key = [b"_____" + i.encode("utf8") + b"_____" for i in key] 21 | key_len_int = len(key) 22 | model_bytes_list = [] 23 | graph_bytes_list = [] 24 | slice_index = source_bytes.index(key[0]) 25 | split_tag_len = len(split_tag) 26 | slice_0 = source_bytes[0: slice_index].split(split_tag) 27 | model_slice_len = len(slice_0[1]) 28 | graph_slice_len = len(slice_0[0]) 29 | slice_len = split_tag_len + model_slice_len + graph_slice_len 30 | 31 | for i in range(key_len_int-1): 32 | slice_index = source_bytes.index(key[i]) 33 | print(slice_index, slice_index - slice_len) 34 | slices = source_bytes[slice_index - slice_len: slice_index].split(split_tag) 35 | model_bytes_list.append(slices[1]) 36 | graph_bytes_list.append(slices[0]) 37 | slices = source_bytes.split(key[-2])[1][:-len(key[-1])].split(split_tag) 38 | 39 | model_bytes_list.append(slices[1]) 40 | graph_bytes_list.append(slices[0]) 41 | model_bytes = b"".join(model_bytes_list) 42 | model_conf: ModelConfig = pickle.loads(model_bytes) 43 | graph_bytes: bytes = b"".join(graph_bytes_list) 44 | return model_conf, graph_bytes 45 | 46 | 47 | def concat_model(output_path, model_bytes, graph_bytes, key=None): 48 | if not key: 49 | key = [b"_____" + i.encode("utf8") + b"_____" for i in "&coriander"] 50 | if isinstance(key, str): 51 | key = [b"_____" + i.encode("utf8") + b"_____" for i in key] 52 | key_len_int = len(key) 53 | model_slice_len = int(len(model_bytes) / key_len_int) + 1 54 | graph_slice_len = int(len(graph_bytes) / key_len_int) + 1 55 | model_slice = [model_bytes[i:i + model_slice_len] for i in range(0, len(model_bytes), model_slice_len)] 

    graph_slice = [graph_bytes[i:i + graph_slice_len] for i in range(0, len(graph_bytes), graph_slice_len)]

    new_model = []
    for i in range(key_len_int):
        new_model.append(graph_slice[i] + b'-#||#-')
        new_model.append(model_slice[i])
        new_model.append(key[i])
    new_model = b"".join(new_model)
    with open(output_path, "wb") as f:
        f.write(new_model)
    print("Successfully wrote model {}".format(output_path))


def output_model(project_name: str, model_type: ModelType, key=None):
    model_conf = ModelConfig(project_name, is_dev=False)

    graph_parent_path = model_conf.compile_model_path
    model_suffix = COMPILE_MODEL_MAP[model_type]
    model_bytes = pickle.dumps(model_conf.conf)
    graph_path = os.path.join(graph_parent_path, "{}{}".format(model_conf.model_name, model_suffix))

    with open(graph_path, "rb") as f:
        graph_bytes = f.read()

    output_path = graph_path.replace(".pb", ".pl").replace(".onnx", ".pl").replace(".tflite", ".pl")
    concat_model(output_path, model_bytes, graph_bytes, key)


if __name__ == '__main__':
    output_model("", ModelType.PB)

--------------------------------------------------------------------------------
/gui/__init__.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
# Author: kerlomz
--------------------------------------------------------------------------------
/gui/data_augmentation.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
# Author: kerlomz
import json
import tkinter as tk
import tkinter.ttk as ttk
from gui.utils import LayoutGUI


class DataAugmentationDialog(tk.Toplevel):

    def __init__(self):
        tk.Toplevel.__init__(self)
        self.title('Data Augmentation')
        self.layout = {
            'global': {
                'start': {'x': 15, 'y': 20},
                'space': {'x': 15, 'y': 25},
                'tiny_space': {'x': 5, 'y': 10}
            }
        }
        self.data_augmentation_entity = None
        self.da_random_captcha = {"Enable": False, "FontPath": ""}
        self.window_width = 750
        self.window_height = 220

        self.layout_utils = LayoutGUI(self.layout, self.window_width)
        screenwidth = self.winfo_screenwidth()
        screenheight = self.winfo_screenheight()
        size = '%dx%d+%d+%d' % (
            self.window_width,
            self.window_height,
            (screenwidth - self.window_width) / 2,
            (screenheight - self.window_height) / 2
        )
        self.geometry(size)
        # ============================= Group 4 =====================================
        self.label_frame_augmentation = ttk.Labelframe(self, text='Data Augmentation')
        self.label_frame_augmentation.place(
            x=self.layout['global']['start']['x'],
            y=self.layout['global']['start']['y'],
            width=725,
            height=150
        )

        # Binaryzation - label
        self.binaryzation_text = ttk.Label(self, text='Binaryzation', anchor=tk.W)
        self.layout_utils.inside_widget(
            src=self.binaryzation_text,
            target=self.label_frame_augmentation,
            width=72,
            height=20,
        )

        # Binaryzation - entry
        self.binaryzation_val = tk.StringVar()
        self.binaryzation_val.set(-1)
        self.binaryzation_entry = ttk.Entry(self, textvariable=self.binaryzation_val, justify=tk.LEFT)
        self.layout_utils.next_to_widget(
            src=self.binaryzation_entry,
            target=self.binaryzation_text,
            width=55,
            height=20,
            tiny_space=True
        )

        # Median blur - label
        self.median_blur_text = ttk.Label(self, text='Median Blur', anchor=tk.W)
        self.layout_utils.next_to_widget(
            src=self.median_blur_text,
            target=self.binaryzation_entry,
            width=80,
            height=20,
            tiny_space=False
        )

        # Median blur - entry
        self.median_blur_val = tk.IntVar()
        self.median_blur_val.set(-1)
        self.median_blur_entry = ttk.Entry(self, textvariable=self.median_blur_val, justify=tk.LEFT)
        self.layout_utils.next_to_widget(
            src=self.median_blur_entry,
            target=self.median_blur_text,
            width=52,
            height=20,
            tiny_space=True
        )

        # Gaussian blur - label
        self.gaussian_blur_text = ttk.Label(self, text='Gaussian Blur', anchor=tk.W)
        self.layout_utils.next_to_widget(
            src=self.gaussian_blur_text,
            target=self.median_blur_entry,
            width=85,
            height=20,
            tiny_space=False
        )

        # Gaussian blur - entry
        self.gaussian_blur_val = tk.IntVar()
        self.gaussian_blur_val.set(-1)
        self.gaussian_blur_entry = ttk.Entry(self, textvariable=self.gaussian_blur_val, justify=tk.LEFT)
        self.layout_utils.next_to_widget(
            src=self.gaussian_blur_entry,
            target=self.gaussian_blur_text,
            width=62,
            height=20,
            tiny_space=True
        )

        # Pepper noise - label
        self.sp_noise_text = ttk.Label(self, text='Pepper Noise (0-1)', anchor=tk.W)
        self.layout_utils.next_to_widget(
            src=self.sp_noise_text,
            target=self.gaussian_blur_entry,
            width=110,
            height=20,
            tiny_space=False
        )

        # Pepper noise - entry
        self.sp_noise_val = tk.DoubleVar()
        self.sp_noise_val.set(-1)
        self.sp_noise_entry = ttk.Entry(self, textvariable=self.sp_noise_val, justify=tk.LEFT)
        self.layout_utils.next_to_widget(
            src=self.sp_noise_entry,
            target=self.sp_noise_text,
            width=71,
            height=20,
            tiny_space=True
        )

        # Rotate - label
        self.rotate_text = ttk.Label(self, text='Rotate (0-90)', anchor=tk.W)
        self.layout_utils.below_widget(
            src=self.rotate_text,
            target=self.binaryzation_text,
            width=72,
            height=20,
            tiny_space=True
        )

        # Rotate - entry
        self.rotate_val = tk.IntVar()
        self.rotate_val.set(-1)
        self.rotate_entry = ttk.Entry(self, textvariable=self.rotate_val, justify=tk.LEFT)
        self.layout_utils.next_to_widget(
            src=self.rotate_entry,
            target=self.rotate_text,
            width=55,
            height=20,
            tiny_space=True
        )

        # Random blank border - label
        self.random_blank_text = ttk.Label(self, text='Blank Border', anchor=tk.W)
        self.layout_utils.next_to_widget(
            src=self.random_blank_text,
            target=self.rotate_entry,
            width=72,
            height=20,
            tiny_space=False
        )

        # Random blank border - entry
        self.random_blank_val = tk.IntVar()
        self.random_blank_val.set(-1)
        self.random_blank_entry = ttk.Entry(self, textvariable=self.random_blank_val, justify=tk.LEFT)
        self.layout_utils.next_to_widget(
            src=self.random_blank_entry,
            target=self.random_blank_text,
            width=55,
            height=20,
            tiny_space=True
        )

        # Random transition - label
        self.random_transition_text = ttk.Label(self, text='Transition', anchor=tk.W)
        self.layout_utils.next_to_widget(
            src=self.random_transition_text,
            target=self.random_blank_entry,
            width=60,
            height=20,
            tiny_space=False
        )

        # Random transition - entry
        self.random_transition_val = tk.IntVar()
        self.random_transition_val.set(-1)
        self.random_transition_entry = ttk.Entry(self, textvariable=self.random_transition_val, justify=tk.LEFT)
        self.layout_utils.next_to_widget(
            src=self.random_transition_entry,
            target=self.random_transition_text,
            width=55,
            height=20,
            tiny_space=True
        )

        # Random captcha font - label
        self.random_captcha_font_text = ttk.Label(self, text='RandomCaptcha - Font', anchor=tk.W)
        self.layout_utils.next_to_widget(
            src=self.random_captcha_font_text,
            target=self.random_transition_entry,
            width=130,
            height=20,
            tiny_space=False
        )

        # Random captcha font - entry
        self.random_captcha_font_val = tk.StringVar()
        self.random_captcha_font_val.set("")
        self.random_captcha_font_entry = ttk.Entry(self, textvariable=self.random_captcha_font_val, justify=tk.LEFT)
        self.layout_utils.next_to_widget(
            src=self.random_captcha_font_entry,
            target=self.random_captcha_font_text,
            width=75,
            height=20,
            tiny_space=True
        )

        # Warp perspective - checkbox
        self.warp_perspective_val = tk.IntVar()
        self.warp_perspective_val.set(0)
        self.warp_perspective = ttk.Checkbutton(
            self, text='Distortion', variable=self.warp_perspective_val, onvalue=1, offvalue=0
        )
        self.layout_utils.below_widget(
            src=self.warp_perspective,
            target=self.rotate_text,
            width=80,
            height=20,
            tiny_space=False
        )

        # Equalize histogram - checkbox
        self.equalize_hist_val = tk.IntVar()
        self.equalize_hist_val.set(0)
        self.equalize_hist = ttk.Checkbutton(
            self, text='EqualizeHist', variable=self.equalize_hist_val, offvalue=0
        )
        self.layout_utils.next_to_widget(
            src=self.equalize_hist,
            target=self.warp_perspective,
            width=100,
            height=20,
            tiny_space=True
        )

        # Laplace - checkbox
        self.laplace_val = tk.IntVar()
        self.laplace_val.set(0)
        self.laplace = ttk.Checkbutton(
            self, text='Laplace', variable=self.laplace_val, onvalue=1, offvalue=0
        )
        self.layout_utils.next_to_widget(
            src=self.laplace,
            target=self.equalize_hist,
            width=64,
            height=20,
            tiny_space=True
        )

        # Random brightness - checkbox
        self.brightness_val = tk.IntVar()
        self.brightness_val.set(0)
        self.brightness = ttk.Checkbutton(
            self, text='Brightness', variable=self.brightness_val, offvalue=0
        )
        self.layout_utils.next_to_widget(
            src=self.brightness,
            target=self.laplace,
            width=80,
            height=20,
            tiny_space=True
        )

        # Random saturation - checkbox
        self.saturation_val = tk.IntVar()
        self.saturation_val.set(0)
        self.saturation = ttk.Checkbutton(
            self, text='Saturation', variable=self.saturation_val, offvalue=0
        )
        self.layout_utils.next_to_widget(
            src=self.saturation,
            target=self.brightness,
            width=80,
            height=20,
            tiny_space=True
        )

        # Random hue - checkbox
        self.hue_val = tk.IntVar()
        self.hue_val.set(0)
        self.hue = ttk.Checkbutton(
            self, text='Hue', variable=self.hue_val, offvalue=0
        )
        self.layout_utils.next_to_widget(
            src=self.hue,
            target=self.saturation,
            width=50,
            height=20,
            tiny_space=True
        )

        # Random gamma - checkbox
        self.gamma_val = tk.IntVar()
        self.gamma_val.set(0)
308 | self.gamma = ttk.Checkbutton( 309 | self, text='Gamma', variable=self.gamma_val, offvalue=0 310 | ) 311 | self.layout_utils.next_to_widget( 312 | src=self.gamma, 313 | target=self.hue, 314 | width=80, 315 | height=20, 316 | tiny_space=True 317 | ) 318 | 319 | # Random channel swap - checkbox 320 | self.channel_swap_val = tk.IntVar() 321 | self.channel_swap_val.set(0) 322 | self.channel_swap = ttk.Checkbutton( 323 | self, text='Channel Swap', variable=self.channel_swap_val, offvalue=0 324 | ) 325 | self.layout_utils.next_to_widget( 326 | src=self.channel_swap, 327 | target=self.gamma, 328 | width=100, 329 | height=20, 330 | tiny_space=True 331 | ) 332 | 333 | # Save - button 334 | self.btn_save = ttk.Button(self, text='Save Configuration', command=lambda: self.save_conf()) 335 | self.layout_utils.widget_from_right( 336 | src=self.btn_save, 337 | target=self.label_frame_augmentation, 338 | width=120, 339 | height=24, 340 | tiny_space=True 341 | ) 342 | 343 | def read_conf(self, entity): 344 | self.data_augmentation_entity = entity 345 | self.binaryzation_val.set(json.dumps(entity.binaryzation)) 346 | self.median_blur_val.set(entity.median_blur) 347 | self.gaussian_blur_val.set(entity.gaussian_blur) 348 | self.equalize_hist_val.set(entity.equalize_hist) 349 | self.laplace_val.set(entity.laplace) 350 | self.warp_perspective_val.set(entity.warp_perspective) 351 | self.rotate_val.set(entity.rotate) 352 | self.sp_noise_val.set(entity.sp_noise) 353 | self.brightness_val.set(entity.brightness) 354 | self.saturation_val.set(entity.saturation) 355 | self.hue_val.set(entity.hue) 356 | self.gamma_val.set(entity.gamma) 357 | self.channel_swap_val.set(entity.channel_swap) 358 | self.random_blank_val.set(entity.random_blank) 359 | self.random_transition_val.set(entity.random_transition) 360 | self.da_random_captcha = entity.random_captcha 361 | if self.da_random_captcha['Enable']: 362 | self.random_captcha_font_val.set(self.da_random_captcha['FontPath']) 363 | 364 | def save_conf(self): 365 | self.data_augmentation_entity.binaryzation = json.loads(self.binaryzation_val.get()) if self.binaryzation_val.get() else [] 366 | self.data_augmentation_entity.median_blur = self.median_blur_val.get() 367 | self.data_augmentation_entity.gaussian_blur = self.gaussian_blur_val.get() 368 | self.data_augmentation_entity.rotate = self.rotate_val.get() 369 | self.data_augmentation_entity.sp_noise = self.sp_noise_val.get() 370 | self.data_augmentation_entity.random_blank = self.random_blank_val.get() 371 | self.data_augmentation_entity.random_transition = self.random_transition_val.get() 372 | 373 | if self.random_captcha_font_val.get(): 374 | self.data_augmentation_entity.random_captcha['Enable'] = True 375 | self.data_augmentation_entity.random_captcha['FontPath'] = self.random_captcha_font_val.get() 376 | else: 377 | self.data_augmentation_entity.random_captcha['Enable'] = False 378 | self.data_augmentation_entity.random_captcha['FontPath'] = "" 379 | 380 | self.data_augmentation_entity.equalize_hist = True if self.equalize_hist_val.get() == 1 else False 381 | self.data_augmentation_entity.laplace = True if self.laplace_val.get() == 1 else False 382 | self.data_augmentation_entity.warp_perspective = True if self.warp_perspective_val.get() == 1 else False 383 | 384 | self.data_augmentation_entity.brightness = True if self.brightness_val.get() == 1 else False 385 | self.data_augmentation_entity.saturation = True if self.saturation_val.get() == 1 else False 386 | self.data_augmentation_entity.hue = True if self.hue_val.get() == 1 else False 387 | 
self.data_augmentation_entity.gamma = True if self.gamma_val.get() == 1 else False 388 | self.data_augmentation_entity.channel_swap = True if self.channel_swap_val.get() == 1 else False 389 | 390 | self.destroy() -------------------------------------------------------------------------------- /gui/pretreatment.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding:utf-8 -*- 3 | # Author: kerlomz 4 | import json 5 | import tkinter as tk 6 | import tkinter.ttk as ttk 7 | from tkinter import messagebox 8 | from gui.utils import LayoutGUI 9 | 10 | 11 | class PretreatmentDialog(tk.Toplevel): 12 | 13 | def __init__(self): 14 | tk.Toplevel.__init__(self) 15 | self.title('Data Pretreatment') 16 | self.layout = { 17 | 'global': { 18 | 'start': {'x': 15, 'y': 20}, 19 | 'space': {'x': 15, 'y': 25}, 20 | 'tiny_space': {'x': 5, 'y': 10} 21 | } 22 | } 23 | self.pretreatment_entity = None 24 | self.window_width = 600 25 | self.window_height = 220 26 | 27 | self.layout_utils = LayoutGUI(self.layout, self.window_width) 28 | screenwidth = self.winfo_screenwidth() 29 | screenheight = self.winfo_screenheight() 30 | size = '%dx%d+%d+%d' % ( 31 | self.window_width, 32 | self.window_height, 33 | (screenwidth - self.window_width) / 2, 34 | (screenheight - self.window_height) / 2 35 | ) 36 | self.geometry(size) 37 | # ============================= Group 4 ===================================== 38 | self.label_frame_pretreatment = ttk.Labelframe(self, text='Data Pretreatment') 39 | self.label_frame_pretreatment.place( 40 | x=self.layout['global']['start']['x'], 41 | y=self.layout['global']['start']['y'], 42 | width=575, 43 | height=150 44 | ) 45 | 46 | # GIF frame concatenation - entry 47 | self.concat_frames_val = tk.StringVar() 48 | self.concat_frames_val.set("") 49 | self.concat_frames_entry = ttk.Entry(self, textvariable=self.concat_frames_val, justify=tk.LEFT) 50 | self.concat_frames_entry['state'] = tk.DISABLED 51 | 52 | # GIF frame concatenation - checkbox 53 | self.concat_frames_check_val = tk.IntVar() 54 | self.concat_frames_check = ttk.Checkbutton( 55 | self, 56 | text='GIF Frame Stitching', 57 | variable=self.concat_frames_check_val, 58 | onvalue=1, 59 | offvalue=0, 60 | command=lambda: self.check_btn_event(src=self.concat_frames_check_val, entry=self.concat_frames_entry) 61 | ) 62 | self.layout_utils.inside_widget( 63 | src=self.concat_frames_check, 64 | target=self.label_frame_pretreatment, 65 | width=140, 66 | height=20, 67 | ) 68 | 69 | # GIF frame concatenation - layout 70 | self.layout_utils.next_to_widget( 71 | src=self.concat_frames_entry, 72 | target=self.concat_frames_check, 73 | width=100, 74 | height=20, 75 | tiny_space=True 76 | ) 77 | 78 | # GIF frame blending - entry 79 | self.blend_frames_val = tk.StringVar() 80 | self.blend_frames_val.set("") 81 | self.blend_frames_entry = ttk.Entry(self, textvariable=self.blend_frames_val, justify=tk.LEFT) 82 | self.blend_frames_entry['state'] = tk.DISABLED 83 | 84 | # GIF frame blending - checkbox 85 | self.blend_frames_check_val = tk.IntVar() 86 | self.blend_frames_check_val.set(0) 87 | self.blend_frames_check = ttk.Checkbutton( 88 | self, text='GIF Blend Frame', 89 | variable=self.blend_frames_check_val, 90 | onvalue=1, 91 | offvalue=0, 92 | command=lambda: self.check_btn_event(src=self.blend_frames_check_val, entry=self.blend_frames_entry) 93 | ) 94 | 95 | # GIF frame blending - layout 96 | self.layout_utils.next_to_widget( 97 | src=self.blend_frames_check, 98 | target=self.concat_frames_entry, 99 | width=120, 100 | height=20, 101 | tiny_space=False 102 | ) 103 | self.layout_utils.next_to_widget( 104 | 
src=self.blend_frames_entry, 105 | target=self.blend_frames_check, 106 | width=110, 107 | height=20, 108 | tiny_space=True 109 | ) 110 | 111 | # Replace transparent - checkbox 112 | self.replace_transparent_check_val = tk.IntVar() 113 | self.replace_transparent_check = ttk.Checkbutton( 114 | self, text='Replace Transparent', 115 | variable=self.replace_transparent_check_val, 116 | onvalue=1, 117 | offvalue=0 118 | ) 119 | self.layout_utils.below_widget( 120 | src=self.replace_transparent_check, 121 | target=self.concat_frames_check, 122 | width=140, 123 | height=20, 124 | ) 125 | 126 | # Horizontal stitching - checkbox 127 | self.horizontal_stitching_check_val = tk.IntVar() 128 | self.horizontal_stitching_check_val.set(0) 129 | self.horizontal_stitching_check = ttk.Checkbutton( 130 | self, text='Horizontal Stitching', 131 | variable=self.horizontal_stitching_check_val, 132 | onvalue=1, 133 | offvalue=0 134 | ) 135 | self.layout_utils.next_to_widget( 136 | src=self.horizontal_stitching_check, 137 | target=self.replace_transparent_check, 138 | width=130, 139 | height=20, 140 | tiny_space=False 141 | ) 142 | 143 | # Binarization - label 144 | self.binaryzation_text = ttk.Label(self, text='Binaryzation', anchor=tk.W) 145 | self.layout_utils.next_to_widget( 146 | src=self.binaryzation_text, 147 | target=self.horizontal_stitching_check, 148 | width=75, 149 | height=20, 150 | tiny_space=False 151 | ) 152 | 153 | # Binarization - entry 154 | self.binaryzation_val = tk.IntVar() 155 | self.binaryzation_val.set(-1) 156 | self.binaryzation_entry = ttk.Entry(self, textvariable=self.binaryzation_val, justify=tk.LEFT) 157 | self.layout_utils.next_to_widget( 158 | src=self.binaryzation_entry, 159 | target=self.binaryzation_text, 160 | width=55, 161 | height=20, 162 | tiny_space=True 163 | ) 164 | 165 | # Save - button 166 | self.btn_save = ttk.Button(self, text='Save Configuration', command=lambda: self.save_conf()) 167 | self.layout_utils.widget_from_right( 168 | src=self.btn_save, 169 | target=self.label_frame_pretreatment, 170 | width=120, 171 | height=24, 172 | tiny_space=True 173 | ) 174 | 175 | @staticmethod 176 | def check_btn_event(src: tk.IntVar, entry: tk.Entry): 177 | if src.get() == 1: 178 | entry['state'] = tk.NORMAL 179 | else: 180 | entry['state'] = tk.DISABLED 181 | return None 182 | 183 | def read_conf(self, entity): 184 | self.pretreatment_entity = entity 185 | 186 | try: 187 | 188 | if entity.blend_frames == -1 or self.blend_frames_entry['state'] == tk.DISABLED: 189 | self.blend_frames_check_val.set(0) 190 | self.blend_frames_val.set(json.dumps([-1])) 191 | else: 192 | self.blend_frames_check_val.set(1) 193 | self.blend_frames_entry['state'] = tk.NORMAL 194 | self.blend_frames_val.set(json.dumps(entity.blend_frames)) 195 | 196 | if entity.concat_frames == -1 or self.concat_frames_entry['state'] == tk.DISABLED: 197 | self.concat_frames_check_val.set(0) 198 | self.concat_frames_val.set(json.dumps([0, -1])) 199 | else: 200 | self.concat_frames_check_val.set(1) 201 | self.concat_frames_entry['state'] = tk.NORMAL 202 | self.concat_frames_val.set(json.dumps(entity.concat_frames)) 203 | 204 | self.horizontal_stitching_check_val.set(1 if entity.horizontal_stitching else 0) 205 | self.replace_transparent_check_val.set(1 if entity.replace_transparent else 0) 206 | 207 | self.binaryzation_val.set(entity.binaryzation) 208 | 209 | except Exception as e: 210 | messagebox.showerror( 211 | e.__class__.__name__, json.dumps(e.args) 212 | ) 213 | return 214 | 215 | def save_conf(self): 216 | try: 217 | 218 | if self.concat_frames_check_val.get() == 1: 219 | 
self.pretreatment_entity.concat_frames = json.loads(self.concat_frames_val.get()) 220 | else: 221 | self.pretreatment_entity.concat_frames = -1 222 | if self.blend_frames_check_val.get() == 1: 223 | self.pretreatment_entity.blend_frames = json.loads(self.blend_frames_val.get()) 224 | else: 225 | self.pretreatment_entity.blend_frames = -1 226 | self.pretreatment_entity.horizontal_stitching = True if self.horizontal_stitching_check_val.get() == 1 else False 227 | self.pretreatment_entity.replace_transparent = True if self.replace_transparent_check_val.get() == 1 else False 228 | self.pretreatment_entity.binaryzation = self.binaryzation_val.get() 229 | except Exception as e: 230 | messagebox.showerror( 231 | e.__class__.__name__, json.dumps(e.args) 232 | ) 233 | return 234 | 235 | self.destroy() -------------------------------------------------------------------------------- /gui/utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding:utf-8 -*- 3 | # Author: kerlomz 4 | 5 | 6 | class LayoutGUI(object): 7 | 8 | def __init__(self, layout, window_width): 9 | self.layout = layout 10 | self.window_width = window_width 11 | 12 | def widget_from_right(self, src, target, width, height, tiny_space=False): 13 | target_edge = self.object_edge_info(target) 14 | src.place( 15 | x=self.window_width - width - self.layout['global']['space']['x'], 16 | y=target_edge['edge_y'] + self.layout['global']['tiny_space' if tiny_space else 'space']['y'], 17 | width=width, 18 | height=height 19 | ) 20 | 21 | def before_widget(self, src, target, width, height, tiny_space=False): 22 | target_edge = self.object_edge_info(target) 23 | src.place( 24 | x=target_edge['x'] - width - self.layout['global']['tiny_space' if tiny_space else 'space']['x'], 25 | y=target_edge['y'], 26 | width=width, 27 | height=height 28 | ) 29 | 30 | @staticmethod 31 | def object_edge_info(obj): 32 | info = obj.place_info() 33 | x = int(info['x']) 34 | y = int(info['y']) 35 | edge_x = int(info['x']) + int(info['width']) 36 | edge_y = int(info['y']) + int(info['height']) 37 | return {'x': x, 'y': y, 'edge_x': edge_x, 'edge_y': edge_y} 38 | 39 | def inside_widget(self, src, target, width, height): 40 | target_edge = self.object_edge_info(target) 41 | src.place( 42 | x=target_edge['x'] + self.layout['global']['space']['x'], 43 | y=target_edge['y'] + self.layout['global']['space']['y'], 44 | width=width, 45 | height=height 46 | ) 47 | 48 | def below_widget(self, src, target, width, height, tiny_space=False): 49 | target_edge = self.object_edge_info(target) 50 | src.place( 51 | x=target_edge['x'], 52 | y=target_edge['edge_y'] + self.layout['global']['tiny_space' if tiny_space else 'space']['y'], 53 | width=width, 54 | height=height 55 | ) 56 | 57 | def next_to_widget(self, src, target, width, height, tiny_space=False, offset_y=0): 58 | target_edge = self.object_edge_info(target) 59 | src.place( 60 | x=target_edge['edge_x'] + self.layout['global']['tiny_space' if tiny_space else 'space']['x'], 61 | y=target_edge['y'] + offset_y, 62 | width=width, 63 | height=height 64 | ) 65 | 66 | 
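# A minimal usage sketch, assuming the same layout dict shape the dialogs
# above pass in; it shows next_to_widget() reading the placed target's
# place_info() to position a sibling widget with a tiny_space (5px) gap.
if __name__ == '__main__':
    import tkinter as tk
    import tkinter.ttk as ttk

    demo_root = tk.Tk()
    demo_layout = {'global': {'start': {'x': 15, 'y': 20}, 'space': {'x': 15, 'y': 25}, 'tiny_space': {'x': 5, 'y': 10}}}
    demo_utils = LayoutGUI(demo_layout, window_width=600)
    demo_label = ttk.Label(demo_root, text='Binaryzation')
    demo_label.place(x=15, y=20, width=75, height=20)
    demo_entry = ttk.Entry(demo_root)
    # Lands at x = 15 + 75 + 5 = 95, y = 20.
    demo_utils.next_to_widget(src=demo_entry, target=demo_label, width=55, height=20, tiny_space=True)
    demo_root.mainloop()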
-------------------------------------------------------------------------------- /loss.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding:utf-8 -*- 3 | # Author: kerlomz 4 | import tensorflow as tf 5 | from config import ModelConfig 6 | 7 | 8 | class Loss(object): 9 | 10 | """Loss function factory""" 11 | @staticmethod 12 | def cross_entropy(labels, logits): 13 | """Cross-entropy loss function""" 14 | 15 | # return tf.nn.softmax_cross_entropy_with_logits_v2(logits=logits, labels=labels) 16 | # return tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels, logits=logits) 17 | # return tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=labels) 18 | # return tf.nn.sigmoid_cross_entropy_with_logits(logits=logits, labels=labels) 19 | target = tf.sparse.to_dense(labels) 20 | # target = labels 21 | print('logits', logits.shape) 22 | print('target', target.shape) 23 | # logits = tf.reshape(tensor=logits, shape=[tf.shape(labels)[0], None]) 24 | return tf.compat.v1.keras.backend.sparse_categorical_crossentropy( 25 | target=target, 26 | output=logits, 27 | from_logits=True, 28 | ) 29 | 30 | @staticmethod 31 | def ctc(labels, logits, sequence_length): 32 | """CTC loss function""" 33 | 34 | return tf.compat.v1.nn.ctc_loss_v2( 35 | labels=labels, 36 | logits=logits, 37 | logit_length=sequence_length, 38 | label_length=sequence_length, 39 | blank_index=-1, 40 | logits_time_major=True 41 | ) 42 | 
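# A minimal shape sketch, assuming CrossEntropy-style inputs: logits shaped
# [batch, max_label_num, category_num] and a sparse label tensor whose dense
# form is [batch, max_label_num] of class indices. All values are dummies.
if __name__ == '__main__':
    demo_logits = tf.zeros([2, 4, 10])
    demo_labels = tf.sparse.from_dense(tf.constant([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=tf.int64))
    # Per-position cross entropy, shape (2, 4).
    print(Loss.cross_entropy(demo_labels, demo_logits))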
-------------------------------------------------------------------------------- /make_dataset.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding:utf-8 -*- 3 | # Author: kerlomz 4 | import sys 5 | import random 6 | from tqdm import tqdm 7 | import tensorflow as tf 8 | from config import * 9 | from constants import RunMode 10 | 11 | _RANDOM_SEED = 0 12 | 13 | 14 | class DataSets: 15 | 16 | """This class packs a dataset into the TFRecords format""" 17 | def __init__(self, model: ModelConfig): 18 | self.ignore_list = ["Thumbs.db", ".DS_Store"] 19 | self.model: ModelConfig = model 20 | if not os.path.exists(self.model.dataset_root_path): 21 | os.makedirs(self.model.dataset_root_path) 22 | 23 | @staticmethod 24 | def read_image(path): 25 | """ 26 | Read an image file 27 | :param path: path of the image 28 | :return: 29 | """ 30 | with open(path, "rb") as f: 31 | return f.read() 32 | 33 | def dataset_exists(self): 34 | """Check whether the dataset already exists""" 35 | for file in (self.model.trains_path[DatasetType.TFRecords] + self.model.validation_path[DatasetType.TFRecords]): 36 | if not os.path.exists(file): 37 | return False 38 | return True 39 | 40 | @staticmethod 41 | def bytes_feature(values): 42 | return tf.train.Feature(bytes_list=tf.train.BytesList(value=[values])) 43 | 44 | def input_to_tfrecords(self, input_data, label): 45 | return tf.train.Example(features=tf.train.Features(feature={ 46 | 'input': self.bytes_feature(input_data), 47 | 'label': self.bytes_feature(label), 48 | })) 49 | 50 | def convert_dataset_from_filename(self, output_filename, file_list, mode: RunMode, is_add=False): 51 | if is_add: 52 | output_filename = self.model.dataset_increasing_name(mode) 53 | if not output_filename: 54 | raise FileNotFoundError('Basic data set missing, please check.') 55 | output_filename = os.path.join(self.model.dataset_root_path, output_filename) 56 | with tf.io.TFRecordWriter(output_filename) as writer: 57 | pbar = tqdm(file_list) 58 | for i, file_name in enumerate(pbar): 59 | try: 60 | if file_name.split("/")[-1] in self.ignore_list: 61 | continue 62 | image_data = self.read_image(file_name) 63 | try: 64 | labels = re.search(self.model.extract_regex, file_name.split(PATH_SPLIT)[-1]) 65 | except re.error as e: 66 | print('error:', e) 67 | return 68 | if labels: 69 | labels = labels.group() 70 | else: 71 | tf.compat.v1.logging.warning('invalid filename {}, ignored.'.format(file_name)) 72 | continue 73 | # raise NameError('invalid filename {}'.format(file_name)) 74 | labels = labels.encode('utf-8') 75 | 76 | example = self.input_to_tfrecords(image_data, labels) 77 | writer.write(example.SerializeToString()) 78 | pbar.set_description('[Processing dataset %s] [filename: %s]' % (mode, file_name)) 79 | 80 | except IOError as e: 81 | print('could not read:', file_name) 82 | print('error:', e) 83 | print('skip it \n') 84 | 85 | def convert_dataset_from_txt(self, output_filename, file_path, label_lines, mode: RunMode, is_add=False): 86 | if is_add: 87 | output_filename = self.model.dataset_increasing_name(mode) 88 | if not output_filename: 89 | raise FileNotFoundError('Basic data set missing, please check.') 90 | output_filename = os.path.join(self.model.dataset_root_path, output_filename) 91 | file_list, label_list = [], [] 92 | for line in label_lines: 93 | filename, label = line.split(" ", 1) 94 | label = label.replace("\n", "") 95 | label_list.append(label.encode('utf-8')) 96 | path = os.path.join(file_path, filename) 97 | file_list.append(path) 98 | 99 | if os.path.exists(output_filename): 100 | print('Already exists, skipping.') 101 | return 102 | 103 | with tf.io.TFRecordWriter(output_filename) as writer: 104 | pbar = tqdm(file_list) 105 | for i, file_name in enumerate(pbar): 106 | try: 107 | image_data = self.read_image(file_name) 108 | labels = label_list[i] 109 | example = self.input_to_tfrecords(image_data, labels) 110 | writer.write(example.SerializeToString()) 111 | pbar.set_description('[Processing dataset %s] [filename: %s]' % (mode, file_name)) 112 | except IOError as e: 113 | print('could not read:', file_name) 114 | print('error:', e) 115 | print('skip it \n') 116 | 117 | @staticmethod 118 | def merge_source(source): 119 | if isinstance(source, list): 120 | origin_dataset = [] 121 | for trains_path in source: 122 | origin_dataset += [ 123 | os.path.join(trains_path, trains).replace("\\", "/") for trains in os.listdir(trains_path) 124 | ] 125 | elif isinstance(source, str): 126 | origin_dataset = [os.path.join(source, trains) for trains in os.listdir(source)] 127 | else: 128 | return 129 | random.seed(0) 130 | random.shuffle(origin_dataset) 131 | return origin_dataset 132 | 133 | def make_dataset(self, trains_path=None, validation_path=None, is_add=False, callback=None, msg=None): 134 | if self.dataset_exists() and not is_add: 135 | state = "EXISTS" 136 | if callback: 137 | callback() 138 | if msg: 139 | msg(state) 140 | return 141 | 142 | if not self.model.dataset_path_root: 143 | state = "CONF_ERROR" 144 | if callback: 145 | callback() 146 | if msg: 147 | msg(state) 148 | return 149 | 150 | trains_path = trains_path if is_add else self.model.trains_path[DatasetType.Directory] 151 | validation_path = validation_path if is_add else self.model.validation_path[DatasetType.Directory] 152 | 153 | trains_path = [trains_path] if isinstance(trains_path, str) else trains_path 154 | validation_path = [validation_path] if isinstance(validation_path, str) else validation_path 155 | 156 | if validation_path and not is_add: 157 | if self.model.label_from == LabelFrom.FileName: 158 | trains_dataset = self.merge_source(trains_path) 159 | validation_dataset = self.merge_source(validation_path) 160 | self.convert_dataset_from_filename( 161 | self.model.validation_path[DatasetType.TFRecords][-1 if is_add else 0], 162 | validation_dataset, 163 | mode=RunMode.Validation, 164 | is_add=is_add, 165 | ) 166 | self.convert_dataset_from_filename( 167 | self.model.trains_path[DatasetType.TFRecords][-1 if is_add else 0], 168 | trains_dataset, 169 | 
mode=RunMode.Trains, 170 | is_add=is_add, 171 | ) 172 | elif self.model.label_from == LabelFrom.TXT: 173 | 174 | train_label_file = os.path.join(os.path.dirname(trains_path[0]), "train.txt") 175 | val_label_file = os.path.join(os.path.dirname(validation_path[0]), "val.txt") 176 | 177 | if not os.path.exists(train_label_file) or not os.path.exists(val_label_file): 178 | if msg: msg("Train or validation label file not found!") 179 | if callback: 180 | callback() 181 | return 182 | 183 | with open(train_label_file, "r", encoding="utf8") as f_train: 184 | train_label_line = f_train.readlines() 185 | 186 | with open(val_label_file, "r", encoding="utf8") as f_val: 187 | val_label_line = f_val.readlines() 188 | 189 | self.convert_dataset_from_txt( 190 | self.model.validation_path[DatasetType.TFRecords][-1 if is_add else 0], 191 | label_lines=val_label_line, 192 | file_path=validation_path[0], 193 | mode=RunMode.Validation, 194 | is_add=is_add, 195 | ) 196 | self.convert_dataset_from_txt( 197 | self.model.trains_path[DatasetType.TFRecords][-1 if is_add else 0], 198 | label_lines=train_label_line, 199 | file_path=trains_path[0], 200 | mode=RunMode.Trains, 201 | is_add=is_add, 202 | ) 203 | 204 | else: 205 | if self.model.label_from == LabelFrom.FileName: 206 | origin_dataset = self.merge_source(trains_path) 207 | trains_dataset = origin_dataset[self.model.validation_set_num:] 208 | if self.model.validation_set_num > 0: 209 | validation_dataset = origin_dataset[:self.model.validation_set_num] 210 | self.convert_dataset_from_filename( 211 | self.model.validation_path[DatasetType.TFRecords][-1 if is_add else 0], 212 | validation_dataset, 213 | mode=RunMode.Validation, 214 | is_add=is_add 215 | ) 216 | elif self.model.validation_set_num < 0: 217 | self.convert_dataset_from_filename( 218 | self.model.validation_path[DatasetType.TFRecords][-1 if is_add else 0], 219 | trains_dataset, 220 | mode=RunMode.Validation, 221 | is_add=is_add 222 | ) 223 | self.convert_dataset_from_filename( 224 | self.model.trains_path[DatasetType.TFRecords][-1 if is_add else 0], 225 | trains_dataset, 226 | mode=RunMode.Trains, 227 | is_add=is_add 228 | ) 229 | elif self.model.label_from == LabelFrom.TXT: 230 | 231 | train_label_file = os.path.join(os.path.dirname(trains_path[0]), "train.txt") 232 | 233 | if not os.path.exists(train_label_file): 234 | if msg: msg("Train label file not found!") 235 | if callback: 236 | callback() 237 | return 238 | 239 | with open(train_label_file, "r", encoding="utf8") as f: 240 | sample_label_line = f.readlines() 241 | 242 | random.shuffle(sample_label_line) 243 | 244 | train_label_line = sample_label_line[self.model.validation_set_num:] 245 | val_label_line = sample_label_line[:self.model.validation_set_num] 246 | 247 | self.convert_dataset_from_txt( 248 | self.model.validation_path[DatasetType.TFRecords][-1 if is_add else 0], 249 | label_lines=val_label_line, 250 | file_path=trains_path[0], 251 | mode=RunMode.Validation, 252 | is_add=is_add, 253 | ) 254 | self.convert_dataset_from_txt( 255 | self.model.trains_path[DatasetType.TFRecords][-1 if is_add else 0], 256 | label_lines=train_label_line, 257 | file_path=trains_path[0], 258 | mode=RunMode.Trains, 259 | is_add=is_add, 260 | ) 261 | 262 | state = "DONE" 263 | if callback: 264 | callback() 265 | if msg: 266 | msg(state) 267 | return 268 | 269 | 270 | if __name__ == '__main__': 271 | model_conf = ModelConfig(sys.argv[-1]) 272 | _dataset = DataSets(model_conf) 273 | _dataset.make_dataset() 274 | 
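# A minimal read-back sketch: it parses records written by input_to_tfrecords()
# above, using the same 'input'/'label' feature keys; the .tfrecords path in the
# commented usage is a placeholder.
def parse_tfrecords_example(serialized):
    features = tf.io.parse_single_example(serialized, {
        'input': tf.io.FixedLenFeature([], tf.string),
        'label': tf.io.FixedLenFeature([], tf.string),
    })
    return features['input'], features['label']

# dataset = tf.data.TFRecordDataset(['trains.0.tfrecords'])
# for image_bytes, label in dataset.map(parse_tfrecords_example).take(1):
#     print(label.numpy().decode('utf-8'))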
-------------------------------------------------------------------------------- /middleware/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding:utf-8 -*- 3 | # Author: kerlomz -------------------------------------------------------------------------------- /middleware/random_captcha.py: -------------------------------------------------------------------------------- 1 | from PIL import Image, ImageDraw, ImageFont 2 | from enum import Enum, unique 3 | from fontTools.ttLib import TTFont 4 | import numpy as np 5 | import io 6 | import os 7 | import base64 8 | import hashlib 9 | import time 10 | import random 11 | import logging 12 | 13 | 14 | class BackgroundType(Enum): 15 | RANDOM = 'random' 16 | IMAGE = 'image' 17 | RGB = 'rgb' 18 | 19 | 20 | class RandomCaptcha(object): 21 | """Random alphanumeric sample generator""" 22 | def __init__(self): 23 | self.__width = [130, 160] 24 | self.__height = [50, 60] 25 | self.__background_mode = BackgroundType.RGB 26 | self.__background_img_assests_path = None 27 | self.__rgb = { 28 | 'r': [0, 255], 29 | 'g': [0, 255], 30 | 'b': [0, 255] 31 | } 32 | self.__fonts_list = [] 33 | self.__samples = [] 34 | self.__fonts_num = [4, 4] 35 | self.__font_size = [26, 36] 36 | self.__font_mode = 0 37 | self.__max_line_count = 2 38 | self.__max_point_count = 20 39 | 40 | @property 41 | def max_point_count(self): 42 | return self.__max_point_count 43 | 44 | @max_point_count.setter 45 | def max_point_count(self, value: int): 46 | self.__max_point_count = value 47 | 48 | @property 49 | def max_line_count(self): 50 | return self.__max_line_count 51 | 52 | @max_line_count.setter 53 | def max_line_count(self, value: int): 54 | self.__max_line_count = value 55 | 56 | @property 57 | def font_mode(self): 58 | return self.__font_mode 59 | 60 | @font_mode.setter 61 | def font_mode(self, value: int): 62 | self.__font_mode = value 63 | 64 | @property 65 | def font_size(self) -> list: 66 | return self.__font_size 67 | 68 | @font_size.setter 69 | def font_size(self, value: list): 70 | if type(value) == list and type(value[0]) == int and type(value[1]) == int and value[0] >= 0 and value[1] > 0 and value[0] < value[1]: 71 | self.__font_size = value 72 | else: 73 | raise ValueError("input value should be like [0, 255]") 74 | 75 | @property 76 | def fonts_num(self) -> list: 77 | return self.__fonts_num 78 | 79 | @fonts_num.setter 80 | def fonts_num(self, value: list): 81 | self.__fonts_num = value 82 | 83 | @property 84 | def sample(self) -> list: 85 | return self.__samples 86 | 87 | @sample.setter 88 | def sample(self, value: list): 89 | self.__samples = value 90 | 91 | @property 92 | def fonts_list(self) -> list: 93 | return self.__fonts_list 94 | 95 | @fonts_list.setter 96 | def fonts_list(self, value: list): 97 | self.__fonts_list = value 98 | 99 | @property 100 | def rgb(self) -> dict: 101 | return self.__rgb 102 | 103 | @property 104 | def rgb_r(self) -> list: 105 | return self.__rgb['r'] 106 | 107 | @rgb_r.setter 108 | def rgb_r(self, value: list): 109 | if type(value) == list and type(value[0]) == int and type(value[1]) == int and value[0] >= 0 and value[1] > 0 and value[0] < value[1] and value[0] <= 255: 110 | self.__rgb['r'] = value 111 | else: 112 | raise ValueError("input value should be like [0, 255]") 113 | 114 | @property 115 | def rgb_g(self) -> list: 116 | return self.__rgb['g'] 117 | 118 | @rgb_g.setter 119 | def rgb_g(self, value: list): 120 | if type(value) == list and type(value[0]) == int and 
type(value[1]) == int and value[0] >= 0 and value[1] > 0 and value[0] < value[1] and value[0] <= 255: 121 | self.__rgb['g'] = value 122 | else: 123 | raise ValueError("input value should be like [0, 255]") 124 | 125 | @property 126 | def rgb_b(self) -> list: 127 | return self.__rgb['b'] 128 | 129 | @rgb_b.setter 130 | def rgb_b(self, value: list): 131 | if type(value) == list and type(value[0]) == int and type(value[1]) == int and value[0] >= 0 and value[1] > 0 and value[0] < value[1]: 132 | self.__rgb['b'] = value 133 | else: 134 | raise ValueError("input value should be like [0, 255]") 135 | 136 | @property 137 | def background_mode(self) -> BackgroundType: 138 | return self.__background_mode 139 | 140 | @background_mode.setter 141 | def background_mode(self, value: BackgroundType): 142 | self.__background_mode = value 143 | 144 | @property 145 | def background_img_path(self) -> str: 146 | return self.__background_img_assests_path 147 | 148 | @background_img_path.setter 149 | def background_img_path(self, value: str): 150 | self.__background_img_assests_path = value 151 | 152 | @property 153 | def height(self): 154 | return self.__height 155 | 156 | @height.setter 157 | def height(self, value): 158 | self.__height = value 159 | 160 | @property 161 | def width(self): 162 | return self.__width 163 | 164 | @width.setter 165 | def width(self, value): 166 | self.__width = value 167 | 168 | def check_font(self): 169 | for font_type in list(self.fonts_list): 170 | try: 171 | font = TTFont(font_type) 172 | uni_map = font['cmap'].tables[0].ttFont.getBestCmap() 173 | for item in self.sample: 174 | codepoint = ord(str(item)) 175 | if codepoint in uni_map.keys(): 176 | continue 177 | else: 178 | font.close() 179 | raise Exception("{} not found!".format(item)) 180 | except Exception as e: 181 | try: 182 | os.remove(font_type) 183 | except: 184 | pass 185 | self.fonts_list.remove(font_type) 186 | 187 | pass 188 | 189 | def set_text(self, __image: ImageDraw, img_width, img_height): 190 | 191 | if img_width >= 150: 192 | font_size = random.choice(range(self.font_size[0], self.font_size[1])) 193 | else: 194 | font_size = random.choice(range(self.font_size[0], int((self.font_size[0] + self.font_size[1])/2))) 195 | 196 | font_num = random.choice(range(self.fonts_num[0], self.fonts_num[1] + 1)) 197 | max_width = int(img_width / font_num) 198 | max_height = int(img_height) 199 | font_type = random.choice(self.fonts_list) 200 | try: 201 | font = ImageFont.truetype(font_type, font_size) 202 | except OSError: 203 | self.fonts_list.remove(font_type) 204 | raise Exception("{} failed to open".format(font_type)) 205 | labels = [] 206 | for idx in range(font_num): 207 | fw = range(int(max_width - font_size)) 208 | if len(fw) > 0: 209 | x = max_width * idx + random.choice(fw) 210 | else: 211 | x = max_width * idx 212 | y = random.choice(range(int(max_height - font_size))) 213 | f = random.choice(self.sample) 214 | labels.append(f) 215 | __image.text((x, y), f, font=font, 216 | fill=(random.randint(0, 255), random.randint(0, 255), random.randint(0, 255))) 217 | return labels, font_type 218 | 219 | def set_noise(self, __image: ImageDraw, img_width, img_height): 220 | for i in range(self.max_line_count): 221 | # Start coordinates (x, y) of the noise line 222 | x1 = random.randint(0, img_width) 223 | y1 = random.randint(0, img_height) 224 | # End coordinates (x, y) of the noise line 225 | x2 = random.randint(0, img_width) 226 | y2 = random.randint(0, img_height) 227 | # Draw the line via draw.line((start_xy, end_xy), fill='color') 228 | __image.line((x1, y1, x2, y2), 229 | 
fill=(random.randint(0, 255), random.randint(0, 255), random.randint(0, 255))) 230 | for i in range(self.max_point_count): 231 | __image.point([random.randint(0, img_width), random.randint(0, img_height)], fill=(random.randint(0, 255), random.randint(0, 255), random.randint(0, 255))) 232 | x = random.randint(0, img_width) 233 | y = random.randint(0, img_height) 234 | __image.arc((x, y, x + 4, y + 4), 0, 40, fill=(random.randint(0, 255), random.randint(0, 255), random.randint(0, 255))) 235 | 236 | def set_content(self, __image: ImageDraw, img_width, img_height): 237 | labels, font_type = self.set_text(__image, img_width, img_height) 238 | self.set_noise(__image, img_width, img_height) 239 | return labels, font_type 240 | 241 | def create(self, mode: str = "bytes", img_format: str = "png"): 242 | if type(self.width) == list: 243 | img_width = random.choice(range(self.width[0], self.width[1])) 244 | else: 245 | img_width = self.width 246 | if type(self.height) == list: 247 | img_height = random.choice(range(self.height[0], self.height[1])) 248 | else: 249 | img_height = self.height 250 | 251 | background_mode = self.background_mode 252 | if type(background_mode) is BackgroundType: 253 | if background_mode.value == BackgroundType.RGB.value: 254 | rgb_range = self.rgb 255 | r_range = rgb_range['r'] 256 | g_range = rgb_range['g'] 257 | b_range = rgb_range['b'] 258 | rgb = (random.randint(r_range[0], r_range[1]), random.randint(g_range[0], g_range[1]), 259 | random.randint(b_range[0], b_range[1])) 260 | __image = Image.new('RGB', (img_width, img_height), rgb) 261 | img = ImageDraw.Draw(__image) 262 | labels, font_type = self.set_content(img, img_width, img_height) 263 | if mode == "bytes": 264 | img_byte_arr = io.BytesIO() 265 | __image.save(img_byte_arr, format=img_format) 266 | return img_byte_arr.getvalue(), labels, font_type 267 | elif mode == "numpy": 268 | return np.array(__image), labels, font_type 269 | elif mode == "base64": 270 | img_byte_arr = io.BytesIO() 271 | __image.save(img_byte_arr, format=img_format) 272 | _bytes = img_byte_arr.getvalue() 273 | return base64.b64encode(_bytes).decode(), labels, font_type 274 | else: 275 | raise FutureWarning("Output type not supported yet") 276 | else: 277 | raise FutureWarning("Background type not supported yet") 278 | else: 279 | raise TypeError("background mode must be BackgroundType.") 280 | 
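# A minimal usage sketch; the font path and character set below are placeholders.
if __name__ == '__main__':
    demo_captcha = RandomCaptcha()
    demo_captcha.fonts_list = ['/path/to/some_font.ttf']
    demo_captcha.sample = list('0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ')
    demo_bytes, demo_labels, demo_font = demo_captcha.create(mode="bytes", img_format="png")
    print(''.join(demo_labels), demo_font)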
-------------------------------------------------------------------------------- /model.template: -------------------------------------------------------------------------------- 1 | # - requirements.txt - GPU: tensorflow-gpu, CPU: tensorflow 2 | # - If you use the GPU version, you need to install some additional applications. 3 | System: 4 | MemoryUsage: {MemoryUsage} 5 | Version: 2 6 | 7 | # CNNNetwork: [CNN5, ResNet, DenseNet] 8 | # RecurrentNetwork: [CuDNNBiLSTM, CuDNNLSTM, CuDNNGRU, BiLSTM, LSTM, GRU, BiGRU, NoRecurrent] 9 | # - The recommended configuration is CNN5+GRU 10 | # UnitsNum: [16, 64, 128, 256, 512] 11 | # - This parameter indicates the number of nodes used to remember and store past states. 12 | # Optimizer: The optimization algorithm used for calculating gradients. 13 | # - [AdaBound, Adam, Momentum] 14 | # OutputLayer: [LossFunction, Decoder] 15 | # - LossFunction: [CTC, CrossEntropy] 16 | # - Decoder: [CTC, CrossEntropy] 17 | NeuralNet: 18 | CNNNetwork: {CNNNetwork} 19 | RecurrentNetwork: {RecurrentNetwork} 20 | UnitsNum: {UnitsNum} 21 | Optimizer: {Optimizer} 22 | OutputLayer: 23 | LossFunction: {LossFunction} 24 | Decoder: {Decoder} 25 | 26 | 27 | # ModelName: Corresponding to the model file in the model directory 28 | # ModelField: [Image, Text] 29 | # ModelScene: [Classification] 30 | # - Currently only Image-Classification is supported. 31 | Model: 32 | ModelName: {ModelName} 33 | ModelField: {ModelField} 34 | ModelScene: {ModelScene} 35 | 36 | # FieldParam contains the parameters for the Image and Text fields. 37 | # When the field is set to Image: 38 | # - Category: Provides a default optional built-in solution: 39 | # -- [ALPHANUMERIC, ALPHANUMERIC_LOWER, ALPHANUMERIC_UPPER, 40 | # -- NUMERIC, ALPHABET_LOWER, ALPHABET_UPPER, ALPHABET, ALPHANUMERIC_CHS_3500_LOWER] 41 | # - or can be customized by: 42 | # -- ['Cat', 'Lion', 'Tiger', 'Fish', 'BigCat'] 43 | # - Resize: [ImageWidth, ImageHeight/-1, ImageChannel] 44 | # - ImageChannel: [1, 3] 45 | # - Used to automatically select a model by image size when multiple models are deployed at the same time: 46 | # -- ImageWidth: The width of the image. 47 | # -- ImageHeight: The height of the image. 48 | # - MaxLabelNum: You can fill in -1, or any integer, where -1 means not defining the value. 49 | # -- Used when the number of labels is fixed 50 | # When the field is set to Text: 51 | # This type is temporarily not supported. 52 | FieldParam: 53 | Category: {Category} 54 | Resize: {Resize} 55 | ImageChannel: {ImageChannel} 56 | ImageWidth: {ImageWidth} 57 | ImageHeight: {ImageHeight} 58 | MaxLabelNum: {MaxLabelNum} 59 | OutputSplit: {OutputSplit} 60 | AutoPadding: {AutoPadding} 61 | 62 | 63 | # The configuration is applied to the label of the data source. 64 | # LabelFrom: [FileName, XML, LMDB] 65 | # ExtractRegex: Only for methods extracted from FileName: 66 | # - Default matching apple_20181010121212.jpg file. 67 | # - The Default is .*?(?=_.*\.) 68 | # LabelSplit: Only for methods extracted from FileName: 69 | # - The split symbol in the file name is like: cat&big cat&lion_20181010121212.png 70 | # - The Default is null. 
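# Example (assuming the default ExtractRegex, .*?(?=_.*\.)):
# - the filename apple_20181010121212.jpg yields the label "apple",
# - i.e. the shortest prefix followed by "_" and, later, the extension dot.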
71 | Label: 72 | LabelFrom: {LabelFrom} 73 | ExtractRegex: {ExtractRegex} 74 | LabelSplit: {LabelSplit} 75 | 76 | 77 | # DatasetPath: [Training/Validation], The local absolute path of a packed training or validation set. 78 | # SourcePath: [Training/Validation], The local absolute path to the source folder of the training or validation set. 79 | # ValidationSetNum: This is an optional parameter that is used when you want to extract some of the validation set 80 | # - from the training set when you are not preparing the validation set separately. 81 | # SavedSteps: A Session.run() execution is called a Step, 82 | # - Used to save training progress, Default value is 100. 83 | # ValidationSteps: Used to calculate accuracy, Default value is 500. 84 | # EndAcc: Finish the training when the accuracy reaches [EndAcc*100]% and other conditions. 85 | # EndCost: Finish the training when the cost reaches EndCost and other conditions. 86 | # EndEpochs: Finish the training when the epoch is greater than the defined epoch and other conditions. 87 | # BatchSize: Number of samples selected for one training step. 88 | # ValidationBatchSize: Number of samples selected for one validation step. 89 | # LearningRate: [0.1, 0.01, 0.001, 0.0001] 90 | # - Use a smaller learning rate for fine-tuning. 91 | Trains: 92 | DatasetPath: 93 | Training: {DatasetTrainsPath} 94 | Validation: {DatasetValidationPath} 95 | SourcePath: 96 | Training: {SourceTrainPath} 97 | Validation: {SourceValidationPath} 98 | ValidationSetNum: {ValidationSetNum} 99 | SavedSteps: {SavedSteps} 100 | ValidationSteps: {ValidationSteps} 101 | EndAcc: {EndAcc} 102 | EndCost: {EndCost} 103 | EndEpochs: {EndEpochs} 104 | BatchSize: {BatchSize} 105 | ValidationBatchSize: {ValidationBatchSize} 106 | LearningRate: {LearningRate} 107 | 108 | # Binaryzation: The argument is of type list and contains the range of int values, -1 is not enabled. 109 | # MedianBlur: The parameter is an int value, -1 is not enabled. 110 | # GaussianBlur: The parameter is an int value, -1 is not enabled. 111 | # EqualizeHist: The parameter is a bool value. 112 | # Laplace: The parameter is a bool value. 113 | # WarpPerspective: The parameter is a bool value. 114 | # Rotate: The parameter is a positive integer int type greater than 0, -1 is not enabled. 115 | # PepperNoise: This parameter is a float type less than 1, -1 is not enabled. 116 | # Brightness: The parameter is a bool value. 117 | # Saturation: The parameter is a bool value. 118 | # Hue: The parameter is a bool value. 119 | # Gamma: The parameter is a bool value. 120 | # ChannelSwap: The parameter is a bool value. 121 | # RandomBlank: The parameter is a positive integer int type greater than 0, -1 is not enabled. 122 | # RandomTransition: The parameter is a positive integer int type greater than 0, -1 is not enabled. 123 | DataAugmentation: 124 | Binaryzation: {DA_Binaryzation} 125 | MedianBlur: {DA_MedianBlur} 126 | GaussianBlur: {DA_GaussianBlur} 127 | EqualizeHist: {DA_EqualizeHist} 128 | Laplace: {DA_Laplace} 129 | WarpPerspective: {DA_WarpPerspective} 130 | Rotate: {DA_Rotate} 131 | PepperNoise: {DA_PepperNoise} 132 | Brightness: {DA_Brightness} 133 | Saturation: {DA_Saturation} 134 | Hue: {DA_Hue} 135 | Gamma: {DA_Gamma} 136 | ChannelSwap: {DA_ChannelSwap} 137 | RandomBlank: {DA_RandomBlank} 138 | RandomTransition: {DA_RandomTransition} 139 | RandomCaptcha: {DA_RandomCaptcha} 140 | 141 | # Binaryzation: The parameter is an integer number between 0 and 255, -1 is not enabled. 142 | # ReplaceTransparent: Transparent background replacement, bool type. 143 | # HorizontalStitching: Horizontal stitching, bool type. 144 | # ConcatFrames: Horizontally merge two frames according to the provided frame index list, -1 is not enabled. 145 | # BlendFrames: Blend the corresponding frames according to the provided frame index list, -1 is not enabled. 
146 | # - [-1] means all frames 147 | Pretreatment: 148 | Binaryzation: {Pre_Binaryzation} 149 | ReplaceTransparent: {Pre_ReplaceTransparent} 150 | HorizontalStitching: {Pre_HorizontalStitching} 151 | ConcatFrames: {Pre_ConcatFrames} 152 | BlendFrames: {Pre_BlendFrames} 153 | ExecuteMap: {Pre_ExecuteMap} 154 | 
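# A filled-in example of the sections above (illustrative values only):
#   NeuralNet:
#     CNNNetwork: CNN5
#     RecurrentNetwork: GRU
#     UnitsNum: 64
#     Optimizer: Adam
#     OutputLayer:
#       LossFunction: CTC
#       Decoder: CTC
#   Model:
#     ModelName: demo-model
#     ModelField: Image
#     ModelScene: Classification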
155 | 156 | 157 | 158 | 159 | 160 | 161 | 162 | -------------------------------------------------------------------------------- /network/CNN.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding:utf-8 -*- 3 | # Author: kerlomz 4 | import tensorflow as tf 5 | from network.utils import NetworkUtils 6 | from config import ModelConfig 7 | from tensorflow.python.keras.regularizers import l1 8 | 9 | 10 | class CNN3(object): 11 | 12 | """ 13 | Implementation of the CNN3 network 14 | """ 15 | def __init__(self, model_conf: ModelConfig, inputs: tf.Tensor, utils: NetworkUtils): 16 | """ 17 | :param model_conf: model configuration 18 | :param inputs: input from the previous network layer, of type tf.keras.layers.Input / tf.Tensor 19 | :param utils: network utility class 20 | """ 21 | self.model_conf = model_conf 22 | self.inputs = inputs 23 | self.utils = utils 24 | self.loss_func = self.model_conf.loss_func 25 | 26 | def build(self): 27 | with tf.keras.backend.name_scope("CNN3"): 28 | x = self.utils.cnn_layer(0, inputs=self.inputs, kernel_size=7, filters=32, strides=(1, 1)) 29 | x = self.utils.cnn_layer(1, inputs=x, kernel_size=5, filters=64, strides=(1, 2)) 30 | x = self.utils.cnn_layer(2, inputs=x, kernel_size=3, filters=64, strides=(1, 2)) 31 | shape_list = x.get_shape().as_list() 32 | print("x.get_shape()", shape_list) 33 | return self.utils.reshape_layer(x, self.loss_func, shape_list) 34 | 35 | 36 | class CNN5(object): 37 | 38 | """ 39 | Implementation of the CNN5 network 40 | """ 41 | def __init__(self, model_conf: ModelConfig, inputs: tf.Tensor, utils: NetworkUtils): 42 | """ 43 | :param model_conf: model configuration 44 | :param inputs: input from the previous network layer, of type tf.keras.layers.Input / tf.Tensor 45 | :param utils: network utility class 46 | """ 47 | self.model_conf = model_conf 48 | self.inputs = inputs 49 | self.utils = utils 50 | self.loss_func = self.model_conf.loss_func 51 | 52 | def build(self): 53 | with tf.keras.backend.name_scope("CNN5"): 54 | x = self.utils.cnn_layer(0, inputs=self.inputs, kernel_size=7, filters=32, strides=(1, 1)) 55 | x = self.utils.cnn_layer(1, inputs=x, kernel_size=5, filters=64, strides=(1, 2)) 56 | x = self.utils.cnn_layer(2, inputs=x, kernel_size=3, filters=128, strides=(1, 2)) 57 | x = self.utils.cnn_layer(3, inputs=x, kernel_size=3, filters=128, strides=(1, 2)) 58 | x = self.utils.cnn_layer(4, inputs=x, kernel_size=3, filters=64, strides=(1, 2)) 59 | shape_list = x.get_shape().as_list() 60 | print("x.get_shape()", shape_list) 61 | return self.utils.reshape_layer(x, self.loss_func, shape_list) 62 | 63 | 64 | class CNNX(object): 65 | 66 | """ Network structure """ 67 | def __init__(self, model_conf: ModelConfig, inputs: tf.Tensor, utils: NetworkUtils): 68 | self.model_conf = model_conf 69 | self.inputs = inputs 70 | self.utils = utils 71 | self.loss_func = self.model_conf.loss_func 72 | 73 | def block(self, inputs, filters, kernel_size, strides, dilation_rate=(1, 1)): 74 | inputs = tf.keras.layers.Conv2D( 75 | filters=filters, 76 | dilation_rate=dilation_rate, 77 | kernel_size=kernel_size, 78 | strides=strides, 79 | kernel_regularizer=l1(0.1), 80 | kernel_initializer=self.utils.msra_initializer(kernel_size, filters), 81 | padding='SAME', 82 | )(inputs) 83 | inputs = tf.compat.v1.layers.batch_normalization( 84 | inputs, 85 | reuse=False, 86 | momentum=0.9, 87 | training=self.utils.is_training 88 | ) 89 | inputs = self.utils.hard_swish(inputs) 90 | return inputs 91 | 92 | def build(self): 93 | with tf.keras.backend.name_scope('CNNX'): 94 | x = self.inputs 95 | 96 | x = self.block(x, filters=16, kernel_size=7, strides=1) 97 | 98 | max_pool0 = tf.keras.layers.MaxPooling2D( 99 | pool_size=(1, 2), 100 | strides=2, 101 | padding='same')(x) 102 | max_pool1 = tf.keras.layers.MaxPooling2D( 103 | pool_size=(1, 3), 104 | strides=2, 105 | padding='same')(x) 106 | max_pool2 = tf.keras.layers.MaxPooling2D( 107 | pool_size=(1, 5), 108 | strides=2, 109 | padding='same')(x) 110 | max_pool3 = tf.keras.layers.MaxPooling2D( 111 | pool_size=(1, 7), 112 | strides=2, 113 | padding='same')(x) 114 | 115 | multi_scale_pool = tf.keras.layers.Add()([max_pool0, max_pool1, max_pool2, max_pool3]) 116 | 117 | x = self.block(multi_scale_pool, filters=32, kernel_size=5, strides=1) 118 | 119 | x1 = self.utils.inverted_res_block(x, filters=16, stride=2, expansion=6, block_id=1) 120 | x1 = self.utils.inverted_res_block(x1, filters=16, stride=1, expansion=6, block_id=2) 121 | 122 | x2 = tf.keras.layers.MaxPooling2D( 123 | pool_size=(2, 2), 124 | strides=2, 125 | padding='same')(x) 126 | x = tf.keras.layers.Concatenate()([x2, x1]) 127 | 128 | x = self.utils.inverted_res_block(x, filters=32, stride=2, expansion=6, block_id=3) 129 | x = self.utils.inverted_res_block(x, filters=32, stride=1, expansion=6, block_id=4) 130 | 131 | x = self.utils.dense_block(x, 2, name='dense_block') 132 | 133 | x = self.utils.inverted_res_block(x, filters=64, stride=1, expansion=6, block_id=5) 134 | 135 | shape_list = x.get_shape().as_list() 136 | print("x.get_shape()", shape_list) 137 | 138 | return self.utils.reshape_layer(x, self.loss_func, shape_list) 139 | -------------------------------------------------------------------------------- /network/DenseNet.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding:utf-8 -*- 3 | # Author: kerlomz 4 | # This network was temporarily suspended 5 | import tensorflow as tf 6 | from network.utils import NetworkUtils 7 | from config import ModelConfig 8 | 9 | 10 | class DenseNet(object): 11 | 12 | def __init__(self, model_conf: ModelConfig, inputs: tf.Tensor, utils: NetworkUtils): 13 | self.model_conf = model_conf 14 | self.inputs = inputs 15 | self.utils = utils 16 | self.loss_func = self.model_conf.loss_func 17 | self.type = { 18 | '121': [6, 12, 24, 16], 19 | '169': [6, 12, 32, 32], 20 | '201': [6, 12, 48, 32] 21 | } 22 | self.blocks = self.type['121'] 23 | self.padding = "SAME" 24 | 25 | def build(self): 26 | 27 | with tf.keras.backend.name_scope('DenseNet'): 28 | 29 | x = tf.keras.layers.Conv2D(64, 3, strides=2, use_bias=False, name='conv1/conv', padding='same')(self.inputs) 30 | x = tf.layers.batch_normalization( 31 | x, 32 | epsilon=1.001e-5, 33 | axis=3, 34 | reuse=False, 35 | momentum=0.9, 36 | name='conv1/bn', 37 | training=self.utils.is_training, 38 | ) 39 | 40 | x = tf.keras.layers.LeakyReLU(0.01, name='conv1/relu')(x) 41 | x = tf.keras.layers.MaxPooling2D(3, strides=2, name='pool1', padding='same')(x) 42 | x = self.utils.dense_block(x, self.blocks[0], name='conv2') 43 | x = self.utils.transition_block(x, 0.5, name='pool2') 44 | x = self.utils.dense_block(x, self.blocks[1], name='conv3') 45 | x = self.utils.transition_block(x, 0.5, name='pool3') 46 | x = self.utils.dense_block(x, self.blocks[2], name='conv4') 47 | x = self.utils.transition_block(x, 0.5, 
name='pool4') 48 | x = self.utils.dense_block(x, self.blocks[3], name='conv5') 49 | x = tf.layers.batch_normalization( 50 | x, 51 | epsilon=1.001e-5, 52 | axis=3, 53 | reuse=False, 54 | momentum=0.9, 55 | name='bn', 56 | training=self.utils.is_training, 57 | ) 58 | 59 | x = tf.keras.layers.LeakyReLU(0.01, name='conv6/relu')(x) 60 | 61 | shape_list = x.get_shape().as_list() 62 | print("x.get_shape()", shape_list) 63 | 64 | return self.utils.reshape_layer(x, self.loss_func, shape_list) 65 | -------------------------------------------------------------------------------- /network/GRU.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding:utf-8 -*- 3 | # Author: kerlomz 4 | 5 | import tensorflow as tf 6 | from config import RunMode, ModelConfig 7 | from network.utils import NetworkUtils 8 | 9 | 10 | class GRU(object): 11 | 12 | def __init__(self, model_conf: ModelConfig, inputs: tf.Tensor, utils: NetworkUtils): 13 | """ 14 | :param model_conf: model configuration 15 | :param inputs: input from the previous network layer, of type tf.keras.layers.Input/tf.Tensor 16 | :param utils: network utility class 17 | """ 18 | self.model_conf = model_conf 19 | self.inputs = inputs 20 | self.utils = utils 21 | self.layer = None 22 | 23 | def build(self): 24 | """ 25 | Build the recurrent layer 26 | :return: the output of the recurrent layer 27 | """ 28 | with tf.keras.backend.name_scope('GRU'): 29 | mask = tf.keras.layers.Masking()(self.inputs) 30 | self.layer = tf.keras.layers.GRU( 31 | units=self.model_conf.units_num * 2, 32 | return_sequences=True, 33 | input_shape=mask.shape, 34 | # reset_after=True, 35 | ) 36 | outputs = self.layer(mask, training=self.utils.is_training) 37 | return outputs 38 | 39 | 40 | class BiGRU(object): 41 | 42 | def __init__(self, model_conf: ModelConfig, inputs: tf.Tensor, utils: NetworkUtils): 43 | self.model_conf = model_conf 44 | self.inputs = inputs 45 | self.utils = utils 46 | self.layer = None 47 | 48 | def build(self): 49 | with tf.keras.backend.name_scope('BiGRU'): 50 | mask = tf.keras.layers.Masking()(self.inputs) 51 | self.layer = tf.keras.layers.Bidirectional( 52 | layer=tf.keras.layers.GRU( 53 | units=self.model_conf.units_num, 54 | return_sequences=True, 55 | ), 56 | input_shape=mask.shape, 57 | ) 58 | outputs = self.layer(mask, training=self.utils.is_training) 59 | return outputs 60 | 61 | 62 | class GRUcuDNN(object): 63 | 64 | def __init__(self, model_conf: ModelConfig, inputs: tf.Tensor, utils: NetworkUtils): 65 | self.model_conf = model_conf 66 | self.inputs = inputs 67 | self.utils = utils 68 | self.layer = None 69 | 70 | def build(self): 71 | with tf.keras.backend.name_scope('GRU'): 72 | mask = tf.keras.layers.Masking()(self.inputs) 73 | self.layer = tf.keras.layers.GRU( 74 | units=self.model_conf.units_num * 2, 75 | return_sequences=True, 76 | input_shape=mask.shape, 77 | reset_after=True 78 | ) 79 | outputs = self.layer(mask, training=self.utils.is_training) 80 | return outputs 81 | -------------------------------------------------------------------------------- /network/LSTM.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding:utf-8 -*- 3 | # Author: kerlomz 4 | import tensorflow as tf 5 | from config import RunMode, ModelConfig 6 | from network.utils import NetworkUtils 7 | 8 | 9 | class LSTM(object): 10 | """ 11 | Implementation of the LSTM network 12 | """ 13 | def __init__(self, model_conf: ModelConfig, inputs: tf.Tensor, utils: NetworkUtils): 14 | """ 15 | :param model_conf: model configuration 16 | :param inputs: input from the previous network layer, of type tf.keras.layers.Input / tf.Tensor 
 17 | :param utils: network utility class 18 | """ 19 | self.model_conf = model_conf 20 | self.inputs = inputs 21 | self.utils = utils 22 | self.layer = None 23 | 24 | def build(self): 25 | """ 26 | Build the recurrent layer 27 | :return: the output of the recurrent layer 28 | """ 29 | with tf.keras.backend.name_scope('LSTM'): 30 | mask = tf.keras.layers.Masking()(self.inputs) 31 | self.layer = tf.keras.layers.LSTM( 32 | units=self.model_conf.units_num * 2, 33 | return_sequences=True, 34 | input_shape=mask.shape, 35 | dropout=0.2, 36 | recurrent_dropout=0.1 37 | ) 38 | outputs = self.layer(mask, training=self.utils.is_training) 39 | return outputs 40 | 41 | 42 | class BiLSTM(object): 43 | 44 | def __init__(self, model_conf: ModelConfig, inputs: tf.Tensor, utils: NetworkUtils): 45 | """Same as above""" 46 | self.model_conf = model_conf 47 | self.inputs = inputs 48 | self.utils = utils 49 | self.layer = None 50 | 51 | def build(self): 52 | """Same as above""" 53 | with tf.keras.backend.name_scope('BiLSTM'): 54 | mask = tf.keras.layers.Masking()(self.inputs) 55 | self.layer = tf.keras.layers.Bidirectional( 56 | layer=tf.keras.layers.LSTM( 57 | units=self.model_conf.units_num, 58 | return_sequences=True, 59 | ), 60 | input_shape=mask.shape, 61 | ) 62 | outputs = self.layer(mask, training=self.utils.is_training) 63 | return outputs 64 | 65 | 66 | class LSTMcuDNN(object): 67 | 68 | def __init__(self, model_conf: ModelConfig, inputs: tf.Tensor, utils: NetworkUtils): 69 | """Same as above""" 70 | self.model_conf = model_conf 71 | self.inputs = inputs 72 | self.utils = utils 73 | self.layer = None 74 | 75 | def build(self): 76 | """Same as above""" 77 | with tf.keras.backend.name_scope('LSTM'): 78 | self.layer = tf.keras.layers.CuDNNLSTM( 79 | units=self.model_conf.units_num * 2, 80 | return_sequences=True, 81 | ) 82 | outputs = self.layer(self.inputs, training=self.utils.is_training) 83 | return outputs 84 | 85 | 86 | class BiLSTMcuDNN(object): 87 | 88 | def __init__(self, model_conf: ModelConfig, inputs: tf.Tensor, utils: NetworkUtils): 89 | """Same as above""" 90 | self.model_conf = model_conf 91 | self.inputs = inputs 92 | self.utils = utils 93 | self.layer = None 94 | 95 | def build(self): 96 | """Same as above""" 97 | with tf.keras.backend.name_scope('BiLSTM'): 98 | self.layer = tf.keras.layers.Bidirectional( 99 | layer=tf.keras.layers.CuDNNLSTM( 100 | units=self.model_conf.units_num, 101 | return_sequences=True 102 | ) 103 | ) 104 | outputs = self.layer(self.inputs, training=self.utils.is_training) 105 | return outputs 106 | -------------------------------------------------------------------------------- /network/MobileNet.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding:utf-8 -*- 3 | # Author: kerlomz 4 | # This network was temporarily suspended 5 | import tensorflow as tf 6 | from network.utils import NetworkUtils 7 | from config import ModelConfig 8 | 9 | 10 | class MobileNetV2(object): 11 | 12 | def __init__(self, model_conf: ModelConfig, inputs: tf.Tensor, utils: NetworkUtils): 13 | self.model_conf = model_conf 14 | self.inputs = inputs 15 | self.utils = utils 16 | self.loss_func = self.model_conf.loss_func 17 | self.last_block_filters = 1280 18 | self.padding = "SAME" 19 | 20 | def first_layer(self, inputs): 21 | x = tf.keras.layers.Conv2D( 22 | filters=32, 23 | kernel_size=(3, 3), 24 | strides=(2, 2), 25 | padding='same', 26 | kernel_initializer='he_normal', 27 | name='conv1')(inputs) 28 | x = tf.layers.batch_normalization( 29 | x, 30 | reuse=False, 31 | momentum=0.9, 32 | training=self.utils.is_training 33 | ) 34 | # x = 
self.utils.BatchNormalization(name='bn_conv1', momentum=0.999)(x, training=self.utils.is_training) 35 | x = tf.keras.layers.LeakyReLU(0.01)(x) 36 | 37 | return x 38 | 39 | def pwise_block(self, inputs): 40 | x = tf.keras.layers.Conv2D( 41 | self.last_block_filters, 42 | kernel_size=1, 43 | use_bias=False, 44 | name='Conv_1')(inputs) 45 | x = tf.layers.batch_normalization( 46 | x, 47 | reuse=False, 48 | momentum=0.9, 49 | training=self.utils.is_training 50 | ) 51 | 52 | x = tf.keras.layers.ReLU(6., name='out_relu')(x) 53 | return x 54 | 55 | def build(self): 56 | 57 | with tf.keras.backend.name_scope('MobileNetV2'): 58 | 59 | x = self.first_layer(self.inputs) 60 | 61 | x = self.utils.inverted_res_block(x, filters=16, stride=1, expansion=1, block_id=0) 62 | 63 | x = self.utils.inverted_res_block(x, filters=24, stride=2, expansion=6, block_id=1) 64 | x = self.utils.inverted_res_block(x, filters=24, stride=1, expansion=6, block_id=2) 65 | 66 | x = self.utils.inverted_res_block(x, filters=32, stride=2, expansion=6, block_id=3) 67 | x = self.utils.inverted_res_block(x, filters=32, stride=1, expansion=6, block_id=4) 68 | x = self.utils.inverted_res_block(x, filters=32, stride=1, expansion=6, block_id=5) 69 | 70 | x = self.utils.inverted_res_block(x, filters=64, stride=2, expansion=6, block_id=6) 71 | x = self.utils.inverted_res_block(x, filters=64, stride=1, expansion=6, block_id=7) 72 | x = self.utils.inverted_res_block(x, filters=64, stride=1, expansion=6, block_id=8) 73 | x = self.utils.inverted_res_block(x, filters=64, stride=1, expansion=6, block_id=9) 74 | 75 | x = self.utils.inverted_res_block(x, filters=96, stride=1, expansion=6, block_id=10) 76 | x = self.utils.inverted_res_block(x, filters=96, stride=1, expansion=6, block_id=11) 77 | x = self.utils.inverted_res_block(x, filters=96, stride=1, expansion=6, block_id=12) 78 | 79 | x = self.utils.inverted_res_block(x, filters=160, stride=2, expansion=6, block_id=13) 80 | x = self.utils.inverted_res_block(x, filters=160, stride=1, expansion=6, block_id=14) 81 | x = self.utils.inverted_res_block(x, filters=160, stride=1, expansion=6, block_id=15) 82 | 83 | x = self.utils.inverted_res_block(x, filters=320, stride=1, expansion=6, block_id=16) 84 | 85 | x = self.pwise_block(x) 86 | 87 | shape_list = x.get_shape().as_list() 88 | print("x.get_shape()", shape_list) 89 | 90 | return self.utils.reshape_layer(x, self.loss_func, shape_list) 91 | -------------------------------------------------------------------------------- /network/ResNet.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding:utf-8 -*- 3 | # Author: kerlomz 4 | 5 | import tensorflow as tf 6 | from network.utils import NetworkUtils 7 | from config import ModelConfig 8 | 9 | 10 | class ResNetUtils(object): 11 | 12 | def __init__(self, utils: NetworkUtils): 13 | self.utils = utils 14 | 15 | def first_layer(self, inputs): 16 | x = tf.keras.layers.Conv2D( 17 | filters=64, 18 | kernel_size=(7, 7), 19 | strides=(2, 2), 20 | padding='same', 21 | kernel_initializer='he_normal', 22 | name='conv1')(inputs) 23 | x = tf.layers.batch_normalization( 24 | x, 25 | reuse=False, 26 | momentum=0.9, 27 | training=self.utils.is_training, 28 | name='bn_conv1', 29 | ) 30 | x = tf.keras.layers.LeakyReLU(0.01)(x) 31 | x = tf.keras.layers.MaxPooling2D((3, 3), strides=(2, 2), padding='same',)(x) 32 | return x 33 | 34 | 35 | class ResNet50(object): 36 | 37 | def __init__(self, model_conf: ModelConfig, inputs: tf.Tensor, utils: 
NetworkUtils): 38 | self.model_conf = model_conf 39 | self.inputs = inputs 40 | self.utils = utils 41 | self.loss_func = self.model_conf.loss_func 42 | 43 | def build(self): 44 | 45 | with tf.keras.backend.name_scope('ResNet50'): 46 | x = ResNetUtils(self.utils).first_layer(self.inputs) 47 | x = self.utils.residual_building_block(x, 3, [64, 64, 256], stage=2, block='a', strides=(1, 1)) 48 | x = self.utils.identity_block(x, 3, [64, 64, 256], stage=2, block='b') 49 | x = self.utils.identity_block(x, 3, [64, 64, 256], stage=2, block='c') 50 | 51 | x = self.utils.residual_building_block(x, 3, [128, 128, 512], stage=3, block='a') 52 | x = self.utils.identity_block(x, 3, [128, 128, 512], stage=3, block='b') 53 | x = self.utils.identity_block(x, 3, [128, 128, 512], stage=3, block='c') 54 | x = self.utils.identity_block(x, 3, [128, 128, 512], stage=3, block='d') 55 | 56 | x = self.utils.residual_building_block(x, 3, [256, 256, 1024], stage=4, block='a') 57 | x = self.utils.identity_block(x, 3, [256, 256, 1024], stage=4, block='b') 58 | x = self.utils.identity_block(x, 3, [256, 256, 1024], stage=4, block='c') 59 | x = self.utils.identity_block(x, 3, [256, 256, 1024], stage=4, block='d') 60 | x = self.utils.identity_block(x, 3, [256, 256, 1024], stage=4, block='e') 61 | x = self.utils.identity_block(x, 3, [256, 256, 1024], stage=4, block='f') 62 | 63 | x = self.utils.residual_building_block(x, 3, [512, 512, 2048], stage=5, block='a', strides=(1, 1)) 64 | x = self.utils.identity_block(x, 3, [512, 512, 2048], stage=5, block='b') 65 | x = self.utils.identity_block(x, 3, [512, 512, 2048], stage=5, block='c') 66 | 67 | print("x.get_shape()", x.get_shape()) 68 | shape_list = x.get_shape().as_list() 69 | return self.utils.reshape_layer(x, self.loss_func, shape_list) 70 | 71 | 72 | class ResNetTiny(object): 73 | 74 | def __init__(self, model_conf: ModelConfig, inputs: tf.Tensor, utils: NetworkUtils): 75 | self.model_conf = model_conf 76 | self.inputs = inputs 77 | self.utils = utils 78 | self.loss_func = self.model_conf.loss_func 79 | 80 | def build(self): 81 | 82 | with tf.keras.backend.name_scope('ResNetTiny'): 83 | x = ResNetUtils(self.utils).first_layer(self.inputs) 84 | x = self.utils.residual_building_block(x, 3, [64, 64, 128], stage=2, block='a', strides=(1, 1), s2=False) 85 | x = self.utils.identity_block(x, 3, [64, 64, 128], stage=2, block='b') 86 | 87 | x = self.utils.residual_building_block(x, 3, [128, 128, 256], stage=3, block='a', s1=False, s2=False) 88 | x = self.utils.identity_block(x, 3, [128, 128, 256], stage=3, block='b') 89 | 90 | x = self.utils.residual_building_block(x, 3, [256, 256, 512], stage=4, block='a', s1=False, s2=False) 91 | x = self.utils.identity_block(x, 3, [256, 256, 512], stage=4, block='b') 92 | 93 | x = self.utils.residual_building_block(x, 3, [512, 512, 1024], stage=5, block='a', strides=(1, 1), s1=False) 94 | x = self.utils.identity_block(x, 3, [512, 512, 1024], stage=5, block='b') 95 | 96 | shape_list = x.get_shape().as_list() 97 | print("x.get_shape()", shape_list) 98 | 99 | return self.utils.reshape_layer(x, self.loss_func, shape_list) 100 | -------------------------------------------------------------------------------- /optimizer/AdaBound.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding:utf-8 -*- 3 | # Author: kerlomz 4 | import tensorflow as tf 5 | from distutils.version import StrictVersion 6 | from tensorflow.python.eager import context 7 | from tensorflow.python.framework import 
ops 8 | from tensorflow.python.ops import control_flow_ops 9 | from tensorflow.python.ops import math_ops 10 | from tensorflow.python.ops import resource_variable_ops 11 | from tensorflow.python.ops import state_ops 12 | from tensorflow.python.ops import variable_scope 13 | from tensorflow.python.training import optimizer 14 | from tensorflow.python.ops.clip_ops import clip_by_value 15 | 16 | """Implements the AdaBound algorithm. 17 | It has been proposed in `Adaptive Gradient Methods with Dynamic Bound of Learning Rate`_. 18 | Arguments: 19 | learning_rate (float, optional): Adam learning rate (default: 1e-3) 20 | beta1, beta2 (float, optional): coefficients used for computing 21 | running averages of the gradient and its square (default: 0.9, 0.999) 22 | final_lr (float, optional): final (SGD) learning rate (default: 0.1) 23 | gamma (float, optional): convergence speed of the bound functions (default: 1e-3) 24 | epsilon (float, optional): term added to the denominator to improve 25 | numerical stability (default: 1e-8) 26 | amsbound (boolean, optional): whether to use the AMSBound variant of this 27 | algorithm (keeps the running maximum of past squared gradients) 28 | Note: unlike the reference PyTorch implementation, this port does not 29 | implement weight decay. 30 | .. Adaptive Gradient Methods with Dynamic Bound of Learning Rate: 31 | https://openreview.net/forum?id=Bkg3g2R9FX 32 | """ 33 | 34 | 35 | class AdaBoundOptimizer(optimizer.Optimizer): 36 | def __init__(self, learning_rate=0.001, final_lr=0.1, beta1=0.9, beta2=0.999, 37 | gamma=1e-3, epsilon=1e-8, amsbound=False, 38 | use_locking=False, name="AdaBound"): 39 | super(AdaBoundOptimizer, self).__init__(use_locking, name) 40 | self._lr = learning_rate 41 | self._final_lr = final_lr 42 | self._beta1 = beta1 43 | self._beta2 = beta2 44 | self._epsilon = epsilon 45 | 46 | self._gamma = gamma 47 | self._amsbound = amsbound 48 | 49 | self._lr_t = None 50 | self._beta1_t = None 51 | self._beta2_t = None 52 | self._epsilon_t = None 53 | 54 | def _create_slots(self, var_list): 55 | first_var = min(var_list, key=lambda x: x.name) 56 | if StrictVersion(tf.__version__) >= StrictVersion('1.10.0'): 57 | graph = None if context.executing_eagerly() else ops.get_default_graph() 58 | else: 59 | graph = ops.get_default_graph() 60 | create_new = self._get_non_slot_variable("beta1_power", graph) is None 61 | if not create_new and context.in_graph_mode(): 62 | create_new = (self._get_non_slot_variable("beta1_power", graph).graph is not first_var.graph) 63 | 64 | if create_new: 65 | self._create_non_slot_variable(initial_value=self._beta1, 66 | name="beta1_power", 67 | colocate_with=first_var) 68 | self._create_non_slot_variable(initial_value=self._beta2, 69 | name="beta2_power", 70 | colocate_with=first_var) 71 | self._create_non_slot_variable(initial_value=self._gamma, 72 | name="gamma_multi", 73 | colocate_with=first_var) 74 | # Create slots for the first and second moments.
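        # Descriptive note on the update these slots feed (a summary of the
        # _apply_dense implementation below, not additional behavior): with
        # m_t and v_t the usual Adam moment estimates, the effective step is
        #     var -= m_t * Clip(step_size / (sqrt(v_t) + eps), lower, upper)
        # where lower = final_lr * (1 - 1 / (gamma_multi + 1)) and
        # upper = final_lr * (1 + 1 / gamma_multi), so the optimizer behaves
        # like Adam early in training and anneals toward SGD(final_lr) as
        # gamma_multi grows (see _finish).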
75 | for v in var_list : 76 | self._zeros_slot(v, "m", self._name) 77 | self._zeros_slot(v, "v", self._name) 78 | self._zeros_slot(v, "vhat", self._name) 79 | 80 | def _prepare(self): 81 | self._lr_t = ops.convert_to_tensor(self._lr) 82 | self._base_lr_t = ops.convert_to_tensor(self._lr) 83 | self._beta1_t = ops.convert_to_tensor(self._beta1) 84 | self._beta2_t = ops.convert_to_tensor(self._beta2) 85 | self._epsilon_t = ops.convert_to_tensor(self._epsilon) 86 | self._gamma_t = ops.convert_to_tensor(self._gamma) 87 | 88 | def _apply_dense(self, grad, var): 89 | if StrictVersion(tf.__version__) >= StrictVersion('1.10.0'): 90 | graph = None if context.executing_eagerly() else ops.get_default_graph() 91 | else: 92 | graph = ops.get_default_graph() 93 | beta1_power = math_ops.cast(self._get_non_slot_variable("beta1_power", graph=graph), var.dtype.base_dtype) 94 | beta2_power = math_ops.cast(self._get_non_slot_variable("beta2_power", graph=graph), var.dtype.base_dtype) 95 | lr_t = math_ops.cast(self._lr_t, var.dtype.base_dtype) 96 | base_lr_t = math_ops.cast(self._base_lr_t, var.dtype.base_dtype) 97 | beta1_t = math_ops.cast(self._beta1_t, var.dtype.base_dtype) 98 | beta2_t = math_ops.cast(self._beta2_t, var.dtype.base_dtype) 99 | epsilon_t = math_ops.cast(self._epsilon_t, var.dtype.base_dtype) 100 | gamma_multi = math_ops.cast(self._get_non_slot_variable("gamma_multi", graph=graph), var.dtype.base_dtype) 101 | 102 | step_size = (lr_t * math_ops.sqrt(1 - beta2_power) / (1 - beta1_power)) 103 | final_lr = self._final_lr * lr_t / base_lr_t 104 | lower_bound = final_lr * (1. - 1. / (gamma_multi + 1.)) 105 | upper_bound = final_lr * (1. + 1. / (gamma_multi)) 106 | 107 | # m_t = beta1 * m + (1 - beta1) * g_t 108 | m = self.get_slot(var, "m") 109 | m_scaled_g_values = grad * (1 - beta1_t) 110 | m_t = state_ops.assign(m, beta1_t * m + m_scaled_g_values, use_locking=self._use_locking) 111 | 112 | # v_t = beta2 * v + (1 - beta2) * (g_t * g_t) 113 | v = self.get_slot(var, "v") 114 | v_scaled_g_values = (grad * grad) * (1 - beta2_t) 115 | v_t = state_ops.assign(v, beta2_t * v + v_scaled_g_values, use_locking=self._use_locking) 116 | 117 | # amsgrad 118 | vhat = self.get_slot(var, "vhat") 119 | if self._amsbound : 120 | vhat_t = state_ops.assign(vhat, math_ops.maximum(v_t, vhat)) 121 | v_sqrt = math_ops.sqrt(vhat_t) 122 | else: 123 | vhat_t = state_ops.assign(vhat, vhat) 124 | v_sqrt = math_ops.sqrt(v_t) 125 | 126 | # Compute the bounds 127 | step_size_bound = step_size / (v_sqrt + epsilon_t) 128 | bounded_lr = m_t * clip_by_value(step_size_bound, lower_bound, upper_bound) 129 | 130 | var_update = state_ops.assign_sub(var, bounded_lr, use_locking=self._use_locking) 131 | return control_flow_ops.group(*[var_update, m_t, v_t, vhat_t]) 132 | 133 | def _resource_apply_dense(self, grad, var): 134 | if StrictVersion(tf.__version__) >= StrictVersion('1.10.0'): 135 | graph = None if context.executing_eagerly() else ops.get_default_graph() 136 | else: 137 | graph = ops.get_default_graph() 138 | beta1_power = math_ops.cast(self._get_non_slot_variable("beta1_power", graph=graph), grad.dtype.base_dtype) 139 | beta2_power = math_ops.cast(self._get_non_slot_variable("beta2_power", graph=graph), grad.dtype.base_dtype) 140 | lr_t = math_ops.cast(self._lr_t, grad.dtype.base_dtype) 141 | base_lr_t = math_ops.cast(self._base_lr_t, var.dtype.base_dtype) 142 | beta1_t = math_ops.cast(self._beta1_t, grad.dtype.base_dtype) 143 | beta2_t = math_ops.cast(self._beta2_t, grad.dtype.base_dtype) 144 | epsilon_t = 
math_ops.cast(self._epsilon_t, grad.dtype.base_dtype) 145 | gamma_multi = math_ops.cast(self._get_non_slot_variable("gamma_multi", graph=graph), var.dtype.base_dtype) 146 | 147 | step_size = (lr_t * math_ops.sqrt(1 - beta2_power) / (1 - beta1_power)) 148 | final_lr = self._final_lr * lr_t / base_lr_t 149 | lower_bound = final_lr * (1. - 1. / (gamma_multi + 1.)) 150 | upper_bound = final_lr * (1. + 1. / (gamma_multi)) 151 | 152 | # m_t = beta1 * m + (1 - beta1) * g_t 153 | m = self.get_slot(var, "m") 154 | m_scaled_g_values = grad * (1 - beta1_t) 155 | m_t = state_ops.assign(m, beta1_t * m + m_scaled_g_values, use_locking=self._use_locking) 156 | 157 | # v_t = beta2 * v + (1 - beta2) * (g_t * g_t) 158 | v = self.get_slot(var, "v") 159 | v_scaled_g_values = (grad * grad) * (1 - beta2_t) 160 | v_t = state_ops.assign(v, beta2_t * v + v_scaled_g_values, use_locking=self._use_locking) 161 | 162 | # amsgrad 163 | vhat = self.get_slot(var, "vhat") 164 | if self._amsbound: 165 | vhat_t = state_ops.assign(vhat, math_ops.maximum(v_t, vhat)) 166 | v_sqrt = math_ops.sqrt(vhat_t) 167 | else: 168 | vhat_t = state_ops.assign(vhat, vhat) 169 | v_sqrt = math_ops.sqrt(v_t) 170 | 171 | # Compute the bounds 172 | step_size_bound = step_size / (v_sqrt + epsilon_t) 173 | bounded_lr = m_t * clip_by_value(step_size_bound, lower_bound, upper_bound) 174 | 175 | var_update = state_ops.assign_sub(var, bounded_lr, use_locking=self._use_locking) 176 | 177 | return control_flow_ops.group(*[var_update, m_t, v_t, vhat_t]) 178 | 179 | def _apply_sparse_shared(self, grad, var, indices, scatter_add): 180 | if StrictVersion(tf.__version__) >= StrictVersion('1.10.0'): 181 | graph = None if context.executing_eagerly() else ops.get_default_graph() 182 | else: 183 | graph = ops.get_default_graph() 184 | beta1_power = math_ops.cast(self._get_non_slot_variable("beta1_power", graph=graph), var.dtype.base_dtype) 185 | beta2_power = math_ops.cast(self._get_non_slot_variable("beta2_power", graph=graph), var.dtype.base_dtype) 186 | lr_t = math_ops.cast(self._lr_t, var.dtype.base_dtype) 187 | base_lr_t = math_ops.cast(self._base_lr_t, var.dtype.base_dtype) 188 | beta1_t = math_ops.cast(self._beta1_t, var.dtype.base_dtype) 189 | beta2_t = math_ops.cast(self._beta2_t, var.dtype.base_dtype) 190 | epsilon_t = math_ops.cast(self._epsilon_t, var.dtype.base_dtype) 191 | gamma_t = math_ops.cast(self._gamma_t, var.dtype.base_dtype) 192 | 193 | step_size = (lr_t * math_ops.sqrt(1 - beta2_power) / (1 - beta1_power)) 194 | final_lr = self._final_lr * lr_t / base_lr_t 195 | lower_bound = final_lr * (1. - 1. / (gamma_t + 1.)) 196 | upper_bound = final_lr * (1. + 1. 
/ (gamma_t)) 197 | 198 | # m_t = beta1 * m + (1 - beta1) * g_t 199 | m = self.get_slot(var, "m") 200 | m_scaled_g_values = grad * (1 - beta1_t) 201 | m_t = state_ops.assign(m, m * beta1_t, use_locking=self._use_locking) 202 | with ops.control_dependencies([m_t]): 203 | m_t = scatter_add(m, indices, m_scaled_g_values) 204 | 205 | # v_t = beta2 * v + (1 - beta2) * (g_t * g_t) 206 | v = self.get_slot(var, "v") 207 | v_scaled_g_values = (grad * grad) * (1 - beta2_t) 208 | v_t = state_ops.assign(v, v * beta2_t, use_locking=self._use_locking) 209 | with ops.control_dependencies([v_t]): 210 | v_t = scatter_add(v, indices, v_scaled_g_values) 211 | 212 | # amsgrad 213 | vhat = self.get_slot(var, "vhat") 214 | if self._amsbound: 215 | vhat_t = state_ops.assign(vhat, math_ops.maximum(v_t, vhat)) 216 | v_sqrt = math_ops.sqrt(vhat_t) 217 | else: 218 | vhat_t = state_ops.assign(vhat, vhat) 219 | v_sqrt = math_ops.sqrt(v_t) 220 | 221 | # Compute the bounds 222 | step_size_bound = step_size / (v_sqrt + epsilon_t) 223 | bounded_lr = m_t * clip_by_value(step_size_bound, lower_bound, upper_bound) 224 | 225 | var_update = state_ops.assign_sub(var, bounded_lr, use_locking=self._use_locking) 226 | 227 | return control_flow_ops.group(*[var_update, m_t, v_t, vhat_t]) 228 | 229 | def _apply_sparse(self, grad, var): 230 | return self._apply_sparse_shared( 231 | grad.values, var, grad.indices, 232 | lambda x, i, v: state_ops.scatter_add( # pylint: disable=g-long-lambda 233 | x, i, v, use_locking=self._use_locking)) 234 | 235 | def _resource_scatter_add(self, x, i, v): 236 | with ops.control_dependencies( 237 | [resource_variable_ops.resource_scatter_add(x, i, v)]): 238 | return x.value() 239 | 240 | def _resource_apply_sparse(self, grad, var, indices): 241 | return self._apply_sparse_shared( 242 | grad, var, indices, self._resource_scatter_add) 243 | 244 | def _finish(self, update_ops, name_scope): 245 | # Update the power accumulators. 
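        # beta1_power and beta2_power accumulate beta1**t and beta2**t for the
        # bias-corrected step size, while gamma_multi accumulates t * gamma so
        # that the [lower_bound, upper_bound] clip interval used in the apply
        # methods tightens around final_lr as training progresses.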
246 | with ops.control_dependencies(update_ops): 247 | if StrictVersion(tf.__version__) >= StrictVersion('1.10.0'): 248 | graph = None if context.executing_eagerly() else ops.get_default_graph() 249 | else: 250 | graph = ops.get_default_graph() 251 | beta1_power = self._get_non_slot_variable("beta1_power", graph=graph) 252 | beta2_power = self._get_non_slot_variable("beta2_power", graph=graph) 253 | gamma_multi = self._get_non_slot_variable("gamma_multi", graph=graph) 254 | with ops.colocate_with(beta1_power): 255 | update_beta1 = beta1_power.assign( 256 | beta1_power * self._beta1_t, 257 | use_locking=self._use_locking) 258 | update_beta2 = beta2_power.assign( 259 | beta2_power * self._beta2_t, 260 | use_locking=self._use_locking) 261 | update_gamma = gamma_multi.assign( 262 | gamma_multi + self._gamma_t, 263 | use_locking=self._use_locking) 264 | return control_flow_ops.group(*update_ops + [update_beta1, update_beta2, update_gamma], name=name_scope) 265 | -------------------------------------------------------------------------------- /optimizer/RAdam.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding:utf-8 -*- 3 | # Author: kerlomz 4 | import tensorflow as tf 5 | from tensorflow.python.eager import context 6 | from tensorflow.python.framework import ops 7 | from tensorflow.python.ops import control_flow_ops 8 | from tensorflow.python.ops import math_ops, state_ops, array_ops 9 | from tensorflow.python.ops import resource_variable_ops 10 | from tensorflow.python.training import optimizer 11 | 12 | 13 | __all__ = ['RAdamOptimizer'] 14 | 15 | 16 | class RAdamOptimizer(optimizer.Optimizer): 17 | """RAdam optimizer. 18 | According to the paper 19 | [On The Variance Of The Adaptive Learning Rate And Beyond](https://arxiv.org/pdf/1908.03265v1.pdf). 20 | """ 21 | 22 | def __init__(self, 23 | learning_rate=0.001, 24 | beta1=0.9, 25 | beta2=0.999, 26 | epsilon=1e-7, 27 | weight_decay=0., 28 | amsgrad=False, 29 | total_steps=0, 30 | warmup_proportion=0.1, 31 | min_lr=0., 32 | use_locking=False, 33 | name="RAdam"): 34 | r"""Construct a new Adam optimizer. 35 | Args: 36 | learning_rate: A Tensor or a floating point value. The learning rate. 37 | beta1: A float value or a constant float tensor. The exponential decay 38 | rate for the 1st moment estimates. 39 | beta2: A float value or a constant float tensor. The exponential decay 40 | rate for the 2nd moment estimates. 41 | epsilon: A small constant for numerical stability. This epsilon is 42 | "epsilon hat" in the Kingma and Ba paper (in the formula just before 43 | Section 2.1), not the epsilon in Algorithm 1 of the paper. 44 | weight_decay: A floating point value. Weight decay for each param. 45 | amsgrad: boolean. Whether to apply AMSGrad variant of this algorithm from 46 | the paper "On the Convergence of Adam and beyond". 47 | total_steps: An integer. Total number of training steps. 48 | Enable warmup by setting a positive value. 49 | warmup_proportion: A floating point value. The proportion of increasing steps. 50 | min_lr: A floating point value. Minimum learning rate after warmup. 51 | name: Optional name for the operations created when applying gradients. 52 | Defaults to "Adam". @compatibility(eager) When eager execution is 53 | enabled, `learning_rate`, `beta_1`, `beta_2`, and `epsilon` can each be 54 | a callable that takes no arguments and returns the actual value to use. 
55 | This can be useful for changing these values across different 56 | invocations of optimizer functions. @end_compatibility 57 | **kwargs: keyword arguments. Allowed to be {`clipnorm`, `clipvalue`, `lr`, 58 | `decay`}. `clipnorm` is clip gradients by norm; `clipvalue` is clip 59 | gradients by value, `decay` is included for backward compatibility to 60 | allow time inverse decay of learning rate. `lr` is included for backward 61 | compatibility, recommended to use `learning_rate` instead. 62 | """ 63 | super(RAdamOptimizer, self).__init__(use_locking, name) 64 | self._lr = learning_rate 65 | self._beta1 = beta1 66 | self._beta2 = beta2 67 | self._epsilon = epsilon 68 | self._weight_decay = weight_decay 69 | self._amsgrad = amsgrad 70 | self._total_steps = float(total_steps) 71 | self._warmup_proportion = warmup_proportion 72 | self._min_lr = min_lr 73 | self._initial_weight_decay = weight_decay 74 | self._initial_total_steps = total_steps 75 | 76 | self._lr_t = None 77 | self._step_t = None 78 | self._beta1_t = None 79 | self._beta2_t = None 80 | self._epsilon_t = None 81 | self._weight_decay_t = None 82 | self._total_steps_t = None 83 | self._warmup_proportion_t = None 84 | self._min_lr_t = None 85 | 86 | def _get_beta_accumulators(self): 87 | with ops.init_scope(): 88 | if context.executing_eagerly(): 89 | graph = None 90 | else: 91 | graph = ops.get_default_graph() 92 | return (self._get_non_slot_variable("step", graph=graph), 93 | self._get_non_slot_variable("beta1_power", graph=graph), 94 | self._get_non_slot_variable("beta2_power", graph=graph)) 95 | 96 | def _create_slots(self, var_list): 97 | first_var = min(var_list, key=lambda x: x.name) 98 | self._create_non_slot_variable(initial_value=1.0, name="step", colocate_with=first_var) 99 | self._create_non_slot_variable(initial_value=self._beta1, name="beta1_power", colocate_with=first_var) 100 | self._create_non_slot_variable(initial_value=self._beta2, name="beta2_power", colocate_with=first_var) 101 | for v in var_list: 102 | self._zeros_slot(v, "m", self._name) 103 | self._zeros_slot(v, "v", self._name) 104 | if self._amsgrad: 105 | self._zeros_slot(v, "vhat", self._name) 106 | 107 | def _prepare(self): 108 | lr = self._call_if_callable(self._lr) 109 | beta1 = self._call_if_callable(self._beta1) 110 | beta2 = self._call_if_callable(self._beta2) 111 | epsilon = self._call_if_callable(self._epsilon) 112 | weight_decay = self._call_if_callable(self._weight_decay) 113 | total_steps = self._call_if_callable(self._total_steps) 114 | warmup_proportion = self._call_if_callable(self._warmup_proportion) 115 | min_lr = self._call_if_callable(self._min_lr) 116 | 117 | self._lr_t = ops.convert_to_tensor(lr, name="learning_rate") 118 | self._beta1_t = ops.convert_to_tensor(beta1, name="beta1") 119 | self._beta2_t = ops.convert_to_tensor(beta2, name="beta2") 120 | self._epsilon_t = ops.convert_to_tensor(epsilon, name="epsilon") 121 | self._weight_decay_t = ops.convert_to_tensor(weight_decay, name="weight_decay") 122 | self._total_steps_t = ops.convert_to_tensor(total_steps, name="total_steps") 123 | self._warmup_proportion_t = ops.convert_to_tensor(warmup_proportion, name="warmup_proportion") 124 | self._min_lr_t = ops.convert_to_tensor(min_lr, name="min_lr") 125 | 126 | def _apply_dense(self, grad, var): 127 | return self._resource_apply_dense(grad, var) 128 | 129 | def _resource_apply_dense(self, grad, var): 130 | step, beta1_power, beta2_power = self._get_beta_accumulators() 131 | beta1_power = math_ops.cast(beta1_power, 
var.dtype.base_dtype) 132 | beta2_power = math_ops.cast(beta2_power, var.dtype.base_dtype) 133 | lr_t = math_ops.cast(self._lr_t, var.dtype.base_dtype) 134 | 135 | if self._initial_total_steps > 0: 136 | total_steps = math_ops.cast(self._total_steps_t, var.dtype.base_dtype) 137 | warmup_proportion = math_ops.cast(self._warmup_proportion_t, var.dtype.base_dtype) 138 | min_lr = math_ops.cast(self._min_lr_t, var.dtype.base_dtype) 139 | warmup_steps = total_steps * warmup_proportion 140 | decay_steps = math_ops.maximum(total_steps - warmup_steps, 1) 141 | decay_rate = (min_lr - lr_t) / decay_steps 142 | lr_t = tf.where( 143 | step <= warmup_steps, 144 | lr_t * (step / warmup_steps), 145 | lr_t + decay_rate * math_ops.minimum(step - warmup_steps, decay_steps), 146 | ) 147 | 148 | beta1_t = math_ops.cast(self._beta1_t, var.dtype.base_dtype) 149 | beta2_t = math_ops.cast(self._beta2_t, var.dtype.base_dtype) 150 | epsilon_t = math_ops.cast(self._epsilon_t, var.dtype.base_dtype) 151 | 152 | sma_inf = 2.0 / (1.0 - beta2_t) - 1.0 153 | sma_t = sma_inf - 2.0 * step * beta2_power / (1.0 - beta2_power) 154 | 155 | m = self.get_slot(var, "m") 156 | m_t = state_ops.assign(m, beta1_t * m + (1.0 - beta1_t) * grad, use_locking=self._use_locking) 157 | m_corr_t = m_t / (1.0 - beta1_power) 158 | 159 | v = self.get_slot(var, "v") 160 | v_t = state_ops.assign(v, beta2_t * v + (1.0 - beta2_t) * math_ops.square(grad), use_locking=self._use_locking) 161 | if self._amsgrad: 162 | vhat = self.get_slot(var, 'vhat') 163 | vhat_t = state_ops.assign(vhat, math_ops.maximum(vhat, v_t), use_locking=self._use_locking) 164 | v_corr_t = math_ops.sqrt(vhat_t / (1.0 - beta2_power)) 165 | else: 166 | v_corr_t = math_ops.sqrt(v_t / (1.0 - beta2_power)) 167 | 168 | r_t = math_ops.sqrt((sma_t - 4.0) / (sma_inf - 4.0) * 169 | (sma_t - 2.0) / (sma_inf - 2.0) * 170 | sma_inf / sma_t) 171 | 172 | var_t = tf.where(sma_t >= 5.0, r_t * m_corr_t / (v_corr_t + epsilon_t), m_corr_t) 173 | 174 | if self._initial_weight_decay > 0.0: 175 | var_t += math_ops.cast(self._weight_decay_t, var.dtype.base_dtype) * var 176 | 177 | var_update = state_ops.assign_sub(var, lr_t * var_t, use_locking=self._use_locking) 178 | 179 | updates = [var_update, m_t, v_t] 180 | if self._amsgrad: 181 | updates.append(vhat_t) 182 | return control_flow_ops.group(*updates) 183 | 184 | def _apply_sparse_shared(self, grad, var, indices, scatter_add): 185 | step, beta1_power, beta2_power = self._get_beta_accumulators() 186 | beta1_power = math_ops.cast(beta1_power, var.dtype.base_dtype) 187 | beta2_power = math_ops.cast(beta2_power, var.dtype.base_dtype) 188 | lr_t = math_ops.cast(self._lr_t, var.dtype.base_dtype) 189 | 190 | if self._initial_total_steps > 0: 191 | total_steps = math_ops.cast(self._total_steps_t, var.dtype.base_dtype) 192 | warmup_proportion = math_ops.cast(self._warmup_proportion_t, var.dtype.base_dtype) 193 | min_lr = math_ops.cast(self._min_lr_t, var.dtype.base_dtype) 194 | warmup_steps = total_steps * warmup_proportion 195 | decay_steps = math_ops.maximum(total_steps - warmup_steps, 1) 196 | decay_rate = (min_lr - lr_t) / decay_steps 197 | lr_t = tf.where( 198 | step <= warmup_steps, 199 | lr_t * (step / warmup_steps), 200 | lr_t + decay_rate * math_ops.minimum(step - warmup_steps, decay_steps), 201 | ) 202 | 203 | beta1_t = math_ops.cast(self._beta1_t, var.dtype.base_dtype) 204 | beta2_t = math_ops.cast(self._beta2_t, var.dtype.base_dtype) 205 | epsilon_t = math_ops.cast(self._epsilon_t, var.dtype.base_dtype) 206 | 207 | sma_inf = 2.0 / (1.0 - beta2_t) 
- 1.0 208 | sma_t = sma_inf - 2.0 * step * beta2_power / (1.0 - beta2_power) 209 | 210 | m = self.get_slot(var, "m") 211 | m_scaled_g_values = grad * (1 - beta1_t) 212 | m_t = state_ops.assign(m, m * beta1_t, use_locking=self._use_locking) 213 | with ops.control_dependencies([m_t]): 214 | m_t = scatter_add(m, indices, m_scaled_g_values) 215 | m_corr_t = m_t / (1.0 - beta1_power) 216 | 217 | v = self.get_slot(var, "v") 218 | v_scaled_g_values = (grad * grad) * (1 - beta2_t) 219 | v_t = state_ops.assign(v, v * beta2_t, use_locking=self._use_locking) 220 | with ops.control_dependencies([v_t]): 221 | v_t = scatter_add(v, indices, v_scaled_g_values) 222 | if self._amsgrad: 223 | vhat = self.get_slot(var, 'vhat') 224 | vhat_t = state_ops.assign(vhat, math_ops.maximum(vhat, v_t), use_locking=self._use_locking) 225 | v_corr_t = math_ops.sqrt(vhat_t / (1.0 - beta2_power)) 226 | else: 227 | v_corr_t = math_ops.sqrt(v_t / (1.0 - beta2_power)) 228 | 229 | r_t = math_ops.sqrt((sma_t - 4.0) / (sma_inf - 4.0) * 230 | (sma_t - 2.0) / (sma_inf - 2.0) * 231 | sma_inf / sma_t) 232 | 233 | var_t = tf.where(sma_t >= 5.0, r_t * m_corr_t / (v_corr_t + epsilon_t), m_corr_t) 234 | 235 | if self._initial_weight_decay > 0.0: 236 | var_t += math_ops.cast(self._weight_decay_t, var.dtype.base_dtype) * var 237 | 238 | var_t = lr_t * var_t 239 | var_update = state_ops.scatter_sub( 240 | var, 241 | indices, 242 | array_ops.gather(var_t, indices), 243 | use_locking=self._use_locking) 244 | 245 | updates = [var_update, m_t, v_t] 246 | if self._amsgrad: 247 | updates.append(vhat_t) 248 | return control_flow_ops.group(*updates) 249 | 250 | def _apply_sparse(self, grad, var): 251 | return self._apply_sparse_shared( 252 | grad.values, 253 | var, 254 | grad.indices, 255 | lambda x, i, v: state_ops.scatter_add(x, i, v, use_locking=self._use_locking)) 256 | 257 | def _resource_scatter_add(self, x, i, v): 258 | with ops.control_dependencies([resource_variable_ops.resource_scatter_add(x.handle, i, v)]): 259 | return x.value() 260 | 261 | def _resource_apply_sparse(self, grad, var, indices): 262 | return self._apply_sparse_shared(grad, var, indices, self._resource_scatter_add) 263 | 264 | def _finish(self, update_ops, name_scope): 265 | with ops.control_dependencies(update_ops): 266 | step, beta1_power, beta2_power = self._get_beta_accumulators() 267 | with ops.colocate_with(beta1_power): 268 | update_step = step.assign(step + 1.0, use_locking=self._use_locking) 269 | update_beta1 = beta1_power.assign(beta1_power * self._beta1_t, use_locking=self._use_locking) 270 | update_beta2 = beta2_power.assign(beta2_power * self._beta2_t, use_locking=self._use_locking) 271 | return control_flow_ops.group(*update_ops + [update_step, update_beta1, update_beta2], name=name_scope) -------------------------------------------------------------------------------- /optimizer/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding:utf-8 -*- 3 | # Author: kerlomz -------------------------------------------------------------------------------- /predict_testing.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding:utf-8 -*- 3 | # Author: kerlomz 4 | """Script used during training to check training results: it loads the network of the project named by the launch argument and runs predictions.""" 5 | import random 6 | import numpy as np 7 | import tensorflow as tf 8 | from config import * 9 | from constants import RunMode 10 | from encoder import Encoder 11 | from core import NeuralNetwork 12 | 13
| # argv = sys.argv[1] 14 | 15 | 16 | class Predict: 17 | def __init__(self, project_name): 18 | self.model_conf = ModelConfig(project_name=project_name) 19 | self.encoder = Encoder(model_conf=self.model_conf, mode=RunMode.Predict) 20 | 21 | def get_image_batch(self, img_bytes): 22 | if not img_bytes: 23 | return [] 24 | return [self.encoder.image(index) for index in [img_bytes]] 25 | 26 | @staticmethod 27 | def decode_maps(categories): 28 | """Decoder: map class indices back to category characters""" 29 | return {index: category for index, category in enumerate(categories, 0)} 30 | 31 | def predict_func(self, image_batch, _sess, dense_decoded, op_input): 32 | """Prediction function""" 33 | dense_decoded_code = _sess.run(dense_decoded, feed_dict={ 34 | op_input: image_batch, 35 | }) 36 | # print(dense_decoded_code) 37 | decoded_expression = [] 38 | for item in dense_decoded_code: 39 | expression = '' 40 | # print(item) 41 | if isinstance(item, int) or isinstance(item, np.int64): 42 | item = [item] 43 | for class_index in item: 44 | if class_index == -1 or class_index == self.model_conf.category_num: 45 | expression += '' 46 | else: 47 | expression += self.decode_maps(self.model_conf.category)[class_index] 48 | decoded_expression.append(expression) 49 | return ''.join(decoded_expression) if len(decoded_expression) > 1 else decoded_expression[0] 50 | 51 | def testing(self, image_dir, limit=None): 52 | 53 | graph = tf.Graph() 54 | sess = tf.compat.v1.Session( 55 | graph=graph, 56 | config=tf.compat.v1.ConfigProto( 57 | # allow_soft_placement=True, 58 | # log_device_placement=True, 59 | gpu_options=tf.compat.v1.GPUOptions( 60 | allocator_type='BFC', 61 | # allow_growth=True, # it will cause fragmentation. 62 | per_process_gpu_memory_fraction=0.1 63 | )) 64 | ) 65 | 66 | with sess.graph.as_default(): 67 | 68 | sess.run(tf.compat.v1.global_variables_initializer()) 69 | # tf.keras.backend.set_session(session=sess) 70 | 71 | model = NeuralNetwork( 72 | self.model_conf, 73 | RunMode.Predict, 74 | self.model_conf.neu_cnn, 75 | self.model_conf.neu_recurrent 76 | ) 77 | model.build_graph() 78 | model.build_train_op() 79 | 80 | saver = tf.compat.v1.train.Saver(var_list=tf.compat.v1.global_variables()) 81 | 82 | """Load the last trained network parameters from the project""" 83 | saver.restore(sess, tf.train.latest_checkpoint(self.model_conf.model_root_path)) 84 | # model.build_graph() 85 | # _ = tf.import_graph_def(graph_def, name="") 86 | 87 | """Define the operators""" 88 | dense_decoded_op = sess.graph.get_tensor_by_name("dense_decoded:0") 89 | x_op = sess.graph.get_tensor_by_name('input:0') 90 | """Freeze the graph""" 91 | sess.graph.finalize() 92 | 93 | true_count = 0 94 | false_count = 0 95 | """ 96 | Below is a demo that calls the prediction function on each file under a path and prints the result 97 | """ 98 | # Fill in your own sample path 99 | dir_list = os.listdir(image_dir) 100 | random.shuffle(dir_list) 101 | lines = [] 102 | for i, p in enumerate(dir_list): 103 | n = os.path.join(image_dir, p) 104 | if limit and i > limit: 105 | break 106 | with open(n, "rb") as f: 107 | b = f.read() 108 | 109 | batch = self.get_image_batch(b) 110 | if not batch: 111 | continue 112 | st = time.time() 113 | predict_text = self.predict_func( 114 | batch, 115 | sess, 116 | dense_decoded_op, 117 | x_op, 118 | ) 119 | et = time.time() 120 | # t = p.split(".")[0].lower() == predict_text.lower() 121 | # csv_output = "{},{}".format(p.split(".")[0], predict_text) 122 | # lines.append(csv_output) 123 | # print(csv_output) 124 | # is_mark = '_' in p 125 | # p = p.replace("\\", "/") 126 | label = re.search(self.model_conf.extract_regex, p.split(PATH_SPLIT)[-1]) 127 | label = label.group() if label else p.split(".")[0] 128
| # if is_mark: 129 | if 'LOWER' in self.model_conf.category_param: 130 | label = label.lower() 131 | t = label == predict_text.lower() 132 | elif 'UPPER' in self.model_conf.category_param: 133 | label = label.upper() 134 | t = label == predict_text.upper() 135 | else: 136 | t = label == predict_text 137 | # Used to verify test sets 138 | if t: 139 | true_count += 1 140 | else: 141 | false_count += 1 142 | print(i, p, label, predict_text, t, true_count / (true_count + false_count), (et-st) * 1000) 143 | # else: 144 | # print(i, p, predict_text, true_count / (true_count + false_count), (et - st) * 1000) 145 | # with open("competition_format.csv", "w", encoding="utf8") as f: 146 | # f.write("\n".join(lines)) 147 | sess.close() 148 | 149 | 150 | if __name__ == '__main__': 151 | 152 | predict = Predict(project_name=sys.argv[1]) 153 | predict.testing(image_dir=r"H:\TrainSet\*", limit=None) 154 | 155 | 156 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | pillow 2 | opencv-python-headless 3 | numpy 4 | pyyaml>=3.13 5 | tqdm 6 | colorama 7 | pyinstaller 8 | astor 9 | fonttools 10 | tensorflow-gpu -------------------------------------------------------------------------------- /resource/VERSION: -------------------------------------------------------------------------------- 1 | 20221023 -------------------------------------------------------------------------------- /resource/captcha_snapshot.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kerlomz/captcha_trainer/6fd35c0c789aaa43130de46d4c04622ec2948052/resource/captcha_snapshot.png -------------------------------------------------------------------------------- /resource/icon.ico: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kerlomz/captcha_trainer/6fd35c0c789aaa43130de46d4c04622ec2948052/resource/icon.ico -------------------------------------------------------------------------------- /resource/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kerlomz/captcha_trainer/6fd35c0c789aaa43130de46d4c04622ec2948052/resource/logo.png -------------------------------------------------------------------------------- /resource/main.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kerlomz/captcha_trainer/6fd35c0c789aaa43130de46d4c04622ec2948052/resource/main.png -------------------------------------------------------------------------------- /resource/net_structure.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kerlomz/captcha_trainer/6fd35c0c789aaa43130de46d4c04622ec2948052/resource/net_structure.png -------------------------------------------------------------------------------- /resource/sample_process.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kerlomz/captcha_trainer/6fd35c0c789aaa43130de46d4c04622ec2948052/resource/sample_process.png -------------------------------------------------------------------------------- /test/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding:utf-8 -*- 3 | # Author: kerlomz 
-------------------------------------------------------------------------------- /test/test_preprocessing_by_func.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding:utf-8 -*- 3 | # Author: kerlomz 4 | 5 | if __name__ == '__main__': 6 | import io 7 | import os 8 | import PIL.Image 9 | import hashlib 10 | import cv2 11 | import numpy as np 12 | from pretreatment import preprocessing_by_func 13 | 14 | src_color = "yellow" 15 | root_dir = r"H:\Samples\tax_gen\simulation\gen_{}_2".format(src_color)  # format placeholder so the source dir follows src_color 16 | target_dir = r"H:\Samples\tax_gen\gen\{}2red".format(src_color) 17 | if not os.path.exists(target_dir): 18 | os.makedirs(target_dir) 19 | 20 | for name in os.listdir(root_dir): 21 | label = name.split("_")[0] 22 | path = os.path.join(root_dir, name) 23 | with open(path, "rb") as f: 24 | path_or_bytes = f.read() 25 | path_or_stream = io.BytesIO(path_or_bytes) 26 | pil_image = PIL.Image.open(path_or_stream).convert("RGB") 27 | im = np.array(pil_image) 28 | im = preprocessing_by_func(exec_map={ 29 | "black": [ 30 | "$$target_arr[:, :, 2] = 255 - target_arr[:, :, 2]", 31 | ], 32 | "red": [], 33 | "yellow": [ 34 | "@@target_arr[:, :, (0, 0, 1)]", 35 | # "$$target_arr[:, :, 2] = 255 - target_arr[:, :, 2]", 36 | # "@@target_arr[:, :, (0, 2, 0)]", 37 | # "$$target_arr[:, :, 2] = 255 - target_arr[:, :, 2]", 38 | 39 | # "$$target_arr[:, :, 2] = 255 - target_arr[:, :, 2]", 40 | # "@@target_arr[:, :, (0, 2, 1)]", 41 | 42 | # "$$target_arr[:, :, 1] = 255 - target_arr[:, :, 1]", 43 | # "@@target_arr[:, :, (2, 1, 0)]", 44 | # "@@target_arr[:, :, (1, 2, 0)]", 45 | ], 46 | "blue": [ 47 | "@@target_arr[:, :, (1, 2, 0)]", 48 | ] 49 | }, 50 | src_arr=im, 51 | key=src_color 52 | ) 53 | im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB) 54 | cv_img = cv2.imencode('.png', im)[1] 55 | img_bytes = bytes(bytearray(cv_img)) 56 | # tag = hashlib.md5(img_bytes).hexdigest() 57 | tag = src_color 58 | new_name = "{}_{}.png".format(label, tag) 59 | new_path = os.path.join(target_dir, new_name) 60 | print(src_color, new_name) 61 | with open(new_path, "wb") as f: 62 | f.write(img_bytes) -------------------------------------------------------------------------------- /tf_onnx_util2.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding:utf-8 -*- 3 | # Author: kerlomz 4 | # Copyright (c) Microsoft Corporation. All rights reserved. 5 | # Licensed under the MIT license.
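# Note: this module adapts tf2onnx's command-line converter into an in-process
# pipeline: it freezes the live session's variables into constants, prunes
# inputs missing from the frozen graph, runs TF- and ONNX-level optimizations,
# and saves the .onnx model next to the source .pb file.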
6 | 7 | """ 8 | python -m tf2onnx.convert : tool to convert a frozen tensorflow graph to onnx 9 | """ 10 | 11 | from __future__ import division 12 | from __future__ import print_function 13 | from __future__ import unicode_literals 14 | 15 | import argparse 16 | import sys 17 | 18 | import tensorflow as tf 19 | 20 | from tf2onnx.tfonnx import process_tf_graph, tf_optimize 21 | from tf2onnx import constants, logging, utils, optimizer 22 | from tf_graph_util import convert_variables_to_constants 23 | # from tensorflow.python.framework.graph_util import convert_variables_to_constants 24 | 25 | # pylint: disable=unused-argument 26 | 27 | _HELP_TEXT = """ 28 | Usage Examples: 29 | 30 | python -m tf2onnx.convert --saved-model saved_model_dir --output model.onnx 31 | python -m tf2onnx.convert --input frozen_graph.pb --inputs X:0 --outputs output:0 --output model.onnx 32 | python -m tf2onnx.convert --checkpoint checkpoint.meta --inputs X:0 --outputs output:0 --output model.onnx 33 | 34 | For help and additional information see: 35 | https://github.com/onnx/tensorflow-onnx 36 | 37 | If you run into issues, open an issue here: 38 | https://github.com/onnx/tensorflow-onnx/issues 39 | """ 40 | 41 | logger = tf.compat.v1.logging 42 | # logger = logging.getLogger(constants.TF2ONNX_PACKAGE_NAME) 43 | 44 | 45 | def freeze_session(sess, keep_var_names=None, output_names=None, clear_devices=True): 46 | """Freezes the state of a session into a pruned computation graph.""" 47 | output_names = [i.split(':')[:-1][0] for i in output_names] 48 | graph = sess.graph 49 | with graph.as_default(): 50 | freeze_var_names = list(set(v.op.name for v in tf.compat.v1.global_variables()).difference(keep_var_names or [])) 51 | output_names = output_names or [] 52 | output_names += [v.op.name for v in tf.compat.v1.global_variables()] 53 | input_graph_def = graph.as_graph_def(add_shapes=True) 54 | if clear_devices: 55 | for node in input_graph_def.node: 56 | node.device = "" 57 | frozen_graph = convert_variables_to_constants(sess, input_graph_def, output_names, freeze_var_names) 58 | return frozen_graph 59 | 60 | 61 | def remove_redundant_inputs(frozen_graph, input_names): 62 | """Remove redundant inputs not in frozen graph.""" 63 | frozen_inputs = [] 64 | # get inputs in frozen graph 65 | for n in frozen_graph.node: 66 | for inp in input_names: 67 | if utils.node_name(inp) == n.name: 68 | frozen_inputs.append(inp) 69 | deleted_inputs = list(set(input_names) - set(frozen_inputs)) 70 | if deleted_inputs: 71 | logger.warning("inputs [%s] is not in frozen graph, delete them", ",".join(deleted_inputs)) 72 | return frozen_inputs 73 | 74 | 75 | def from_graphdef(sess, graph_def, model_path, input_names, output_names): 76 | """Load tensorflow graph from graphdef.""" 77 | # make sure we start with clean default graph 78 | with tf.io.gfile.GFile(model_path, 'rb') as f: 79 | graph_def.ParseFromString(f.read()) 80 | tf.import_graph_def(graph_def, name='') 81 | frozen_graph = freeze_session(sess, output_names=output_names) 82 | input_names = remove_redundant_inputs(frozen_graph, input_names) 83 | # clean up 84 | return frozen_graph, input_names, output_names 85 | 86 | 87 | def convert_onnx(sess, graph_def, input_path, inputs_op, outputs_op): 88 | 89 | graphdef = input_path 90 | 91 | if inputs_op: 92 | inputs_op, shape_override = utils.split_nodename_and_shape(inputs_op) 93 | if outputs_op: 94 | outputs_op = outputs_op.split(",") 95 | 96 | # logging.basicConfig(level=logging.get_verbosity_level(True)) 97 | 98 | utils.set_debug_mode(True) 
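    # Pipeline sketch (what the calls below do): from_graphdef() freezes the
    # session into a GraphDef, tf_optimize() folds and prunes TF-level ops,
    # process_tf_graph() maps the graph to ONNX opset 9, and
    # optimizer.optimize_graph() applies ONNX-level rewrites before saving.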
99 | 100 | graph_def, inputs_op, outputs_op = from_graphdef(sess, graph_def, graphdef, inputs_op, outputs_op) 101 | model_path = graphdef 102 | 103 | graph_def = tf_optimize(inputs_op, outputs_op, graph_def, True) 104 | 105 | with tf.Graph().as_default() as tf_graph: 106 | tf.compat.v1.import_graph_def(graph_def, name='') 107 | with tf.compat.v1.Session(graph=tf_graph): 108 | g = process_tf_graph(tf_graph, 109 | continue_on_error=True, 110 | target=",".join(constants.DEFAULT_TARGET), 111 | opset=9, 112 | custom_op_handlers=None, 113 | extra_opset=None, 114 | shape_override=None, 115 | input_names=inputs_op, 116 | output_names=outputs_op, 117 | inputs_as_nchw=None) 118 | 119 | onnx_graph = optimizer.optimize_graph(g) 120 | model_proto = onnx_graph.make_model("converted from {}".format(model_path)) 121 | 122 | # write onnx graph 123 | logger.info("") 124 | logger.info("Successfully converted TensorFlow model %s to ONNX", model_path) 125 | # if args.output: 126 | output_path = input_path.replace(".pb", ".onnx") 127 | utils.save_protobuf(output_path, model_proto) 128 | logger.info("ONNX model is saved at %s", output_path) 129 | 130 | 131 | if __name__ == "__main__": 132 | 133 | model_path = r"E:\Workplaces\PythonProjects\captcha_trainer\projects\test-CNN3-GRU-H64-CTC-C1\out\graph\test-CNN3-GRU-H64-CTC-C1_0.pb" 134 | tf.compat.v1.disable_eager_execution() 135 | graph = tf.compat.v1.Graph() 136 | sess = tf.compat.v1.Session( 137 | graph=graph, 138 | config=tf.compat.v1.ConfigProto( 139 | 140 | # allow_soft_placement=True, 141 | # log_device_placement=True, 142 | gpu_options=tf.compat.v1.GPUOptions( 143 | # allocator_type='BFC', 144 | allow_growth=True, # it will cause fragmentation. 145 | # per_process_gpu_memory_fraction=self.model_conf.device_usage 146 | per_process_gpu_memory_fraction=0.1 147 | ) 148 | ) 149 | ) 150 | graph_def = graph.as_graph_def() 151 | with tf.io.gfile.GFile(model_path, "rb") as f: 152 | graph_def_file = f.read() 153 | graph_def.ParseFromString(graph_def_file) 154 | with graph.as_default(): 155 | sess.run(tf.compat.v1.global_variables_initializer()) 156 | _ = tf.import_graph_def(graph_def, name="") 157 | 158 | output_graph_def = convert_variables_to_constants( 159 | sess, 160 | graph_def, 161 | output_node_names=['dense_decoded'] 162 | ) 163 | 164 | def compile_onnx(path): 165 | convert_onnx( 166 | sess=sess, 167 | graph_def=output_graph_def, 168 | input_path=path, 169 | inputs_op="input:0", 170 | # outputs_op="output/transpose:0" 171 | outputs_op="output/predict:0", 172 | # outputs_op="dense_decoded:0" 173 | ) 174 | tf.compat.v1.reset_default_graph() 175 | tf.compat.v1.keras.backend.clear_session() 176 | sess.close() 177 | 178 | 179 | for op in graph.get_operations(): 180 | print(op.name, ": ", op.values()) 181 | 182 | compile_onnx(model_path) 183 | 184 | -------------------------------------------------------------------------------- /tools/delete_repeat_img.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import os 4 | import cv2 as cv 5 | import time 6 | import numpy as np 7 | from concurrent.futures import ThreadPoolExecutor 8 | from skimage.measure import compare_ssim 9 | 10 | EXT = ['.jpg', '.jpeg'] 11 | path = r'C:\Users\sml2h\Desktop\sb' 12 | codes = [item.split("_")[0].lower() for item in os.listdir(path)] 13 | codes = list(set(codes)) 14 | codes_dict = {} 15 | 16 | for code in codes: 17 | codes_dict[code] = [] 18 | for item in os.listdir(path): 19 | 
codes_dict[item.split("_")[0].lower()].append(item) 20 | 21 | def delete(imgs_n): 22 | # return 23 | for image in imgs_n: 24 | os.remove(image) 25 | 26 | # 27 | def find_sim_images(code_lists): 28 | imgs_n = [] 29 | img_files = [os.path.join(path, code) for code in code_lists] 30 | for currIndex, filename in enumerate(img_files): 31 | if filename in imgs_n: 32 | continue 33 | if currIndex >= len(img_files) - 1: 34 | break 35 | for filename2 in img_files[currIndex + 1:]: 36 | if filename2 in imgs_n: 37 | continue 38 | img = cv.imdecode(np.fromfile(filename, dtype=np.uint8), -1) 39 | img1 = cv.imdecode(np.fromfile(filename2, dtype=np.uint8), -1) 40 | try: 41 | ssim = compare_ssim(img, img1, multichannel=True) 42 | if ssim > 0.9: 43 | imgs_n.append(filename2) 44 | print(filename, filename2, ssim) 45 | except ValueError: 46 | pass 47 | print(imgs_n) 48 | delete(imgs_n) 49 | return imgs_n 50 | 51 | 52 | with ThreadPoolExecutor(max_workers=30) as t: 53 | tasks = [] 54 | for key in codes_dict: 55 | codes = codes_dict[key] 56 | task = t.submit(find_sim_images, codes) 57 | tasks.append(task) 58 | while True: 59 | result = [] 60 | for i in range(len(tasks)): 61 | result.append(tasks[i].done())  # poll every task, not just the last one submitted 62 | if False not in result: 63 | break 64 | time.sleep(1) -------------------------------------------------------------------------------- /tools/gif_frames.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding:utf-8 -*- 3 | # Author: kerlomz 4 | 5 | import cv2 6 | import numpy as np 7 | from PIL import ImageSequence 8 | 9 | 10 | def split_frames(image_obj, need_frame=None): 11 | if not need_frame: 12 | need_frame = [0] 13 | image_seq = ImageSequence.all_frames(image_obj) 14 | image_arr_last = [np.asarray(image_seq[-1])] if -1 in need_frame and len(need_frame) > 1 else [] 15 | image_arr = [np.asarray(item) for i, item in enumerate(image_seq) if (i in need_frame or need_frame == [-1])] 16 | image_arr += image_arr_last 17 | return image_arr 18 | 19 | 20 | def concat_arr(img_arr): 21 | if len(img_arr) < 1: 22 | return img_arr 23 | all_slice = img_arr[0] 24 | for im_slice in img_arr[1:]: 25 | all_slice = np.concatenate((all_slice, im_slice), axis=1) 26 | return all_slice 27 | 28 | 29 | def numpy_to_bytes(numpy_arr): 30 | cv_img = cv2.imencode('.png', numpy_arr)[1] 31 | img_bytes = bytes(bytearray(cv_img)) 32 | return img_bytes 33 | 34 | 35 | def concat_frames(image_obj, need_frame=None): 36 | if not need_frame: 37 | need_frame = [0] 38 | img_arr = split_frames(image_obj, need_frame) 39 | img_arr = concat_arr(img_arr) 40 | return img_arr 41 | 42 | 43 | def blend_arr(img_arr): 44 | if len(img_arr) < 1: 45 | return img_arr 46 | all_slice = img_arr[0] 47 | for im_slice in img_arr[1:]: 48 | all_slice = cv2.addWeighted(all_slice, 0.5, im_slice, 0.5, 0) 49 | # print(all_slice) 50 | # all_slice = cv2.equalizeHist(all_slice) 51 | return all_slice 52 | 53 | 54 | def blend_frame(image_obj, need_frame=None): 55 | if not need_frame: 56 | need_frame = [-1] 57 | img_arr = split_frames(image_obj, need_frame) 58 | img_arr = blend_arr(img_arr) 59 | if len(img_arr.shape) > 2 and img_arr.shape[2] == 3: 60 | img_arr = cv2.cvtColor(img_arr, cv2.COLOR_RGB2GRAY) 61 | img_arr = cv2.equalizeHist(img_arr) 62 | return img_arr 63 | 64 | 65 | if __name__ == "__main__": 66 | pass -------------------------------------------------------------------------------- /tools/package.py: -------------------------------------------------------------------------------- 1 |
#!/usr/bin/env python3 2 | # -*- coding:utf-8 -*- 3 | # Author: kerlomz 4 | import time 5 | from PyInstaller.__main__ import run 6 | from config import resource_path 7 | 8 | 9 | with open("../resource/VERSION", "w", encoding="utf8") as f: 10 | today = time.strftime("%Y%m%d", time.localtime(time.time())) 11 | f.write(today) 12 | 13 | 14 | def package(prefix): 15 | """Build a single executable with PyInstaller""" 16 | opts = ['{}app.spec'.format(prefix), '--distpath={}dist'.format(prefix), '--workpath={}build'.format(prefix)] 17 | run(opts) 18 | 19 | 20 | if __name__ == '__main__': 21 | try: 22 | package("../") 23 | except FileNotFoundError: 24 | package("/") 25 | -------------------------------------------------------------------------------- /trains.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding:utf-8 -*- 3 | # Author: kerlomz 4 | import tensorflow as tf 5 | tf.compat.v1.disable_v2_behavior() 6 | tf.compat.v1.disable_eager_execution() 7 | try: 8 | gpus = tf.config.list_physical_devices('GPU') 9 | tf.config.experimental.set_memory_growth(gpus[0], True) 10 | 11 | except Exception as e: 12 | print(e, "No available gpu found.") 13 | # from tensorflow.python.platform.build_info import build_info 14 | import core 15 | import utils 16 | import utils.data 17 | import validation 18 | from config import * 19 | from tf_graph_util import convert_variables_to_constants 20 | from PIL import ImageFile 21 | # if build_info['cuda_version'] == '64_110': 22 | 23 | ImageFile.LOAD_TRUNCATED_IMAGES = True 24 | tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO) 25 | 26 | 27 | class Trains: 28 | 29 | stop_flag: bool = False 30 | """Training task class""" 31 | 32 | def __init__(self, model_conf: ModelConfig): 33 | """ 34 | :param model_conf: the project configuration to load 35 | """ 36 | self.model_conf = model_conf 37 | self.validation = validation.Validation(self.model_conf) 38 | 39 | def compile_graph(self, acc): 40 | """ 41 | Compile the current computation graph into a .pb model; the accuracy is only used as part of the model name 42 | :param acc: accuracy 43 | :return: 44 | """ 45 | input_graph = tf.compat.v1.Graph() 46 | tf.compat.v1.keras.backend.clear_session() 47 | tf.compat.v1.reset_default_graph() 48 | predict_sess = tf.compat.v1.Session(graph=input_graph) 49 | tf.compat.v1.keras.backend.set_session(predict_sess) 50 | 51 | with predict_sess.graph.as_default(): 52 | model = core.NeuralNetwork( 53 | model_conf=self.model_conf, 54 | mode=RunMode.Predict, 55 | backbone=self.model_conf.neu_cnn, 56 | recurrent=self.model_conf.neu_recurrent 57 | ) 58 | model.build_graph() 59 | model.build_train_op() 60 | input_graph_def = predict_sess.graph.as_graph_def() 61 | saver = tf.compat.v1.train.Saver(var_list=tf.compat.v1.global_variables()) 62 | tf.compat.v1.logging.info(tf.train.latest_checkpoint(self.model_conf.model_root_path)) 63 | saver.restore(predict_sess, tf.train.latest_checkpoint(self.model_conf.model_root_path)) 64 | 65 | output_graph_def = convert_variables_to_constants( 66 | predict_sess, 67 | input_graph_def, 68 | output_node_names=['dense_decoded'] 69 | ) 70 | 71 | if not os.path.exists(self.model_conf.compile_model_path): 72 | os.makedirs(self.model_conf.compile_model_path) 73 | 74 | last_compile_model_path = ( 75 | os.path.join(self.model_conf.compile_model_path, "{}.pb".format(self.model_conf.model_name)) 76 | ).replace('.pb', '_{}.pb'.format(int(acc * 10000))) 77 | 78 | self.model_conf.output_config(target_model_name="{}_{}".format(self.model_conf.model_name, int(acc * 10000))) 79 | with tf.io.gfile.GFile(last_compile_model_path,
mode='wb') as gf: 80 | gf.write(output_graph_def.SerializeToString()) 81 | 82 | def achieve_cond(self, acc, cost, epoch): 83 | achieve_accuracy = acc >= self.model_conf.trains_end_acc 84 | achieve_cost = cost <= self.model_conf.trains_end_cost 85 | achieve_epochs = epoch >= self.model_conf.trains_end_epochs 86 | over_epochs = epoch > 10000 87 | if (achieve_accuracy and achieve_epochs and achieve_cost) or over_epochs: 88 | return True 89 | return False 90 | 91 | def init_captcha_gennerator(self, ran_captcha): 92 | 93 | path = self.model_conf.da_random_captcha['FontPath'] 94 | if not os.path.exists(path): 95 | exception("Font path does not exist.", code=-6754) 96 | items = os.listdir(path) 97 | fonts = [os.path.join(path, item) for item in items] 98 | ran_captcha.sample = NUMBER + ALPHA_UPPER + ALPHA_LOWER 99 | ran_captcha.fonts_list = fonts 100 | ran_captcha.check_font() 101 | ran_captcha.rgb_r = [0, 255] 102 | ran_captcha.rgb_g = [0, 255] 103 | ran_captcha.rgb_b = [0, 255] 104 | ran_captcha.fonts_num = [4, 8] 105 | 106 | def train_process(self): 107 | """ 108 | The training task 109 | :return: 110 | """ 111 | # Print the key configuration parameters 112 | self.model_conf.println() 113 | # Define the network structure 114 | model = core.NeuralNetwork( 115 | mode=RunMode.Trains, 116 | model_conf=self.model_conf, 117 | backbone=self.model_conf.neu_cnn, 118 | recurrent=self.model_conf.neu_recurrent 119 | ) 120 | model.build_graph() 121 | 122 | tf.compat.v1.logging.info('Loading Trains DataSet...') 123 | train_feeder = utils.data.DataIterator( 124 | model_conf=self.model_conf, mode=RunMode.Trains 125 | ) 126 | train_feeder.read_sample_from_tfrecords(self.model_conf.trains_path[DatasetType.TFRecords]) 127 | 128 | tf.compat.v1.logging.info('Loading Validation DataSet...') 129 | validation_feeder = utils.data.DataIterator( 130 | model_conf=self.model_conf, mode=RunMode.Validation 131 | ) 132 | validation_feeder.read_sample_from_tfrecords(self.model_conf.validation_path[DatasetType.TFRecords]) 133 | 134 | tf.compat.v1.logging.info('Total {} Trains DataSets'.format(train_feeder.size)) 135 | tf.compat.v1.logging.info('Total {} Validation DataSets'.format(validation_feeder.size)) 136 | if validation_feeder.size >= train_feeder.size: 137 | exception("The number of training sets cannot be less than the validation set.", ) 138 | if validation_feeder.size < self.model_conf.validation_batch_size: 139 | exception("The number of validation sets cannot be less than the validation batch size.", ) 140 | 141 | num_train_samples = train_feeder.size 142 | num_validation_samples = validation_feeder.size 143 | 144 | if num_validation_samples < self.model_conf.validation_batch_size: 145 | self.model_conf.validation_batch_size = num_validation_samples 146 | tf.compat.v1.logging.warn( 147 | 'The number of validation sets ({}) is less than the validation batch size, ' 148 | 'will use validation set size as validation batch size.'.format(validation_feeder.size)) 149 | 150 | num_batches_per_epoch = int(num_train_samples / self.model_conf.batch_size) 151 | 152 | model.build_train_op(num_train_samples) 153 | 154 | # Session configuration 155 | sess_config = tf.compat.v1.ConfigProto( 156 | allow_soft_placement=True, 157 | # log_device_placement=False, 158 | gpu_options=tf.compat.v1.GPUOptions( 159 | allocator_type='BFC', 160 | allow_growth=True, # it will cause fragmentation.
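                # allow_growth claims GPU memory on demand instead of
                # reserving the configured fraction up front.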
161 | # per_process_gpu_memory_fraction=0.3 162 | ) 163 | ) 164 | accuracy = 0 165 | epoch_count = 1 166 | 167 | if num_train_samples < 500: 168 | save_step = 10 169 | trains_validation_steps = 50 170 | 171 | else: 172 | save_step = 100 173 | trains_validation_steps = self.model_conf.trains_validation_steps 174 | 175 | sess = tf.compat.v1.Session(config=sess_config) 176 | 177 | init_op = tf.compat.v1.global_variables_initializer() 178 | sess.run(init_op) 179 | saver = tf.compat.v1.train.Saver(var_list=tf.compat.v1.global_variables(), max_to_keep=3) 180 | train_writer = tf.compat.v1.summary.FileWriter('logs', sess.graph) 181 | # try: 182 | checkpoint_state = tf.train.get_checkpoint_state(self.model_conf.model_root_path) 183 | if checkpoint_state and checkpoint_state.model_checkpoint_path: 184 | # Resume an interrupted training task 185 | saver.restore(sess, checkpoint_state.model_checkpoint_path) 186 | 187 | tf.compat.v1.logging.info('Start training...') 188 | 189 | # Enter the training task loop 190 | while 1: 191 | 192 | start_time = time.time() 193 | batch_cost = 65535 194 | # Batch loop 195 | for cur_batch in range(num_batches_per_epoch): 196 | 197 | if self.stop_flag: 198 | break 199 | 200 | batch_time = time.time() 201 | 202 | trains_batch = train_feeder.generate_batch_by_tfrecords(sess) 203 | 204 | batch_inputs, batch_labels = trains_batch 205 | 206 | feed = { 207 | model.inputs: batch_inputs, 208 | model.labels: batch_labels, 209 | model.utils.is_training: True 210 | } 211 | 212 | summary_str, batch_cost, step, _, seq_len = sess.run( 213 | [model.merged_summary, model.cost, model.global_step, model.train_op, model.seq_len], 214 | feed_dict=feed 215 | ) 216 | train_writer.add_summary(summary_str, step) 217 | 218 | if step % save_step == 0 and step != 0: 219 | tf.compat.v1.logging.info( 220 | 'Step: {} Time: {:.3f} sec/batch, Cost = {:.8f}, BatchSize: {}, Shape[1]: {}'.format( 221 | step, 222 | time.time() - batch_time, 223 | batch_cost, 224 | len(batch_inputs), 225 | seq_len[0] 226 | ) 227 | ) 228 | 229 | # Save a checkpoint every save_step steps 230 | if step % save_step == 0 and step != 0: 231 | saver.save(sess, self.model_conf.save_model, global_step=step) 232 | 233 | # Validation phase 234 | if step % trains_validation_steps == 0 and step != 0: 235 | 236 | batch_time = time.time() 237 | validation_batch = validation_feeder.generate_batch_by_tfrecords(sess) 238 | 239 | test_inputs, test_labels = validation_batch 240 | val_feed = { 241 | model.inputs: test_inputs, 242 | model.labels: test_labels, 243 | model.utils.is_training: False 244 | } 245 | dense_decoded, lr = sess.run( 246 | [model.dense_decoded, model.lrn_rate], 247 | feed_dict=val_feed 248 | ) 249 | # Compute the accuracy 250 | accuracy = self.validation.accuracy_calculation( 251 | validation_feeder.labels, 252 | dense_decoded, 253 | ) 254 | log = "Epoch: {}, Step: {}, Accuracy = {:.4f}, Cost = {:.5f}, " \ 255 | "Time = {:.3f} sec/batch, LearningRate: {}" 256 | tf.compat.v1.logging.info(log.format( 257 | epoch_count, 258 | step, 259 | accuracy, 260 | batch_cost, 261 | time.time() - batch_time, 262 | lr / len(validation_batch), 263 | )) 264 | 265 | # Break out of the epoch loop once the stop conditions are met, even mid-epoch 266 | if self.achieve_cond(acc=accuracy, cost=batch_cost, epoch=epoch_count): 267 | break 268 | 269 | # When the stop conditions are met, exit the task loop 270 | if self.stop_flag: 271 | break 272 | if self.achieve_cond(acc=accuracy, cost=batch_cost, epoch=epoch_count): 273 | # sess.close() 274 | tf.compat.v1.keras.backend.clear_session() 275 | sess.close() 276 | self.compile_graph(accuracy) 277 | tf.compat.v1.logging.info('Total Time: {} sec.'.format(time.time() - start_time))
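                # Stop conditions met: the session was closed and the final
                # graph compiled to .pb above, so leave the task loop.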
                break
            epoch_count += 1
            tf.compat.v1.logging.info('Epoch time: {} sec.'.format(time.time() - start_time))


def main(argv):
    project_name = argv[-1]
    model_conf = ModelConfig(project_name=project_name)
    Trains(model_conf).train_process()
    tf.compat.v1.logging.info('Training completed.')


if __name__ == '__main__':
    # tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO)
    tf.compat.v1.app.run()
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
# Author: kerlomz

# from . import sparse
# from . import data
--------------------------------------------------------------------------------
/utils/category_frequency_statistics.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
# Author: kerlomz
import re
import os
import json
from config import ModelConfig, LabelFrom, DatasetType

ignore_list = ["Thumbs.db", ".DS_Store"]
PATH_SPLIT = "/"


def extract_labels_from_filename(filename: str, extract_regex):
    """Extract the label part of a sample filename with the configured regex."""
    if filename.split(PATH_SPLIT)[-1] in ignore_list:
        return None
    try:
        labels = re.search(extract_regex, filename.split(PATH_SPLIT)[-1])
    except re.error as e:
        print('error:', e)
        return None
    if labels:
        labels = labels.group()
    else:
        print('invalid filename {}, ignored.'.format(filename))
        return None
    return labels


def fetch_category_freq(model: ModelConfig):
    """Count how often each category occurs across the training set filenames."""
    if model.label_from == LabelFrom.FileName:
        category_dict = dict()
        for iter_dir in model.trains_path[DatasetType.Directory]:
            for filename in os.listdir(iter_dir):

                labels = extract_labels_from_filename(filename, model.extract_regex)

                if not labels:
                    continue

                for label_item in labels:
                    if label_item in category_dict:
                        category_dict[label_item] += 1
                    else:
                        # A first occurrence counts as one, not zero
                        category_dict[label_item] = 1

        return sorted(category_dict.items(), key=lambda item: item[1], reverse=True)


def fetch_category_list(model: ModelConfig, is_json=False):
    """Collect the sorted list of distinct categories found in the training set filenames."""
    if model.label_from == LabelFrom.FileName:
        category_set = set()
        for iter_dir in model.trains_path[DatasetType.Directory]:
            for filename in os.listdir(iter_dir):

                labels = extract_labels_from_filename(filename, model.extract_regex)
                if not labels:
                    continue
                if int(model.max_label_num) == 1:
                    category_set.add(labels)
                elif '&' in labels:
                    for label_item in labels.split('&'):
                        category_set.add(label_item)
                else:
                    for label_item in labels:
                        category_set.add(label_item)
        category_list = list(category_set)
        category_list.sort()
        if is_json:
            return json.dumps(category_list, ensure_ascii=False)
        return category_list


if __name__ == '__main__':
    model_conf = ModelConfig("test-CNNX-GRU-H64-CTC-C1")
    # labels_dict = fetch_category_freq(model_conf)
    # label_list = [k for k, v in labels_dict if v < 5000]
    # label_list.sort()
    # high_freq = "".join(label_list)
    # print(high_freq)
    # print(len(high_freq))
    labels_list = fetch_category_list(model_conf)
    print(labels_list)
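
# A minimal usage sketch with hypothetical data: given training images named like
# "aB3d_1620.png" and an extract_regex matching the leading characters, the helpers
# above would behave roughly as follows:
#
#     freq = fetch_category_freq(model_conf)   # e.g. [('a', 1203), ('3', 1150), ...]
#     cats = fetch_category_list(model_conf)   # e.g. ['3', 'B', 'a', 'd']
#     cats_json = fetch_category_list(model_conf, is_json=True)  # '["3", "B", "a", "d"]'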
--------------------------------------------------------------------------------
/utils/data.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
# Author: kerlomz
import os
import hashlib
import utils
import random
import utils.sparse
import tensorflow as tf
import numpy as np
from constants import RunMode, ModelField, DatasetType, LossFunction
from config import ModelConfig, EXCEPT_FORMAT_MAP
from encoder import Encoder
from exception import exception


class DataIterator:
    """Dataset iterator"""

    def __init__(self, model_conf: ModelConfig, mode: RunMode, ran_captcha=None):
        """
        :param model_conf: project configuration
        :param mode: run mode (training / validation)
        """
        self.model_conf = model_conf
        self.mode = mode
        self.path_map = {
            RunMode.Trains: self.model_conf.trains_path[DatasetType.TFRecords],
            RunMode.Validation: self.model_conf.validation_path[DatasetType.TFRecords]
        }
        self.batch_map = {
            RunMode.Trains: self.model_conf.batch_size,
            RunMode.Validation: self.model_conf.validation_batch_size
        }
        self.data_dir = self.path_map[mode]
        self.next_element = None
        self.image_path = []
        self.label_list = []
        self._label_list = []
        self._size = 0
        self.encoder = Encoder(self.model_conf, self.mode)
        self.ran_captcha = ran_captcha

    @staticmethod
    def parse_example(serial_example):

        features = tf.io.parse_single_example(
            serial_example,
            features={
                'label': tf.io.FixedLenFeature([], tf.string),
                'input': tf.io.FixedLenFeature([], tf.string),
            }
        )
        _input = tf.cast(features['input'], tf.string)
        _label = tf.cast(features['label'], tf.string)

        return _input, _label

    @staticmethod
    def total_sample(file_name):
        sample_nums = 0
        for _ in tf.compat.v1.python_io.tf_record_iterator(file_name):
            sample_nums += 1
        return sample_nums

    def read_sample_from_tfrecords(self, path):
        """
        Read samples from TFRecords
        :param path: path (or list of paths) of the TFRecords files
        :return:
        """
        if isinstance(path, list):
            for p in path:
                self._size += self.total_sample(p)
        else:
            self._size = self.total_sample(path)

        min_after_dequeue = 1000  # shuffle buffer size
        batch = self.batch_map[self.mode]
        # When random captcha augmentation is enabled, draw a smaller batch from
        # the TFRecords and fill the remainder with generated samples later.
        if self.model_conf.da_random_captcha['Enable']:
            batch = random.randint(int(batch / 3 * 2), batch)

        dataset_train = tf.data.TFRecordDataset(
            filenames=path,
            num_parallel_reads=20
        ).map(self.parse_example)
        dataset_train = dataset_train.shuffle(
            min_after_dequeue,
            reshuffle_each_iteration=True
        ).prefetch(128).batch(batch, drop_remainder=True).repeat()
        iterator = tf.compat.v1.data.make_one_shot_iterator(dataset_train)
        self.next_element = iterator.get_next()

    @property
    def size(self):
        """Number of samples"""
        return self._size

    @property
    def labels(self):
        """Labels"""
        return self.label_list

    @staticmethod
    def to_sparse(input_batch, label_batch):
        """Convert dense labels to a sparse representation"""
        batch_inputs = input_batch
        batch_labels = utils.sparse.sparse_tuple_from_sequences(label_batch)
        return batch_inputs, batch_labels

    def generate_captcha(self, num) -> (list, list):
        _images = []
        _labels = []
        for i in range(num):
            try:
                image, labels, font_type = self.ran_captcha.create()
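                # create() is assumed to return a rendered captcha image, its label
                # characters, and the font used; failed generations are logged and
                # skipped below so augmentation never aborts the batch.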
                _images.append(image)
                _labels.append(''.join(labels).encode())
            except Exception as e:
                print(e)
        return _images, _labels

    def generate_batch_by_tfrecords(self, session):
        """Generate the current batch from TFRecords; takes the active TensorFlow session, returns sparse X and Y"""
        batch = self.batch_map[self.mode]

        _input, _label = session.run(self.next_element)
        # Top up the batch with randomly generated captcha samples
        if self.model_conf.da_random_captcha['Enable']:
            remain_batch = batch - len(_label)
            extra_input, extra_label = self.generate_captcha(remain_batch)
            _input = np.concatenate((_input, extra_input), axis=0)
            _label = np.concatenate((_label, extra_label), axis=0)

        input_batch = []
        label_batch = []
        for i1, i2 in zip(_input, _label):
            try:
                label_array = self.encoder.text(i2)
                if self.model_conf.model_field == ModelField.Image:
                    input_array = self.encoder.image(i1)
                else:
                    input_array = self.encoder.text(i1)

                if input_array is None:
                    # tf.compat.v1.logging.warn(
                    #     "{}, Cannot identify image file labeled: {}, ignored.".format(input_array, label_array))
                    continue

                if isinstance(input_array, str):
                    # tf.compat.v1.logging.warn("{}, \nInput errors labeled: {} [{}], ignored.".format(input_array, i1, label_array))
                    continue
                if isinstance(label_array, dict):
                    # tf.logging.warn("The sample label {} contains invalid charset: {}.".format(
                    #     label_array['label'], label_array['char']
                    # ))
                    continue

                if input_array.shape[-1] != self.model_conf.image_channel:
                    tf.compat.v1.logging.warn("Channel mismatch: expected {}, got {}, ignored.".format(
                        self.model_conf.image_channel, input_array.shape[-1])
                    )
                    continue

                label_len_mismatch = len(label_array) != self.model_conf.max_label_num
                using_cross_entropy = self.model_conf.loss_func == LossFunction.CrossEntropy
                if label_len_mismatch and using_cross_entropy and not self.model_conf.auto_padding:
                    tf.compat.v1.logging.warn(
                        "The number of labels must be fixed when using cross entropy; "
                        "label {} has the wrong number of labels, ignored.".format(i2))
                    continue

                if len(label_array) > self.model_conf.max_label_num and using_cross_entropy:
                    tf.compat.v1.logging.warn(
                        "The number of labels in [{}] exceeds the maximum, ignored. {}".format(i2, label_array))
                    continue

                input_batch.append(input_array)
                label_batch.append(label_array)
            except OSError:
                # Dump the unreadable sample to disk for later inspection
                random_suffix = hashlib.md5(i1).hexdigest()
                file_format = EXCEPT_FORMAT_MAP[self.model_conf.model_field]
                with open(file="oserror_{}.{}".format(random_suffix, file_format), mode="wb") as f:
                    f.write(i1)
                tf.compat.v1.logging.warn("OSError [{}]".format(i2))
                continue

        # If the image size is not fixed, pad the current batch so that the
        # largest width becomes the maximum sequence length.
        if self.model_conf.model_field == ModelField.Image and self.model_conf.resize[0] == -1:
            input_batch = tf.keras.preprocessing.sequence.pad_sequences(
                sequences=input_batch,
                maxlen=None,
                dtype='float32',
                padding='post',
                truncating='post',
                value=0
            )

        self.label_list = label_batch
        return self.to_sparse(input_batch, self.label_list)
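
# A minimal usage sketch, assuming a configured project `conf` and an active
# tf.compat.v1.Session `sess` (these names are illustrative, not part of this module):
#
#     feeder = DataIterator(model_conf=conf, mode=RunMode.Trains)
#     feeder.read_sample_from_tfrecords(conf.trains_path[DatasetType.TFRecords])
#     inputs, sparse_labels = feeder.generate_batch_by_tfrecords(sess)
#     # sparse_labels is an (indices, values, shape) tuple, ready to feed a
#     # tf.SparseTensor placeholder for CTC loss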
--------------------------------------------------------------------------------
/utils/sparse.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
# Author: kerlomz
import numpy as np


def sparse_tuple_from_sequences(sequences, dtype=np.int32):
    """Convert dense sequences to a sparse (indices, values, shape) tuple"""
    indices = []
    values = []
    for n, seq in enumerate(sequences):
        indices.extend(zip([n] * len(seq), range(0, len(seq), 1)))
        values.extend(seq)

    indices = np.asarray(indices, dtype=np.int64)
    try:
        values = np.asarray(values, dtype=dtype)
    except Exception as e:
        print(e, values)
    shape = np.asarray([len(sequences), indices.max(0)[1] + 1], dtype=np.int64)
    return indices, values, shape
--------------------------------------------------------------------------------
/validation.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
# Author: kerlomz
import json
import numpy as np
import tensorflow as tf
from config import ModelConfig


class Validation(object):
    """Validation class, used for accuracy calculation"""
    def __init__(self, model: ModelConfig):
        """
        :param model: reads the key project parameters from the configuration: category_num, category
        """
        self.model = model
        self.category_num = self.model.category_num
        self.category = self.model.category

    def accuracy_calculation(self, original_seq, decoded_seq):
        """
        Accuracy calculation
        :param original_seq: dense array of ground-truth labels
        :param decoded_seq: dense array of predicted labels
        :return:
        """
        if isinstance(decoded_seq, np.ndarray):
            decoded_seq = decoded_seq.tolist()

        ignore_value = [-1, self.category_num, 0]
        original_seq_len = len(original_seq)
        decoded_seq_len = len(decoded_seq)

        if original_seq_len != decoded_seq_len:
            tf.compat.v1.logging.error(original_seq)
            tf.compat.v1.logging.error(decoded_seq)
            tf.compat.v1.logging.error(
                'original length {} differs from decoded length {}, please check again'.format(
                    original_seq_len,
                    decoded_seq_len
                ))
            return 0
        count = 0

        # For debugging: collect a few mismatched samples to locate error sources
        error_sample = []
        for i, origin_label in enumerate(original_seq):

            decoded_label = decoded_seq[i]
            if isinstance(decoded_label, int):
                decoded_label = [decoded_label]
            processed_decoded_label = [j for j in decoded_label if j not in ignore_value]
            processed_origin_label = [j for j in origin_label if j not in ignore_value]

            if i < 5:
                tf.compat.v1.logging.info(
                    "{} {} {} {} {} --> {} {}".format(
                        i,
                        len(processed_origin_label),
                        len(processed_decoded_label),
                        origin_label,
                        decoded_label,
                        [self.category[_] if _ != self.category_num else '-' for _ in origin_label if _ != -1],
                        [self.category[_] if _ != self.category_num else '-' for _ in decoded_label if _ != -1]
                    )
                )
            if processed_origin_label == processed_decoded_label:
                count += 1
            # Debugging only; has no effect on training
            if processed_origin_label != processed_decoded_label and len(error_sample) < 5:
                error_sample.append({
                    "origin": "".join([self.category[_] if _ != self.category_num else '-' for _ in origin_label if _ != -1]),
                    "decode": "".join([self.category[_] if _ != self.category_num else '-' for _ in decoded_label if _ != -1])
                })
        tf.compat.v1.logging.error(json.dumps(error_sample, ensure_ascii=False))
        return count * 1.0 / len(original_seq)
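
# A worked example with hypothetical values: with category_num = 36, entries in
# ignore_value ([-1, 36, 0]) are stripped from both sequences before comparison, so
#
#     Validation(model_conf).accuracy_calculation([[10, 2, 3]], [[10, 2, 3, -1]])  # -> 1.0
#     Validation(model_conf).accuracy_calculation([[10, 2, 3]], [[10, 2, 4]])      # -> 0.0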
--------------------------------------------------------------------------------