├── .gitignore ├── LICENSE ├── README.md ├── acc-easy.svg ├── acc.svg ├── captcha-images ├── .classpath ├── .project ├── .settings │ ├── org.eclipse.core.resources.prefs │ ├── org.eclipse.jdt.apt.core.prefs │ ├── org.eclipse.jdt.core.prefs │ └── org.eclipse.m2e.core.prefs ├── pom.xml └── src │ ├── main │ └── java │ │ ├── log4j.properties │ │ └── me │ │ └── zouzhipeng │ │ ├── App.java │ │ ├── CaptchaGenerator.java │ │ ├── CaptchaTaskRunner.java │ │ ├── EasyCaptchaGeneratorWorker.java │ │ ├── Generator.java │ │ ├── GeneratorMentor.java │ │ ├── KaptchaGeneratorWorker.java │ │ ├── config │ │ ├── Config.java │ │ ├── ConfigBuilder.java │ │ └── ConfigConstants.java │ │ └── utils │ │ └── ImageOutputUtil.java │ └── test │ └── java │ └── me │ └── zouzhipeng │ └── AppTest.java ├── captcha-net.svg ├── captchas.jpg ├── kaptcha-net.svg ├── kaptchas.jpg ├── loss-easy.svg ├── loss.svg ├── requirements.txt ├── src ├── captcha_recognition.ipynb ├── data.py ├── data │ ├── dev.npy │ ├── dev.y.npy │ ├── test.npy │ ├── test.y.npy │ ├── train.npy │ └── train.y.npy ├── dataloader.py ├── eval.py ├── kaptcha_model.py ├── logs │ └── eval.json ├── main.py ├── metrics.py ├── model.py ├── model │ ├── captcha-model.pkl │ ├── kaptcha-model.pkl │ └── model-latest.pkl ├── predict.py └── train.py └── test.svg /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos 
into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | .vscode/ 132 | .DS_Store 133 | .idea/ 134 | venv/ 135 | target/ -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Frank Tsau 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # CNN验证码识别 2 | 3 | ![Version](https://img.shields.io/badge/version-1.0-brightgreen.svg) ![Build](https://img.shields.io/badge/build-passing-brightgreen.svg) ![https://github.com/Frank17/fight-game/releases](https://img.shields.io/badge/release-v1.0-brightgreen.svg) ![java-version](https://img.shields.io/badge/java->=1.8-brightgreen.svg) ![python-version](https://img.shields.io/badge/python->=3.6-brightgreen.svg) 4 | 5 | ## 介绍 6 | 7 | 本项目使用神经网络方法,基于两个开源验证码框架:EasyCaptcha和Kaptcha,设计了两个神经网络模型,并随机生成一批验证码图片供模型训练及验证。经过实验发现,EasyCaptcha验证码相对简单,较容易识别,并且噪音较少,因此只需简单的模型即可达到高精度识别。而Kaptcha相对复杂,含有噪音,并且验证码可变形式较多,识别较困难。 8 | 9 | 本项目两个模型在EasyCaptcha验证码上达到了$98\sim 99\% $的精度水平,而在Kaptcha上达到了$93\sim 94\%$的水平,基本可实现较为准确的识别。 10 | 11 | 该项目提供了模型训练所用的训练集,均采用开源框架自动生成。此外还提供可直接嵌入使用的预训练模型。 12 | 13 | > 若您的公式无法正常显示,请您移步:[https://chrome.google.com/webstore/detail/mathjax-plugin-for-github/ioemnmodlmafdkllaclgeombjnmnbima](https://chrome.google.com/webstore/detail/mathjax-plugin-for-github/ioemnmodlmafdkllaclgeombjnmnbima),使用Chrome安装公式渲染扩展。 14 | 15 | ### 依赖环境 16 | 17 | * Python 3.6 18 | * click 7.1.1 19 | * PyTorch 1.4.0 20 | * torchvision 0.1.6.dev0 21 | * tqdm 4.45.0 22 | * Scikit-learn 0.22.2.post1 23 | * Pillow 7.1.1 24 | * matplotlib 3.2.1 25 | 26 | 详细使用方法,请见下节**开始使用**。 27 | 28 | ## 开始使用 29 | 30 | 若要生成训练验证码图片,请使用`captcha-images-1.0.jar`,其使用方式如下: 31 | 32 | ```shell 33 | usage: java -jar [jarfile].jar 34 | -c The text color of the captcha, if not specified, 35 | use randomly, should be ,, 36 | -e If the noise color is the same with the text 37 | color, should be true or false, default false. 
38 | -h Show help 39 | -k The kinds of captchas, integer, default 1 40 | -l The length of characters in the captchas, integer, 41 | default 4 42 | -m The mode of captcha, could be [easycaptcha] or 43 | [kaptcha] 44 | -n The noise color of the captcha, if not specified, 45 | use randomly, should be ,, 46 | -o The output directory of captchas, string 47 | -p Thread pool size, integer, default 20 48 | -s The size of the capthcas, should be: 49 | , 50 | -t The count to produce, integer, default 5,000 51 | -v The height of the capthcas, integer, default 80 52 | -w The width of the capthcas, integer, default 120 53 | 54 | ``` 55 | 56 | 任务启动后,将并发创建一批验证码图片,并存入指定的目录中。 57 | 58 | 如果在notebook中使用,请在`train.py`与`eval.py`中切换如下包 59 | ```python 60 | # from tqdm import tqdm 61 | from tqdm.notebook import tqdm 62 | ``` 63 | 64 | > `tqdm`在shell中展示进度条,而`tqdm.notebook`则在notebook环境中,展示进度条,其显示风格更加符合HTML规范 65 | 66 | 若需要从头开始,或根据已有的快照继续训练模型,请使用`train.py`文件,使用方法如下: 67 | 68 | ```shell 69 | Usage: train.py [OPTIONS] 70 | 71 | Options: 72 | -h, --help Show this message 73 | and exit. 74 | 75 | -i, --data_dir PATH The path of train 76 | data 77 | 78 | -m, --mode [captcha|kaptcha] The model type to 79 | train, could be 80 | captcha or kaptcha 81 | 82 | -e, --epoch INTEGER The number of 83 | epoch model 84 | trained 85 | 86 | -p, --data_split INTEGER... 
The split of train 87 | data to split 88 | 89 | -c, --continue_train TEXT If continue after 90 | last checkpoint or 91 | a specified one 92 | 93 | -t, --checkpoint INTEGER The initial 94 | checkpoint to 95 | start, if set, it 96 | will load model-[c 97 | heckpoint].pkl 98 | 99 | -b, --batch_size INTEGER The batch size of 100 | input data 101 | 102 | -o, --model_dir PATH The model dir to 103 | save models or 104 | load models 105 | 106 | -r, --lr FLOAT The learning rate 107 | to train 108 | 109 | -l, --log_dir PATH The log files path 110 | -u, --use_gpu BOOLEAN Train by gpu or 111 | cpu 112 | 113 | -s, --save_frequency INTEGER The frequence to 114 | save the models 115 | during training 116 | ``` 117 | 118 | 若需要对已有模型进行评估,请使用`eval.py`文件,其使用方式如下: 119 | 120 | ```shell 121 | Usage: eval.py [OPTIONS] 122 | 123 | Options: 124 | -h, --help Show this message and exit. 125 | -i, --data_dir PATH The path of train data 126 | -m, --mode [captcha|kaptcha] The model type to train, could be captcha or 127 | kaptcha 128 | 129 | -b, --batch_size INTEGER The batch size of input data 130 | -o, --model_dir PATH The model dir to save models or load models 131 | -l, --log_dir PATH The log files path 132 | -u, --use_gpu BOOLEAN Train by gpu or cpu 133 | ``` 134 | 135 | 136 | 若需要使用现有的模型对验证码进行识别,请使用`predict.py`,其使用方法如下: 137 | 138 | 139 | ```shell 140 | Usage: predict.py [OPTIONS] 141 | 142 | Options: 143 | -h, --help Show this message and exit. 
144 | -i, --image_path PATH The path of the captcha image [required] 145 | -m, --mode [captcha|kaptcha] The model type to train, could be captcha or 146 | kaptcha 147 | 148 | -o, --model_dir PATH The model dir to load 149 | -u, --use_gpu BOOLEAN Train by gpu or cpu 150 | ``` 151 | 152 | 153 | ### 示例 154 | 155 | ```shell 156 | java -jar captcha-images-1.0.jar 157 | 158 | # 默认输出路径为:./captchas 159 | # 开始训练模型 160 | python train.py -i ./captchas -m captcha -b 1024 -o ./models -u True 161 | 162 | # 开始评估模型 163 | python eval.py -i ./captchas -m captcha -b 128 164 | 165 | # 开始预测 使用中,请将abcd.jpg换成实际的测试用验证码文件 166 | python predict.py -i ./captchas/abcd.jpg 167 | ``` 168 | 169 | 170 | ## 数据集介绍 171 | 172 | 所有训练数据均以验证码图片内容为名称命名,如`2ANF.jpg`,因此可以保证训练数据没有重复项,根据文件名即可获取样本label。 173 | 174 | 数据集下载:[Dataset-Google Drive for Easy Captcha](https://drive.google.com/file/d/1HhTTbN8cjBs-1hEv38lpxBhD_z8EOGlk/view?usp=sharing) 175 | 176 | > 数据规格:48,320张验证码图片,全由`Easy Captcha`框架生成,大小为$120 \times 80$。 177 | 178 | 验证码示例: 179 | ![Captcha示例](captchas.jpg) 180 | 181 | [EasyCaptcha项目主页](https://github.com/whvcse/EasyCaptcha) 182 | 183 | EasyCaptcha验证码特点在于可以构造Gif动态验证码,而其他验证码则显得相对简单,主要在于该验证码间隔较开,易于区分,因此识别较为简单。根据对上例中的验证码分析可知,验证码由不定位置的1-2个圆圈与曲线构成噪音,对文本加以干扰,文字颜色可变。从布局来看,文字的布局位置相对固定,且间隔也相对固定,这无疑也简化了识别过程。 184 | 185 | 数据集下载:[Dataset-Google Drive for Kaptcha](https://drive.google.com/file/d/1QwLBp35Q_-b6GCZ7BjCsIf5L3YpkLRLg/view?usp=sharing) 186 | 187 | > 数据规格:52,794张验证码图片,全由`Kaptcha`生成,大小为$200 \times 50$。 188 | 189 | 验证码示例: 190 | ![Kaptcha示例](kaptchas.jpg) 191 | 192 | [Kaptcha项目主页](https://github.com/penggle/kaptcha) 193 | 194 | 相对而言,Kaptcha验证码的文本排布默认更加紧凑,但是文字间距在Kaptcha中是一个可以调节的超参数。Kaptcha较难识别的主要原因在于其文本存在可能的扭曲形变,并且形变状态不定,因此模型需要能够克服该形变,方可较为准确地识别,因此Kaptcha识别较captcha困难,并且准确度指标会有所下降。 195 | 196 | **注:在直接使用模型时需要严格注意验证码规格,这主要在于图片过小会导致CNN过程异常。若对图片进行分辨率调整,长宽比不一,将导致严重形变,导致识别精度下降。** 197 | 198 | ## 模型介绍 199 | 200 | 
针对两个不同的数据集,本项目设计了两个不同的模型,但是总体上都是基于CNN和FCN结构的分类任务。在诸多OCR任务中,通常会使用multi-stage方法设计模型,即:**通常使用对象识别模型,识别图片中的文字,并用框标记出文字所在位置,再利用CNN和FCN的结构对所识别的文字进行分类。并且,若为文档OCR识别,输出层还可能借助LSTM等RNN结构网络。** 201 | 202 | 考虑到验证码通常位数有限,即4位、5位较为常见,因此该模型采用end2end multi-task方法也可满足需求,且模型复杂度并不高。 203 | 204 | 针对EasyCaptcha验证码,其产生的验证码较容易区分,字符分隔较开,且变形选项较少,因此使用很简单的模型即可达到较高的精度,在本项目的模型中,验证集准确度可达到$98-99%$左右。 205 | 206 | 而对于Kaptcha验证码,其存在较多可选的配置项,并且会在验证码中间添加噪音扰动,因此识别较为困难,使用EasyCaptcha的模型,精度仅能达到70%左右,准确度较低,Kaptcha模型适当地加大了CNN网络的深度,并增加了一层全连接隐藏层,在验证集上达到93-94%的准确度。 207 | 208 | 在训练过程中,采用长度为4的验证码,其中验证码中可选字符为:a-zA-Z0-9,共62个可能字符。 209 | 210 | 下面为两个模型:EasyNet, KCapNet的详细介绍。 211 | 212 | ### EasyNet模型 213 | 214 | 215 | ![captcha-net](captcha-net.svg) 216 | 217 | EasyNet模型由2层卷积层和4个输出层构成,该模型结构细节如下: 218 | 219 | 1. 第一层卷积,卷积核大小为$5 \times 5$,步长为1,通道为16,参数量为$3 \times 5 \times 5 \times 16$,得到大小$76\times 116$的特征图共16个; 220 | 2. 第二层为最大池化层,无参数,池化核大小为$5\times 5$,步长为5,得到特征图大小为$15 \times 23$; 221 | 3. 第三层为批归一化层,避免过拟合并加速模型收敛。根据效果,同时还可尝试使用shortcut方法。归一化后,采用RReLu激活函数; 222 | 4. 第四层为卷积层,卷积核大小为$5 \times 5$,步长为1,通道数为32,得到特征图大小为$11 \times 19$,参数量为$16 \times 5 \times 5 \times 32$; 223 | 5. 第五层为最大池化层,无参数,池化核大小为$5 \times 5$,步长为5,得到特征图大小为$2 \times 4$,无参数; 224 | 6. 第六层仍为批归一化层,并采用RReLu函数激活; 225 | 7. 第七层为Dropout层,经过第二个卷积后,将得到特征图展开,得到特征向量维度为256维,对256的特征向量进行Dropout处理,避免过拟合,采用的失效概率为$p=0.3$; 226 | 8. 第八层为输出层,用于得到验证码序列,由于模型为multi-task,因此输出层有4个(根据验证码中字符长度确定),使用softmax激活,参数量为$4\times 256 \times 62$。 227 | 228 | ### Kaptcha模型 229 | 230 | ![kaptcha-net](kaptcha-net.svg) 231 | 232 | KCapNet共由3个卷积层,1个全连接层,4个输出层组成,以下为模型具体细节: 233 | 234 | 1. 第一层为卷积层,卷积核大小为$5 \times 5$,步长为1,通道数为16,输入图片大小为$50 \times 200$,因此可得到16个大小为$46 \times 196$的特征图,参数量为$3\times 5 \times 5 \times 16$; 235 | 2. 第二层为最大池化层,无参数,池化核大小为$3 \times 3$,步长为3,得到特征图大小为$15 \times 64$; 236 | 3. 第三层为批归一化层,在归一化结束后,使用RReLu激活函数激活; 237 | 4. 第四层为第二个卷积层,卷积核大小为$3 \times 3$,通道数为32,可得到大小为$13 \times 62$的特征图32个,参数量为$16 \times 3 \times 3 \times 32$; 238 | 5. 
第五层为最大池化层,无参数,池化核大小为$ 3 \times 3$,步长为3,无参数,可将特征图压缩为$4 \times 20$; 239 | 6. 第六层为批归一化层,并使用RReLU函数激活; 240 | 7. 第七层为第三个卷积层,卷积核大小为$3 \times 3$,步长为1,通道数为64,可得到大小为$2 \times 18$的特征图64个,参数量为$32 \times 3 \times 3 \times 64$; 241 | 8. 第八层为最大池化层,池化视野为$2 \times 2$,步长为2,无参数,特征图被进一步压缩为$1 \times 9$; 242 | 9. 第九层为归一化层,归一化后使用RReLu函数激活; 243 | 10. 第十层为Dropout层,输入为第九层输出展开后的特征向量,维度为576维,该层采用$ p = 0.15$ 的概率失效一定神经元; 244 | 11. 第十一层为全连接层,输入为第十层的输出,维度为576维,全连接层输出维度为128维,参数量为$576 \times 128$,并使用RReLU函数激活; 245 | 12. 第十二层为全连接的Dropout层,神经元失效概率为$p=0.1$; 246 | 13. 第十三层为输出层,根据multi-task数量,为4个输出层,维度为62维,使用softmax函数激活,参数量为$4 \times 128 \times 62$; 247 | 248 | > 模型部分参数未描述,由于是少量参数,相比之下可以忽略,如RReLu中的参数。 249 | 250 | 251 | ## 模型训练 252 | 253 | ### 优化方法与超参数 254 | 255 | 在该模型中,采用了Adam作为优化算法,并设定学习率为0.001,可达到较好效果。在模型训练过程中,尝试使用较大学习率,如0.01, 0.1, 0.05等,均不如低学习率收敛效果好。上述两个模型,均在[Google Colab Pro](https://colab.research.google.com)上使用P100训练,该算力可胜任batch至少为1024的配置,在EasyNet模型中使用了512的batch,而KCapNet使用1024的batch。 256 | 257 | > 该batch设置未达到算力极限,如有条件可测试,但是不推荐模型采用较大batch,而应尽可能选择合理的batch。 258 | 259 | 模型训练过程中,优化算法未使用学习率衰减算法。 260 | 261 | 在模型训练过程中,对于EasyNet,采用$p=0.3$的Dropout能达到较好效果,若采用$0.4 \sim 0.5$效果略差,但精度仍然可观,可见对于EasyNet其数据简单因而模型即便简单也仍能达到较好效果。 262 | 263 | 而对于KCapNet,Dropout从最初的$0.5$拟合效果较差,大概稳定在$85\% $上下,而逐步降低Dropout拟合能力逐渐提升,最终在$p_1=p_2=0.2$时效果较好,最终采用$p_1=0.15,p_2=0.1$得到最终模型,其训练集精度为$95\%$左右,验证集精度为$93\sim 94 \%$。 264 | 265 | ### 数据集划分 266 | 267 | 在模型训练过程中,默认采用$6:1:1$的分配比切分训练集、验证集、测试集,切分过程大致为: 268 | 269 | 1. 按3:1第一次切分,其中$75\%$为训练集; 270 | 2. 
对上述剩余的$25\%$进行$1:1$切分,得到验证集及测试集。 271 | 272 | 根据需要,开发者自行训练模型时,可根据需要手动指定数据集切分比例。 273 | 274 | ### EasyNet 275 | 276 | 下图为EasyNet训练过程的模型损失曲线,从图中可以看出,模型在前10个epoch迅速收敛,在20 epoch之后,模型达到相对稳定状态。从图中可以看出,验证集损失相较于训练集损失,下降比较健康,并且收敛曲线相对光滑,在后期也未出现验证集损失波动情况,说明其未发生严重过拟合,模型可以被认为训练过程可信。 277 | 278 | ![损失函数曲线](loss-easy.svg) 279 | 280 | 从精度曲线中可以看出,在训练初期,验证集上的精度基本能优于训练集上的精度,这得益于正则化手段,使得模型的子模型也能具有较好的表现,在25个epoch直至更后期,验证集精度和训练集精度开始趋于重合,甚至验证集精度略低于训练集精度,并且精度不再明显上升。从精度曲线的光滑程度来看,同样证明模型在训练过程中未发生严重过拟合,因此模型可信度及有效性较高。 281 | 282 | ![ACC曲线](acc-easy.svg) 283 | 284 | ### KCapNet 285 | 286 | 下图分别为KCapNet模型的损失曲线及精度曲线,从曲线中可以看出,在epoch为120时,曲线发生了剧烈波动,**这是因为在训练过程中,调整了batch的缘故**。通常,较大batch可以一定程度地加速模型收敛,使得梯度方向更加准确,更有利于模型收敛,但是batch过大会导致对于部分较低比例的hard sample影响被淡化,从而使得模型不具备hard sample的识别能力,制约了模型的拟合能力。因此在使用较大batch训练模型基本收敛后,调小batch以强化模型对于小部分样本的识别能力。根据损失曲线可以看出,模型收敛过程相对健康,在前25个epoch时,模型迅速收敛,并达到较好效果,随后训练集损失继续稳定下降,而验证集损失开始出现一定范围内的波动,但是未呈现明显的上升趋势,说明模型达到一定稳定程度的拟合能力。随着训练集损失的持续下降,验证集损失始终在1上下波动,无明显的损失整体下降趋势,因此在60 epoch之后,可以选择性早停,即Early stopping。 287 | 288 | 在120 epoch之后,即batch调小之后,模型损失突然小幅度上升,随后继续下降,但验证集上损失较之前波动情况更加严重,这也一定程度地说明**较大的batch相较于较小的batch,能够使模型损失更加光滑**。 289 | 290 | 在该模型中,较小的batch取为256。 291 | 292 | ![损失曲线](loss.svg) 293 | 294 | 与损失曲线相反,在前25个epochs中,模型精度提升较快,并且能迅速达到0.9上下,随后训练集精度开始小幅度持续上升,而验证集精度开始出现波动,在70 epochs之后,验证集上的精度最好能达到$0.93\sim 0.94$上下。在调小batch之后,验证集的精度波动更大,但最好精度与大batch之前相差较小,说明在较大batch下,模型收敛相对较好。 295 | 296 | 综合损失曲线与精度曲线,可知,在70 epoch之后,选择$70 \sim 120$ epoch中损失最低的模型,可基本视为最佳模型。而在小batch之后,推荐选择$130 \sim 170$ epoch间的最低损失模型可达到较好效果。 297 | 298 | **在本项目提供的预训练模型中,选择了第169个epoch的模型,其训练集精度可达0.94。** 299 | 300 | ![ACC曲线](acc.svg) 301 | 302 | 下图为从测试集随机选择的5组验证码样本,其中大部分均识别正确(标绿),小部分识别错误(标红)。从标红的案例中可以看出,该验证码人为识别正确难度仍然较高,因此识别错误也可以接受。同时,根据更广泛的测试集评估研究,模型对于0与O的识别准确度较低,甚至于O大部分被识别为0,这很大程度上是由验证码的字体形变引发的,根据人工对这些特殊案例的对比,部分能够被人眼正确地分辨,而少部分确实存在人为无法准确分辨的案例。可以认为,人为地区分0与O,可能有$60\sim 70\%$成功率,这也同样对模型的准确度产生了干扰。 303 | 304 | 由于模型达到了基本可接受的识别准确度,因此再未将识别错误的样本单独挑出并训练,从理论上推测, 将分类错误的样本挑出重新分类,可以一定程度地提升模型效果,进行该操作的方法可有两种: 305 | 306 | 1. 
将识别错误的训练数据单独挑出,并重新构成训练集,并重新训练,该方式可能使得模型对这些样本过于拟合,因此训练的迭代次数需要控制; 307 | 2. 将识别错误的样本标记,在下一轮训练时,在损失函数上,为上次识别失败的样本增大权重,使得分类错误的样本对模型的提升影响更大,降低正确识别样本对模型的影响,但是训练时仍提供正确样本,能够避免第一种方法的过拟合。(类似于Boosting) 308 | 309 | 310 | ![测试结果](test.svg) 311 | 312 | > 上述模型的提升方法,有条件地可以进一步实验,以进一步提升模型性能。同时,对于验证码识别,还可以考虑使用注意力机制,针对不同的输出层关注不同的Feature Map,从直观上理解,应该能一定程度地提升模型的拟合能力,开发者们可以进一步尝试。 313 | 314 | ## 算力要求 315 | 316 | 为了能够尽量评估运算所需算力,可以对模型的内存消耗进行评估,此处忽略激活函数中的参数,偏置等少量参数,模型的算力要求应等于**参数量+输入输出+梯度与动量**,根据神经网络反向传播理论,在更新参数时需要计算下一层输出关于上一层参数的梯度,因此**参数量==梯度**,而在优化方法中需要保存动量,以记录之前参数更新的历史记录,因此**参数量==动量**,而对于Adam优化器,则更有**动量==2参数量**,因此整个模型的算力要求为: 317 | 318 | $MEM = W*4+I+O$ 319 | 320 | 通常网络中使用的数据类型为Float32类型,其占4 Byte,于是便可通过存储量来计算内存消耗。 321 | 322 | ### EasyNet算力计算 323 | 324 | 根据下表统计,该模型算力大致要求为:1.9 MB /sample。 325 | 326 | | 层 | 参数量 | 特征图 | 所需内存 | 327 | | -------- | -------------------------------------- | ------------------------------- | -------- | 328 | | Input | 0 | $3 \times 80 \times 120=28800$ | 112.5 KB | 329 | | Conv1 | $3 \times 5 \times 5 \times 16 = 1200$ | $76\times 116 \times 16=141056$ | 569.8 KB | 330 | | Maxpool1 | 0 | $16 \times 15 \times 23=5520$ | 21.6 KB | 331 | | BN1 | $2 \times 16=32$ | $16 \times 15 \times 23=5520$ | 22.0 KB | 332 | | Conv2 | $16 \times 5 \times 5 \times 32=12800$ | $11 \times 19 \times 32 = 6688$ | 226.1 KB | 333 | | Maxpool2 | 0 | $2 \times 4 \times 32 = 256$ | 1 KB | 334 | | BN2 | $2\times 32=64$ | $2 \times 4 \times 32 = 256$ | 2 KB | 335 | | Output | $4\times 256 \times 62=63488$ | $4 \times 62=248$ | 993.0 KB | 336 | 337 | ### KCap算力计算 338 | 339 | 根据下表统计,其算力要求大致为:2.9 MB /sample。 340 | 341 | | 层 | 参数量 | 特征图 | 所需内存 | 342 | | -------- | -------------------------------------- | -------------------------------- | --------- | 343 | | Input | 0 | $3 \times 50 \times 200=30000$ | 117.2 KB | 344 | | Conv1 | $3 \times 5 \times 5 \times 16=1200$ | $46 \times 196 \times 16=144256$ | 582.3 KB | 345 | | Maxpool1 | 0 | $15 \times 64 \times 16=15360$ | 60 KB | 346 | | BN1 | $2 \times 
16=32$ | $15 \times 64 \times 16=15360$ | 60.5 KB | 347 | | Conv2 | $16 \times 3 \times 3 \times 32=4608$ | $13\times 62 \times 32=25792$ | 172.75 KB | 348 | | Maxpool2 | 0 | $4\times 20 \times 32=2560$ | 10 KB | 349 | | BN2 | $2 \times 32=64$ | $4\times 20 \times 32=2560$ | 11 KB | 350 | | Conv3 | $32 \times 3 \times 3 \times 64=18432$ | $2\times 18 \times 64=2304$ | 297 KB | 351 | | Maxpool3 | 0 | $1 \times 9 \times 64=576$ | 2.3 KB | 352 | | BN3 | $2\times 64=128$ | $1 \times 9 \times 64=576$ | 4.3 KB | 353 | | Fatten | 0 | $576$ | 2.25 KB | 354 | | FCN | $576\times 128=73728$ | $128 $ | 1152 KB | 355 | | BN4 | $2 \times 128=256$ | $128$ | 4.5 KB | 356 | | Output | $128 \times 62 \times 4=31744$ | $4 \times 62 = 248$ | 497.0 kB | 357 | 358 | -------------------------------------------------------------------------------- /acc-easy.svg: -------------------------------------------------------------------------------- 1 | 2 | 4 | 5 | 6 | 7 | 10 | 11 | 12 | 13 | 19 | 20 | 21 | 22 | 28 | 29 | 30 | 31 | 32 | 33 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | 74 | 75 | 76 | 77 | 78 | 79 | 80 | 93 | 94 | 95 | 96 | 97 | 98 | 99 | 100 | 101 | 102 | 103 | 104 | 105 | 106 | 107 | 108 | 109 | 133 | 134 | 135 | 136 | 137 | 138 | 139 | 140 | 141 | 142 | 143 | 144 | 145 | 146 | 147 | 148 | 149 | 181 | 182 | 183 | 184 | 185 | 186 | 187 | 188 | 189 | 190 | 191 | 192 | 193 | 194 | 195 | 196 | 197 | 214 | 215 | 216 | 217 | 218 | 219 | 220 | 221 | 222 | 223 | 224 | 225 | 226 | 227 | 228 | 229 | 230 | 254 | 255 | 256 | 257 | 258 | 259 | 260 | 261 | 262 | 263 | 264 | 265 | 266 | 267 | 268 | 269 | 270 | 300 | 301 | 302 | 303 | 304 | 305 | 306 | 307 | 308 | 309 | 310 | 321 | 341 | 365 | 381 | 382 | 383 | 384 | 385 | 386 | 387 | 388 | 389 | 390 | 391 | 392 | 393 | 394 | 397 | 398 | 399 | 400 | 401 | 402 | 403 | 404 | 405 | 411 | 412 | 413 | 414 | 415 | 416 | 417 | 418 | 419 | 420 | 421 | 422 | 423 | 424 | 425 | 426 | 427 | 428 | 429 | 430 | 431 
| 432 | 433 | 434 | 435 | 436 | 437 | 438 | 439 | 440 | 441 | 442 | 443 | 444 | 445 | 446 | 447 | 448 | 449 | 450 | 451 | 452 | 453 | 454 | 455 | 456 | 457 | 458 | 459 | 460 | 461 | 462 | 463 | 464 | 465 | 466 | 467 | 468 | 469 | 470 | 471 | 472 | 473 | 482 | 483 | 484 | 485 | 486 | 487 | 488 | 489 | 490 | 491 | 492 | 493 | 494 | 495 | 496 | 497 | 498 | 499 | 538 | 539 | 540 | 541 | 542 | 543 | 544 | 545 | 546 | 547 | 548 | 549 | 550 | 551 | 552 | 553 | 554 | 555 | 585 | 586 | 587 | 588 | 589 | 590 | 591 | 592 | 593 | 594 | 595 | 596 | 597 | 598 | 599 | 600 | 601 | 602 | 603 | 604 | 605 | 606 | 607 | 608 | 609 | 610 | 611 | 637 | 658 | 679 | 698 | 699 | 700 | 701 | 702 | 703 | 704 | 705 | 706 | 707 | 708 | 709 | 770 | 771 | 772 | 833 | 834 | 835 | 838 | 839 | 840 | 843 | 844 | 845 | 848 | 849 | 850 | 853 | 854 | 855 | 856 | 857 | 871 | 892 | 924 | 940 | 941 | 950 | 951 | 952 | 953 | 954 | 955 | 956 | 957 | 958 | 959 | 960 | 961 | 962 | 963 | 964 | 965 | 966 | 967 | 968 | 969 | 970 | 981 | 982 | 983 | 986 | 987 | 988 | 989 | 990 | 991 | 1010 | 1011 | 1012 | 1013 | 1014 | 1015 | 1016 | 1017 | 1018 | 1019 | 1020 | 1021 | 1022 | 1023 | 1024 | 1027 | 1028 | 1029 | 1030 | 1031 | 1032 | 1058 | 1059 | 1060 | 1061 | 1062 | 1063 | 1064 | 1065 | 1066 | 1067 | 1068 | 1069 | 1070 | 1071 | 1072 | 1073 | 1074 | 1075 | 1076 | 1077 | 1078 | -------------------------------------------------------------------------------- /acc.svg: -------------------------------------------------------------------------------- 1 | 2 | 4 | 5 | 6 | 7 | 10 | 11 | 12 | 13 | 19 | 20 | 21 | 22 | 28 | 29 | 30 | 31 | 32 | 33 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | 74 | 75 | 76 | 77 | 78 | 79 | 80 | 104 | 128 | 129 | 130 | 131 | 132 | 133 | 134 | 135 | 136 | 137 | 138 | 139 | 140 | 141 | 142 | 143 | 144 | 145 | 146 | 147 | 148 | 149 | 150 | 151 | 152 | 153 | 154 | 155 | 156 | 157 | 158 | 167 | 168 | 169 | 170 | 171 | 172 | 173 | 174 | 175 | 176 | 177 | 178 
| 179 | 180 | 181 | 182 | 183 | 196 | 197 | 198 | 199 | 200 | 201 | 202 | 203 | 204 | 205 | 206 | 207 | 208 | 209 | 210 | 211 | 212 | 213 | 214 | 215 | 216 | 217 | 218 | 219 | 220 | 221 | 222 | 223 | 224 | 225 | 226 | 227 | 228 | 229 | 230 | 231 | 232 | 233 | 234 | 235 | 236 | 237 | 238 | 239 | 240 | 241 | 242 | 243 | 244 | 245 | 246 | 247 | 248 | 249 | 250 | 251 | 252 | 276 | 302 | 323 | 344 | 363 | 364 | 365 | 366 | 367 | 368 | 369 | 370 | 371 | 372 | 373 | 374 | 375 | 376 | 377 | 380 | 381 | 382 | 383 | 384 | 385 | 386 | 387 | 388 | 394 | 395 | 396 | 397 | 398 | 399 | 400 | 401 | 402 | 403 | 404 | 405 | 406 | 407 | 408 | 409 | 410 | 411 | 428 | 429 | 430 | 431 | 432 | 433 | 434 | 435 | 436 | 437 | 438 | 439 | 440 | 441 | 442 | 443 | 444 | 445 | 475 | 476 | 477 | 478 | 479 | 480 | 481 | 482 | 483 | 484 | 485 | 486 | 487 | 488 | 489 | 490 | 491 | 492 | 531 | 532 | 533 | 534 | 535 | 536 | 537 | 538 | 539 | 540 | 541 | 542 | 543 | 544 | 545 | 546 | 547 | 548 | 549 | 550 | 551 | 552 | 553 | 554 | 555 | 556 | 557 | 571 | 592 | 608 | 640 | 656 | 657 | 658 | 659 | 660 | 661 | 662 | 663 | 664 | 665 | 666 | 667 | 668 | 669 | 670 | 792 | 793 | 794 | 963 | 964 | 965 | 969 | 970 | 971 | 974 | 975 | 976 | 979 | 980 | 981 | 984 | 985 | 986 | 989 | 990 | 991 | 992 | 993 | 994 | 1003 | 1004 | 1005 | 1006 | 1007 | 1008 | 1009 | 1010 | 1011 | 1012 | 1013 | 1014 | 1015 | 1016 | 1017 | 1018 | 1019 | 1020 | 1021 | 1022 | 1023 | 1034 | 1035 | 1036 | 1039 | 1040 | 1041 | 1042 | 1043 | 1044 | 1064 | 1075 | 1094 | 1095 | 1096 | 1097 | 1098 | 1099 | 1100 | 1101 | 1102 | 1103 | 1104 | 1105 | 1106 | 1107 | 1108 | 1111 | 1112 | 1113 | 1114 | 1115 | 1116 | 1142 | 1143 | 1144 | 1145 | 1146 | 1147 | 1148 | 1149 | 1150 | 1151 | 1152 | 1153 | 1154 | 1155 | 1156 | 1157 | 1158 | 1159 | 1160 | 1161 | 1162 | -------------------------------------------------------------------------------- /captcha-images/.classpath: -------------------------------------------------------------------------------- 1 | 2 
| 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | -------------------------------------------------------------------------------- /captcha-images/.project: -------------------------------------------------------------------------------- 1 | 2 | 3 | captcha-images 4 | 5 | 6 | 7 | 8 | 9 | org.eclipse.jdt.core.javabuilder 10 | 11 | 12 | 13 | 14 | org.eclipse.m2e.core.maven2Builder 15 | 16 | 17 | 18 | 19 | 20 | org.eclipse.jdt.core.javanature 21 | org.eclipse.m2e.core.maven2Nature 22 | 23 | 24 | -------------------------------------------------------------------------------- /captcha-images/.settings/org.eclipse.core.resources.prefs: -------------------------------------------------------------------------------- 1 | eclipse.preferences.version=1 2 | encoding//src/main/java=UTF-8 3 | encoding//src/test/java=UTF-8 4 | encoding/=UTF-8 5 | -------------------------------------------------------------------------------- /captcha-images/.settings/org.eclipse.jdt.apt.core.prefs: -------------------------------------------------------------------------------- 1 | eclipse.preferences.version=1 2 | org.eclipse.jdt.apt.aptEnabled=false 3 | -------------------------------------------------------------------------------- /captcha-images/.settings/org.eclipse.jdt.core.prefs: -------------------------------------------------------------------------------- 1 | eclipse.preferences.version=1 2 | org.eclipse.jdt.core.compiler.codegen.targetPlatform=1.7 3 | org.eclipse.jdt.core.compiler.compliance=1.7 4 | org.eclipse.jdt.core.compiler.problem.enablePreviewFeatures=disabled 5 | org.eclipse.jdt.core.compiler.problem.forbiddenReference=warning 6 | org.eclipse.jdt.core.compiler.problem.reportPreviewFeatures=ignore 7 | org.eclipse.jdt.core.compiler.processAnnotations=disabled 8 | org.eclipse.jdt.core.compiler.release=disabled 9 | 
org.eclipse.jdt.core.compiler.source=1.7 10 | -------------------------------------------------------------------------------- /captcha-images/.settings/org.eclipse.m2e.core.prefs: -------------------------------------------------------------------------------- 1 | activeProfiles= 2 | eclipse.preferences.version=1 3 | resolveWorkspaceProjects=true 4 | version=1 5 | -------------------------------------------------------------------------------- /captcha-images/pom.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 5 | 4.0.0 6 | 7 | me.zouzhipeng 8 | captcha-images 9 | 1.0 10 | 11 | captcha-images 12 | http://captcha.zouzhipeng.me/ 13 | 14 | 15 | UTF-8 16 | 1.7 17 | 1.7 18 | 19 | 20 | 21 | 22 | junit 23 | junit 24 | 4.13.1 25 | test 26 | 27 | 28 | com.github.whvcse 29 | easy-captcha 30 | 1.6.2 31 | 32 | 33 | 34 | commons-cli 35 | commons-cli 36 | 1.4 37 | 38 | 39 | 40 | com.github.penggle 41 | kaptcha 42 | 2.3.2 43 | 44 | 45 | 46 | log4j 47 | log4j 48 | 1.2.17 49 | 50 | 51 | 52 | 53 | 54 | 55 | src/main/resources 56 | 57 | **/*.properties 58 | **/*.xml 59 | **/*.tld 60 | 61 | false 62 | 63 | 64 | src/main/java 65 | 66 | **/*.properties 67 | **/*.xml 68 | **/*.tld 69 | 70 | false 71 | 72 | 73 | 74 | 75 | 76 | 77 | maven-clean-plugin 78 | 3.1.0 79 | 80 | 81 | 82 | maven-resources-plugin 83 | 3.0.2 84 | 85 | 86 | maven-compiler-plugin 87 | 3.8.0 88 | 89 | 90 | maven-surefire-plugin 91 | 2.22.1 92 | 93 | 94 | maven-jar-plugin 95 | 3.0.2 96 | 97 | 98 | maven-install-plugin 99 | 2.5.2 100 | 101 | 102 | maven-deploy-plugin 103 | 2.8.2 104 | 105 | 106 | 107 | maven-site-plugin 108 | 3.7.1 109 | 110 | 111 | maven-project-info-reports-plugin 112 | 3.0.0 113 | 114 | 115 | maven-assembly-plugin 116 | 117 | 118 | 119 | 120 | 121 | me.zouzhipeng.App 122 | 123 | 124 | 125 | jar-with-dependencies 126 | 127 | 128 | 129 | 130 | 131 | 132 | 133 | -------------------------------------------------------------------------------- 
/captcha-images/src/main/java/log4j.properties: -------------------------------------------------------------------------------- 1 | 2 | ### 定义了log4j.rootLogger,优先级为debug,即包含了所有的优先级日志 ### 3 | ### 定义了三个 Appender,分别为 stdout, D, E ### 4 | log4j.rootLogger = debug,stdout 5 | 6 | ### 设置 stdout,它实现了控制台日志输出 ### 7 | log4j.appender.stdout = org.apache.log4j.ConsoleAppender 8 | log4j.appender.stdout.Target = System.out 9 | ### 使用了自定义的布局模式 ### 10 | log4j.appender.stdout.layout = org.apache.log4j.PatternLayout 11 | ### 自定义了输出模式 ### 12 | log4j.appender.stdout.layout.ConversionPattern = [%-5p] %d{yyyy-MM-dd HH:mm:ss,SSS} method:%l%n%m%n -------------------------------------------------------------------------------- /captcha-images/src/main/java/me/zouzhipeng/App.java: -------------------------------------------------------------------------------- 1 | package me.zouzhipeng; 2 | 3 | import org.apache.commons.cli.CommandLine; 4 | import org.apache.commons.cli.CommandLineParser; 5 | import org.apache.commons.cli.DefaultParser; 6 | import org.apache.commons.cli.HelpFormatter; 7 | import org.apache.commons.cli.Option; 8 | import org.apache.commons.cli.Options; 9 | import org.apache.commons.cli.ParseException; 10 | 11 | import me.zouzhipeng.config.Config; 12 | import me.zouzhipeng.config.ConfigConstants; 13 | 14 | /** 15 | * 16 | * 17 | */ 18 | public class App { 19 | public static void main(String[] args) { 20 | // generateCaptcha("/Users/frank/Downloads/captchas", 50000); 21 | // generateKaptcha("/Users/frank/Downloads/kaptchas-test", 10, 10/2); 22 | Options options = defineOptions(); 23 | 24 | // String[] testArgs = {"-m", "kaptcha", "-t", "10"}; 25 | Config conf = parseOptions(options, args); 26 | 27 | if (null == conf) { 28 | return; 29 | } 30 | 31 | Generator generator = new GeneratorMentor(); 32 | generator.generate(conf); 33 | } 34 | 35 | public static Options defineOptions() { 36 | Option modeOption = Option.builder("m").argName("mode").desc(String 37 | .format("The mode of 
captcha, could be [%s] or [%s]", ConfigConstants.EASY_CAPTCHA, ConfigConstants.KAPTCHA)) 38 | .hasArg().required(false).build(); 39 | Option helpOption = Option.builder("h").argName("help").desc("Show help").required(false).hasArg(false).build(); 40 | Option outputDirOption = Option.builder("o").argName("output_dir").hasArg() 41 | .desc("The output directory of captchas, string").required(false).build(); 42 | 43 | Option widthOption = Option.builder("w").hasArg().argName("width") 44 | .desc("The width of the capthcas, integer, default 120").required(false).build(); 45 | Option heightOption = Option.builder("v").hasArg().argName("height") 46 | .desc("The height of the capthcas, integer, default 80").required(false).build(); 47 | Option sizeOption = Option.builder("s").argName("size") 48 | .desc("The size of the capthcas, should be: ,").hasArg().numberOfArgs(2).valueSeparator(',') 49 | .required(false).build(); 50 | Option textColorOption = Option.builder("c").argName("text_color") 51 | .desc("The text color of the captcha, if not specified, use randomly, should be ,,").hasArg() 52 | .required(false).build(); 53 | Option noiseColorOption = Option.builder("n").argName("noise_color") 54 | .desc("The noise color of the captcha, if not specified, use randomly, should be ,,").hasArg() 55 | .required(false).build(); 56 | Option lengthOption = Option.builder("l").argName("length") 57 | .desc("The length of characters in the captchas, integer, default 4").hasArg().required(false).build(); 58 | Option countOption = Option.builder("t").argName("count").desc("The count to produce, integer, default 5,000") 59 | .hasArg().required(false).build(); 60 | Option kindOption = Option.builder("k").argName("kinds").desc("The kinds of captchas, integer, default 1").hasArg() 61 | .required(false).build(); 62 | Option poolOption = Option.builder("p").argName("pool_size").desc("Thread pool size, integer, default 20").hasArg() 63 | .required(false).build(); 64 | Option sameColorOption = 
Option.builder("e").argName("noise_same_text") 65 | .desc("If the noise color is the same with the text color, should be true or false, default false.").hasArg() 66 | .required(false).build(); 67 | 68 | Options options = new Options(); 69 | options.addOption(modeOption); 70 | options.addOption(helpOption); 71 | options.addOption(outputDirOption); 72 | options.addOption(widthOption); 73 | options.addOption(heightOption); 74 | options.addOption(sizeOption); 75 | options.addOption(textColorOption); 76 | options.addOption(noiseColorOption); 77 | options.addOption(lengthOption); 78 | options.addOption(countOption); 79 | options.addOption(kindOption); 80 | options.addOption(poolOption); 81 | options.addOption(sameColorOption); 82 | 83 | return options; 84 | } 85 | 86 | public static Config parseOptions(Options options, String[] args) { 87 | HelpFormatter helpFormatter = new HelpFormatter(); 88 | CommandLine cmdLine = null; 89 | CommandLineParser cmdParser = new DefaultParser(); 90 | try { 91 | cmdLine = cmdParser.parse(options, args); 92 | if (null != cmdLine) { 93 | if (0 >= cmdLine.getOptions().length) { 94 | return new Config(); 95 | } 96 | Config config = new Config(); 97 | if (cmdLine.hasOption('h')) { 98 | helpFormatter.printHelp("java -jar [jarfile].jar", options); 99 | return null; 100 | } else { 101 | if (cmdLine.hasOption('m')) { 102 | String modeValue = cmdLine.getOptionValue('m'); 103 | if (modeValue.equals(ConfigConstants.KAPTCHA) || modeValue.equals(ConfigConstants.EASY_CAPTCHA)) { 104 | config.set(ConfigConstants.MODE, modeValue); 105 | } else { 106 | throw new ParseException(String.format("Only %s and %s modes are supported!", ConfigConstants.EASY_CAPTCHA, 107 | ConfigConstants.KAPTCHA)); 108 | } 109 | } 110 | if (cmdLine.hasOption('o')) { 111 | String outputDirValue = cmdLine.getOptionValue('o'); 112 | config.set(ConfigConstants.OUT_DIR, outputDirValue); 113 | } 114 | if (cmdLine.hasOption('w')) { 115 | String widthValue = cmdLine.getOptionValue('w'); 
116 | config.set(ConfigConstants.WIDTH, widthValue); 117 | } 118 | if (cmdLine.hasOption('v')) { 119 | String heightValue = cmdLine.getOptionValue('v'); 120 | config.set(ConfigConstants.HEIGHT, heightValue); 121 | } 122 | if (cmdLine.hasOption('s')) { 123 | String[] sizeValue = cmdLine.getOptionValues('s'); 124 | config.set(ConfigConstants.WIDTH, sizeValue[0]); 125 | config.set(ConfigConstants.HEIGHT, sizeValue[1]); 126 | } 127 | if (cmdLine.hasOption('c')) { 128 | String textColorValue = cmdLine.getOptionValue('c'); 129 | config.set(ConfigConstants.TEXT_COLOR, textColorValue); 130 | } 131 | if (cmdLine.hasOption('n')) { 132 | String noiseColor = cmdLine.getOptionValue('n'); 133 | config.set(ConfigConstants.NOISE_COLOR, noiseColor); 134 | } 135 | if (cmdLine.hasOption('l')) { 136 | String lengthValue = cmdLine.getOptionValue('l'); 137 | config.set(ConfigConstants.LENGTH, lengthValue); 138 | } 139 | if (cmdLine.hasOption('t')) { 140 | String countValue = cmdLine.getOptionValue('t'); 141 | config.set(ConfigConstants.COUNT, countValue); 142 | } 143 | if (cmdLine.hasOption('k')) { 144 | String kindCountValue = cmdLine.getOptionValue('k'); 145 | config.set(ConfigConstants.KIND, kindCountValue); 146 | } 147 | if (cmdLine.hasOption('p')) { 148 | String poolSize = cmdLine.getOptionValue('k'); 149 | config.set(ConfigConstants.POOL_SIZE, poolSize); 150 | } 151 | if (cmdLine.hasOption('e')) { 152 | String noiseEqualTextColorValue = cmdLine.getOptionValue('e'); 153 | if (noiseEqualTextColorValue.equals("true") || noiseEqualTextColorValue.equals("false")) { 154 | config.set(ConfigConstants.NOISE_SAME_TEXT_COLOR, noiseEqualTextColorValue); 155 | } else { 156 | throw new ParseException("Only can be true or false."); 157 | } 158 | } 159 | } 160 | return config; 161 | } 162 | } catch (ParseException ex) { 163 | helpFormatter.printHelp("java -jar [jarfile].jar", options); 164 | } 165 | 166 | return null; 167 | } 168 | 169 | } 170 | 
-------------------------------------------------------------------------------- /captcha-images/src/main/java/me/zouzhipeng/CaptchaGenerator.java: -------------------------------------------------------------------------------- 1 | package me.zouzhipeng; 2 | 3 | public interface CaptchaGenerator { 4 | 5 | /** 6 | * Generator captcha images automaticlly to destination. 7 | * @param folder the folder path to output 8 | * @return Successful or failed after created. 9 | */ 10 | @Deprecated 11 | public boolean generate(String folder); 12 | 13 | /** 14 | * Generator captcha images automaticlly to current workspace. 15 | * @return Successful or failed after created. 16 | */ 17 | public boolean generate(); 18 | } -------------------------------------------------------------------------------- /captcha-images/src/main/java/me/zouzhipeng/CaptchaTaskRunner.java: -------------------------------------------------------------------------------- 1 | package me.zouzhipeng; 2 | 3 | import org.apache.log4j.Logger; 4 | 5 | public class CaptchaTaskRunner implements Runnable { 6 | 7 | private static final Logger LOG = Logger.getLogger(CaptchaTaskRunner.class); 8 | 9 | private CaptchaGenerator generator; 10 | 11 | @Override 12 | public void run() { 13 | boolean success = generator.generate(); 14 | if (success) { 15 | if (LOG.isInfoEnabled()) { 16 | LOG.info("Complete!"); 17 | } 18 | } else { 19 | if (LOG.isInfoEnabled()) { 20 | LOG.info("Failed!"); 21 | } 22 | } 23 | } 24 | 25 | /** 26 | * @return CaptchaGenerator return the generator 27 | */ 28 | public CaptchaGenerator getGenerator() { 29 | return generator; 30 | } 31 | 32 | /** 33 | * @param generator the generator to set 34 | */ 35 | public void setGenerator(CaptchaGenerator generator) { 36 | this.generator = generator; 37 | } 38 | } -------------------------------------------------------------------------------- /captcha-images/src/main/java/me/zouzhipeng/EasyCaptchaGeneratorWorker.java: 
-------------------------------------------------------------------------------- 1 | package me.zouzhipeng; 2 | 3 | import java.awt.FontFormatException; 4 | import java.io.IOException; 5 | 6 | import com.wf.captcha.SpecCaptcha; 7 | import com.wf.captcha.base.Captcha; 8 | 9 | import org.apache.log4j.Logger; 10 | 11 | import me.zouzhipeng.config.Config; 12 | import me.zouzhipeng.config.ConfigConstants; 13 | import me.zouzhipeng.utils.ImageOutputUtil; 14 | 15 | public class EasyCaptchaGeneratorWorker implements CaptchaGenerator { 16 | 17 | private Config config; 18 | 19 | private static final Logger LOG = Logger.getLogger(EasyCaptchaGeneratorWorker.class); 20 | 21 | public EasyCaptchaGeneratorWorker(Config config) { 22 | this.config = config; 23 | } 24 | 25 | 26 | @Override 27 | public boolean generate(String path) { 28 | SpecCaptcha captcha = new SpecCaptcha(120, 80, 4); 29 | captcha.setCharType(Captcha.TYPE_DEFAULT); 30 | try { 31 | captcha.setFont(Captcha.FONT_3); 32 | } catch (IOException | FontFormatException e1) { 33 | e1.printStackTrace(); 34 | return false; 35 | } 36 | String codes = captcha.text(); 37 | return ImageOutputUtil.writeToFile(captcha, path, codes); 38 | } 39 | 40 | @Override 41 | public boolean generate() { 42 | String outputFolder = config.get(ConfigConstants.OUT_DIR); 43 | int width = Integer.parseInt(config.get(ConfigConstants.WIDTH, "120")); 44 | int height = Integer.parseInt(config.get(ConfigConstants.HEIGHT, "80")); 45 | int len = Integer.parseInt(config.get(ConfigConstants.LENGTH)); 46 | SpecCaptcha captcha = new SpecCaptcha(width, height, len); 47 | captcha.setCharType(Captcha.TYPE_DEFAULT); 48 | try { 49 | captcha.setFont(Captcha.FONT_3); 50 | } catch (IOException | FontFormatException e1) { 51 | e1.printStackTrace(); 52 | return false; 53 | } 54 | String codes = captcha.text(); 55 | if (LOG.isInfoEnabled()) { 56 | LOG.info("Generating " + codes + "..."); 57 | } 58 | return ImageOutputUtil.writeToFile(captcha, outputFolder, codes); 59 | } 
60 | 61 | } -------------------------------------------------------------------------------- /captcha-images/src/main/java/me/zouzhipeng/Generator.java: -------------------------------------------------------------------------------- 1 | package me.zouzhipeng; 2 | 3 | import java.util.Properties; 4 | 5 | import me.zouzhipeng.config.Config; 6 | 7 | public interface Generator { 8 | /** 9 | * 10 | * @param config 11 | */ 12 | public void generate(Config config); 13 | /** 14 | * 15 | * @param prop 16 | */ 17 | public void generate(Properties prop); 18 | } -------------------------------------------------------------------------------- /captcha-images/src/main/java/me/zouzhipeng/GeneratorMentor.java: -------------------------------------------------------------------------------- 1 | package me.zouzhipeng; 2 | 3 | import java.util.Properties; 4 | import java.util.concurrent.ExecutorService; 5 | import java.util.concurrent.Executors; 6 | 7 | import org.apache.log4j.Logger; 8 | 9 | import me.zouzhipeng.config.Config; 10 | import me.zouzhipeng.config.ConfigConstants; 11 | 12 | public class GeneratorMentor implements Generator { 13 | 14 | private static final Logger LOG = Logger.getLogger(GeneratorMentor.class); 15 | @Override 16 | public void generate(Config config) { 17 | int count = Integer.parseInt(config.get(ConfigConstants.COUNT)); 18 | int kind = Integer.parseInt(config.get(ConfigConstants.KIND)); 19 | 20 | String mode = config.get(ConfigConstants.MODE); 21 | 22 | if (mode.equals(ConfigConstants.EASY_CAPTCHA)) { 23 | startWorking(mode, count, kind, config); 24 | if (LOG.isInfoEnabled()) { 25 | LOG.info("Starting task, current mode: " + mode + "."); 26 | } 27 | } else if (mode.equals(ConfigConstants.KAPTCHA)) { 28 | startWorking(mode, count, kind, config); 29 | if (LOG.isInfoEnabled()) { 30 | LOG.info("Starting task, current mode: " + mode + "."); 31 | } 32 | } else { 33 | if (LOG.isInfoEnabled()) { 34 | LOG.info("WARN: " + mode + " is not be supported, skip this 
task."); 35 | } 36 | } 37 | 38 | } 39 | 40 | @Override 41 | public void generate(Properties prop) { 42 | Config cfg = new Config(prop); 43 | generate(cfg); 44 | } 45 | 46 | /** 47 | * 48 | * @param mode 49 | * @param count 50 | * @param kind 51 | * @param config 52 | * @throws NullPointerException 53 | */ 54 | protected void startWorking(String mode, int count, int kind, final Config config) { 55 | int size = Integer.parseInt(config.get(ConfigConstants.POOL_SIZE, "20")); 56 | ExecutorService pool = Executors.newFixedThreadPool(size); 57 | CaptchaGenerator generator = null; 58 | for (int i = 0; i < count; i++) { 59 | CaptchaTaskRunner runner = new CaptchaTaskRunner(); 60 | if (i % (count / kind) == 0) { 61 | if (mode.equals(ConfigConstants.KAPTCHA)) { 62 | generator = new KaptchaGeneratorWorker(config); 63 | } else if (mode.equals(ConfigConstants.EASY_CAPTCHA)) { 64 | generator = new EasyCaptchaGeneratorWorker(config); 65 | } else { 66 | NullPointerException nullGeneratorEx = new NullPointerException("Unsupported mode made generator null."); 67 | if (LOG.isTraceEnabled()) { 68 | LOG.trace(nullGeneratorEx); 69 | } 70 | throw nullGeneratorEx; 71 | } 72 | } 73 | runner.setGenerator(generator); 74 | pool.submit(runner); 75 | } 76 | pool.shutdown(); 77 | } 78 | } -------------------------------------------------------------------------------- /captcha-images/src/main/java/me/zouzhipeng/KaptchaGeneratorWorker.java: -------------------------------------------------------------------------------- 1 | package me.zouzhipeng; 2 | 3 | import java.awt.image.BufferedImage; 4 | import java.util.Properties; 5 | import java.util.Random; 6 | 7 | import com.google.code.kaptcha.Constants; 8 | import com.google.code.kaptcha.Producer; 9 | import com.google.code.kaptcha.util.Config; 10 | 11 | import org.apache.log4j.Logger; 12 | 13 | import me.zouzhipeng.config.ConfigConstants; 14 | import me.zouzhipeng.utils.ImageOutputUtil; 15 | 16 | public class KaptchaGeneratorWorker implements 
CaptchaGenerator { 17 | private Random rand = new Random(); 18 | private Producer producer; 19 | private String output; 20 | private static final Logger LOG = Logger.getLogger(KaptchaGeneratorWorker.class); 21 | 22 | public KaptchaGeneratorWorker(me.zouzhipeng.config.Config config) { 23 | Properties prop = new Properties(); 24 | prop.put(Constants.KAPTCHA_BORDER, true); 25 | prop.put(Constants.KAPTCHA_BORDER_COLOR, 26 | String.join(",", rand.nextInt(256) + "", rand.nextInt(256) + "", rand.nextInt(256) + "")); 27 | prop.put(Constants.KAPTCHA_IMAGE_WIDTH, config.get(ConfigConstants.WIDTH, "200")); 28 | prop.put(Constants.KAPTCHA_IMAGE_HEIGHT, config.get(ConfigConstants.HEIGHT, "50")); 29 | String textColor = config.get(ConfigConstants.TEXT_COLOR); 30 | if (null == textColor) { 31 | textColor = String.join(",", rand.nextInt(256) + "", rand.nextInt(256) + "", rand.nextInt(256) + ""); 32 | } 33 | prop.put(Constants.KAPTCHA_TEXTPRODUCER_FONT_COLOR, 34 | textColor); 35 | prop.put(Constants.KAPTCHA_TEXTPRODUCER_CHAR_LENGTH, config.get(ConfigConstants.LENGTH, "4")); 36 | prop.put(Constants.KAPTCHA_TEXTPRODUCER_FONT_NAMES, "彩云,宋体,楷体,微软雅黑,Arial,SimHei,SimKai,SimSum"); 37 | if (Boolean.parseBoolean(config.get(ConfigConstants.NOISE_SAME_TEXT_COLOR, "true"))) { 38 | prop.put(Constants.KAPTCHA_NOISE_COLOR, textColor); 39 | } else { 40 | prop.put(Constants.KAPTCHA_NOISE_COLOR, 41 | String.join(",", rand.nextInt(256) + "", rand.nextInt(256) + "", rand.nextInt(256) + "")); 42 | } 43 | prop.put(Constants.KAPTCHA_TEXTPRODUCER_CHAR_STRING, "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789"); 44 | this.output = config.get(ConfigConstants.OUT_DIR); 45 | Config kaptchaConfig = new Config(prop); 46 | producer = kaptchaConfig.getProducerImpl(); 47 | } 48 | 49 | @Override 50 | @Deprecated 51 | public boolean generate(String folder) { 52 | String text = producer.createText(); 53 | BufferedImage imageBuffered = producer.createImage(text); 54 | if (LOG.isInfoEnabled()) { 55 | 
LOG.info("Generating " + text + "..."); 56 | } 57 | return ImageOutputUtil.writeToFile(imageBuffered, folder, text, "jpg"); 58 | 59 | } 60 | 61 | @Override 62 | public boolean generate() { 63 | String text = producer.createText(); 64 | BufferedImage imageBuffered = producer.createImage(text); 65 | 66 | return ImageOutputUtil.writeToFile(imageBuffered, this.output, text, "jpg"); 67 | } 68 | } -------------------------------------------------------------------------------- /captcha-images/src/main/java/me/zouzhipeng/config/Config.java: -------------------------------------------------------------------------------- 1 | package me.zouzhipeng.config; 2 | 3 | import java.io.File; 4 | import java.util.Properties; 5 | import java.util.Set; 6 | import java.util.Map.Entry; 7 | 8 | import org.apache.log4j.Logger; 9 | 10 | public class Config { 11 | private static final Logger LOG = Logger.getLogger(Config.class); 12 | private Properties prop; 13 | 14 | public Config() { 15 | this.prop = new Properties(); 16 | this.prop.setProperty(ConfigConstants.MODE, ConfigConstants.EASY_CAPTCHA); 17 | this.prop.setProperty(ConfigConstants.COUNT, "50000"); 18 | this.prop.setProperty(ConfigConstants.KIND, "1"); 19 | this.prop.setProperty(ConfigConstants.LENGTH, "4"); 20 | String cwd = System.getProperty("user.dir"); 21 | File outputDir = new File(cwd, "captchas"); 22 | this.prop.setProperty(ConfigConstants.OUT_DIR, outputDir.getAbsolutePath()); 23 | } 24 | 25 | public Config(Properties prop) { 26 | this(); 27 | Set> entries = this.prop.entrySet(); 28 | for (Entry entry : entries) { 29 | this.prop.put(entry.getKey(), entry.getValue()); 30 | } 31 | this.prop = prop; 32 | } 33 | 34 | public String get(String key) { 35 | if (key.contains(key)) { 36 | return prop.getProperty(key); 37 | } 38 | if (LOG.isInfoEnabled()) { 39 | LOG.info(key + " not found!"); 40 | } 41 | return null; 42 | } 43 | 44 | public String get(String key, String defaultValue) { 45 | String value = get(key); 46 | if (null == 
value) { 47 | return defaultValue; 48 | } else { 49 | return value; 50 | } 51 | } 52 | 53 | public void set(String key, String value) { 54 | this.prop.setProperty(key, value); 55 | } 56 | } -------------------------------------------------------------------------------- /captcha-images/src/main/java/me/zouzhipeng/config/ConfigBuilder.java: -------------------------------------------------------------------------------- 1 | package me.zouzhipeng.config; 2 | 3 | public class ConfigBuilder { 4 | 5 | } -------------------------------------------------------------------------------- /captcha-images/src/main/java/me/zouzhipeng/config/ConfigConstants.java: -------------------------------------------------------------------------------- 1 | package me.zouzhipeng.config; 2 | 3 | public class ConfigConstants { 4 | public static final String MODE = "captcha.mode"; 5 | public static final String KAPTCHA = "kaptcha"; 6 | public static final String EASY_CAPTCHA = "easycaptcha"; 7 | public static final String OUT_DIR = "output.dir"; 8 | public static final String WIDTH = "captcha.width"; 9 | public static final String HEIGHT = "captcha.height"; 10 | public static final String TEXT_COLOR = "captcha.kaptcha.textcolor"; 11 | public static final String NOISE_COLOR = "captcha.kaptcha.noisecolor"; 12 | public static final String LENGTH = "captcha.length"; 13 | public static final String COUNT = "captcha.count"; 14 | public static final String KIND = "captcha.kind"; 15 | public static final String NOISE_SAME_TEXT_COLOR = "kaptcha.noise.color.samewithtext"; 16 | public static final String POOL_SIZE = "pool.size"; 17 | } -------------------------------------------------------------------------------- /captcha-images/src/main/java/me/zouzhipeng/utils/ImageOutputUtil.java: -------------------------------------------------------------------------------- 1 | package me.zouzhipeng.utils; 2 | 3 | import java.awt.image.BufferedImage; 4 | import java.io.File; 5 | import 
java.io.FileNotFoundException; 6 | import java.io.FileOutputStream; 7 | import java.io.IOException; 8 | 9 | import javax.imageio.ImageIO; 10 | 11 | import com.wf.captcha.base.Captcha; 12 | 13 | public class ImageOutputUtil { 14 | /** 15 | * To write buffered image to file. 16 | * @param image the image to save 17 | * @param folder the folder to save images 18 | * @param name the name of the image 19 | * @param extension the extension of the image, i.e. jpg 20 | * @return successful or failed after output 21 | */ 22 | public static boolean writeToFile(BufferedImage image, String folder, String name, String extension) { 23 | File imageFolder = new File(folder); 24 | if (!imageFolder.exists()) { 25 | imageFolder.mkdirs(); 26 | } 27 | File imageFile = new File(imageFolder, name + "." + extension); 28 | if (imageFile.exists()) { 29 | return false; 30 | } 31 | FileOutputStream imageOutput = null; 32 | try { 33 | imageOutput = new FileOutputStream(imageFile); 34 | ImageIO.write(image, extension, imageOutput); 35 | return true; 36 | } catch (FileNotFoundException e) { 37 | e.printStackTrace(); 38 | return false; 39 | } catch (IOException e) { 40 | e.printStackTrace(); 41 | return false; 42 | } finally { 43 | if (null != imageOutput) { 44 | try { 45 | imageOutput.close(); 46 | } catch (IOException e) { 47 | e.printStackTrace(); 48 | return false; 49 | } 50 | } 51 | } 52 | } 53 | 54 | /** 55 | * To write captcha image to an image file. 
56 | * @param captcha the captcha to write 57 | * @param folder the folder to save captchas 58 | * @param name the name of the captcha 59 | * @return successful or failed after output 60 | */ 61 | public static boolean writeToFile(Captcha captcha, String folder, String name) { 62 | File imageFolder = new File(folder); 63 | if (!imageFolder.exists()) { 64 | imageFolder.mkdirs(); 65 | } 66 | File imageFile = new File(imageFolder, name + ".jpg"); 67 | if (imageFile.exists()) { 68 | return false; 69 | } 70 | FileOutputStream imageOutput = null; 71 | try { 72 | imageOutput = new FileOutputStream(imageFile); 73 | captcha.out(imageOutput); 74 | return true; 75 | } catch (FileNotFoundException e) { 76 | e.printStackTrace(); 77 | return false; 78 | } finally { 79 | if (null != imageOutput) { 80 | try { 81 | imageOutput.close(); 82 | } catch (IOException e) { 83 | e.printStackTrace(); 84 | return false; 85 | } 86 | } 87 | } 88 | } 89 | } -------------------------------------------------------------------------------- /captcha-images/src/test/java/me/zouzhipeng/AppTest.java: -------------------------------------------------------------------------------- 1 | package me.zouzhipeng; 2 | 3 | import static org.junit.Assert.assertTrue; 4 | 5 | import org.junit.Test; 6 | 7 | /** 8 | * Unit test for simple App. 
9 | */ 10 | public class AppTest 11 | { 12 | /** 13 | * Rigorous Test :-) 14 | */ 15 | @Test 16 | public void shouldAnswerWithTrue() 17 | { 18 | assertTrue( true ); 19 | } 20 | } 21 | -------------------------------------------------------------------------------- /captchas.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zipzou/captcha-recognition/05b43461f37925d7e0f228ca183d2288e007ca0a/captchas.jpg -------------------------------------------------------------------------------- /kaptchas.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zipzou/captcha-recognition/05b43461f37925d7e0f228ca183d2288e007ca0a/kaptchas.jpg -------------------------------------------------------------------------------- /loss-easy.svg: -------------------------------------------------------------------------------- 1 | 2 | 4 | 5 | 6 | 7 | 10 | 11 | 12 | 13 | 19 | 20 | 21 | 22 | 28 | 29 | 30 | 31 | 32 | 33 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | 74 | 75 | 76 | 77 | 78 | 79 | 80 | 93 | 94 | 95 | 96 | 97 | 98 | 99 | 100 | 101 | 102 | 103 | 104 | 105 | 106 | 107 | 108 | 109 | 133 | 134 | 135 | 136 | 137 | 138 | 139 | 140 | 141 | 142 | 143 | 144 | 145 | 146 | 147 | 148 | 149 | 181 | 182 | 183 | 184 | 185 | 186 | 187 | 188 | 189 | 190 | 191 | 192 | 193 | 194 | 195 | 196 | 197 | 214 | 215 | 216 | 217 | 218 | 219 | 220 | 221 | 222 | 223 | 224 | 225 | 226 | 227 | 228 | 229 | 230 | 254 | 255 | 256 | 257 | 258 | 259 | 260 | 261 | 262 | 263 | 264 | 265 | 266 | 267 | 268 | 269 | 270 | 300 | 301 | 302 | 303 | 304 | 305 | 306 | 307 | 308 | 309 | 310 | 334 | 360 | 381 | 402 | 421 | 422 | 423 | 424 | 425 | 426 | 427 | 428 | 429 | 430 | 431 | 432 | 433 | 434 | 435 | 438 | 439 | 440 | 441 | 442 | 443 | 444 | 445 | 446 | 447 | 448 | 449 | 450 | 451 | 452 | 453 | 454 | 455 | 456 | 457 | 458 | 459 | 460 | 461 | 462 | 
def get_data(path):
    """
    Collect the captcha images in a folder together with their labels.

    The label of an image is taken from the first four characters of its file
    name. ``.jpg`` files whose name does not start with four characters of the
    alphabet returned by `get_dict` are skipped instead of raising `KeyError`.
    Names are processed in sorted order so the result does not depend on the
    directory listing order.
    :param path: the folder that holds the images
    :return: the list of image paths and an int32 array of shape (n, 4) with
        zero-based class ids
    """
    label2id, _ = get_dict()
    image_names = []
    labels = []
    for name in sorted(os.listdir(path)):
        if not name.endswith(".jpg"):
            continue
        code = name[:4]
        if len(code) < 4 or any(ch not in label2id for ch in code):
            continue
        image_names.append(os.path.join(path, name))
        labels.append([label2id[ch] for ch in code])

    # ids from get_dict start at 1; shift them to start at 0.
    # reshape keeps the (n, 4) shape even when no image was found.
    return image_names, np.array(labels, dtype=np.int32).reshape(-1, 4) - 1


def get_dict():
    """
    Get the dictionaries label2id and id2label. label2id maps a character to
    its id, id2label is the reverse mapping. Ids start at 1 and run over the
    upper case letters (1-26), the lower case letters (27-52) and the digits
    (53-62).
    :return: two dictionaries: label->id, id->label
    """
    alphabet = (
        [chr(ord('A') + i) for i in range(26)]
        + [chr(ord('a') + i) for i in range(26)]
        + [chr(ord('0') + i) for i in range(10)]
    )
    label2id = {label: idx for idx, label in enumerate(alphabet, start=1)}
    id2label = {idx: label for label, idx in label2id.items()}

    return label2id, id2label


def get_data_split(path, split=(6, 1, 1), save=True, out_dir='./data', modes=('train', 'dev', 'test')):
    """
    Get data after split.

    When ``train.npy``, ``dev.npy`` and ``test.npy`` already exist in
    `out_dir`, the requested parts are loaded from there. Otherwise the images
    in `path` are split randomly and, if `save` is set, written to `out_dir`.
    :param path: the path to save images
    :param split: the ratio of train set, dev set and test set
    :param save: whether a freshly computed split is written to `out_dir`
    :param out_dir: the output directory to save data files
    :param modes: the modes at different timestamp, support modes like: (train, dev, test), (train, dev) and (test)
    :return: six arrays for (train, dev, test), four for (train, dev), two for (test)
    :raises ValueError: if `modes` contains neither train+dev nor test
    """
    want_train_dev = 'train' in modes and 'dev' in modes
    want_test = 'test' in modes
    if not (want_train_dev or want_test):
        # Previously this fell through and returned None.
        raise ValueError("modes must contain 'train' and 'dev', or 'test': %r" % (modes,))

    # makedirs also creates missing parent folders and tolerates an existing one.
    os.makedirs(out_dir, exist_ok=True)

    parts = ('train', 'dev', 'test')
    files = {
        part: (os.path.join(out_dir, part + '.npy'), os.path.join(out_dir, part + '.y.npy'))
        for part in parts
    }

    if all(os.path.exists(x_file) for x_file, _ in files.values()):
        # Only read what the caller asked for.
        data = {
            part: (np.load(x_file, allow_pickle=True), np.load(y_file, allow_pickle=True))
            for part, (x_file, y_file) in files.items()
            if part in modes
        }
    else:
        names, labels = get_data(path)

        ratios = np.array(split) / np.sum(split)
        x_train, x_dev_test, y_train, y_dev_test = train_test_split(names, labels, train_size=ratios[0])
        # Share of dev inside the remaining dev+test portion.
        ratios = np.array(split[1:]) / np.sum(split[1:])
        x_dev, x_test, y_dev, y_test = train_test_split(x_dev_test, y_dev_test, train_size=ratios[0])

        data = {'train': (x_train, y_train), 'dev': (x_dev, y_dev), 'test': (x_test, y_test)}
        if save:
            for part in parts:
                x_file, y_file = files[part]
                np.save(x_file, data[part][0], allow_pickle=True)
                np.save(y_file, data[part][1], allow_pickle=True)

    if want_train_dev and want_test:
        return data['train'] + data['dev'] + data['test']
    if want_train_dev:
        return data['train'] + data['dev']
    return data['test']
class CaptchaLoader(Dataset):
    """
    In-memory dataset of captcha images and their labels.

    All images are read and converted to tensors once, in the constructor.
    """

    def __init__(self, data, shuffle=True):
        """
        :param data: pair of (image paths, label array); the label array must
            expose ``shape[0]`` and row indexing
        :param shuffle: stored on the instance only; this class does not use it
        """
        super().__init__()
        self.shuffle = shuffle
        image_paths = data[0]
        self.y_data = data[1]

        # Compose takes a list of transforms and is itself callable.
        self.image_transformer = tv.transforms.Compose([tv.transforms.ToTensor()])

        self.x_data = []
        for path in image_paths:
            # The tensor owns a copy of the pixels, so the file can be closed
            # as soon as the conversion is done.
            with Image.open(path) as img_pil:
                self.x_data.append(self.image_transformer(img_pil))

    def __len__(self):
        """Return the number of labelled samples."""
        return self.y_data.shape[0]

    def __getitem__(self, index):
        """
        Return the (image tensor, int64 label tensor) pair at `index`.

        Indices beyond the end wrap around to the beginning.
        :raises IndexError: if the dataset is empty
        """
        size = self.y_data.shape[0]
        if size == 0:
            raise IndexError("CaptchaLoader is empty")
        actual_index = index % size  # avoid out of bound

        return self.x_data[actual_index], torch.tensor(self.y_data[actual_index], dtype=torch.int64)
@click.command()
@click.help_option('-h', '--help')
@click.option('-i', '--data_dir', default='./captchas', type=click.Path(), help='The path of the evaluation data', required=False)
@click.option('-m', '--mode', default='captcha', help='The model type to evaluate, could be captcha or kaptcha', type=click.Choice(['captcha', 'kaptcha']), required=False)
@click.option('-b', '--batch_size', default=128, type=int, help='The batch size of input data', required=False)
@click.option('-o', '--model_dir', default='./captcha_models/model-latest.pkl', type=click.Path(), help='The path of the model checkpoint to load', required=False)
@click.option('-l', '--log_dir', default='./logs', type=click.Path(), help='The log files path', required=False)
@click.option('-u', '--use_gpu', type=bool, default=False, help='Evaluate on gpu or cpu', required=False)
def read_cli(data_dir, mode, batch_size, model_dir, log_dir, use_gpu):
    """Evaluate a trained captcha model and write the accuracy log."""
    # The docstring above is what click prints for ``--help``, so it is kept
    # short and user-facing. The help strings were copied from the training
    # command and described training; they now describe evaluation.
    #
    # ``eval`` below is this module's evaluation function (it shadows the
    # builtin of the same name). Its positional order differs from the option
    # order: model_dir, data_dir, batch_size, log_dir, use_gpu, mode.
    eval(model_dir, data_dir, batch_size, log_dir, use_gpu, mode)
def recall(input, target):
    """Placeholder for a recall metric; not implemented, returns ``None``."""
    return None


def precision(input, target):
    """Placeholder for a precision metric; not implemented, returns ``None``."""
    return None


def acc(pred, y):
    """Return the fraction of samples whose arg-max class equals the label.

    :param pred: score tensor; the predicted class is the arg-max over the
        last dimension
    :param y: tensor of target class ids, compared element-wise with the
        arg-max result
    :return: accuracy as a Python float
    """
    predicted_ids = pred.argmax(dim=-1)
    hits = (predicted_ids == y).to(torch.float32)
    return hits.mean().item()


def multi_acc(pred, y):
    """Return the fraction of rows in which every position is correct.

    :param pred: tensor of predicted class ids
    :param y: tensor of target class ids, compared element-wise with ``pred``
    :return: fraction of entries along the leading dimensions whose last
        dimension matches completely, as a Python float
    """
    row_correct = (pred == y).all(dim=-1)
    return row_correct.to(torch.float32).mean().item()
class CaptchaModel(nn.Module):
    """Small CNN with four 62-way classification heads, one per character.

    Two convolution stages feed a flattened 256-value feature vector through
    dropout into four independent linear heads.
    """

    def __init__(self):
        super(CaptchaModel, self).__init__()

        # Stage 1: 3 -> 16 channels, 5x5 convolution, 5x5 max-pooling.
        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 16, 5),
            nn.MaxPool2d(5, 5),
            nn.BatchNorm2d(16),
            nn.RReLU()
        )

        # Stage 2: 16 -> 32 channels, then flatten. The heads below take 256
        # inputs, so the image size must make this stage emit 256 values per
        # sample.
        self.conv2 = nn.Sequential(
            nn.Conv2d(16, 32, 5),
            nn.MaxPool2d(4, 4),
            nn.BatchNorm2d(32),
            nn.RReLU(),
            nn.Flatten()
        )

        # NOTE: ``dense1`` is no longer called in ``forward`` (its output was
        # always discarded). The module stays registered so that state dicts
        # saved by earlier versions, which contain ``dense1.*`` entries, still
        # load with the default strict ``load_state_dict``.
        self.dense1 = nn.Sequential(
            nn.Linear(256, 64),
            nn.BatchNorm1d(64),
            nn.RReLU()
        )

        self.dropout = nn.Dropout(0.3)

        self.out1 = nn.Linear(256, 62)
        self.out2 = nn.Linear(256, 62)
        self.out3 = nn.Linear(256, 62)
        self.out4 = nn.Linear(256, 62)

    def forward(self, input):
        """Return the four per-position logits tensors for an image batch.

        :param input: image batch accepted by ``conv1`` (3 input channels)
        :return: tuple ``(y_1, y_2, y_3, y_4)`` of 62-way logits
        """
        features = self.conv2(self.conv1(input))

        # The previous version also evaluated ``self.dense1(features)`` here
        # and threw the result away. That was wasted work, and in training
        # mode its BatchNorm1d raised on a batch holding a single sample.
        z_dropout = self.dropout(features)

        y_1 = self.out1(z_dropout)
        y_2 = self.out2(z_dropout)
        y_3 = self.out3(z_dropout)
        y_4 = self.out4(z_dropout)

        return y_1, y_2, y_3, y_4
def predict(captcha, model_dir='./model/model-latest.pkl', use_gpu=True, mode='captcha'):
    """Recognise the text of a single captcha image.

    :param captcha: path (or file object) of the image, passed to
        ``PIL.Image.open``
    :param model_dir: path of the checkpoint file; it must contain the model
        weights under the ``'network'`` key
    :param use_gpu: run on CUDA when it is available, otherwise on the CPU
    :param mode: ``'captcha'`` or ``'kaptcha'``, selects the network class
    :return: the decoded string, or ``None`` when ``mode`` is unknown
    """
    if mode == 'captcha':
        from model import CaptchaModel
    elif mode == 'kaptcha':
        from kaptcha_model import CaptchaModel
    else:
        # Unknown model type: keep the existing contract of returning None.
        return None

    on_gpu = use_gpu and torch.cuda.is_available()

    model = CaptchaModel()
    # NOTE: ``torch.load`` unpickles the file; only load trusted checkpoints.
    # Without a usable GPU every storage is mapped onto the CPU.
    map_location = None if on_gpu else (lambda storage, loc: storage)
    model_state = torch.load(model_dir, map_location=map_location)
    model.load_state_dict(model_state['network'])
    model = model.cuda() if on_gpu else model.cpu()
    model.eval()

    # ``Compose`` expects a list of transforms; the previous code passed a
    # bare ``ToTensor()`` and called the ``transforms`` attribute directly.
    transformer = Compose([ToTensor()])
    with Image.open(captcha) as img_pil:
        # Force three channels so grey-scale, palette or RGBA files do not
        # reach the network with the wrong channel count. RGB images pass
        # through unchanged.
        # NOTE(review): assumes the selected network takes 3-channel input.
        img_tensor = transformer(img_pil.convert('RGB'))

    x = img_tensor.unsqueeze(0)  # batch of one
    if on_gpu:
        x = x.cuda()

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        pred1, pred2, pred3, pred4 = model(x)

    # Class ids are shifted by one before the dictionary lookup, as before.
    # NOTE(review): presumably ``id2label`` is keyed from 1; confirm against
    # ``get_dict``.
    pred_seq = [torch.argmax(p).item() + 1 for p in (pred1, pred2, pred3, pred4)]

    _, id2label = get_dict()

    return ''.join(id2label[i] for i in pred_seq)
def _batch_metrics(model, loss_fn, x, y):
    """Run one forward pass and score it against the four label columns.

    :param model: network returning four logits tensors, one per position
    :param loss_fn: criterion applied to each position separately
    :param x: image batch
    :param y: label batch; column ``i`` is scored against output ``i``
    :return: ``(loss, acc_mean, multi_acc_mean)``: the summed loss tensor,
        the mean of the four per-position accuracies, and the whole-sequence
        accuracy
    """
    pred_1, pred_2, pred_3, pred_4 = model(x)

    loss = (loss_fn(pred_1, y[:, 0]) + loss_fn(pred_2, y[:, 1])
            + loss_fn(pred_3, y[:, 2]) + loss_fn(pred_4, y[:, 3]))
    acc_mean = (acc(pred_1, y[:, 0]) + acc(pred_2, y[:, 1])
                + acc(pred_3, y[:, 2]) + acc(pred_4, y[:, 3])) / 4.

    pred = torch.stack((pred_1, pred_2, pred_3, pred_4), dim=-1)
    multi_acc_mean = multi_acc(torch.argmax(pred, dim=1), y)
    return loss, acc_mean, multi_acc_mean


def train(path, split=[6, 1, 1], batch_size=64, epochs=100, learning_rate=0.001, initial_epoch=0, step_saving=2, model_dir='./', log_file='./history', continue_pkl=None, gpu=True, mode='captcha'):
    """Train a captcha model, validating and checkpointing after every epoch.

    :param path: directory of the training data; must exist
    :param split: split ratios forwarded to ``get_data_split`` (the default
        list is never mutated in this function)
    :param batch_size: batch size of the train and dev loaders
    :param epochs: epoch index at which training stops (exclusive)
    :param learning_rate: Adam learning rate
    :param initial_epoch: used to locate ``model-<initial_epoch>.pkl`` when
        resuming; it is then replaced by the epoch stored in the checkpoint,
        or reset to 0 when nothing is resumed
    :param step_saving: a numbered checkpoint is written whenever
        ``epoch % step_saving == 0``
    :param model_dir: directory where checkpoints are read and written
    :param log_file: directory of the JSON history files
    :param continue_pkl: checkpoint file name to resume from. When it is not
        ``None``, the first existing file among ``continue_pkl``,
        ``model-latest.pkl`` and ``model-<initial_epoch>.pkl`` is loaded.
    :param gpu: train on CUDA when it is available
    :param mode: ``'captcha'`` or ``'kaptcha'``; any other value returns
        ``None`` without training
    :raises FileNotFoundError: when ``path`` does not exist
    """
    if mode == 'captcha':
        from model import CaptchaModel
        CaptchaModelDynamic = CaptchaModel
    elif mode == 'kaptcha':
        from kaptcha_model import CaptchaModel
        CaptchaModelDynamic = CaptchaModel
    else:
        return
    if not os.path.exists(path):
        raise FileNotFoundError("未知的训练数据")
    x_train, y_train, x_dev, y_dev = get_data_split(path, split=split, modes=['train', 'dev'])

    train_ds = CaptchaLoader((x_train, y_train), shuffle=True)
    dev_ds = CaptchaLoader((x_dev, y_dev))

    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
    dev_loader = DataLoader(dev_ds, batch_size=batch_size, shuffle=True)

    on_gpu = gpu and torch.cuda.is_available()

    model = CaptchaModelDynamic()
    loss_fn = nn.CrossEntropyLoss()
    if on_gpu:
        model = model.cuda()
        loss_fn = loss_fn.cuda()
    # Built after the model is moved, so the optimizer is created over the
    # parameters on their final device.
    optm = torch.optim.Adam(model.parameters(), lr=learning_rate)

    def _existing(name):
        # Full checkpoint path when the file exists in ``model_dir``, else None.
        candidate = os.path.join(model_dir, name)
        return candidate if os.path.exists(candidate) else None

    # Resume from the first checkpoint found, in the same priority order as
    # the three formerly duplicated branches.
    checkpoint_path = None
    if continue_pkl is not None:
        checkpoint_path = _existing(continue_pkl) or _existing('model-latest.pkl')
        if checkpoint_path is None and initial_epoch is not None:
            checkpoint_path = _existing('model-%d.pkl' % initial_epoch)

    if checkpoint_path is not None:
        # NOTE: ``torch.load`` unpickles the file; only load trusted checkpoints.
        # Without a usable GPU every storage is mapped onto the CPU.
        map_location = None if on_gpu else (lambda storage, loc: storage)
        state = torch.load(checkpoint_path, map_location=map_location)
        model.load_state_dict(state['network'])
        optm.load_state_dict(state['optimizer'])
        initial_epoch = state['epoch'] + 1
    else:
        initial_epoch = 0

    # load history
    batch_history_train = load_history(filename='history_batch_train.json', history_path=log_file)
    epoch_history_train = load_history(filename='history_epoch_train.json', history_path=log_file)
    epoch_history_dev = load_history(filename='history_epoch_dev.json', history_path=log_file)
    # Drop records of epochs that will be re-run. The batch history holds one
    # record per batch, so it is cut at ``initial_epoch`` whole epochs; it was
    # previously cut at ``initial_epoch`` records.
    # NOTE(review): assumes earlier runs used the same number of batches per
    # epoch.
    batches_per_epoch = len(train_loader)
    batch_history_train = batch_history_train[:initial_epoch * batches_per_epoch]
    epoch_history_train = epoch_history_train[:initial_epoch]
    epoch_history_dev = epoch_history_dev[:initial_epoch]

    with tqdm(total=epochs, desc='Epoch', initial=initial_epoch) as epoch_bar:
        for epoch in range(initial_epoch, epochs):
            model.train()
            loss_batchs = []
            acc_batchs = []
            multi_acc_batchs = []
            with tqdm(total=batches_per_epoch, desc='Batch') as batch_bar:
                for x, y in train_loader:
                    optm.zero_grad()
                    # The loader already yields tensors; re-wrapping them in
                    # ``torch.tensor(..., requires_grad=True)`` only copied
                    # the batch and tracked unused input gradients.
                    if on_gpu:
                        x = x.cuda()
                        y = y.cuda()

                    loss_count, acc_mean, multi_acc_mean = _batch_metrics(model, loss_fn, x, y)
                    loss_value = loss_count.item()

                    loss_batchs.append(loss_value)
                    acc_batchs.append(acc_mean)
                    multi_acc_batchs.append(multi_acc_mean)

                    batch_bar.set_postfix(loss=loss_value, acc=acc_mean, multi_acc=multi_acc_mean)
                    batch_bar.update()
                    batch_history_train.append([loss_value, acc_mean, multi_acc_mean])

                    loss_count.backward()
                    optm.step()

            # Written once per epoch: rewriting the whole, ever-growing list
            # after every batch made the I/O quadratic, and a resume discards
            # the records of an unfinished epoch anyway.
            save_history('history_batch_train.json', batch_history_train, log_file)

            epoch_bar.set_postfix(loss_mean=np.mean(loss_batchs), acc_mean=np.mean(acc_batchs), multi_acc_mean=np.mean(multi_acc_batchs))
            epoch_bar.update()
            epoch_history_train.append([np.mean(loss_batchs).item(), np.mean(acc_batchs).item(), np.mean(multi_acc_batchs).item()])
            save_history('history_epoch_train.json', epoch_history_train, log_file)

            # validate
            model.eval()
            loss_batchs_dev = []
            acc_batchs_dev = []
            multi_acc_batchs_dev = []
            # No gradients are needed here, so no autograd graph is built.
            with tqdm(total=len(dev_loader), desc='Val Batch') as batch_bar, torch.no_grad():
                for x, y in dev_loader:
                    if on_gpu:
                        x = x.cuda()
                        y = y.cuda()

                    loss_count, acc_mean, multi_acc_mean = _batch_metrics(model, loss_fn, x, y)
                    loss_value = loss_count.item()

                    loss_batchs_dev.append(loss_value)
                    acc_batchs_dev.append(acc_mean)
                    multi_acc_batchs_dev.append(multi_acc_mean)

                    batch_bar.set_postfix(loss=loss_value, acc=acc_mean, multi_acc=multi_acc_mean)
                    batch_bar.update()
            epoch_history_dev.append([np.mean(loss_batchs_dev).item(), np.mean(acc_batchs_dev).item(), np.mean(multi_acc_batchs_dev).item()])
            save_history('history_epoch_dev.json', epoch_history_dev, log_file)

            # saving
            # ``makedirs`` also creates missing parent directories.
            os.makedirs(model_dir, exist_ok=True)
            state_dict = {
                'network': model.state_dict(),
                'optimizer': optm.state_dict(),
                'epoch': epoch
            }
            if epoch % step_saving == 0:
                model_path = os.path.join(model_dir, 'model-%d.pkl' % epoch)
                torch.save(state_dict, model_path)

            torch.save(state_dict, os.path.join(model_dir, 'model-latest.pkl'))
| @click.option('-b', '--batch_size', default=128, type=int, help='The batch size of input data', required=False) 233 | @click.option('-o', '--model_dir', default='./captcha_models', type=click.Path(), help='The model dir to save models or load models', required=False) 234 | @click.option('-r', '--lr', default=0.001, type=float, help='The learning rate to train', required=False) 235 | @click.option('-l', '--log_dir', default='./logs', type=click.Path(), help='The log files path', required=False) 236 | @click.option('-u', '--use_gpu', type=bool, default=False, help='Train by gpu or cpu', required=False) 237 | @click.option('-s', '--save_frequency', default=2, type=int, help='The frequence to save the models during training', required=False) 238 | def read_cli(data_dir, mode, epoch, data_split, continue_train, checkpoint, batch_size, model_dir, lr, log_dir, use_gpu, save_frequency): 239 | """ 240 | 241 | :param data_dir: 242 | :param mode: 243 | :param epoch: 244 | :param data_split: 245 | :param continue_train: 246 | :param checkpoint: 247 | :param batch_size: 248 | :param model_dir: 249 | :param lr: 250 | :param log_dir: 251 | :param use_gpu: 252 | :param save_frequency: 253 | :return: 254 | """ 255 | train(data_dir, data_split, batch_size, epoch, lr, checkpoint, save_frequency, model_dir, log_dir, continue_train, use_gpu, mode) 256 | 257 | 258 | if __name__ == "__main__": 259 | read_cli() --------------------------------------------------------------------------------