├── layout_analysis
│   ├── __init__.py
│   ├── dataset
│   │   ├── image_augment.py
│   │   ├── data_factory.py
│   │   ├── image_augmation.py
│   │   ├── data_util.py
│   │   └── book_data.py
│   ├── images
│   │   └── 40_predict.png
│   ├── README.md
│   ├── plot_xml.py
│   ├── load_saved_model.py
│   ├── train.py
│   ├── model.py
│   └── eval.py
├── dataset
│   ├── 13575.png
│   └── sample
│       ├── 1204675708.PDF-189.png
│       └── 188.xml
├── crnn_ocr
│   ├── image_augment.py
│   ├── README.md
│   ├── .gitignore
│   ├── input_fn.py
│   ├── load_saved_mode.py
│   ├── create_tfrecord.py
│   ├── model.py
│   ├── eval.py
│   └── train.py
├── single_word_ocr
│   ├── image_augment.py
│   ├── README.md
│   ├── .gitignore
│   ├── create_tfrecord.py
│   ├── input_fn.py
│   ├── data_util.py
│   ├── data_generator.py
│   ├── eval.py
│   ├── load_saved_model.py
│   ├── train.py
│   └── densenet.py
├── README.md
├── .gitignore
└── LICENSE
--------------------------------------------------------------------------------
/layout_analysis/__init__.py:
--------------------------------------------------------------------------------
1 | #-*-coding:utf-8-*-
--------------------------------------------------------------------------------
/dataset/13575.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rockyzhengwu/document-ocr/HEAD/dataset/13575.png
--------------------------------------------------------------------------------
/dataset/sample/1204675708.PDF-189.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rockyzhengwu/document-ocr/HEAD/dataset/sample/1204675708.PDF-189.png
--------------------------------------------------------------------------------
/layout_analysis/dataset/image_augment.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | #-*- coding:utf-8 -*-
3 | #author: wu.zheng midday.me
4 |
5 |
--------------------------------------------------------------------------------
/layout_analysis/images/40_predict.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rockyzhengwu/document-ocr/HEAD/layout_analysis/images/40_predict.png
--------------------------------------------------------------------------------
/layout_analysis/dataset/data_factory.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | #-*- coding:utf-8 -*-
3 | #author: wu.zheng midday.me
4 |
5 | from dataset import book_data
6 |
7 |
8 | DATA_GENERATOR = {
9 |   "book": book_data
10 | }
11 |
12 | def get_data(name):
13 |   data_generator_fn = DATA_GENERATOR.get(name)
14 |   if data_generator_fn is None:
15 |     print("data %s not exists" % name)
16 |   return data_generator_fn
17 |
--------------------------------------------------------------------------------
/crnn_ocr/image_augment.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | #-*- coding:utf-8 -*-
3 | #author: wu.zheng midday.me
4 |
5 | from imgaug import augmenters as iaa
6 |
7 | seq = iaa.SomeOf((1, 4), [
8 |   iaa.Salt(p=(0.1, 0.2)),
9 |   iaa.GaussianBlur(sigma=(0, 0.5)),
10 |   iaa.CoarseDropout(p=(0.02, 0.1), size_percent=(0.2, 0.3)),
11 |   iaa.JpegCompression(compression=(50, 80)),
12 | ])
13 |
14 |
15 |
16 | def augment_images(images):
17 |   images_aug = seq(images=images)
18 |   return images_aug
19 |
20 |
--------------------------------------------------------------------------------
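Not part of the repo, but as a usage sketch: `augment_images` takes a batch of `uint8` images (this is how `crnn_ocr/input_fn.py` calls it through `tf.py_func`), so it can also be used standalone. The sample filename here is a stand-in.

```python
# Hedged example: apply the crnn_ocr augmentation pipeline to a small batch.
# "line_0.png" is a hypothetical text-line crop.
import cv2
from image_augment import augment_images

image = cv2.imread("line_0.png")      # uint8 array, shape (H, W, 3)
batch = [image, image.copy()]         # imgaug accepts a list of images
augmented = augment_images(batch)     # returns a list of augmented uint8 images
for i, aug in enumerate(augmented):
    cv2.imwrite("line_0_aug_%d.png" % i, aug)
```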
/single_word_ocr/image_augment.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | #-*- coding:utf-8 -*-
3 | #author: wu.zheng midday.me
4 |
5 | from imgaug import augmenters as iaa
6 |
7 | seq = iaa.SomeOf((1, 4), [
8 |   iaa.Salt(p=(0.2, 0.4)),
9 |   iaa.GaussianBlur(sigma=(0, 1.0)),
10 |   iaa.CoarseDropout(p=(0.02, 0.1), size_percent=(0.2, 0.6)),
11 |   iaa.JpegCompression(compression=(50, 80)),
12 | ])
13 |
14 |
15 | def augment_image(image):
16 |   image = 255 - image
17 |   image_aug = seq(image=image)
18 |   image_aug = 255 - image_aug
19 |   return image_aug
20 |
21 |
--------------------------------------------------------------------------------
/single_word_ocr/README.md:
--------------------------------------------------------------------------------
1 | ## Single-Character Recognition
2 |
3 | Training data can be synthesized with the [synthesis tool](https://github.com/rockyzhengwu/synthtext).
4 |
5 |
6 | The model is a DenseNet.
7 |
8 |
9 | ## Usage
10 |
11 |
12 |
13 | 1. Prepare the data, generate an image_list file in the format below, and prepare the matching vocabulary file
14 | ```
15 | /data/9916/Z01.png 9916
16 | /data/9916/Z02.png 9916
17 | ```
18 |
19 | 2. Training
20 |
21 | Training does not require generating tfrecord files.
22 |
23 | num_class: the number of classes, taken from the vocabulary file
24 |
25 | python train.py --train_image_list --num_class --checkpoint_path
26 |
27 |
28 | See the code for the other arguments.
29 |
30 | 3. Evaluation
31 |
32 | python eval.py (see the code for the full argument list)
33 |
34 |
35 |
--------------------------------------------------------------------------------
/layout_analysis/README.md:
--------------------------------------------------------------------------------
1 | # Layout Analysis
2 |
3 | ## Overview
4 |
5 | The model uses a UNet structure with an auxiliary decoder during training; a prediction example is shown at the bottom of this page.
6 |
7 |
8 | ## Usage
9 |
10 | 1. Data preparation
11 | Prepare data as in the [sample](https://gitee.com/rockyzheng/document-ocr/blob/master/dataset/sample)
12 | and generate a label_list file like this:
13 |
14 | ```
15 | /data/1.xml
16 | /data/2.xml
17 | ```
18 | Configure the related parameters in [data_factory.py](./dataset/data_factory.py).
19 |
20 | To use data in another format, just implement a matching data_generator (see the sketch below).
21 |
22 | 2. Training
23 |
24 | ```
25 | python train.py --data_name
26 | ```
27 |
28 | 3. Evaluation & model export
29 | ```
30 | python eval.py --name
31 | python eval.py --export
32 | ```
33 |
34 | 4. Loading the exported model
35 | Set the model path in ```load_saved_model.py```, then run it to test the exported model or to serve it as an API.
36 |
37 |
38 | ![](./images/40_predict.png)
--------------------------------------------------------------------------------
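The data_factory hook mentioned above is just a name-to-module map. A hedged sketch of registering a custom generator follows; the `my_data` module is hypothetical, and the `get_batch` signature and the `(images, labels, masks)` batch layout mirror what `layout_analysis/train.py` actually consumes.

```python
# dataset/my_data.py (hypothetical): any module exposing get_batch with this
# signature can be registered; train.py iterates over (images, labels, masks).
def get_batch(label_list_path, image_dir, batch_size, mode):
    while True:
        images, labels, masks = [], [], []
        # ... fill the three lists from your own annotation format ...
        yield images, labels, masks

# dataset/data_factory.py: register the module next to "book"
from dataset import book_data, my_data

DATA_GENERATOR = {
    "book": book_data,
    "my_data": my_data,  # then: python train.py --data_name my_data
}
```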
/README.md:
--------------------------------------------------------------------------------
1 | # document-ocr
2 |
3 | A fairly complete document analysis and recognition project, in five parts:
4 |
5 | 1. document analysis data,
6 | 2. recognition data synthesis tools,
7 | 3. a document layout analysis model,
8 | 4. a text-line recognition model,
9 | 5. a single-character recognition model
10 |
11 | ### Data
12 |
13 | The data is generated by parsing annual-report PDF files of publicly listed companies crawled from the web, with position information for images and text.
14 |
15 | [A sample](./dataset/sample)
16 |
17 | [Netdisk download link](https://pan.baidu.com/s/1dcZAqRxJtsXw9l0n6j8pbg), extraction code: nn1g
18 |
19 | The text data is annotated at the text-line level; some samples have minor flaws. There are 34,000 samples in total.
20 | ![](./dataset/13575.png)
21 |
22 | Text-line recognition data can be generated from the annotations.
23 |
24 | #### Recognition data synthesis
25 |
26 | The single-character and text-line [data synthesis tool](https://github.com/rockyzhengwu/synthtext) does a good job of filtering out characters that a font does not support.
27 |
28 | ## Implementations
29 | All code depends on TensorFlow 1.14 and OpenCV 3.x.
30 |
31 | 1. [Layout analysis](./layout_analysis/README.md)
32 | 2. [Text-line recognition](./crnn_ocr/README.md)
33 | 3. [Single-character recognition](./single_word_ocr/README.md)
34 |
35 | ### Notes
36 |
37 | - There is still a lot to polish in the code; issues of all kinds are welcome.
38 |
39 | - Many parameters (learning_rate, for example) are not exposed on the command line; please read the code before using it.
--------------------------------------------------------------------------------
/layout_analysis/dataset/image_augmation.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding:utf-8 -*-
3 | # author: wu.zheng midday.me
4 |
5 | import cv2
6 |
7 | from albumentations import (
8 |   HorizontalFlip, IAAPerspective, ShiftScaleRotate, CLAHE, RandomRotate90,
9 |   Transpose, Blur, OpticalDistortion, GridDistortion, HueSaturationValue,
10 |   IAAAdditiveGaussianNoise, GaussNoise, MotionBlur, MedianBlur, IAAPiecewiseAffine,
11 |   IAASharpen, IAAEmboss, RandomBrightnessContrast, Flip, OneOf, Compose, RandomBrightness
12 | )
13 |
14 |
15 | def image_aug(image, p=1.0):
16 |   def strong_aug(p=p):
17 |     return Compose([
18 |       OneOf([
19 |         IAAAdditiveGaussianNoise(scale=(0.01 * 255, 0.1 * 255)),
20 |         GaussNoise(var_limit=(10.0, 100.0)),
21 |       ], p=0.5),
22 |       OneOf([
23 |         MotionBlur(p=0.2),
24 |         MedianBlur(blur_limit=7, p=0.1),
25 |         Blur(blur_limit=7, p=0.1),
26 |       ], p=0.5),
27 |       OneOf([
28 |         IAASharpen(p=1.0),
29 |         IAAEmboss(p=1.0),
30 |         CLAHE(clip_limit=2, p=1.0),
31 |         RandomBrightnessContrast(brightness_limit=(-0.15, 0.15), contrast_limit=(-0.15, 0.15), p=1.0),
32 |       ], p=0.5),
33 |       HueSaturationValue(p=0.5),
34 |     ], p=p)
35 |
36 |   augmentation = strong_aug(p=p)
37 |   augmented = augmentation(image=image)
38 |   aug_image = augmented['image']
39 |   return aug_image
40 |
41 |
--------------------------------------------------------------------------------
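As a hedged usage sketch (using the sample page shipped in dataset/sample), `image_aug` takes an HxWxC numpy image and applies the whole albumentations pipeline with probability `p`:

```python
# Minimal sketch: run the albumentations pipeline over the bundled sample page.
import cv2
from dataset.image_augmation import image_aug

page = cv2.imread("dataset/sample/1204675708.PDF-189.png")  # BGR uint8 image
augmented = image_aug(page, p=0.9)                          # augmented image back
cv2.imwrite("sample_aug.png", augmented)
```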
/crnn_ocr/README.md:
--------------------------------------------------------------------------------
1 | ## Text-Line Recognition
2 |
3 | A CNN + RNN + CTC implementation.
4 |
5 | ## Usage
6 |
7 | 1. Prepare the annotation file. Paths are best given as absolute paths; path and text are separated by a space.
8 |
9 |
10 | ```
11 | /data/9b9723f0-f7e4-49b4-bc95-28cd1cdd28e0.png 游的片曲「Come Home! Princess」是
12 | /data/4ed93c5d-b0f6-4232-a16a-78bdd5296a08.png 有8个公交港湾,留5个大的出入口潘多脂
13 | /data/1d588889-e28e-4b33-8705-b10865785efe.png 摩哥大
14 | /data/334c4175-d25e-4d61-b5eb-576f8983a0fd.png 甸,中国古代官名,于周礼》中,主管
15 |
16 | ```
17 |
18 | The vocabulary is stored as JSON in the format below; any symbol outside the vocabulary is uniformly replaced by the `""` entry.
19 |
20 |
21 | ```
22 | {
23 |   "": 0,
24 |   "天":1,
25 |   "文":2
26 | }
27 |
28 | ```
29 |
30 | 2. Create the tfrecord files
31 |
32 | - image_list: the annotation file prepared above
33 | - data_dir: directory where the tfrecords are written
34 | - vocab_file: the vocabulary file prepared above
35 |
36 |
37 | ```
38 | python ./create_tfrecord.py --image_list ${LABELS_FILE} --vocab_file {vocab.json} --data_dir ${TF_RECRD_DIRS} --max_seq_length ${MAX_SEQ_LENGTH} --channel_size ${CHANNEL_SIZE}
39 | ```
40 |
41 | The script uses multiple worker processes to create several train tfrecord shards; the remaining parameters can be adjusted directly in the code:
42 |
43 | ```
44 | start_create_process(train_anno_lines, 100, 10, 'train')
45 | start_create_process(validation_anno_lines, 10, 10, 'validation')
46 | start_create_process(test_anno_lines, 10, 10, 'test')
47 |
48 | ```
49 |
50 | 3. Training
51 |
52 |
53 | ```
54 | python train.py --data_dir ${TF_RECRD_DIRS} --model_dir ${MODEL_DIR} --max_seq_length ${MAX_SEQ_LENGTH} --channel_size ${CHANNEL_SIZE}
55 |
56 | ```
57 |
58 |
59 | 4. Evaluation
60 |
61 |
62 | ```
63 | python ./eval.py --max_seq_length ${MAX_SEQ_LENGTH} --channel_size ${CHANNEL_SIZE} --model_dir ${MODEL_DIR} --image_list ${LABELS_FILE} --image_dir ${IMAGE_DIR}
64 |
65 | ```
66 |
67 | Evaluation takes image_list-format data directly as input, which makes it easy to inspect bad cases; if you need batch evaluation from tfrecords, implement the corresponding input code yourself.
68 |
69 | Add the ```export``` flag to export the model, and use the sample code in ```load_saved_mode.py``` to load the saved model.
70 |
71 |
72 |
--------------------------------------------------------------------------------
/single_word_ocr/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | *.egg-info/
24 | .installed.cfg
25 | *.egg
26 | MANIFEST
27 |
28 | # PyInstaller
29 | # Usually these files are written by a python script from a template
30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest
32 | *.spec
33 |
34 | # Installer logs
35 | pip-log.txt
36 | pip-delete-this-directory.txt
37 |
38 | # Unit test / coverage reports
39 | htmlcov/
40 | .tox/
41 | .coverage
42 | .coverage.*
43 | .cache
44 | nosetests.xml
45 | coverage.xml
46 | *.cover
47 | .hypothesis/
48 | .pytest_cache/
49 |
50 | # Translations
51 | *.mo
52 | *.pot
53 |
54 | # Django stuff:
55 | *.log
56 | local_settings.py
57 | db.sqlite3
58 |
59 | # Flask stuff:
60 | instance/
61 | .webassets-cache
62 |
63 | # Scrapy stuff:
64 | .scrapy
65 |
66 | # Sphinx documentation
67 | docs/_build/
68 |
69 | # PyBuilder
70 | target/
71 |
72 | # Jupyter Notebook
73 | .ipynb_checkpoints
74 |
75 | # pyenv
76 | .python-version
77 |
78 | # celery beat schedule file
79 | celerybeat-schedule
80 |
81 | # SageMath parsed files
82 | *.sage.py
83 |
84 | # Environments
85 | .env
86 | .venv
87 | env/
88 | venv/
89 | ENV/
90 | env.bak/
91 | venv.bak/
92 |
93 | # Spyder project settings
94 | .spyderproject
95 | .spyproject
96 |
97 | # Rope project settings
98 | .ropeproject
99 |
100 | # mkdocs documentation
101 | /site
102 |
103 | # mypy
104 | .mypy_cache/
--------------------------------------------------------------------------------
/crnn_ocr/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | *.egg-info/
24 | .installed.cfg
25 | *.egg
26 | MANIFEST
27 |
28 | # PyInstaller
29 | # Usually these files are written by a python script from a template
30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | 106 | .idea/ 107 | tf_records/* 108 | -------------------------------------------------------------------------------- /single_word_ocr/create_tfrecord.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | #-*- coding:utf-8 -*- 3 | #author: wu.zheng midday.me 4 | 5 | import os 6 | import tensorflow as tf 7 | import random 8 | import cv2 9 | import json 10 | 11 | CHAR_MAP_DICT = json.load(open("vocab.json")) 12 | 13 | def _int64_feature(value): 14 | if not isinstance(value, list): 15 | value = [value] 16 | return tf.train.Feature(int64_list=tf.train.Int64List(value=value)) 17 | 18 | def _bytes_feature(value): 19 | if not isinstance(value, list): 20 | value = [value] 21 | return tf.train.Feature(bytes_list=tf.train.BytesList(value=value)) 22 | 23 | def _string_to_int(label): 24 | int_list = [] 25 | for c in list(label): 26 | if CHAR_MAP_DICT.get(c) is None: 27 | print("inot in :",c) 28 | continue 29 | int_list.append(CHAR_MAP_DICT[c]) 30 | return int_list 31 | 32 | def create_tf_record(data_dir, tfrecords_path): 33 | image_names = [] 34 | for root, dirs, files in os.walk(data_dir): 35 | image_names +=[os.path.join(root, name) for name in files] 36 | random.shuffle(image_names) 37 | writer = tf.python_io.TFRecordWriter(tfrecords_path) 38 | print("handle image : %d"%(len(image_names))) 39 | i = 0 40 | for image_name in image_names: 41 | if i % 10000 == 0: 42 | print(i, len(image_names)) 43 | i+=1 44 | im = cv2.imread(image_name, cv2.IMREAD_GRAYSCALE) 45 | try: 46 | is_success, image_buffer = cv2.imencode('.png', im) 47 | except Exception as e: 48 | continue 49 | if not is_success: 50 | continue 51 | label = int(image_name.split("/")[-2]) 52 | features = tf.train.Features(feature={ 53 | 'labels': _int64_feature(label), 54 | 'images': _bytes_feature(image_buffer.tostring()), 55 | 'imagenames': _bytes_feature(image_name.encode("utf-8"))}) 56 | example = tf.train.Example(features=features) 57 | writer.write(example.SerializeToString()) 58 | writer.close() 59 | 60 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | 
develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | .hypothesis/ 50 | .pytest_cache/ 51 | 52 | # Translations 53 | *.mo 54 | *.pot 55 | 56 | # Django stuff: 57 | *.log 58 | local_settings.py 59 | db.sqlite3 60 | 61 | # Flask stuff: 62 | instance/ 63 | .webassets-cache 64 | 65 | # Scrapy stuff: 66 | .scrapy 67 | 68 | # Sphinx documentation 69 | docs/_build/ 70 | 71 | # PyBuilder 72 | target/ 73 | 74 | # Jupyter Notebook 75 | .ipynb_checkpoints 76 | 77 | # IPython 78 | profile_default/ 79 | ipython_config.py 80 | 81 | # pyenv 82 | .python-version 83 | 84 | # celery beat schedule file 85 | celerybeat-schedule 86 | 87 | # SageMath parsed files 88 | *.sage.py 89 | 90 | # Environments 91 | .env 92 | .venv 93 | env/ 94 | venv/ 95 | ENV/ 96 | env.bak/ 97 | venv.bak/ 98 | 99 | # Spyder project settings 100 | .spyderproject 101 | .spyproject 102 | 103 | # Rope project settings 104 | .ropeproject 105 | 106 | # mkdocs documentation 107 | /site 108 | 109 | # mypy 110 | .mypy_cache/ 111 | .dmypy.json 112 | dmypy.json 113 | 114 | # Pyre type checker 115 | .pyre/ 116 | .idea/ 117 | -------------------------------------------------------------------------------- /single_word_ocr/input_fn.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | #-*- coding:utf-8 -*- 3 | #author: wu.zheng midday.me 4 | 5 | import tensorflow as tf 6 | 7 | def _decode_record(record_proto, aug=False): 8 | feature_map = { 9 | 'images': tf.FixedLenFeature((), tf.string), 10 | 'labels' : tf.VarLenFeature(tf.int64), 11 | 'imagenames': tf.FixedLenFeature((), tf.string), 12 | } 13 | features = tf.parse_single_example(record_proto, features=feature_map) 14 | images = tf.image.decode_png(features['images'], channels=1) 15 | 16 | images = tf.image.resize_images(images, [32, 32] ) 17 | images.set_shape([32, 32, 1]) 18 | images = tf.cast(images, tf.float32) 19 | labels = tf.cast(features['labels'], tf.int32) 20 | 21 | #images = _image_augmentation(images) 22 | example = { 23 | "images": images / 255.0 , 24 | "labels" : tf.squeeze(tf.sparse_tensor_to_dense(labels)), 25 | } 26 | return example 27 | 28 | def _image_augmentation(image): 29 | #image = tf.image.random_flip_up_down(image) 30 | image = tf.image.random_brightness(image, max_delta=0.3) 31 | image = tf.image.random_contrast(image, 0.8, 1.2) 32 | return image 33 | 34 | 35 | def input_fn(tf_record_dir, batch_size, mode): 36 | dataset = tf.data.TFRecordDataset(tf_record_dir) 37 | aug = False 38 | if mode == "train": 39 | dataset = dataset.repeat().shuffle(buffer_size=10000) 40 | aug = True 41 | else : 42 | dataset = dataset.repeat(1) 43 | dataset = dataset.map(lambda x: _decode_record(x, aug)) 44 | dataset = dataset.batch(batch_size=batch_size) 45 | return dataset.make_one_shot_iterator().get_next() 46 | 47 | if __name__ == "__main__": 48 | tfrecord_dir = 
'/data/zhengwu_workspace/ocr/dataset/single_word_gen/tfrecords/train.tfrecord' 49 | iterator = input_fn(tfrecord_dir, 3, 'train') 50 | with tf.Session() as sess: 51 | sess.run(tf.global_variables_initializer()) 52 | batch = sess.run(iterator) 53 | print(batch['labels'].shape) 54 | print(batch['images'].shape) 55 | 56 | -------------------------------------------------------------------------------- /layout_analysis/dataset/data_util.py: -------------------------------------------------------------------------------- 1 | import multiprocessing 2 | import threading 3 | import time 4 | from multiprocessing import Queue 5 | 6 | import numpy as np 7 | 8 | class GeneratorEnqueuer(): 9 | def __init__(self, generator,wait_time=0.05, random_seed=None): 10 | self.wait_time = wait_time 11 | self._generator = generator 12 | self._threads = [] 13 | self._stop_event = None 14 | self.queue = None 15 | self.random_seed = random_seed 16 | 17 | def start(self, workers=1, max_queue_size=10): 18 | def data_generator_task(): 19 | while not self._stop_event.is_set(): 20 | try: 21 | generator_output = next(self._generator) 22 | self.queue.put(generator_output) 23 | except Exception as e: 24 | print(e) 25 | self._stop_event.set() 26 | raise 27 | try: 28 | self.queue = Queue(maxsize=max_queue_size) 29 | self._stop_event = multiprocessing.Event() 30 | for _ in range(workers): 31 | np.random.seed(self.random_seed) 32 | p = multiprocessing.Process(target=data_generator_task) 33 | p.daemon = True 34 | if self.random_seed is not None: 35 | self.random_seed += 1 36 | self._threads.append(p) 37 | p.start() 38 | except Exception as e: 39 | print(e) 40 | p.stop() 41 | raise 42 | 43 | def is_running(self): 44 | return self._stop_event is not None and not self._stop_event.is_set() 45 | 46 | def stop(self, timeout=None): 47 | if self.is_running(): 48 | self._stop_event.set() 49 | 50 | for thread in self._threads: 51 | if thread.is_alive(): 52 | thread.terminate() 53 | 54 | if self.queue is not None: 55 | self.queue.close() 56 | 57 | self._threads = [] 58 | self._stop_event = None 59 | self.queue = None 60 | 61 | def get(self): 62 | while self.is_running(): 63 | if not self.queue.empty(): 64 | inputs = self.queue.get() 65 | if inputs is not None: 66 | yield inputs 67 | else: 68 | time.sleep(self.wait_time) 69 | -------------------------------------------------------------------------------- /single_word_ocr/data_util.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | #-*- coding:utf-8 -*- 3 | #author: wu.zheng midday.me 4 | 5 | 6 | import multiprocessing 7 | import time 8 | from multiprocessing import Queue 9 | 10 | import numpy as np 11 | 12 | class GeneratorEnqueuer(): 13 | def __init__(self, generator,wait_time=0.05, random_seed=None): 14 | 15 | self.wait_time = wait_time 16 | self._generator = generator 17 | self._threads = [] 18 | self._stop_event = None 19 | self.queue = None 20 | self.random_seed = random_seed 21 | 22 | def start(self, workers=1, max_queue_size=10): 23 | def data_generator_task(): 24 | while not self._stop_event.is_set(): 25 | try: 26 | generator_output = next(self._generator) 27 | self.queue.put(generator_output) 28 | except Exception: 29 | self._stop_event.set() 30 | raise 31 | try: 32 | self.queue = Queue(maxsize=max_queue_size) 33 | self._stop_event = multiprocessing.Event() 34 | for _ in range(workers): 35 | np.random.seed(self.random_seed) 36 | p = multiprocessing.Process(target=data_generator_task) 37 | p.daemon = True 38 | if 
self.random_seed is not None:
39 |           self.random_seed += 1
40 |         self._threads.append(p)
41 |         p.start()
42 |     except Exception as e:
43 |       print(e)
44 |       p.terminate()
45 |       raise
46 |
47 |   def is_running(self):
48 |     return self._stop_event is not None and not self._stop_event.is_set()
49 |
50 |   def stop(self, timeout=None):
51 |     if self.is_running():
52 |       self._stop_event.set()
53 |
54 |     for thread in self._threads:
55 |       if thread.is_alive():
56 |         thread.terminate()
57 |
58 |     if self.queue is not None:
59 |       self.queue.close()
60 |
61 |     self._threads = []
62 |     self._stop_event = None
63 |     self.queue = None
64 |
65 |   def get(self):
66 |     while self.is_running():
67 |       if not self.queue.empty():
68 |         inputs = self.queue.get()
69 |         if inputs is not None:
70 |           yield inputs
71 |       else:
72 |         time.sleep(self.wait_time)
73 |
--------------------------------------------------------------------------------
/crnn_ocr/input_fn.py:
--------------------------------------------------------------------------------
1 | #-*-coding:utf-8-*-
2 |
3 | import tensorflow as tf
4 | import glob
5 | import random
6 | import image_augment
7 | import numpy as np
8 |
9 | _MAX_LENGTH = 1024
10 |
11 | def standardize(img):
12 |   mean = np.mean(img)
13 |   std = np.std(img)
14 |   img = (img - mean) / std
15 |   return img
16 |
17 | def _decode_record(record_proto, channel_size):
18 |   feature_map = {
19 |     'images': tf.FixedLenFeature((), tf.string),
20 |     'labels': tf.VarLenFeature(tf.int64),
21 |     'imagenames': tf.FixedLenFeature((), tf.string),
22 |   }
23 |   features = tf.parse_single_example(record_proto, features=feature_map)
24 |
25 |   images = tf.image.decode_jpeg(features['images'], channels=channel_size)
26 |   images = tf.py_func(image_augment.augment_images, [images], tf.uint8)
27 |   images = tf.cast(images, tf.float32)
28 |   images = images / 255.0
29 |   image_w = tf.cast(tf.shape(images)[1], tf.int32)
30 |   sequence_length = tf.cast(tf.shape(images)[1] / 4, tf.int32)
31 |   paddings = tf.convert_to_tensor([[0, 0], [0, _MAX_LENGTH - image_w], [0, 0]])
32 |   images = tf.pad(images, paddings)
33 |   images.set_shape([32, _MAX_LENGTH, channel_size])
34 |   labels = tf.cast(features['labels'], tf.int32)
35 |   example = {
36 |     "images": images,
37 |     "labels": labels,
38 |     "sequence_length": sequence_length
39 |   }
40 |   return example
41 |
42 |
43 | def _decode_record_estimator(record_proto, channel_size):
44 |   feature_map = {
45 |     'images': tf.FixedLenFeature((), tf.string),
46 |     'labels': tf.VarLenFeature(tf.int64),
47 |     'imagenames': tf.FixedLenFeature((), tf.string),
48 |   }
49 |   features = tf.parse_single_example(record_proto, features=feature_map)
50 |   images = tf.image.decode_jpeg(features['images'], channels=channel_size)
51 |   image_w = tf.cast(tf.shape(images)[1], tf.int32)
52 |   paddings = tf.convert_to_tensor([[0, 0], [0, _MAX_LENGTH - image_w], [0, 0]])
53 |   images = tf.pad(images, paddings)
54 |   images.set_shape([32, _MAX_LENGTH, channel_size])
55 |   images = tf.cast(images, tf.float32)
56 |   labels = tf.cast(features['labels'], tf.int32)
57 |   sequence_length = tf.cast(tf.shape(images)[1] / 4, tf.int32)
58 |   features = {
59 |     "images": images,
60 |     "sequence_length": sequence_length
61 |   }
62 |   return features, labels
63 |
64 | def input_fn(tfrecord_path, batch_size, is_training=True, channel_size=3):
65 |   filenames = glob.glob(tfrecord_path)
66 |   print(filenames)
67 |   random.shuffle(filenames)
68 |   dataset = tf.data.TFRecordDataset(filenames)
69 |   if is_training:
70 |     dataset = dataset.repeat().shuffle(buffer_size=10000)
71 |   else:
72 |     dataset = 
dataset.repeat(1) 73 | dataset = dataset.map(lambda x: _decode_record(x, channel_size)) 74 | dataset = dataset.batch(batch_size=batch_size) 75 | dataset = dataset.prefetch(buffer_size=10000) 76 | return dataset.make_one_shot_iterator() 77 | 78 | -------------------------------------------------------------------------------- /layout_analysis/plot_xml.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | #-*- coding:utf-8 -*- 3 | #author: wu.zheng midday.me 4 | 5 | from bs4 import BeautifulSoup 6 | import random 7 | import cv2 8 | import os 9 | 10 | def plot_element(image, item, page_height, page_width, color): 11 | image_height, image_width = image.shape[:2] 12 | width_ratio = image_width/ page_width 13 | height_ratio = image_height / page_height 14 | left = int(item.get("left")) * width_ratio 15 | top = int(item.get("top")) * height_ratio 16 | width = int(item.get("width")) * width_ratio 17 | height = int(item.get("height")) * height_ratio 18 | pt1 = (int(left), int(top)) 19 | pt2 = (int(left + width), int(top + height)) 20 | image=cv2.rectangle(image, pt1=pt1, pt2=pt2, color=color, thickness=2) 21 | return image 22 | 23 | 24 | def plot_from_path(xml_path, image_path): 25 | with open(xml_path) as f: 26 | soup = BeautifulSoup(f.read(), 'xml') 27 | page = soup.find('page') 28 | image = cv2.imread(image_path) 29 | out_image_path = os.path.join('./test_out', page.get("number") + ".png") 30 | plot_one(page, image, out_image_path) 31 | 32 | def plot_one(page, image, out_image_path): 33 | h,w = image.shape[:2] 34 | page_height = int(page.get("height")) 35 | page_width = int(page.get("width")) 36 | image_list = page.find_all('image') 37 | text_list = page.find_all('text') 38 | for item in image_list: 39 | image = plot_element(image, item, page_height, page_width, (125, 0, 255)) 40 | for item in text_list: 41 | image = plot_element(image, item, page_height, page_width, (255, 255, 0)) 42 | cv2.imwrite(out_image_path, image) 43 | 44 | 45 | def run(xml_path, out_path): 46 | with open(xml_path, encoding='utf-8', errors='ignore') as f: 47 | xml_data = f.read() 48 | soup = BeautifulSoup(xml_data, 'xml') 49 | pages = soup.find_all('page') 50 | for page in pages: 51 | image_path = page.get('image_path') 52 | print(image_path) 53 | image= cv2.imread(image_path) 54 | out_image_path = os.path.join(out_path, page.get("number") + ".png") 55 | plot_one(page, image, out_image_path) 56 | 57 | 58 | def plot_page(): 59 | xml_path_list = [] 60 | for top, dirs, files in os.walk('/data/zhengwu_workspace/document_text_line/layout_xml_anno/year_report'): 61 | for name in files: 62 | xml_path = os.path.join(top, name) 63 | xml_path_list.append(xml_path) 64 | 65 | random.shuffle(xml_path_list) 66 | for xml_path in xml_path_list[:100]: 67 | print(xml_path) 68 | with open(xml_path) as f: 69 | xml_data = f.read() 70 | soup = BeautifulSoup(xml_data, 'xml') 71 | page = soup.find('page') 72 | out_image_path = xml_path.split("/")[-1].replace("xml", "png") 73 | out_image_path = os.path.join('./test_out', out_image_path) 74 | print(out_image_path) 75 | plot_one(page, out_image_path) 76 | 77 | if __name__ == "__main__": 78 | xml_path = "/home/zhengwu/data/book_layout/part_2/314560.xml" 79 | image_path = "./dataset/input.png" 80 | plot_from_path(xml_path, image_path) 81 | -------------------------------------------------------------------------------- /crnn_ocr/load_saved_mode.py: -------------------------------------------------------------------------------- 1 | 
#!/usr/bin/env python 2 | #-*- coding:utf-8 -*- 3 | #author: wu.zheng midday.me 4 | 5 | import tensorflow as tf 6 | import cv2 7 | import numpy as np 8 | import json 9 | 10 | EXPORT_PATH = "./exported_model" 11 | _IMAGE_HEIGHT = 32 12 | _MAX_LENGTH = 1024 13 | char_map_path = './vocab.json' 14 | CHAR_TO_ID = json.load(open(char_map_path)) 15 | ID_TO_CHAR = {v:k for k,v in CHAR_TO_ID.items()} 16 | CHAR_SIZE = len(CHAR_TO_ID) 17 | 18 | def pre_process_image(image): 19 | h, w,_ = image.shape 20 | height = _IMAGE_HEIGHT 21 | width = int(w * height / h) 22 | image = cv2.resize(image, (width, height)) 23 | image = np.array(image, dtype=np.float32) 24 | if width > _MAX_LENGTH: 25 | image = cv2.resize(image, (_MAX_LENGTH, height)) 26 | width = _MAX_LENGTH 27 | image = image / 255.0 28 | seq_len = np.array([width / 4], dtype=np.int32) 29 | image = np.expand_dims(image, axis=0) 30 | return image, seq_len 31 | 32 | def load_image(image_path): 33 | image = cv2.imread(image_path) 34 | return pre_process_image(image) 35 | 36 | 37 | def _int_to_string(value): 38 | return ID_TO_CHAR.get(int(value), "~") 39 | 40 | def _sparse_matrix_to_list(sparse_matrix ): 41 | indices = sparse_matrix.indices 42 | values = sparse_matrix.values 43 | dense_shape = sparse_matrix.dense_shape 44 | dense_matrix = CHAR_SIZE * np.ones(dense_shape, dtype=np.int32) 45 | 46 | for i, indice in enumerate(indices): 47 | dense_matrix[indice[0], indice[1]] = values[i] 48 | string_list = [] 49 | for row in dense_matrix: 50 | string = [] 51 | for val in row: 52 | string.append(_int_to_string(val )) 53 | string_list.append(''.join(s for s in string if s != '*')) 54 | return string_list 55 | 56 | class Model(): 57 | def __init__(self,): 58 | self.sess = tf.Session() 59 | tf.saved_model.loader.load( 60 | self.sess, 61 | [tf.saved_model.tag_constants.SERVING], EXPORT_PATH) 62 | 63 | graph = tf.get_default_graph() 64 | self.image = graph.get_tensor_by_name('images:0') 65 | self.sequence_length = graph.get_tensor_by_name("sequence_length:0") 66 | self.is_trainig = graph.get_tensor_by_name("training:0") 67 | self.keep_prob = graph.get_tensor_by_name("keep_prob:0") 68 | self.logits = graph.get_tensor_by_name("lstm_layers/logits:0") 69 | self.decoded ,_ = tf.nn.ctc_beam_search_decoder( 70 | self.logits, 71 | self.sequence_length, 72 | merge_repeated=True, 73 | beam_width=10, 74 | top_paths=1) 75 | 76 | def predict(self, im): 77 | image, seq_len = pre_process_image(im) 78 | feed_dict = { 79 | self.image: image, 80 | self.sequence_length:seq_len, 81 | self.keep_prob:1.0, 82 | self.is_trainig:True} 83 | 84 | decoded = self.sess.run(self.decoded, feed_dict = feed_dict) 85 | pred = _sparse_matrix_to_list(decoded[0]) 86 | return pred 87 | 88 | if __name__ == "__main__": 89 | image_path = './497af4b4-08c0-40cd-b46e-b1576d13e689_6.jpg' 90 | im = cv2.imread(image_path) 91 | model = Model() 92 | model.predict(im) 93 | 94 | -------------------------------------------------------------------------------- /layout_analysis/load_saved_model.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | #-*- coding:utf-8 -*- 3 | #author: wu.zheng midday.me 4 | 5 | import tensorflow as tf 6 | import os 7 | import copy 8 | import numpy as np 9 | import cv2 10 | 11 | 12 | MODEL_PATH='/data/zhengwu_workspace/ocr/models/book_export_models/image_txt/book/1571968585' 13 | IMAGE_HEIGHT=1024 14 | IMAGE_WIDTH=768 15 | NUM_CLASS=3 16 | COLOR_LIST=[(0, 255,0), (0, 0, 255), (0, 255, 255)] 17 | 18 | def mask_to_bbox(mask, 
image, num_class, area_threhold=0, out_path=None, out_file_name=None): 19 | bbox_list = [] 20 | im = copy.copy(image) 21 | mask = mask.astype(np.uint8) 22 | for i in range(1, num_class, 1): 23 | c_bbox_list = [] 24 | c_mask = np.zeros_like(mask) 25 | c_mask[np.where(mask==i)] = 255 26 | bimg , countours, hier = cv2.findContours(c_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) 27 | for cnt in countours: 28 | area = cv2.contourArea(cnt) 29 | if area < area_threhold: 30 | continue 31 | epsilon = 0.005 * cv2.arcLength(cnt,True) 32 | approx = cv2.approxPolyDP(cnt,epsilon,True) 33 | (x, y, w, h) = cv2.boundingRect(approx) 34 | c_bbox_list.append([x, y, x+w, y+h]) 35 | if out_path is not None: 36 | color = COLOR_LIST[i-1] 37 | im=cv2.rectangle(im, pt1=(x, y), pt2=(x+w, y+h),color=color, thickness=2) 38 | bbox_list.append(c_bbox_list) 39 | if out_path is not None: 40 | outf = os.path.join(out_path, out_file_name) 41 | cv2.imwrite(outf, im) 42 | return bbox_list 43 | 44 | def resize_bbox(bbox_list, w_factor, h_factor, class_names): 45 | bbox_map = {} 46 | for c, c_bbox_list in enumerate(bbox_list): 47 | c_name = class_names[c] 48 | bbox_map[c_name] = [] 49 | for bbox in c_bbox_list: 50 | new_bbox = [bbox[0]/ w_factor, bbox[1]/ h_factor, bbox[2]/w_factor, bbox[3]/h_factor] 51 | new_bbox = list(map(int, new_bbox)) 52 | bbox_map[c_name].append(new_bbox) 53 | return bbox_map 54 | 55 | 56 | class Model(object): 57 | def __init__(self, model_dir, area_threhold, class_names): 58 | self.model_dir = model_dir 59 | self.area_threhold = area_threhold 60 | self.num_class = len(class_names) + 1 61 | self.class_names = class_names 62 | self.graph = tf.Graph() 63 | self.sess = tf.Session(graph=self.graph) 64 | tf.saved_model.loader.load(self.sess, [tf.saved_model.tag_constants.SERVING], self.model_dir) 65 | self.image = self.graph.get_tensor_by_name('image:0') 66 | self.prob = self.graph.get_tensor_by_name("prob:0") 67 | 68 | 69 | def predict(self, img): 70 | h, w = img.shape[:2] 71 | h_factor = IMAGE_HEIGHT / h 72 | w_factor = IMAGE_WIDTH / w 73 | img = cv2.resize(img, (IMAGE_WIDTH, IMAGE_HEIGHT)) 74 | feed_dict = {self.image:[img/255.0]} 75 | prob = self.sess.run([self.prob], feed_dict=feed_dict) 76 | prob = prob[0][0] 77 | mask = np.argmax(prob, axis=-1) 78 | 79 | mask = mask.astype(np.uint8) 80 | bbox_list = mask_to_bbox(mask, img, self.num_class, self.area_threhold, "./", "server_predict.png") 81 | bbox_map = resize_bbox(bbox_list, w_factor, h_factor, self.class_names) 82 | return bbox_map 83 | 84 | 85 | if __name__ =="__main__": 86 | image_path="" 87 | img = cv2.imread(image_path) 88 | model = Model(MODEL_PATH, 10, ['image', 'text']) 89 | model.predict(img) 90 | -------------------------------------------------------------------------------- /single_word_ocr/data_generator.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | #-*- coding:utf-8 -*- 3 | #author: wu.zheng midday.me 4 | 5 | import cv2 6 | import random 7 | import image_augment 8 | import numpy as np 9 | 10 | import time 11 | import data_util 12 | 13 | def load_image_list(image_data_file_path): 14 | image_path_list = [] 15 | label_list = [] 16 | with open(image_data_file_path) as f: 17 | for _, line in enumerate(f): 18 | line = line.strip("\n").split() 19 | if len(line)<2: 20 | continue 21 | image_path = " ".join(line[0:-1]) 22 | label = int(line[-1]) 23 | image_path_list.append(image_path) 24 | label_list.append(label) 25 | return image_path_list, label_list 26 | 27 | def 
pre_process_image(image): 28 | image_size = 64 29 | h, w = image.shape[:2] 30 | if h == image_size and w == image_size: 31 | pass 32 | elif h > image_size or w > image_size: 33 | image = cv2.resize(image, (image_size, image_size)) 34 | image = image / 255.0 35 | else: 36 | pad_height = int((image_size - h) / 2) 37 | pad_width = int((image_size - w) / 2) 38 | image = image / 255.0 39 | image = np.pad(image, ((pad_height, image_size-h - pad_height), (pad_width, image_size-w-pad_width),(0,0)), mode='constant' ) 40 | return image 41 | 42 | 43 | def data_generator(image_data_file_path, batch_size, mode='train'): 44 | image_path_list, label_list = load_image_list(image_data_file_path) 45 | print(len(image_path_list), len(label_list)) 46 | index_list = list(range(len(image_path_list))) 47 | while True: 48 | #if mode=='train': 49 | random.shuffle(index_list) 50 | image_batch = [] 51 | label_batch = [] 52 | image_path_batch = [] 53 | for idx in index_list: 54 | image_path = image_path_list[idx] 55 | label = label_list[idx] 56 | #image = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE) 57 | image = cv2.imread(image_path) 58 | # todo crop image 59 | if image is None: 60 | print('image is none: %s'%(image_path)) 61 | continue 62 | image_batch.append(pre_process_image(image)) 63 | label_batch.append(label) 64 | image_path_batch.append(image_path) 65 | if len(image_batch) == batch_size: 66 | yield image_batch, label_batch, image_path_batch 67 | image_batch = [] 68 | label_batch = [] 69 | image_path_batch = [] 70 | if mode=='train': 71 | for i in range(1): 72 | # image_aug = image_augment.augment_image(image) 73 | image_aug = image 74 | image_batch.append(pre_process_image(image_aug)) 75 | label_batch.append(label) 76 | if len(image_batch) == batch_size: 77 | yield image_batch, label_batch, image_path_batch 78 | image_batch = [] 79 | label_batch = [] 80 | image_path_batch= [] 81 | if mode!='train': 82 | break 83 | 84 | def get_batch(data_dir, batch_size, mode='train', workers=1, max_queue_size=32): 85 | enqueuer = data_util.GeneratorEnqueuer(data_generator(data_dir, batch_size, mode)) 86 | enqueuer.start(max_queue_size=max_queue_size, workers=workers) 87 | generator_output = None 88 | while True: 89 | while enqueuer.is_running(): 90 | if not enqueuer.queue.empty(): 91 | generator_output = enqueuer.queue.get() 92 | break 93 | else: 94 | time.sleep(0.01) 95 | yield generator_output 96 | generator_output = None 97 | -------------------------------------------------------------------------------- /single_word_ocr/eval.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | #-*- coding:utf-8 -*- 3 | #author: wu.zheng midday.me 4 | 5 | import os 6 | 7 | import tensorflow as tf 8 | import json 9 | import numpy as np 10 | import time 11 | import densenet 12 | import data_generator 13 | import argparse 14 | 15 | parser = argparse.ArgumentParser(description="SingleWord Reconition") 16 | parser.add_argument('--test_image_list', help='test_data_dir') 17 | parser.add_argument('--checkpoint_path', help='checkpoint path') 18 | parser.add_argument('--batch_size', default=128, help='batch size', type=int) 19 | parser.add_argument('--num_class', help='num of class', type=int) 20 | parser.add_argument('--export', type=bool, help='whether to export model', dest='export') 21 | parser.add_argument('--export_dir', help='export saved model dir') 22 | 23 | 24 | def export_model(sess, model): 25 | tensor_image_input_info = tf.saved_model.utils.build_tensor_info(model.images) 
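  # build_tensor_info wraps each tensor into a TensorInfo proto (name, dtype,
  # shape), so the signature below can map the serving keys "images", "prob"
  # and "prediction" back to concrete graph tensors by name at load time.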
26 |   tensor_prob_output_info = tf.saved_model.utils.build_tensor_info(model.predict_prob)
27 |   tensor_prediction_output_info = tf.saved_model.utils.build_tensor_info(model.prediction)
28 |   signature = tf.saved_model.signature_def_utils.build_signature_def(
29 |     inputs={"images": tensor_image_input_info},
30 |     outputs={"prob": tensor_prob_output_info, "prediction": tensor_prediction_output_info}
31 |   )
32 |   ex_dir = str(int(time.time()))
33 |   export_dir = os.path.join(args.export_dir, ex_dir)
34 |   builder = tf.saved_model.builder.SavedModelBuilder(export_dir)
35 |   builder.add_meta_graph_and_variables(
36 |     sess=sess,
37 |     tags=[tf.saved_model.tag_constants.SERVING],
38 |     signature_def_map={"predict": signature})
39 |   builder.save()
40 |   print("export model at %s" % (export_dir))
41 |
42 | def load_vocab():
43 |   data = json.load(open("gbk_eng.json"))
44 |   data_reverse = {v: k for k, v in data.items()}
45 |   return data_reverse
46 |
47 | def main():
48 |   save_path = tf.train.latest_checkpoint(args.checkpoint_path)
49 |   model = densenet.DenseNet(1, args.num_class, mode='test')
50 |   saver = tf.train.Saver()
51 |   id_to_word = load_vocab()
52 |
53 |   with tf.Session() as sess:
54 |     saver.restore(sess=sess, save_path=save_path)
55 |     if args.export:
56 |       export_model(sess, model)
57 |       exit(0)
58 |
59 |     print("load model from %s" % (save_path))
60 |     counter = 0
61 |     right_counter = 0
62 |     for batch_data in data_generator.get_batch(args.test_image_list, batch_size=1, mode='test', workers=1, max_queue_size=12):
63 |       image = np.array(batch_data[0])
64 |       label = batch_data[1]
65 |       image_path = batch_data[2]
66 |       feed_dict = {model.images: image}
67 |       prediction, predict_prob = sess.run([model.prediction, model.predict_prob], feed_dict=feed_dict)
68 |       predict_id = prediction[0]
69 |       predict_label = id_to_word[predict_id]
70 |       predict_prob = predict_prob[0][predict_id]
71 |       true_label = id_to_word[label[0]]
72 |       print("image_path: %s, true_id: %d, true_label: %s, predict_label: %s, predict_prob: %f" % (
73 |         image_path, label[0], true_label, predict_label, predict_prob))
74 |
75 |       if true_label == predict_label:
76 |         right_counter += 1
77 |       counter += 1
78 |       if counter > 100:
79 |         break
80 |     print("acc : %f" % (1.0 * right_counter / counter))
81 |
82 | if __name__ == "__main__":
83 |   args = parser.parse_args()
84 |   print(args)
85 |   main()
86 |
87 |
88 |
--------------------------------------------------------------------------------
/single_word_ocr/load_saved_model.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | #-*- coding:utf-8 -*-
3 | #author: wu.zheng midday.me
4 |
5 | import tensorflow as tf
6 | import cv2
7 | import numpy as np
8 | import json
9 | import time
10 | import os
11 |
12 | EXPORT_PATH = "single_word_model/densenet_exported/1565353481"
13 | VOCAB_PATH = './gbk.json'
14 |
15 |
16 | def load_charset_map():
17 |   char_set = json.load(open(VOCAB_PATH))
18 |   char_set = {v: k for k, v in char_set.items()}
19 |   return char_set
20 |
21 | def pre_process_image(im):
22 |   image_size = 64
23 |   h, w = im.shape[:2]
24 |   if h > image_size or w > image_size:
25 |     im = cv2.resize(im, (image_size, image_size))
26 |   else:
27 |     pad_height = int((image_size - h) / 2)
28 |     pad_width = int((image_size - w) / 2)
29 |     im = np.pad(im, ((pad_height, image_size - h - pad_height), (pad_width, image_size - w - pad_width)),
30 |                 mode='constant', constant_values=((255, 255), (255, 255)))
31 |   print(im.shape)
32 |   im = im / 255.0
33 |   im = im.reshape([1, image_size, 
image_size, 1]) 34 | return im 35 | 36 | class Model(): 37 | def __init__(self): 38 | self.sess = tf.Session() 39 | tf.saved_model.loader.load(self.sess, ['serve'], EXPORT_PATH) 40 | graph = tf.get_default_graph() 41 | self.input_image = graph.get_tensor_by_name('image:0') 42 | self.predict_prob = graph.get_tensor_by_name('prob:0') 43 | self.predict_ids = graph.get_tensor_by_name('prediction:0') 44 | self.char_dict = load_charset_map() 45 | 46 | def predict(self, im): 47 | im = pre_process_image(im) 48 | feed_dict = {self.input_image: im } 49 | predict_prob, predict_ids = self.sess.run([self.predict_prob, self.predict_ids], feed_dict=feed_dict) 50 | out_index = int(predict_ids) 51 | print(self.predict_prob, self.char_dict[out_index]) 52 | return self.char_dict[int(out_index)] 53 | 54 | 55 | 56 | def predict(image_path, true_label="", counter=0): 57 | im = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE ) 58 | print(im.shape) 59 | word_im_list = [im] 60 | for i, word_im in enumerate(word_im_list): 61 | cv2.imwrite(os.path.join('word_image', str(counter)+".png"), word_im) 62 | start = time.time() 63 | word = model.predict(word_im) 64 | cost = time.time() - start 65 | print('image_path: %s true_label: %s predict_label: %s cost : %f'%(image_path, true_label, word, cost)) 66 | return true_label == word 67 | 68 | def load_image_path(file_path): 69 | image_path_list = [] 70 | label_path_list = [] 71 | with open(file_path) as f: 72 | for _, line in enumerate(f): 73 | line = line.strip("\n") 74 | line = line.split() 75 | if len(line) < 2: 76 | continue 77 | image_path = line[0] 78 | label = line[1] 79 | label_path_list.append(label) 80 | image_path_list.append(image_path) 81 | return image_path_list, label_path_list 82 | 83 | 84 | model = Model() 85 | 86 | if __name__ == "__main__": 87 | file_path = '/home/zhengwu/data/pdfs/books_line/d83a286f-08c9-5cb4-8dc0-65162a99781e/labels.txt' 88 | image_path_list, label_path_list = load_image_path(file_path) 89 | counter = 0 90 | right_counter = 0 91 | for image_path, label in zip(image_path_list, label_path_list): 92 | print(image_path) 93 | print(label) 94 | res = predict(image_path, label, counter) 95 | if res: 96 | right_counter +=1 97 | counter += 1 98 | if counter > 1000: 99 | break 100 | print(right_counter*1.0 / counter) 101 | 102 | 103 | 104 | -------------------------------------------------------------------------------- /single_word_ocr/train.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | #-*- coding:utf-8 -*- 3 | #author: wu.zheng midday.me 4 | 5 | 6 | import os 7 | import time 8 | import numpy as np 9 | import tensorflow as tf 10 | import densenet 11 | import data_generator 12 | import argparse 13 | 14 | parser = argparse.ArgumentParser(description="SingleWord Reconition") 15 | parser.add_argument("--train_image_list", help='train data label dir') 16 | parser.add_argument('--test_image_list', help='test_data_dir') 17 | parser.add_argument('--checkpoint_path', help='checkpoint path') 18 | parser.add_argument('--batch_size', default=128, help='batch size', type=int) 19 | parser.add_argument('--num_class', help='num of class', type=int) 20 | 21 | def train(): 22 | batch_size = args.batch_size 23 | num_class = args.num_class 24 | model = densenet.DenseNet(batch_size=batch_size, num_classes=num_class) 25 | global_step = tf.train.get_or_create_global_step() 26 | start_learning_rate= 0.0001 27 | learning_rate = tf.train.exponential_decay( 28 | start_learning_rate, 29 | global_step, 30 | 
100000,
31 |     0.98,
32 |     staircase=False,
33 |     name="learning_rate"
34 |   )
35 |   update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
36 |   train_op = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(loss=model.loss, global_step=global_step)
37 |   train_op = tf.group([train_op, update_ops])
38 |   #optimizer=tf.train.MomentumOptimizer(learning_rate=learning_rate, momentum=0.9).minimize(loss=model.loss)
39 |   saver = tf.train.Saver()
40 |   tf.summary.scalar(name='loss', tensor=model.loss)
41 |   #tf.summary.scalar(name='softmax_loss', tensor=model.softmax_loss)
42 |   #tf.summary.scalar(name='center_loss', tensor=model.center_loss)
43 |   tf.summary.scalar(name='accuracy', tensor=model.accuracy)
44 |   merge_summary_op = tf.summary.merge_all()
45 |   sess_config = tf.ConfigProto(allow_soft_placement=True,)
46 |   with tf.Session(config=sess_config) as sess:
47 |     ckpt = tf.train.latest_checkpoint(args.checkpoint_path)
48 |     if ckpt:
49 |       print("restore from %s " % (ckpt))
50 |       st = int(ckpt.split('-')[-1])
51 |       saver.restore(sess, ckpt)
52 |       sess.run(global_step.assign(st))
53 |     else:
54 |       tf.global_variables_initializer().run()
55 |     summary_writer = tf.summary.FileWriter(args.checkpoint_path)
56 |     summary_writer.add_graph(sess.graph)
57 |     start_time = time.time()
58 |     step = 0
59 |     iterator = data_generator.get_batch(args.train_image_list, batch_size)
60 |     for batch in iterator:
61 |       if batch is None:
62 |         print("batch is None")
63 |         continue
64 |       image = batch[0]
65 |       labels = batch[1]
66 |       feed_dict = {model.images: image, model.labels: labels}
67 |       _, loss, accuracy, summary, g_step, logits, lr = sess.run(
68 |         [train_op, model.loss, model.accuracy, merge_summary_op, global_step, model.logits, learning_rate],
69 |         feed_dict=feed_dict)
70 |       if loss is None:
71 |         print(np.max(logits), np.min(logits))
72 |         exit(0)
73 |       if step % 10 == 0:
74 |         print(np.max(logits), np.min(logits))
75 |         print("step:%d, lr: %f, loss: %f, accuracy: %f" % (g_step, lr, loss, accuracy))
76 |       if step % 100 == 0:
77 |         summary_writer.add_summary(summary=summary, global_step=g_step)
78 |         saver.save(sess=sess, save_path=os.path.join(args.checkpoint_path, 'model'), global_step=g_step)
79 |       step += 1
80 |     print("cost: ", time.time() - start_time)
81 |
82 | if __name__ == "__main__":
83 |   args = parser.parse_args()
84 |   print(args)
85 |   train()
86 |
--------------------------------------------------------------------------------
/layout_analysis/train.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding:utf-8 -*-
3 | #author: wu.zheng midday.me
4 |
5 | import os
6 | import tensorflow as tf
7 | import numpy as np
8 | from dataset import data_factory
9 | import model
10 | import argparse
11 |
12 | parser = argparse.ArgumentParser(description='Process some integers.')
13 | # train
14 | parser.add_argument('--max_step', default=500000, help='max step for train')
15 | parser.add_argument('--learning_rate', default=0.00001, help='init learning rate')
16 | parser.add_argument('--checkpoints_dir', default='./checkpoints', help='checkpoints dir')
17 | parser.add_argument('--batch_size', default=4, help='batch size', type=int)
18 |
19 | # data
20 | parser.add_argument('--train_label_list', help='train label list')
21 | parser.add_argument('--image_dir', help='image dir')
22 | parser.add_argument('--data_name', help='data generator name')
23 | parser.add_argument('--num_class', default=3, help='num of classes', type=int)
24 | parser.add_argument('--image_width', 
default=768, help='image width', type=int) 25 | parser.add_argument('--image_height', default=768, help='image height', type=int) 26 | 27 | 28 | 29 | def train(): 30 | train_model = model.UnetModel(args.num_class, is_training=True , dice_loss=True) 31 | global_step = tf.train.get_or_create_global_step() 32 | update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) 33 | saver = tf.train.Saver() 34 | with tf.control_dependencies(update_ops): 35 | loss_op = tf.train.AdamOptimizer(args.learning_rate).minimize(train_model.loss, global_step=global_step) 36 | train_op = tf.group([loss_op, update_ops]) 37 | #tf.summary.scalar(name='seg_loss', tensor=train_model.loss) 38 | #tf.summary.scalar(name='aux_loss', tensor=train_model.aux_loss) 39 | pred = tf.argmax(train_model.logits, axis=-1, output_type=tf.int32) 40 | pred = tf.cast(pred, dtype=tf.float32) 41 | pred = tf.expand_dims(pred, axis=-1) 42 | label = tf.cast(train_model.label, dtype=tf.float32) 43 | tf.summary.image('image', train_model.image * 255.0) 44 | tf.summary.image('label', label* 50.0) 45 | tf.summary.image('mask', train_model.mask* 50.0) 46 | tf.summary.image('pred', pred * 50.0) 47 | tf.summary.scalar(name="loss", tensor=train_model.loss) 48 | tf.summary.scalar(name="acc", tensor=train_model.acc) 49 | merge_summary_op = tf.summary.merge_all() 50 | data_generator_fn = data_factory.get_data(args.data_name) 51 | with tf.Session() as sess: 52 | ckpt = tf.train.latest_checkpoint(args.checkpoints_dir) 53 | init = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer()) 54 | sess.run(init) 55 | if ckpt: 56 | st = int(ckpt.split("-")[-1]) 57 | saver.restore(sess, ckpt) 58 | sess.run(global_step.assign(st)) 59 | print("restore from %s"%(ckpt)) 60 | summary_writer = tf.summary.FileWriter(args.checkpoints_dir) 61 | summary_writer.add_graph(sess.graph) 62 | data_iterator = data_generator_fn.get_batch(args.train_label_list, args.image_dir, args.batch_size, 'train') 63 | for batch_data in data_iterator: 64 | if batch_data is None: 65 | print("Warning: batch_data is None") 66 | continue 67 | images = np.array(batch_data[0]) 68 | labels = np.array(batch_data[1]) 69 | labels = np.expand_dims(labels, -1) 70 | mask = np.array(batch_data[2]) 71 | mask = np.expand_dims(mask, -1) 72 | _, loss, _, acc, s, summary = sess.run([train_op, train_model.loss, train_model.acc_op, train_model.acc, global_step, merge_summary_op], 73 | feed_dict={train_model.image: images, train_model.label: labels, train_model.mask: mask}) 74 | print("step:%d loss : %f acc: %f"%(s, loss , acc )) 75 | if s % 100 == 0: 76 | summary_writer.add_summary(summary=summary, global_step=s) 77 | saver.save(sess=sess, save_path=os.path.join(args.checkpoints_dir, 'model'), global_step=s) 78 | if s > args.max_step: 79 | print("train finish") 80 | break 81 | 82 | if __name__ == '__main__': 83 | args = parser.parse_args() 84 | train() 85 | 86 | -------------------------------------------------------------------------------- /crnn_ocr/create_tfrecord.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding:utf-8 -*- 3 | #author: wu.zheng midday.me 4 | 5 | from __future__ import absolute_import 6 | from __future__ import division 7 | from __future__ import print_function 8 | 9 | import os 10 | import sys 11 | import random 12 | import json 13 | import tensorflow as tf 14 | import cv2 15 | from multiprocessing import Pool 16 | 17 | _IMAGE_HEIGHT = 32 18 | 19 | tf.app.flags.DEFINE_string( 20 | 'image_dir', '', 
'Dataset root folder with images.') 21 | 22 | tf.app.flags.DEFINE_string( 23 | 'image_list', 'ocr/dataset/text_line_gen/labels.txt', 'Path of dataset annotation file.') 24 | 25 | tf.app.flags.DEFINE_string( 26 | 'data_dir', 'ocr/dataset/text_line_gen_tfrecords', 'Directory where tfrecords are written to.') 27 | 28 | tf.app.flags.DEFINE_float( 29 | 'validation_split_fraction', 0.1, 'Fraction of training data to use for validation.') 30 | 31 | tf.app.flags.DEFINE_boolean( 32 | 'shuffle_list', True, 'Whether shuffle data in annotation file list.') 33 | 34 | tf.app.flags.DEFINE_string( 35 | 'vocab_file', './simple_vocab.json', 'Path to char map json file') 36 | 37 | tf.app.flags.DEFINE_integer('max_seq_length', 1024, 'max sequence length') 38 | tf.app.flags.DEFINE_integer("channel_size", 1, 'image channle size') 39 | 40 | 41 | FLAGS = tf.app.flags.FLAGS 42 | 43 | _MAGE_MAX_LENGTH = FLAGS.max_seq_length 44 | _MAX_LABEL_LENGTH = 150 45 | 46 | if FLAGS.channel_size == 3: 47 | GRAY = False 48 | else: 49 | GRAY = True 50 | VOCAB_DICT = json.load(open(FLAGS.vocab_file, 'r')) 51 | 52 | 53 | def _int64_feature(value): 54 | if not isinstance(value, list): 55 | value = [value] 56 | return tf.train.Feature(int64_list=tf.train.Int64List(value=value)) 57 | 58 | def _bytes_feature(value): 59 | if not isinstance(value, list): 60 | value = [value] 61 | return tf.train.Feature(bytes_list=tf.train.BytesList(value=value)) 62 | 63 | 64 | def _string_to_int(label): 65 | # convert string label to int list by char map 66 | int_list = [] 67 | for c in list(label): 68 | if VOCAB_DICT.get(c) is None: 69 | # todo same unk 70 | int_list.append(VOCAB_DICT.get("")) 71 | else: 72 | int_list.append(VOCAB_DICT[c]) 73 | return int_list 74 | 75 | def _write_tfrecord(writer_path, anno_lines ): 76 | writer= tf.io.TFRecordWriter(writer_path) 77 | for i, line in enumerate(anno_lines): 78 | line = line.strip('\n') 79 | line = line.strip() 80 | image_name = line.split(" ")[0] 81 | image_path = image_name 82 | label = " ".join(line.split(" ")[1:]) 83 | if not label: 84 | print("label is None") 85 | continue 86 | if FLAGS.channel_size==1: 87 | image = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE ) 88 | else: 89 | image = cv2.imread(image_path) 90 | if image is None: 91 | print("image is None") 92 | continue 93 | h, w = image.shape[:2] 94 | height = _IMAGE_HEIGHT 95 | width = int(w * height / h) 96 | try: 97 | image = cv2.resize(image, (width, height)) 98 | except Exception as e: 99 | print(e) 100 | continue 101 | if width > _MAGE_MAX_LENGTH: 102 | image = cv2.resize(image, (_MAGE_MAX_LENGTH, height)) 103 | is_success, image_buffer = cv2.imencode('.png', image) 104 | if not is_success: 105 | print("encoder image error") 106 | continue 107 | image_name = image_name if sys.version_info[0] < 3 else image_name.encode('utf-8') 108 | labels_ids = _string_to_int(label) 109 | if len(labels_ids) < 1 or len(labels_ids) > _MAX_LABEL_LENGTH - 1: 110 | print("labels_ids is too long or short") 111 | continue 112 | features = tf.train.Features(feature={ 113 | "labels":_int64_feature(labels_ids), 114 | 'images': _bytes_feature(image_buffer.tostring()), 115 | 'imagenames': _bytes_feature(image_name) 116 | }) 117 | example = tf.train.Example(features=features) 118 | writer.write(example.SerializeToString()) 119 | writer.close() 120 | 121 | 122 | 123 | def start_create_process(anno_lines, num_shards, num_thread, dataset_split): 124 | with Pool(num_thread) as pool: 125 | total_num = len(anno_lines) 126 | every_shard_num = int(total_num / num_shards) 127 | 
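    # Shard math: each of the num_shards output files receives
    # total_num // num_shards annotation lines, and the Pool writes the
    # (path, lines) pairs in parallel, e.g. 100 train shards across 10 workers.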
127 |     shard_anno_lines = []
128 |     for i in range(num_shards - 1):
129 |       shard_anno_lines.append(anno_lines[i * every_shard_num: (i + 1) * every_shard_num])
130 |     # the last shard takes the remainder so no annotation line is dropped
131 |     shard_anno_lines.append(anno_lines[(num_shards - 1) * every_shard_num:])
132 |     writer_list = [os.path.join(FLAGS.data_dir, 'ocr-%s-%d.tfrecord' % (dataset_split, i)) for i in range(num_shards)]
133 |     assert len(shard_anno_lines) == len(writer_list)
134 |     args = list(zip(writer_list, shard_anno_lines))
135 |     pool.starmap(_write_tfrecord, args)
136 | 
137 | 
138 | def _convert_dataset():
139 |   with open(FLAGS.image_list, 'r') as anno_fp:
140 |     anno_lines = anno_fp.readlines()
141 |   print(FLAGS.image_list)
142 |   print(len(anno_lines))
143 |   if FLAGS.shuffle_list:
144 |     random.shuffle(anno_lines)
145 | 
146 |   if not os.path.exists(FLAGS.data_dir):
147 |     os.makedirs(FLAGS.data_dir)
148 |   # note: the split is hardcoded to 97% / 2% / 1%; the validation_split_fraction flag is unused
149 |   split_train_index = int(len(anno_lines) * 0.97)
150 |   split_test_index = int(len(anno_lines) * 0.99)
151 |   train_anno_lines = anno_lines[:split_train_index]
152 |   test_anno_lines = anno_lines[split_train_index: split_test_index]
153 |   validation_anno_lines = anno_lines[split_test_index:]
154 |   start_create_process(train_anno_lines, 100, 10, 'train')
155 |   start_create_process(validation_anno_lines, 10, 10, 'validation')
156 |   start_create_process(test_anno_lines, 10, 10, 'test')
157 | 
158 | 
159 | def main(unused_argv):
160 |   _convert_dataset()
161 | 
162 | if __name__ == '__main__':
163 |   tf.app.run()
164 | 
-------------------------------------------------------------------------------- /crnn_ocr/model.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python
2 | # -*- coding:utf-8 -*-
3 | #author: wu.zheng midday.me
4 | 
5 | import tensorflow as tf
6 | from tensorflow.contrib import rnn
7 | 
8 | 
9 | class CRNN(object):
10 |   def __init__(self, config):
11 |     self.config = config
12 |     self.lstm_num_units = self.config.lstm_num_units
13 |     self.channel_size = config.channel_size
14 |     self.num_classes = config.num_classes + 1  # +1 for the CTC blank label
15 |     self.build_graph()
16 |     self.train_loss()
17 | 
18 |   def build_graph(self):
19 |     self.images = tf.placeholder(shape=[None, 32, None, self.channel_size], name='images', dtype=tf.float32)
20 |     self.batch_size = tf.shape(self.images)[0]
21 |     self.labels = tf.sparse.placeholder(name="labels", dtype=tf.int32)
22 | 
23 |     self.keep_prob = tf.placeholder(dtype=tf.float32, name='keep_prob')
24 |     self.is_training = tf.placeholder(tf.bool, name='training')
25 | 
26 |     self.sequence_length = tf.placeholder(shape=[None], dtype=tf.int32, name='sequence_length')
27 |     conv_output = self.cnn(self.images)  # batch_size, 1, w / 4, 512
28 |     conv_output = tf.transpose(conv_output, (0, 2, 1, 3))
29 |     self.conv_output = tf.squeeze(conv_output, axis=2)
30 |     self.bidirectional_rnn(self.conv_output, self.sequence_length)
31 | 
32 | 
33 |   def train_loss(self):
34 |     loss = tf.nn.ctc_loss(
35 |         labels=self.labels,
36 |         inputs=self.logits,
37 |         sequence_length=self.sequence_length,
38 |         ignore_longer_outputs_than_inputs=True,
39 |         #ctc_merge_repeated=True
40 |     )
41 |     self.loss = tf.reduce_mean(loss, name='loss')
42 |     return self.loss
43 | 
44 | 
45 |   def bidirectional_rnn(self, input_tensor, input_sequence_length):
46 |     lstm_num_units = self.config.lstm_num_units
47 |     print("rnn input tensor ===> ", input_tensor)
48 |     with tf.variable_scope('lstm_layers'):
49 |       fw_cell_list = [rnn.BasicLSTMCell(nh, forget_bias=1.0, name='fw_cell_%d'%(nh)) for nh in [lstm_num_units] * 2]
50 |       bw_cell_list = [rnn.BasicLSTMCell(nh, forget_bias=1.0,
name='bw_cell_%d'%(nh)) for nh in [lstm_num_units] * 2] 51 | 52 | stack_lstm_layer, _, _ = rnn.stack_bidirectional_dynamic_rnn( 53 | cells_fw=fw_cell_list, 54 | cells_bw=bw_cell_list, 55 | inputs=input_tensor, 56 | sequence_length=input_sequence_length, 57 | dtype=tf.float32) 58 | hidden_num = lstm_num_units * 2 59 | rnn_reshaped = tf.nn.dropout(stack_lstm_layer, keep_prob=self.keep_prob) 60 | w = tf.get_variable(initializer=tf.truncated_normal([hidden_num, self.num_classes], stddev=0.02), name="w") 61 | w_t = tf.tile(tf.expand_dims(w, 0),[self.batch_size,1,1]) 62 | logits = tf.matmul(rnn_reshaped, w_t, name="nn_logits") 63 | self.logits = tf.identity(tf.transpose(logits, (1, 0, 2)), name='logits') 64 | return logits 65 | 66 | def _build_pred(self): 67 | decoded, log_prob = tf.nn.ctc_greedy_decoder(self.logits, self.sequence_length) 68 | self.decoded = tf.identity(decoded[0], name='decoded') 69 | self.log_prob = tf.identity(log_prob, name='log_prob') 70 | if self.is_training: 71 | pred_str_labels = tf.as_string(self.decoded.values) 72 | pred_tensor = tf.SparseTensor(indices=self.decoded.indices, values=pred_str_labels, dense_shape=self.decoded.dense_shape) 73 | true_str_labels = tf.as_string(self.labels.values) 74 | true_tensor = tf.SparseTensor(indices=self.labels.indices, values=true_str_labels, dense_shape=self.labels.dense_shape) 75 | self.edit_distance = tf.reduce_mean(tf.edit_distance(pred_tensor, true_tensor, normalize=True), name='distance') 76 | 77 | def cnn(self, inputs): 78 | with tf.variable_scope('cnn_feature'): 79 | # (None, 32, w, 64) 80 | conv1 = tf.layers.conv2d(inputs=inputs, filters=64, kernel_size=(3, 3), padding="same", activation=tf.nn.relu, name="conv1",) 81 | # (None, 16, w/2, 64) 82 | pool1 = tf.layers.max_pooling2d(inputs=conv1, pool_size=[2, 2], strides=2, name='pool1', padding='valid') 83 | # (None, 16, w/2, 128) 84 | conv2 = tf.layers.conv2d(inputs=pool1, filters=128, kernel_size=(3, 3), padding="same", activation=tf.nn.relu, name="conv2") 85 | # (None, 8, w/4, 128) 86 | pool2 = tf.layers.max_pooling2d(inputs=conv2, pool_size=[2, 2], strides=2, name="pool2", padding='valid') 87 | # (None, 8, w/4, 256) 88 | conv3 = tf.layers.conv2d(inputs=pool2, filters=256, kernel_size=(3, 3), padding="same", activation=tf.nn.relu, name="conv3") 89 | conv4 = tf.layers.conv2d(inputs=conv3, filters=256, kernel_size=(3, 3), padding="same", activation=tf.nn.relu, name="conv4") 90 | # (None, 4, w/4, 256), 91 | pool3 = tf.layers.max_pooling2d(inputs=conv4, pool_size=[2, 1], strides=[2, 1], padding="valid", name="pool3") 92 | # (None, 4, w/4, 512) 93 | conv5 = tf.layers.conv2d(inputs=pool3, filters=512, kernel_size=(3, 3), padding="same", activation=None, name='conv5') 94 | # (None, 4, w/4, 512) 95 | bnorm1 = tf.layers.batch_normalization(conv5, name="bnorm1", training=self.is_training) 96 | bnorm1 = tf.nn.relu(bnorm1) 97 | # (None, 4, w/4, 512) 98 | conv6 = tf.layers.conv2d(inputs=bnorm1, filters=512, kernel_size=(3, 3), padding="same", activation=None, name="conv6") 99 | bnorm2 = tf.layers.batch_normalization(conv6, name="bnorm2", training=self.is_training) 100 | bnorm2 = tf.nn.relu(bnorm2) 101 | # (None, 2, w/4, 512) 102 | pool4 = tf.layers.max_pooling2d(inputs=bnorm2, pool_size=[2, 1], strides=[2, 1], padding="valid", name="pool4") 103 | conv7 = tf.layers.conv2d(inputs=pool4, filters=512, kernel_size=2, strides=[2, 1], padding="same", activation=tf.nn.relu, name="conv7") 104 | return conv7 105 | 106 | -------------------------------------------------------------------------------- 
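A minimal smoke test for the CRNN class above can help when porting it. The sketch below is not part of the repo; it assumes it runs inside `crnn_ocr/` (so `import model` resolves), the `Config` values are placeholders, and the sequence-length rule is read off `cnn()`, whose two 2x2 poolings downsample width by 4:

```python
# hypothetical smoke test, not part of the repo
import numpy as np
import tensorflow as tf
import model  # crnn_ocr/model.py

class Config:
    num_classes = 128      # assumption: vocab size + 1, as train.py passes it
    lstm_num_units = 256
    channel_size = 1

net = model.CRNN(Config())
decoded, _ = tf.nn.ctc_greedy_decoder(net.logits, net.sequence_length)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    image = np.random.rand(1, 32, 320, 1).astype(np.float32)
    seq_len = np.array([320 // 4], dtype=np.int32)  # cnn() halves the width twice
    sparse_pred = sess.run(decoded[0], feed_dict={
        net.images: image,
        net.sequence_length: seq_len,
        net.keep_prob: 1.0,      # dropout off
        net.is_training: False,  # batch norm in inference mode
    })
    print(sparse_pred.dense_shape)
```

With random weights the decoded values are meaningless; the point is only that the shapes and placeholders line up before training starts.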
/single_word_ocr/densenet.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python
2 | #-*- coding:utf-8 -*-
3 | #author: wu.zheng midday.me
4 | 
5 | import tensorflow as tf
6 | 
7 | class DenseNet(object):
8 |   def __init__(self, batch_size, num_classes, mode='train', center_loss_alpha=0.95):
9 |     self.filters = 24
10 |     self.center_loss_alpha = center_loss_alpha
11 | 
12 |     if mode == 'train':
13 |       self.dropout_rate = 0.5
14 |     else:
15 |       self.dropout_rate = 0.0
16 | 
17 |     self.num_classes = num_classes
18 |     self.is_training = mode == 'train'
19 |     self.images = tf.placeholder(shape=[batch_size, 64, 64, 3], name='image', dtype=tf.float32)
20 |     self.labels = tf.placeholder(shape=[batch_size], name='labels', dtype=tf.int64)
21 |     self.logits = self.dense_net(self.images)
22 | 
23 |     if self.is_training:
24 |       print("features =====> ", self.features)
25 |       #self.center_loss, _ = self.center_loss(self.features, self.labels, self.center_loss_alpha, self.num_classes)
26 |       # the softmax loss must stay active: self.loss references it below
27 |       self.softmax_loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(logits=self.logits, labels=self.labels))
28 |       self.loss = self.softmax_loss
29 | 
30 |     self.predict_prob = tf.nn.softmax(self.logits, name='prob')
31 |     print("predict_prob====>", self.predict_prob)
32 |     self.prediction = tf.argmax(self.predict_prob, axis=-1, name='prediction')
33 |     print("prediction====>", self.prediction)
34 | 
35 | 
36 |     if self.is_training:
37 |       equal = tf.equal(self.labels, self.prediction)
38 |       self.accuracy = tf.reduce_mean(tf.cast(equal, tf.float32))
39 | 
40 |   def center_loss(self, features, label, alpha, num_classes):
41 |     """Center loss based on the paper "A Discriminative Feature Learning Approach for Deep Face Recognition"
42 |     (http://ydwen.github.io/papers/WenECCV16.pdf)
43 |     copy from facenet: https://github.com/davidsandberg/facenet
44 |     """
45 |     num_features = features.get_shape()[1]
46 |     centers = tf.get_variable('centers', [num_classes, num_features], dtype=tf.float32,
47 |         initializer=tf.constant_initializer(0), trainable=False)
48 |     label = tf.reshape(label, [-1])
49 |     centers_batch = tf.gather(centers, label)
50 |     diff = (1 - alpha) * (centers_batch - features)
51 |     centers = tf.scatter_sub(centers, label, diff)
52 |     with tf.control_dependencies([centers]):
53 |       loss = tf.reduce_mean(tf.square(features - centers_batch))
54 |     return loss, centers
55 | 
56 | 
57 | 
58 |   def bottleneck_layer(self, net, name):
59 |     with tf.variable_scope(name) as scope:
60 |       net = tf.layers.batch_normalization(net, training=self.is_training)
61 |       net = tf.nn.relu(net)
62 |       net = tf.layers.conv2d(net, use_bias=False, filters=4 * self.filters, kernel_size=[1, 1], strides=(1, 1))
63 |       net = tf.nn.dropout(net, keep_prob=1.0 - self.dropout_rate)
64 |       net = tf.layers.batch_normalization(net, training=self.is_training)
65 |       net = tf.nn.relu(net)
66 |       net = tf.layers.conv2d(net, use_bias=False, filters=self.filters, kernel_size=(3, 3), strides=(1, 1), padding="SAME")
67 |       net = tf.nn.dropout(net, keep_prob=1.0 - self.dropout_rate)
68 |       return net
69 | 
70 | 
71 |   def dense_block(self, net, layers, name):
72 |     with tf.variable_scope(name) as scope:
73 |       layers_output = []
74 |       layers_output.append(net)
75 |       net = self.bottleneck_layer(net, 'bottleneck_0')
76 |       layers_output.append(net)
77 |       for i in range(layers - 1):
78 |         net = tf.concat(layers_output, axis=-1, name='concat_%d'%(i+1))
79 |         net = self.bottleneck_layer(net, 'bottleneck_%d'%(i+1))
80 |         layers_output.append(net)
81 |       net = tf.concat(layers_output,
axis=-1, name='concat_out') 81 | return net 82 | 83 | 84 | def transition_layer(self, net, name): 85 | with tf.variable_scope(name) as scope: 86 | in_channel_size = net.get_shape().as_list()[-1] 87 | print("in_channel_size====>", in_channel_size) 88 | net = tf.layers.batch_normalization(net, training=self.is_training) 89 | net = tf.nn.relu(net) 90 | net = tf.layers.conv2d(net, filters= int(in_channel_size*0.5), kernel_size=(1,1), strides=(1,1)) 91 | net = tf.nn.dropout(net, keep_prob=1-self.dropout_rate) 92 | net = tf.layers.average_pooling2d(net, pool_size=[2,2], strides=2, padding="VALID") 93 | return net 94 | 95 | 96 | def global_avg_pool(self, net): 97 | shape = net.get_shape().as_list() 98 | print(shape) 99 | width = shape[2] 100 | height = shape[1] 101 | net = tf.layers.average_pooling2d(net, pool_size=(height, width), strides=1) 102 | return net 103 | 104 | 105 | def dense_net(self, x ): 106 | net = tf.layers.conv2d(inputs=x, 107 | use_bias=False, filters=self.filters, kernel_size=(7,7), strides=(2,2), padding="SAME", name='conv_0') 108 | print("first conv===> ", net) 109 | net = self.dense_block(net, 6, 'block_0') 110 | print("block_0 ====>", net) 111 | net = self.transition_layer(net, 'transition_0') 112 | print("transition_0 ====>", net) 113 | 114 | net = self.dense_block(net, 12, 'block_1') 115 | print("block_1 ====>", net) 116 | net = self.transition_layer(net, 'transition_1') 117 | print("transition_1 ====>", net) 118 | 119 | net = self.dense_block(net, 48, 'block_2') 120 | print("block_2 ====>", net) 121 | net = self.transition_layer(net, 'transition_2') 122 | print("transition_2 ====>", net) 123 | 124 | net = self.dense_block(net, 32, 'block_3') 125 | 126 | net = tf.layers.batch_normalization(net, training=self.is_training) 127 | net = tf.nn.relu(net) 128 | net = self.global_avg_pool(net) 129 | # todo add center loss 130 | self.features = tf.squeeze(net, name='features') 131 | net = tf.layers.dense(net, units=self.num_classes, name='linear') 132 | net = tf.squeeze(net, axis=(1,2), name='logits') 133 | return net 134 | 135 | -------------------------------------------------------------------------------- /layout_analysis/model.py: -------------------------------------------------------------------------------- 1 | #-*-coding:utf-8-*- 2 | 3 | 4 | import tensorflow as tf 5 | 6 | def down_block(inputs, filters, block_index, is_training): 7 | with tf.variable_scope('g_encoder') as scope: 8 | inputs = tf.pad(inputs, paddings=[[0, 0], [1,1], [1, 1],[0, 0]], name="pad_%d"%(block_index)) 9 | out = tf.layers.conv2d(inputs, filters=filters, kernel_size=[4,4], strides=2, activation=tf.nn.leaky_relu, 10 | name="g_conv_%d"%(block_index)) 11 | return out 12 | 13 | 14 | def up_block(input_a, input_b, out_filters, block_index, is_training, name_scope='g_decoder'): 15 | with tf.variable_scope(name_scope) as scope: 16 | inputs = tf.concat([input_a, input_b], axis=-1) 17 | out = tf.layers.conv2d_transpose(inputs, filters=out_filters, kernel_size=[4,4], strides=2, activation=None, 18 | name="g_up_cov_%d"%(block_index), padding="SAME") 19 | out = tf.layers.batch_normalization(out, training=is_training, name='g_up_norm_%d'%(block_index)) 20 | out = tf.nn.relu(out) 21 | return out 22 | 23 | 24 | def conv_norm_leakrelu_layer(inputs, filters, is_training, scope_name): 25 | with tf.variable_scope(scope_name) as scope: 26 | out = tf.pad(inputs, paddings=[[0, 0], [1,1], [1,1], [0, 0]] ) 27 | out = tf.layers.conv2d(out, filters=filters, kernel_size=[4,4], strides=[2, 2], activation=None) 28 | out = 
tf.layers.batch_normalization(out, training=is_training)
29 |     out = tf.nn.leaky_relu(out)
30 |     return out
31 | 
32 | class UnetModel(object):
33 |   def __init__(self, num_class, is_training=True, dice_loss=False):
34 |     self.image = tf.placeholder(shape=[None, None, None, 3], name="image", dtype=tf.float32)
35 |     self.mask = tf.placeholder(shape=[None, None, None, 1], name='mask', dtype=tf.float32)
36 |     self.label = tf.placeholder(shape=[None, None, None, 1], name='label', dtype=tf.int32)
37 |     self.is_training = is_training
38 |     self.num_class = num_class
39 |     self._build_graph(self.is_training)
40 |     # softmax over classes, consistent with dice_loss below; sigmoid would leave the class scores unnormalized
41 |     self.prob = tf.nn.softmax(self.logits, name='prob')
42 |     self.prediction = tf.argmax(self.logits, axis=-1, name='prediction')
43 | 
44 |     self.filter_size_list = [64, 128, 256, 512, 512, 512]  # note: unused
45 |     if is_training:
46 |       if not dice_loss:
47 |         self._loss()
48 |       else:
49 |         self.dice_loss()
50 | 
51 |   def _build_graph(self, is_training):
52 |     with tf.variable_scope('encoder') as scope:
53 |       block1 = down_block(self.image, 64, 1, is_training)
54 |       block2 = down_block(block1, 128, 2, is_training)
55 |       block3 = down_block(block2, 256, 3, is_training)
56 |       block4 = down_block(block3, 512, 4, is_training)
57 |       block5 = down_block(block4, 512, 5, is_training)
58 |       block6 = down_block(block5, 512, 6, is_training)
59 |       block7 = down_block(block6, 512, 7, is_training)
60 |       center = down_block(block7, 512, 8, is_training)
61 | 
62 |     with tf.variable_scope('decoder') as scope:
63 |       center = tf.layers.conv2d_transpose(center, 512, kernel_size=[4,4], strides=2,
64 |           activation=None, padding="SAME", name='g_center')
65 |       center_norm = tf.layers.batch_normalization(center, training=is_training, name='g_center_norml')
66 |       center_norm = tf.nn.relu(center_norm)
67 |       upblock7 = up_block(block7, center_norm, 512, 7, is_training)
68 |       upblock6 = up_block(block6, upblock7, 512, 6, is_training)
69 |       upblock5 = up_block(block5, upblock6, 512, 4, is_training)
70 |       upblock4 = up_block(block4, upblock5, 256, 3, is_training)
71 |       upblock3 = up_block(block3, upblock4, 128, 2, is_training)
72 |       upblock2 = up_block(block2, upblock3, 64, 1, is_training)
73 |       out = tf.concat([block1, upblock2], axis=-1)
74 |       out = tf.layers.conv2d_transpose(out, filters=self.num_class, kernel_size=(4, 4),
75 |           strides=2, padding="SAME", activation=None, name='g_out')
76 |       self.logits = tf.identity(out, name='logits')
77 | 
78 | 
79 |   def _loss(self):
80 |     shaped_mask = tf.reshape(self.mask, shape=[-1,])
81 |     shaped_logits = tf.reshape(self.logits, shape=[-1, self.num_class])
82 | 
83 |     reshape_labels = tf.reshape(self.label, shape=[-1, ])
84 |     sparse_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=reshape_labels, logits=shaped_logits)
85 | 
86 |     self.loss = tf.reduce_mean(tf.multiply(shaped_mask, sparse_loss))
87 |     self.acc, self.acc_op = tf.metrics.accuracy(labels=tf.reshape(self.label, shape=[-1, ]), predictions=tf.argmax(shaped_logits, 1))
88 | 
89 |   def dice_coefficient(self, y_true_cls, y_pred_cls, training_mask):
90 |     '''
91 |     dice loss
92 |     :param y_true_cls: ground truth
93 |     :param y_pred_cls: predict
94 |     :param training_mask:
95 |     :return:
96 |     '''
97 |     eps = 1e-5
98 |     intersection = tf.reduce_sum(y_true_cls * y_pred_cls * training_mask)
99 |     union = tf.reduce_sum(y_true_cls * training_mask) + tf.reduce_sum(y_pred_cls * training_mask) + eps
100 |     dice = 2 * intersection / union
101 |     loss = 1. - dice
102 |     return dice, loss
103 | 
104 | 
105 |   def dice_loss(self):
106 |     probs = tf.nn.softmax(self.logits, axis=-1, name='probs')
107 |     labels = tf.one_hot(self.label, depth=self.num_class, axis=-1)
108 |     # one_hot on a [N, H, W, 1] label yields [N, H, W, 1, C]; squeeze only the
109 |     # singleton axis so a batch size of 1 is not squeezed away as well
110 |     labels = tf.squeeze(labels, axis=3)
111 |     prob_list = tf.split(value=probs, num_or_size_splits=self.num_class, axis=3)
112 |     label_list = tf.split(value=labels, num_or_size_splits=self.num_class, axis=3)
113 |     loss_list = []
114 | 
115 |     for i in range(self.num_class):
116 |       cls_dice, cls_loss = self.dice_coefficient(label_list[i], prob_list[i], self.mask)
117 |       loss_list.append(cls_loss)
118 | 
119 |     self.loss = tf.reduce_sum(loss_list)
120 |     shaped_logits = tf.reshape(self.logits, shape=[-1, self.num_class])
121 |     self.acc, self.acc_op = tf.metrics.accuracy(labels=tf.reshape(self.label, shape=[-1, ]), predictions=tf.argmax(shaped_logits, 1))
122 | 
-------------------------------------------------------------------------------- /layout_analysis/eval.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python
2 | #-*- coding:utf-8 -*-
3 | #author: wu.zheng midday.me
4 | 
5 | import os
6 | import model
7 | import tensorflow as tf
8 | import numpy as np
9 | import cv2
10 | import copy
11 | import time
12 | from tensorflow.python.saved_model import tag_constants
13 | from dataset import data_factory
14 | import argparse
15 | 
16 | 
17 | parser = argparse.ArgumentParser(description='Evaluate a layout analysis model or export it as a SavedModel.')
18 | parser.add_argument('--name', help='dataset name registered in data_factory')
19 | parser.add_argument('--export_dir', default='./export_models', help='export model dir')
20 | parser.add_argument('--export', action='store_true', help='export saved model instead of evaluating')
21 | parser.add_argument('--data_dir', help='valid data dir')
22 | parser.set_defaults(export=False)
23 | 
24 | COLOR_LIST = [(0, 255, 0), (0, 0, 255), (0, 255, 255)]
25 | 
26 | def mask_to_bbox(mask, image, num_class, out_path=None, out_file_name=None):
27 |   bbox_list = []
28 |   im = copy.copy(image)
29 |   mask = mask.astype(np.uint8)
30 |   for i in range(1, num_class, 1):
31 |     c_bbox_list = []
32 |     c_mask = np.zeros_like(mask)
33 |     c_mask[np.where(mask==i)] = 255
34 |     # OpenCV 3.x returns (image, contours, hierarchy)
35 |     bimg, contours, hier = cv2.findContours(c_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
36 |     for cnt in contours:
37 |       area = cv2.contourArea(cnt)
38 |       if area < 50:
39 |         continue
40 |       epsilon = 0.005 * cv2.arcLength(cnt, True)
41 |       approx = cv2.approxPolyDP(cnt, epsilon, True)
42 |       (x, y, w, h) = cv2.boundingRect(approx)
43 |       c_bbox_list.append([x, y, x+w, y+h])
44 |       if out_path is not None:
45 |         color = COLOR_LIST[i-1]
46 |         im = cv2.rectangle(im, pt1=(x, y), pt2=(x+w, y+h), color=color, thickness=2)
47 |     bbox_list.append(c_bbox_list)
48 |   if out_path is not None:
49 |     outf = os.path.join(out_path, out_file_name)
50 |     print(outf)
51 |     cv2.imwrite(outf, im)
52 |   return bbox_list
53 | 
54 | 
55 | def metrics(true_mask, predict_mask):
56 |   pixel_accuracy = np.sum(predict_mask == true_mask) / true_mask.size
57 |   mean_pixel_accuracy = mean_accuracy(true_mask, predict_mask)
58 |   print("pixel_accuracy:", pixel_accuracy)
59 |   print("mean_pixel_accuracy:", mean_pixel_accuracy)
60 |   return pixel_accuracy, mean_pixel_accuracy
61 | 
62 | def mean_accuracy(true_mask, predict_mask):
63 |   """
64 |   computes mean accuracy: 1/n_cl * sum_i(n_ii/t_i)
65 |   """
66 |   s, cl = per_class_accuracy(true_mask, predict_mask)
67 |   return np.sum(s) / cl.size
68 | 
69 | def per_class_accuracy(true_mask, predict_mask):
70 |   """
71 |   computes pixel by pixel accuracy per class in target
72 |   sum_i(n_ii/t_i)
73 |   """
74 |   cl = np.unique(predict_mask)
75 |   n_cl = cl.size
76 |   s = np.zeros(n_cl)
77 |   for i, c in enumerate(cl):
78 |     s[i] = (predict_mask[predict_mask == true_mask] == c).sum() / (predict_mask == c).sum()
79 |   return (s, cl)
80 | 
81 | 
82 | def compute_iou(groundtruth_box, detection_box):
83 |   g_ymin, g_xmin, g_ymax, g_xmax = groundtruth_box
84 |   d_ymin, d_xmin, d_ymax, d_xmax = detection_box
85 | 
86 |   xa = max(g_xmin, d_xmin)
87 |   ya = max(g_ymin, d_ymin)
88 |   xb = min(g_xmax, d_xmax)
89 |   yb = min(g_ymax, d_ymax)
90 | 
91 |   intersection = max(0, xb - xa + 1) * max(0, yb - ya + 1)
92 | 
93 |   boxAArea = (g_xmax - g_xmin + 1) * (g_ymax - g_ymin + 1)
94 |   boxBArea = (d_xmax - d_xmin + 1) * (d_ymax - d_ymin + 1)
95 | 
96 |   return intersection / float(boxAArea + boxBArea - intersection)
97 | 
98 | def bbox_accuracy(true_bbox_list, predict_bbox_list):
99 |   total_true = len(true_bbox_list[0])
100 |   total_predict = len(predict_bbox_list[0])
101 |   tp = 0
102 |   for true_bbox in true_bbox_list[0]:
103 |     for predict_bbox in predict_bbox_list[0]:
104 |       iou = compute_iou(true_bbox, predict_bbox)
105 |       if iou >= 0.9:
106 |         tp += 1
107 |         break
108 |   # guard against empty prediction or ground-truth lists
109 |   precision = 1.0 * tp / max(total_predict, 1)
110 |   recall = 1.0 * tp / max(total_true, 1)
111 |   f1 = 2 * precision * recall / max(precision + recall, 1e-8)
112 |   print(total_true, total_predict, tp)
113 |   return total_true, total_predict, tp
114 | 
115 | 
116 | def predict():
117 |   data_config = data_factory.get_data_config(args.name)
118 |   check_points_path = tf.train.latest_checkpoint(data_config.get("check_points_path"))
119 |   num_class = data_config.get("num_class")
120 |   test_model = model.UnetModel(num_class, False)
121 |   saver = tf.train.Saver()
122 |   total_true = 0
123 |   total_predict = 0
124 |   total_tp = 0
125 |   counter = 0
126 |   with tf.Session() as sess:
127 |     saver.restore(sess=sess, save_path=check_points_path)
128 |     if args.export:
129 |       tensor_image_input_info = tf.saved_model.utils.build_tensor_info(test_model.image)
130 |       tensor_prob_output_info = tf.saved_model.utils.build_tensor_info(test_model.prob)
131 |       signature = tf.saved_model.signature_def_utils.build_signature_def(
132 |           inputs={"images": tensor_image_input_info, },
133 |           outputs={"prob": tensor_prob_output_info})
134 |       ex_dir = str(int(time.time()))
135 |       export_dir = os.path.join(args.export_dir, args.name, ex_dir)
136 |       builder = tf.saved_model.builder.SavedModelBuilder(export_dir)
137 |       builder.add_meta_graph_and_variables(sess=sess, tags=[tag_constants.SERVING], signature_def_map={"predict": signature})
138 |       builder.save()
139 |       print("model exported at %s" % (export_dir))
140 |       exit(0)
141 |     data_iterator = data_factory.get_data_iterator(args.name, mode='test', batch_size=1)
142 |     for batch_data in data_iterator:
143 |       if batch_data is None:
144 |         continue
145 |       images = np.array(batch_data[0])
146 |       labels = batch_data[1]
147 |       labels = np.array(labels[0])
148 |       logits = sess.run([test_model.logits], feed_dict={test_model.image: images})
149 |       mask = logits[0]
150 |       mask = np.argmax(mask[0], axis=-1)
151 |       mask = mask.astype(np.uint8)
152 |       im = images[0] * 255.0
153 |       labels = labels.astype(np.uint8)
154 |       test_out_path = './test_out'
155 |       if not os.path.exists(test_out_path):
156 |         os.makedirs(test_out_path)
157 |       true_bbox_list = mask_to_bbox(labels, im, num_class, out_path=test_out_path, out_file_name=str(counter)+"_true.png")
158 |       predict_bbox_list = mask_to_bbox(mask, im, num_class, out_path=test_out_path, out_file_name=str(counter) + "_predict.png")
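      # pixel-level metrics follow; bbox_accuracy(true_bbox_list, predict_bbox_list)
      # could be called here as well to report box-level precision/recall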
      metrics(labels, mask)
159 |       counter += 1
160 |       if counter > 20:
161 |         break
162 | 
163 | if __name__ == "__main__":
164 |   args = parser.parse_args()
165 |   print(args)
166 |   predict()
167 | 
-------------------------------------------------------------------------------- /crnn_ocr/eval.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python
2 | # -*- coding:utf-8 -*-
3 | #author: wu.zheng midday.me
4 | 
5 | import tensorflow as tf
6 | import cv2
7 | import json
8 | import os
9 | import time
10 | import model
11 | import numpy as np
12 | import random
13 | from tensorflow.python.saved_model import tag_constants
14 | import Levenshtein
15 | 
16 | 
17 | tf.app.flags.DEFINE_string(
18 |     'image_dir', 'ocr/dataset/text_line_gen', 'Path to the directory containing images.')
19 | tf.app.flags.DEFINE_string(
20 |     'image_list', 'ocr/dataset/text_line_gen/labels.txt', 'Path to the images list txt file.')
21 | tf.app.flags.DEFINE_string(
22 |     'model_dir', 'ocr/models/crnn_gen_model', 'Base directory for the model.')
23 | tf.app.flags.DEFINE_integer(
24 |     'lstm_hidden_layers', 2, 'The number of stacked LSTM cells.')
25 | tf.app.flags.DEFINE_integer(
26 |     'lstm_hidden_units', 256, 'The number of units in each LSTM cell.')
27 | tf.app.flags.DEFINE_string(
28 |     'char_map_json_file', '', 'Path to char map json file')
29 | 
30 | tf.flags.DEFINE_boolean('export', False, 'if export model')
31 | tf.flags.DEFINE_boolean('eval', False, 'if evaluate model')
32 | tf.app.flags.DEFINE_integer('max_seq_length', 1024, '')
33 | tf.app.flags.DEFINE_integer('channel_size', 3, 'image channels')
34 | 
35 | FLAGS = tf.app.flags.FLAGS
36 | _IMAGE_HEIGHT = 32
37 | _MAX_LENGTH = FLAGS.max_seq_length
38 | 
39 | def _int_to_string(value, char_map_dict):
40 |   # the appended class index (== vocab size) is the CTC blank: decode it as ""
41 |   if len(char_map_dict.keys()) == int(value):
42 |     return ""
43 |   word = char_map_dict.get(int(value))
44 |   if word is None:
45 |     return " "
46 |   return word.strip("\n")
47 | 
48 | 
49 | def _sparse_matrix_to_list(sparse_matrix, char_map_dict):
50 |   indices = sparse_matrix.indices
51 |   values = sparse_matrix.values
52 |   dense_shape = sparse_matrix.dense_shape
53 |   dense_matrix = len(char_map_dict.keys()) * np.ones(dense_shape, dtype=np.int32)
54 | 
55 |   for i, indice in enumerate(indices):
56 |     dense_matrix[indice[0], indice[1]] = values[i]
57 |   string_list = []
58 |   for row in dense_matrix:
59 |     string = []
60 |     for val in row:
61 |       string.append(_int_to_string(val, char_map_dict))
62 |     string_list.append(''.join(s for s in string if s != '*'))
63 |   return string_list
64 | 
65 | def standardize(img):
66 |   mean = np.mean(img)
67 |   std = np.std(img)
68 |   img = (img - mean) / std
69 |   return img
70 | 
71 | class Config():
72 |   num_classes = 0
73 |   lstm_num_units = 256
74 |   channel_size = FLAGS.channel_size
75 | 
76 | def load_char_map():
77 |   char_map_dict = json.load(open(FLAGS.char_map_json_file))
78 |   return char_map_dict
79 | 
80 | def merge_text(content):
81 |   # drop spaces the CTC decoder tends to insert around CJK characters,
82 |   # keeping those between Latin words
83 |   words = list(content)
84 |   new_words = []
85 |   last_word = ""
86 |   for w in words:
87 |     if w == ' ':
88 |       if last_word and (last_word.islower() or last_word.isupper()):
89 |         new_words.append(w)
90 |     else:
91 |       new_words.append(w)
92 |       last_word = w  # track the previous non-space character
93 |   return "".join(new_words)
94 | 
95 | 
96 | def eval():
97 |   tf.reset_default_graph()
98 |   char_map_dict = load_char_map()
99 |   config = Config()
100 |   config.num_classes = len(char_map_dict) + 1
101 |   id_to_char = {v: k for k, v in char_map_dict.items()}
102 |   crnn_net = model.CRNN(config)
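  # each line of image_list is "<image_path> <transcript>", the same format create_tfrecord.py reads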
103 |   with open(FLAGS.image_list, 'r') as fd:
104 |     image_names = []
105 |     true_labels = []
106 |     for i, line in enumerate(fd):
107 |       seg = " "
108 |       line = line.strip().split(seg)
109 |       image_names.append(line[0])
110 |       true_labels.append(seg.join(line[1:]))
111 | 
112 |   index_list = random.choices(list(range(len(image_names))), k=50)
113 |   image_names = [image_names[i] for i in index_list]
114 |   labels = [true_labels[i] for i in index_list]
115 | 
116 |   saver = tf.train.Saver()
117 |   save_path = tf.train.latest_checkpoint(FLAGS.model_dir)
118 |   with tf.Session() as sess:
119 |     saver.restore(sess=sess, save_path=save_path)
120 |     print("restored from %s" % (save_path))
121 |     decoded, log_prob = tf.nn.ctc_greedy_decoder(crnn_net.logits, crnn_net.sequence_length, merge_repeated=True)
122 |     if FLAGS.export:
123 |       tensor_image_input_info = tf.saved_model.utils.build_tensor_info(crnn_net.images)
124 |       tensor_seq_len_input_info = tf.saved_model.utils.build_tensor_info(crnn_net.sequence_length)
125 |       tensor_is_training_info = tf.saved_model.utils.build_tensor_info(crnn_net.is_training)
126 |       tensor_keep_prob = tf.saved_model.utils.build_tensor_info(crnn_net.keep_prob)
127 |       output_info = tf.saved_model.utils.build_tensor_info(decoded[0])
128 |       signature = tf.saved_model.signature_def_utils.build_signature_def(
129 |           inputs={
130 |               'images': tensor_image_input_info,
131 |               'sequence_length': tensor_seq_len_input_info,
132 |               "is_training": tensor_is_training_info,
133 |               "keep_prob": tensor_keep_prob},
134 |           outputs={'decoded': output_info})
135 | 
136 |       ex_dir = str(int(time.time()))
137 |       export_dir = "./all_exported_models/%s/" % (ex_dir,)
138 |       builder = tf.saved_model.builder.SavedModelBuilder(export_dir)
139 |       builder.add_meta_graph_and_variables(sess=sess, tags=[tag_constants.SERVING], signature_def_map={"predict": signature})
140 |       builder.save()
141 |       print("exported model at %s" % (export_dir))
142 |     ignore = 0
143 |     error_count = 0
144 |     total_count = 0
145 |     for i, image_name in enumerate(image_names):
146 |       image_path = os.path.join(FLAGS.image_dir, image_name)
147 |       if FLAGS.channel_size == 3:
148 |         image = cv2.imread(image_path)
149 |       else:
150 |         image = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
151 |       if image is None:
152 |         print("ignore unreadable image %s" % image_path)
153 |         ignore += 1
154 |         continue
155 |       h, w = image.shape[:2]
156 |       height = _IMAGE_HEIGHT
157 |       width = int(w * height / h)
158 |       image = cv2.resize(image, (width, height))
159 |       image = np.array(image, dtype=np.float32)
160 |       image = image / 255.0
161 |       seq_len = np.array([width / 4], dtype=np.int32)
162 |       print("length: ", seq_len)
163 |       if FLAGS.channel_size == 1:
164 |         image = image[:, :, np.newaxis]
165 |       cv2.imwrite("test.png", image * 255.0)
166 |       image = np.expand_dims(image, axis=0)
167 |       start = time.time()
168 |       print(image.shape)
169 |       logit, preds, prob = sess.run(
170 |           [crnn_net.logits, decoded, log_prob],
171 |           feed_dict={
172 |               crnn_net.images: image,
173 |               crnn_net.sequence_length: seq_len,
174 |               crnn_net.keep_prob: 1.0,
175 |               crnn_net.is_training: False})
176 |       preds = _sparse_matrix_to_list(preds[0], id_to_char)
177 |       cost_time = time.time() - start
178 |       res_text = preds[0]
179 |       res_text = merge_text(res_text)
180 |       err_count = Levenshtein.distance(labels[i], res_text)
181 |       total_count += len(labels[i])
182 |       error_count += err_count
183 |       print(image_name)
184 |       print('true label: {:s}\npredict result: {:s} cost: {:f}\nerror_count: {:d}'.format(labels[i], res_text, cost_time, err_count))
185 |     print(1 - 1.0 * error_count / total_count)
186 | 
187 | if __name__ == "__main__":
  eval()
188 | 
-------------------------------------------------------------------------------- /layout_analysis/dataset/book_data.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python
2 | #-*- coding:utf-8 -*-
3 | #author: wu.zheng midday.me
4 | 
5 | import os
6 | import json
7 | import cv2
8 | import random
9 | import numpy as np
10 | import time
11 | from dataset import image_augmation
12 | from dataset import data_util
13 | import xml.etree.ElementTree as ET
14 | 
15 | IMAGE_WIDTH = 1024
16 | IMAGE_HEIGHT = 1024
17 | NUM_CLASS = 3
18 | 
19 | 
20 | def load_image_list(image_data_path):
21 |   image_path_list = []
22 |   with open(image_data_path) as f:
23 |     for _, line in enumerate(f):
24 |       line = line.strip("\n")
25 |       image_path_list.append(line)
26 |   return image_path_list
27 | 
28 | def get_int_value(item, name):
29 |   return int(item.get(name))
30 | 
31 | def get_item_box(item):
32 |   x = get_int_value(item, 'left')
33 |   y = get_int_value(item, 'top')
34 |   w = get_int_value(item, 'width')
35 |   h = get_int_value(item, 'height')
36 |   return x, y, w, h
37 | 
38 | def load_label_data(label_file_path):
39 |   tree = ET.parse(label_file_path)
40 |   root = tree.getroot()
41 |   label_data = {}
42 |   image_list = root.findall('image')
43 |   text_list = root.findall('text')
44 |   label_data['image_path'] = root.get("image_path")
45 |   page_height = int(root.get('height'))
46 |   page_width = int(root.get('width'))
47 |   label_data['width'] = page_width
48 |   label_data['height'] = page_height
49 |   label_data['images'] = []
50 |   label_data['texts'] = []
51 |   for image in image_list:
52 |     box = get_item_box(image)
53 |     label_data['images'].append(box)
54 |   for text in text_list:
55 |     box = get_item_box(text)
56 |     label_data['texts'].append(box)
57 |   return label_data
58 | 
59 | def get_shape_by_type(shapes, label):
60 |   shape_list = []
61 |   for shape in shapes:
62 |     if shape['label'] == label:
63 |       shape_list.append(shape)
64 |   return shape_list
65 | 
66 | def fill_image(label_image, boxs, label_value, w_factor, h_factor):
67 |   for box in boxs:
68 |     x, y, w, h = box
69 |     min_x, min_y = x, y
70 |     max_x, max_y = x + w, y + h
71 |     area = (max_x - min_x) * (max_y - min_y)
72 |     point_box = [(min_x, min_y), (max_x, min_y), (max_x, max_y), (min_x, max_y)]
73 |     point_box = np.array(point_box)
74 |     point_box = point_box.reshape((4, 2))
75 |     point_box[:, 0] = point_box[:, 0] * w_factor
76 |     point_box[:, 1] = point_box[:, 1] * h_factor
77 |     label_image = cv2.fillPoly(label_image, point_box.astype(np.int32)[np.newaxis, :, :], label_value)
78 |   return label_image
79 | 
80 | 
81 | def data_generator(list_path, image_dir, batch_size, mode='train'):
82 |   label_file_list = load_image_list(list_path)
83 |   print("example size:", len(label_file_list))
84 |   image_batch = []
85 |   label_batch = []
86 |   mask_batch = []
87 |   xml_path_batch = []
88 |   scale_list = [0.3, 0.5, 1.0, 2.0, 3.0]
89 |   while True:
90 |     random.shuffle(label_file_list)
91 |     for xml_path in label_file_list:
92 |       xml_path = os.path.join(image_dir, xml_path)
93 |       label_data = load_label_data(xml_path)
94 |       image_path = os.path.join(image_dir, label_data['image_path'])
95 |       image = cv2.imread(image_path)
96 |       if image is None:
97 |         # skip samples whose image is missing or unreadable before touching it
98 |         continue
99 |       image_labels = np.array(label_data['images'])
100 |       text_labels = np.array(label_data['texts'])
101 |       # TODO: image augmentation
102 |       #aug_image = image_augmation.image_aug(image.copy())
103 |       aug_image = image.copy()
104 |       #rd_scale = np.random.choice(scale_list, 1)
105 |       ##rd_scale = 1.0
106 |       #r_image = cv2.resize(aug_image, dsize=None, fx=rd_scale, fy=rd_scale)
107 |       #image_labels = image_labels * rd_scale
108 |       #text_labels = text_labels * rd_scale
109 |       r_image = aug_image
110 | 
111 |       #h, w = image.shape[:2]
112 |       h = label_data['height']
113 |       w = label_data['width']
114 | 
115 |       image_h, image_w = r_image.shape[:2]
116 |       h_ratio = image_h / h
117 |       w_ratio = image_w / w
118 |       if image_h > image_w:
119 |         factor = IMAGE_HEIGHT / image_h
120 |         new_h = IMAGE_HEIGHT
121 |         new_w = int(image_w * factor)
122 |       else:
123 |         factor = IMAGE_WIDTH / image_w
124 |         new_w = IMAGE_WIDTH
125 |         new_h = int(image_h * factor)
126 |       # pad the resized page onto a fixed IMAGE_HEIGHT x IMAGE_WIDTH canvas below
127 | 
128 |       w_factor = new_w / w
129 |       h_factor = new_h / h
130 |       r_image = cv2.resize(r_image, (new_w, new_h))
131 |       label_image = np.zeros((new_h, new_w))
132 | 
133 |       label_image = fill_image(label_image, image_labels, 1, w_factor, h_factor)
134 |       label_image = fill_image(label_image, text_labels, 2, w_factor, h_factor)
135 | 
136 |       train_image = np.zeros((IMAGE_HEIGHT, IMAGE_WIDTH, 3))
137 |       train_image[0:new_h, 0:new_w] = r_image
138 |       train_label = np.zeros((IMAGE_HEIGHT, IMAGE_WIDTH))
139 |       train_label[0:new_h, 0:new_w] = label_image
140 | 
141 |       mask = np.ones((new_h, new_w))
142 |       train_mask = np.zeros((IMAGE_HEIGHT, IMAGE_WIDTH))
143 |       train_mask[0:new_h, 0:new_w] = mask
144 | 
145 | 
146 |       label_batch.append(train_label)
147 |       train_image = train_image / 255.0
148 |       image_batch.append(train_image)
149 |       mask_batch.append(train_mask)
150 |       xml_path_batch.append(xml_path)
151 |       if len(image_batch) == batch_size:
152 |         yield image_batch, label_batch, mask_batch, xml_path_batch
153 |         image_batch = []
154 |         label_batch = []
155 |         mask_batch = []
156 |         xml_path_batch = []
157 |     if mode != 'train':
158 |       break
159 | 
160 | 
161 | def get_batch(list_dir, image_dir, batch_size, mode='train', workers=1, max_queue_size=32):
162 |   enqueuer = data_util.GeneratorEnqueuer(data_generator(list_dir, image_dir, batch_size, mode))
163 |   enqueuer.start(max_queue_size=max_queue_size, workers=workers)
164 |   enqueuer.is_running()
165 |   generator_output = None
166 |   while True:
167 |     while enqueuer.is_running():
168 |       if not enqueuer.queue.empty():
169 |         generator_output = enqueuer.queue.get()
170 |         break
171 |       else:
172 |         time.sleep(0.01)
173 |     yield generator_output
174 |     generator_output = None
175 | 
176 | 
177 | def get_random_color():
178 |   color = (int(random.random() * 255), int(random.random() * 255), int(random.random() * 255))
179 |   return color
180 | 
181 | def mask_to_bbox(mask, im, num_class, out_path=None, out_file_name=None):
182 |   bbox_list = []
183 |   mask = mask.astype(np.uint8)
184 |   for i in range(1, num_class, 1):
185 |     c_bbox_list = []
186 |     c_mask = np.zeros_like(mask)
187 |     c_mask[np.where(mask == i)] = 255
188 |     # OpenCV 3.x returns (image, contours, hierarchy)
189 |     _, contours, hierarchy = cv2.findContours(c_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
190 |     color = get_random_color()
191 |     for cnt in contours:
192 |       area = cv2.contourArea(cnt)
193 |       if area < 50:
194 |         continue
195 |       epsilon = 0.005 * cv2.arcLength(cnt, True)
196 |       approx = cv2.approxPolyDP(cnt, epsilon, True)
197 |       (x, y, w, h) = cv2.boundingRect(approx)
198 |       c_bbox_list.append([x, y, x+w, y+h])
199 |       if out_path is not None:
200 |         cv2.putText(im, str(i), (x, y), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 1, cv2.LINE_AA)
201 |         im = cv2.rectangle(im, pt1=(x, y), pt2=(x+w, y+h), color=color, thickness=1)
202 |     bbox_list.append(c_bbox_list)
203 |   if out_path is not None:
204 |     outf = os.path.join(out_path,
out_file_name) 206 | print(outf) 207 | cv2.imwrite(outf, im) 208 | return bbox_list 209 | 210 | 211 | -------------------------------------------------------------------------------- /dataset/sample/188.xml: -------------------------------------------------------------------------------- 1 | 2 | 合肥百货大楼集团股份有限公司 2017 年年度报告全文 3 | 189 4 | 组合中,按账龄分析法计提坏账准备的其他应收款: 5 | √ 适用 □ 不适用 6 | 单位: 元 7 | 账龄 8 | 期末余额 9 | 其他应收款 10 | 坏账准备 11 | 计提比例 12 | 1 年以内分项 13 | 1 年以内小计 14 | 5,058,899.73 15 | 101,177.99 16 | 2.00% 17 | 1 至 2 年 18 | 16,832.00 19 | 841.60 20 | 5.00% 21 | 2 至 3 年 22 | 864,340.27 23 | 86,434.03 24 | 10.00% 25 | 3 年以上 26 | 8,073,921.81 27 | 7,498,961.81 28 | 3 至 4 年 29 | 718,700.00 30 | 143,740.00 31 | 20.00% 32 | 5 年以上 33 | 7,355,221.81 34 | 7,355,221.81 35 | 100.00% 36 | 合计 37 | 14,013,993.81 38 | 7,687,415.43 39 | 确定该组合依据的说明: 40 | 组合中,采用余额百分比法计提坏账准备的其他应收款: 41 | □ 适用 √ 不适用 42 | 组合中,采用其他方法计提坏账准备的其他应收款: 43 | √ 适用 □ 不适用 44 | 单位:元 45 | 组合名称 46 | 账面余额 47 | 坏账准备 48 | 单项金额重大并单项计提坏账准备的其他应收款 49 | 737,343,558.97 50 | 0.00 51 | 合计 52 | 737,343,558.97 53 | 0.00 54 | (2 55 | 本期计提坏账准备金额 1,741,636.12 元;本期收回或转回坏账准备金额 0.00 元。 56 | 其中本期坏账准备转回或收回金额重要的: 57 | 单位: 元 58 | 单位名称 59 | 转回或收回金额 60 | 收回方式 61 | (3 62 | 单位: 元 63 | 项目 64 | 核销金额 65 | 其中重要的其他应收款核销情况: 66 | 单位: 元 67 | -------------------------------------------------------------------------------- /crnn_ocr/train.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding:utf-8 -*- 3 | #author: wu.zheng midday.me 4 | 5 | import tensorflow as tf 6 | import model 7 | import os 8 | import json 9 | import time 10 | import numpy as np 11 | import input_fn 12 | 13 | 14 | tf.app.flags.DEFINE_string( 15 | 'data_dir', '', 'Path to the directory containing data tf record.') 16 | 17 | tf.app.flags.DEFINE_string( 18 | 'model_dir', '/data/zhengwu_workspace/ocr/models/crnn_gen_model', 'Base directory for the model.') 19 | 20 | tf.app.flags.DEFINE_integer( 21 | 'num_threads', 4, 'The number of threads to use in batch shuffling') 22 | 23 | tf.app.flags.DEFINE_integer( 24 | 'step_per_eval', 5000, 'The number of training steps to run between evaluations.') 25 | 26 | tf.app.flags.DEFINE_integer( 27 | 'step_per_save', 100, 'The number of training steps to run between save checkpoints.') 28 | 29 | tf.app.flags.DEFINE_integer( 30 | 'batch_size', 64, 'The number of samples in each batch.') 31 | 32 | tf.app.flags.DEFINE_integer( 33 | 'max_train_steps', 500000, 'The number of maximum iteration steps for training') 34 | 35 | tf.app.flags.DEFINE_float( 36 | 'learning_rate', 0.0001, 'The initial learning rate for training.') 37 | 38 | tf.app.flags.DEFINE_integer( 39 | 'decay_steps', 30000, 'The learning rate decay steps for training.') 40 | 41 | tf.app.flags.DEFINE_float( 42 | 'decay_rate', 0.98, 'The learning rate decay rate for training.') 43 | 44 | tf.app.flags.DEFINE_string( 45 | 'char_map_json_file', './simple_vocab.json', 'Path to char map json file') 46 | 47 | tf.app.flags.DEFINE_integer('max_seq_length', 1024 , '') 48 | tf.app.flags.DEFINE_integer('channel_size', 3, 'image channels') 49 | 50 | FLAGS = tf.app.flags.FLAGS 51 | 52 | def _int_to_string(value, char_map_dict=None): 53 | if char_map_dict is None: 54 | char_map_dict = json.load(open(FLAGS.char_map_json_file, 'r')) 55 | 56 | assert(isinstance(char_map_dict, dict) and 'char_map_dict is not a dict') 57 | 58 | for key in char_map_dict.keys(): 59 | if char_map_dict[key] == int(value): 60 | return str(key) 61 | elif len(char_map_dict.keys()) == int(value): 62 | 
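      # index == vocab size is the CTC blank label appended by the model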
return "" 63 | raise ValueError('char map dict not has {:d} value. convert index to char failed.'.format(value)) 64 | 65 | def _sparse_matrix_to_list(sparse_matrix, char_map_dict=None): 66 | indices = sparse_matrix.indices 67 | values = sparse_matrix.values 68 | dense_shape = sparse_matrix.dense_shape 69 | # the last index in sparse_matrix is ctc blanck note 70 | if char_map_dict is None: 71 | char_map_dict = json.load(open(FLAGS.char_map_json_file, 'r')) 72 | assert(isinstance(char_map_dict, dict) and 'char_map_dict is not a dict') 73 | dense_matrix = len(char_map_dict.keys()) * np.ones(dense_shape, dtype=np.int32) 74 | 75 | for i, indice in enumerate(indices): 76 | dense_matrix[indice[0], indice[1]] = values[i] 77 | string_list = [] 78 | for row in dense_matrix: 79 | string = [] 80 | for val in row: 81 | string.append(_int_to_string(val, char_map_dict)) 82 | string_list.append(''.join(s for s in string if s != '*')) 83 | return string_list 84 | 85 | 86 | class Config(): 87 | num_classes = 123 88 | lstm_num_units = 256 89 | batch_size = 1 90 | is_training= True 91 | channel_size = FLAGS.channel_size 92 | 93 | def load_char_map(): 94 | char_map_dict = json.load(open(FLAGS.char_map_json_file)) 95 | return char_map_dict 96 | 97 | def sparse_labels(labels): 98 | values = [] 99 | indices = [] 100 | max_len = max(map(len, labels)) 101 | batch_size = len(labels) 102 | for i, item in enumerate(labels): 103 | for j, value in enumerate(item): 104 | ind = (i,j) 105 | indices.append(ind) 106 | values.append(value) 107 | sparse_labels = tf.SparseTensor( 108 | indices=indices, 109 | values=values, 110 | dense_shape=(batch_size, max_len)) 111 | return sparse_labels 112 | 113 | 114 | def main(): 115 | train_tf_record = os.path.join(FLAGS.data_dir, 'ocr-train-*.tfrecord') 116 | eval_tf_record = os.path.join(FLAGS.data_dir, 'ocr-validation-*.tfrecord') 117 | 118 | char_map_dict = load_char_map() 119 | train_start_time = time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime(time.time())) 120 | model_name = 'crnn_ctc_ocr_{:s}.ckpt'.format(str(train_start_time)) 121 | model_save_path = os.path.join(FLAGS.model_dir, model_name) 122 | 123 | config = Config() 124 | config.batch_size = FLAGS.batch_size 125 | config.num_classes = len(char_map_dict) + 1 126 | train_input_fn = input_fn.input_fn(train_tf_record, FLAGS.batch_size, channel_size=FLAGS.channel_size) 127 | 128 | crnn_model = model.CRNN(config) 129 | saver = tf.train.Saver() 130 | if not os.path.exists(FLAGS.model_dir): 131 | os.makedirs(FLAGS.model_dir) 132 | 133 | global_step = tf.train.get_or_create_global_step() 134 | learning_rate = tf.train.exponential_decay(FLAGS.learning_rate, 135 | global_step, 136 | FLAGS.decay_steps, 137 | FLAGS.decay_rate, 138 | staircase = True) 139 | update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) 140 | with tf.control_dependencies(update_ops): 141 | train_op= tf.train.AdamOptimizer( 142 | learning_rate=FLAGS.learning_rate).minimize(crnn_model.loss, 143 | global_step=global_step) 144 | train_op = tf.group([train_op, update_ops]) 145 | decoded, log_prob = tf.nn.ctc_greedy_decoder(crnn_model.logits, crnn_model.sequence_length) 146 | pred_str_labels = tf.as_string(decoded[0].values) 147 | pred_tensor = tf.SparseTensor(indices=decoded[0].indices, values=pred_str_labels, dense_shape=decoded[0].dense_shape) 148 | true_str_labels = tf.as_string(crnn_model.labels.values) 149 | true_tensor = tf.SparseTensor(indices=crnn_model.labels.indices, values=true_str_labels, dense_shape=crnn_model.labels.dense_shape) 150 | edit_distance = 
152 |   tf.summary.scalar(name='edit_distance', tensor=edit_distance)
153 |   tf.summary.scalar(name='ctc_loss', tensor=crnn_model.loss)
154 |   tf.summary.scalar(name='learning_rate', tensor=learning_rate)
155 |   merge_summary_op = tf.summary.merge_all()
156 |   config = tf.ConfigProto()
157 |   config.gpu_options.allow_growth = True
158 |   with tf.Session(config=config) as sess:
159 |     sess.run(tf.global_variables_initializer())
160 |     summary_writer = tf.summary.FileWriter(FLAGS.model_dir)
161 |     summary_writer.add_graph(sess.graph)
162 |     train_next_batch = train_input_fn.get_next()
163 | 
164 |     save_path = tf.train.latest_checkpoint(FLAGS.model_dir)
165 |     if save_path:
166 |       saver.restore(sess=sess, save_path=save_path)
167 |       print("restore from %s" % (save_path))
168 |       st = int(save_path.split("-")[-1])
169 |       sess.run(global_step.assign(st))
170 | 
171 |     for s in range(FLAGS.max_train_steps):
172 |       batch = sess.run(train_next_batch)
173 |       images = batch['images']
174 |       labels = batch['labels']
175 |       sequence_length = batch['sequence_length']
176 |       _, loss, lr, summary, step, logits, dis = sess.run(
177 |           [train_op, crnn_model.loss, learning_rate, merge_summary_op, global_step, crnn_model.logits, edit_distance],
178 |           feed_dict={
179 |               crnn_model.images: images,
180 |               crnn_model.labels: labels,
181 |               crnn_model.sequence_length: sequence_length,
182 |               crnn_model.keep_prob: 0.5,
183 |               crnn_model.is_training: True})
184 | 
185 |       print("step: {step} lr: {lr} loss: {loss} acc: {dis} ".format(step=step, lr=lr, loss=loss, dis=(1-dis)))
186 |       if step % FLAGS.step_per_save == 0:
187 |         summary_writer.add_summary(summary=summary, global_step=step)
188 |         saver.save(sess=sess, save_path=model_save_path, global_step=step)
189 | 
190 |       # in-training evaluation is disabled; drop the "False and" to re-enable it
191 |       if False and step % FLAGS.step_per_eval == 0:
192 |         eval_input_fn = input_fn.input_fn(eval_tf_record, FLAGS.batch_size, False, channel_size=FLAGS.channel_size)
193 |         eval_next_batch = eval_input_fn.get_next()
194 |         all_distance = []
195 |         while True:
196 |           try:
197 |             eval_batch = sess.run(eval_next_batch)
198 |             images = eval_batch['images']
199 |             labels = eval_batch['labels']
200 |             sequence_length = eval_batch['sequence_length']
201 |             eval_distance = sess.run([edit_distance],
202 |                 feed_dict={
203 |                     crnn_model.images: images,
204 |                     crnn_model.labels: labels,
205 |                     crnn_model.keep_prob: 1.0,
206 |                     crnn_model.is_training: False,
207 |                     crnn_model.sequence_length: sequence_length})
208 |             all_distance.append(eval_distance[0])
209 |           except tf.errors.OutOfRangeError as e:
210 |             print("eval acc: ", 1 - np.mean(np.array(all_distance)))
211 |             break
212 | 
213 | if __name__ == "__main__":
214 |   main()
215 | 
216 | 
-------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. 
For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 
134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 
193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | --------------------------------------------------------------------------------