├── .gitignore ├── 1 - wei ├── ASERT.ipynb ├── README.md ├── aster │ ├── LICENSE │ ├── README.md │ ├── builders │ │ ├── __init__.py │ │ ├── attention_recognition_model_builder_test.py │ │ ├── bidirectional_rnn_builder.py │ │ ├── bidirectional_rnn_builder_test.py │ │ ├── convnet_builder.py │ │ ├── convnet_builder_test.py │ │ ├── ctc_recognition_model_builder_test.py │ │ ├── feature_extractor_builder.py │ │ ├── feature_extractor_builder_test.py │ │ ├── hyperparams_builder.py │ │ ├── hyperparams_builder_test.py │ │ ├── input_reader_builder.py │ │ ├── input_reader_builder_test.py │ │ ├── label_map_builder.py │ │ ├── label_map_builder_test.py │ │ ├── loss_builder.py │ │ ├── loss_builder_test.py │ │ ├── model_builder.py │ │ ├── model_builder_test.py │ │ ├── optimizer_builder.py │ │ ├── optimizer_builder_test.py │ │ ├── predictor_builder.py │ │ ├── predictor_builder_test.py │ │ ├── preprocessor_builder.py │ │ ├── preprocessor_builder_test.py │ │ ├── rnn_cell_builder.py │ │ ├── rnn_cell_builder_test.py │ │ ├── spatial_transformer_builder.py │ │ └── spatial_transformer_builder_test.py │ ├── c_ops │ │ ├── CMakeLists.txt │ │ ├── build.sh │ │ ├── divide_curve_op.cc │ │ ├── ops.py │ │ ├── ops_test.py │ │ ├── string_filtering_op.cc │ │ └── string_reverse_op.cc │ ├── convnets │ │ ├── __init__.py │ │ ├── crnn_net.py │ │ ├── resnet.py │ │ └── stn_convnet.py │ ├── core │ │ ├── __init__.py │ │ ├── batcher.py │ │ ├── bidirectional_rnn.py │ │ ├── convnet.py │ │ ├── feature_extractor.py │ │ ├── label_map.py │ │ ├── loss.py │ │ ├── model.py │ │ ├── predictor.py │ │ ├── prefetcher.py │ │ ├── preprocessor.py │ │ ├── spatial_transformer.py │ │ ├── spatial_transformer_test.py │ │ ├── standard_fields.py │ │ └── sync_attention_wrapper.py │ ├── data_decoders │ │ ├── __init__.py │ │ └── tf_example_decoder.py │ ├── demo.py │ ├── demo_1.py │ ├── eval.py │ ├── eval_util.py │ ├── evaluator.py │ ├── experiments │ │ ├── demo │ │ │ └── config │ │ │ │ └── trainval.prototxt │ │ └── tinymind │ │ │ └── config │ │ │ └── trainval.prototxt │ ├── meta_architectures │ │ ├── __init__.py │ │ ├── ctc_recognition_model.py │ │ └── multi_predictors_recognition_model.py │ ├── overview.png │ ├── predictors │ │ ├── __init__.py │ │ └── attention_predictor.py │ ├── protos │ │ ├── bidirectional_rnn.proto │ │ ├── bidirectional_rnn_pb2.py │ │ ├── convnet.proto │ │ ├── convnet_pb2.py │ │ ├── eval.proto │ │ ├── eval_pb2.py │ │ ├── feature_extractor.proto │ │ ├── feature_extractor_pb2.py │ │ ├── hyperparams.proto │ │ ├── hyperparams_pb2.py │ │ ├── input_reader.proto │ │ ├── input_reader_pb2.py │ │ ├── label_map.proto │ │ ├── label_map_pb2.py │ │ ├── loss.proto │ │ ├── loss_pb2.py │ │ ├── model.proto │ │ ├── model_pb2.py │ │ ├── optimizer.proto │ │ ├── optimizer_pb2.py │ │ ├── pipeline.proto │ │ ├── pipeline_pb2.py │ │ ├── predictor.proto │ │ ├── predictor_pb2.py │ │ ├── preprocessor.proto │ │ ├── preprocessor_pb2.py │ │ ├── rnn_cell.proto │ │ ├── rnn_cell_pb2.py │ │ ├── spatial_transformer.proto │ │ ├── spatial_transformer_pb2.py │ │ ├── train.proto │ │ └── train_pb2.py │ ├── tools │ │ ├── create_cute80_tfrecord.py │ │ ├── create_ic03_tfrecord.py │ │ ├── create_ic13_tfrecord.py │ │ ├── create_ic15_tfrecord.py │ │ ├── create_iiit5k_tfrecord.py │ │ ├── create_svt_perspective_tfrecord.py │ │ ├── create_svt_tfrecord.py │ │ ├── create_synth90k_tfrecord.py │ │ └── create_synthtext_tfrecord.py │ ├── train.py │ ├── train_1.py │ ├── trainer.py │ └── utils │ │ ├── __init__.py │ │ ├── dataset_util.py │ │ ├── learning_schedules.py │ │ ├── model_deploy.py │ │ ├── 
model_deploy_test.py │ │ ├── profile_session_run_hooks.py │ │ ├── recognition_evaluation.py │ │ ├── shape_utils.py │ │ ├── variables_helper.py │ │ ├── visualization_utils.py │ │ └── visualization_utils_test.py ├── ctpn-crnn │ ├── README.md │ ├── cptn │ │ ├── data │ │ │ └── demo │ │ │ │ ├── img_calligraphy_70001_bg.jpg │ │ │ │ ├── img_calligraphy_70001_bg.txt │ │ │ │ └── source │ │ │ │ ├── img_calligraphy_70001_bg.jpg │ │ │ │ └── lable.csv │ │ ├── main │ │ │ ├── predict.py │ │ │ └── train.py │ │ ├── nets │ │ │ ├── model_train.py │ │ │ └── vgg.py │ │ └── utils │ │ │ ├── bbox │ │ │ ├── bbox.c │ │ │ ├── bbox.pyx │ │ │ ├── bbox_transform.py │ │ │ ├── make.sh │ │ │ ├── nms.c │ │ │ ├── nms.pyx │ │ │ └── setup.py │ │ │ ├── dataset │ │ │ ├── data_provider.py │ │ │ └── data_util.py │ │ │ ├── prepare │ │ │ ├── __init__.py │ │ │ ├── split_label.py │ │ │ └── utils.py │ │ │ ├── rpn_msr │ │ │ ├── __init__.py │ │ │ ├── anchor_target_layer.py │ │ │ ├── config.py │ │ │ ├── generate_anchors.py │ │ │ └── proposal_layer.py │ │ │ └── text_connector │ │ │ ├── __init__.py │ │ │ ├── detectors.py │ │ │ ├── other.py │ │ │ ├── text_connect_cfg.py │ │ │ ├── text_proposal_connector.py │ │ │ ├── text_proposal_connector_oriented.py │ │ │ └── text_proposal_graph_builder.py │ └── crnn │ │ ├── alphabets.py │ │ ├── crnn_main.py │ │ ├── dataset.py │ │ ├── models │ │ ├── __init__.py │ │ └── crnn.py │ │ ├── params.py │ │ ├── predict.py │ │ ├── to_lmdb │ │ ├── __init__.py │ │ └── tolmdb.py │ │ └── utils.py ├── densent.ipynb ├── ocr_densenet │ ├── README.md │ ├── code │ │ ├── .DS_Store │ │ ├── ocr │ │ │ ├── .DS_Store │ │ │ ├── dataloader.py │ │ │ ├── densenet.py │ │ │ ├── main.py │ │ │ ├── resnet.py │ │ │ └── tools │ │ │ │ ├── __init__.py │ │ │ │ ├── measures.py │ │ │ │ ├── parse.py │ │ │ │ ├── plot.py │ │ │ │ ├── py_op.py │ │ │ │ ├── segmentation.py │ │ │ │ └── utils.py │ │ └── preprocessing │ │ │ ├── analysis_dataset.py │ │ │ ├── map_word_to_index.py │ │ │ └── show_black.py │ ├── files │ │ ├── .DS_Store │ │ ├── alphabet_count_dict.json │ │ ├── alphabet_index_dict.json │ │ ├── black.json │ │ ├── image_hw_ratio_dict.json │ │ ├── src │ │ │ ├── A81.png │ │ │ └── B1000_0.png │ │ ├── train_alphabet.json │ │ └── ttf │ │ │ └── simsun.ttf │ └── requirement.txt └── tinymind.ipynb ├── 2 - TitanikData ├── README.md ├── deep-text-recognition-benchmark-master │ ├── LICENSE.md │ ├── create_lmdb_dataset.py │ ├── dataset.py │ ├── demo.py │ ├── model.py │ ├── modules │ │ ├── feature_extraction.py │ │ ├── prediction.py │ │ ├── sequence_modeling.py │ │ └── transformation.py │ ├── predict.sh │ ├── test.py │ ├── train.py │ └── utils.py └── 检测 │ ├── README.txt │ ├── faster_rcnn_r50_fpn_1x_voc0712.py │ ├── pick_picname.py │ └── test.py ├── 3 - TechDing ├── .idea │ ├── RMB_TechDing.iml │ ├── dictionaries │ │ └── pc.xml │ ├── encodings.xml │ ├── misc.xml │ ├── modules.xml │ └── workspace.xml ├── CTC_Models │ ├── Utils.py │ ├── __init__.py │ └── model_predict.py ├── CTPN_ROI │ ├── __init__.py │ ├── checkpoints │ │ └── checkpoint │ ├── ctpn │ │ ├── __init__.py │ │ ├── get_ROI_imgs.py │ │ ├── rmb_db.py │ │ └── text.yml │ └── 说明.txt ├── CV_ROI │ ├── __init__.py │ ├── extract_ROI.py │ └── 说明.txt ├── Main.py ├── config.cfg ├── private_test_final_submit.csv ├── test_result.csv └── 使用方法.txt ├── 4 - HLearning ├── CRNN.ipynb ├── readme.md ├── text-detection-ctpn │ ├── LICENSE │ ├── README.md │ ├── main │ │ ├── demo.py │ │ └── train.py │ ├── nets │ │ ├── model_train.py │ │ └── vgg.py │ ├── requirements.txt │ └── utils │ │ ├── bbox │ │ ├── 
.fuse_hidden0000c4ff00000003 │ │ ├── bbox.c │ │ ├── bbox.pyx │ │ ├── bbox_transform.py │ │ ├── make.sh │ │ ├── nms.c │ │ ├── nms.pyx │ │ └── setup.py │ │ ├── dataset │ │ ├── data_provider.py │ │ └── data_util.py │ │ ├── prepare │ │ ├── split_label.py │ │ └── utils.py │ │ ├── rpn_msr │ │ ├── __init__.py │ │ ├── anchor_target_layer.py │ │ ├── config.py │ │ ├── generate_anchors.py │ │ └── proposal_layer.py │ │ └── text_connector │ │ ├── __init__.py │ │ ├── detectors.py │ │ ├── other.py │ │ ├── text_connect_cfg.py │ │ ├── text_proposal_connector.py │ │ ├── text_proposal_connector_oriented.py │ │ └── text_proposal_graph_builder.py └── zzz-Xecp&&IncV2&&Dense&&resnet.csv ├── 5 - ResNet34 ├── README.md ├── crnn-pytorch │ ├── .gitignore │ ├── LICENSE │ ├── README.md │ ├── __init__.py │ ├── dataset │ │ ├── __init__.py │ │ ├── collate_fn.py │ │ ├── data_transform.py │ │ ├── test_data.py │ │ └── text_data.py │ ├── fold_tta.pkl │ ├── lr_policy.py │ ├── models │ │ ├── __init__.py │ │ ├── crnn.py │ │ └── model_loader.py │ ├── pb_rcnn_label.csv │ ├── submit.py │ ├── test.py │ ├── test2.py │ ├── test2_tta.py │ ├── train.py │ └── 数据读取.ipynb └── multi-digit-pytorch │ ├── 1_train.py │ ├── 2_predict.py │ └── 未命名.ipynb ├── README.md └── 赛后分享PPT ├── 1 - wei.pdf ├── 2 - TitanikData.pdf ├── 3 - TechDing.pdf ├── 4 - HLearning.pdf └── 5 - ResNet34.pdf /1 - wei/README.md: -------------------------------------------------------------------------------- 1 | ### All files can be opened and run directly on Google Colab 2 | Detection part: 3 | 1. Download the CTPN code and model from 4 | https://drive.google.com/drive/folders/1VdsIQ4sgGmNI8Gor-_ie3xktLMFDp8Sd?usp=sharing 5 | (the download is large because it contains both the VGG16 and CTPN models) 6 | 7 | 2. Run the CTPN part of tinymind.ipynb to generate the cropped images (the training data is 160-odd images with manually labeled serial-number positions) 8 | Note: in 99% of cases only one candidate box is detected, and it can be used directly 9 | The remaining cases need simple filtering; images that still have multiple boxes after filtering are all sent to recognition, and the results are checked with a regular expression. Wrong boxes come out as a jumble of digits; very rarely (about three images) a box clips a small piece off the edge, which the regex check on the prediction catches, and demo_1 is run to re-crop them 10 | 11 | Recognition part: 12 | CRNN model link: 13 | The code to run is in tinymind.ipynb 14 | https://drive.google.com/file/d/1ywyH25xtcSHhZxeIACgso4Bslo-ZidNV/view?usp=sharing 15 | Code is in densent.ipynb 16 | DenseNet model link: 17 | https://drive.google.com/file/d/1_xU2d7bU6FOLHJPjy1dgDLitdlGHCS7-/view?usp=sharing 18 | Code is in ASTER.ipynb 19 | ASTER model link: 20 | https://drive.google.com/drive/folders/1ctd55IG30aRAC4xUcSyirOyKaWW7mAfh?usp=sharing 21 | Unzip experiments into the aster folder 22 | 23 | The model-ensembling code is in the last part of tinymind.ipynb 24 | 25 | 2D9QKGRM.jpg is a damaged image; its result was forcibly overwritten 26 | -------------------------------------------------------------------------------- /1 - wei/aster/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2017 Baoguang Shi 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /1 - wei/aster/builders/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mind/RMB/cd656b3000fc612657b57dc70dfa76432a448014/1 - wei/aster/builders/__init__.py -------------------------------------------------------------------------------- /1 - wei/aster/builders/bidirectional_rnn_builder.py: -------------------------------------------------------------------------------- 1 | from aster.core import bidirectional_rnn 2 | from aster.protos import hyperparams_pb2 3 | from aster.protos import bidirectional_rnn_pb2 4 | from aster.builders import hyperparams_builder 5 | from aster.builders import rnn_cell_builder 6 | 7 | 8 | def build(config, is_training): 9 | if not isinstance(config, bidirectional_rnn_pb2.BidirectionalRnn): 10 | raise ValueError('config not of type bidirectional_rnn_pb2.BidirectionalRnn') 11 | 12 | if config.static: 13 | brnn_class = bidirectional_rnn.StaticBidirectionalRnn 14 | else: 15 | brnn_class = bidirectional_rnn.DynamicBidirectionalRnn 16 | 17 | fw_cell_object = rnn_cell_builder.build(config.fw_bw_rnn_cell) 18 | bw_cell_object = rnn_cell_builder.build(config.fw_bw_rnn_cell) 19 | rnn_regularizer_object = hyperparams_builder._build_regularizer(config.rnn_regularizer) 20 | fc_hyperparams_object = None 21 | if config.num_output_units > 0: 22 | if config.fc_hyperparams.op != hyperparams_pb2.Hyperparams.FC: 23 | raise ValueError('op type must be FC') 24 | fc_hyperparams_object = hyperparams_builder.build(config.fc_hyperparams, is_training) 25 | 26 | return brnn_class( 27 | fw_cell_object, bw_cell_object, 28 | rnn_regularizer=rnn_regularizer_object, 29 | num_output_units=config.num_output_units, 30 | fc_hyperparams=fc_hyperparams_object, 31 | summarize_activations=config.summarize_activations) 32 | -------------------------------------------------------------------------------- /1 - wei/aster/builders/bidirectional_rnn_builder_test.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | from google.protobuf import text_format 4 | from aster.builders import bidirectional_rnn_builder 5 | from aster.protos import bidirectional_rnn_pb2 6 | 7 | 8 | class BidirectionalRnnBuilderTest(tf.test.TestCase): 9 | 10 | def test_bidirectional_rnn(self): 11 | text_proto = """ 12 | static: true 13 | fw_bw_rnn_cell { 14 | lstm_cell { 15 | num_units: 32 16 | forget_bias: 1.0 17 | initializer { orthogonal_initializer {} } 18 | } 19 | } 20 | rnn_regularizer { l2_regularizer { weight: 1e-4 } } 21 | num_output_units: 31 22 | fc_hyperparams { 23 | op: FC 24 | activation: RELU 25 | initializer { variance_scaling_initializer { } } 26 | regularizer { l2_regularizer { weight: 1e-4 } } 27 | } 28 | """ 29 | config = bidirectional_rnn_pb2.BidirectionalRnn() 30 | text_format.Merge(text_proto, config) 31 | brnn_object = bidirectional_rnn_builder.build(config, True) 32 | 33 | test_input = tf.random_uniform([2, 5, 32], dtype=tf.float32) 34 | test_output = brnn_object.predict(test_input) 35 | 36 | with self.test_session() as sess: 37 | tf.global_variables_initializer().run() 38 | sess_outputs = 
sess.run({'outputs': test_output}) 39 | self.assertAllEqual(sess_outputs['outputs'].shape, [2, 5, 31]) 40 | 41 | def test_dynamic_bidirectional_rnn(self): 42 | text_proto = """ 43 | static: false 44 | fw_bw_rnn_cell { 45 | lstm_cell { 46 | num_units: 32 47 | forget_bias: 1.0 48 | initializer { orthogonal_initializer {} } 49 | } 50 | } 51 | rnn_regularizer { l2_regularizer { weight: 1e-4 } } 52 | num_output_units: 31 53 | fc_hyperparams { 54 | op: FC 55 | activation: RELU 56 | initializer { variance_scaling_initializer { } } 57 | regularizer { l2_regularizer { weight: 1e-4 } } 58 | } 59 | """ 60 | config = bidirectional_rnn_pb2.BidirectionalRnn() 61 | text_format.Merge(text_proto, config) 62 | brnn_object = bidirectional_rnn_builder.build(config, True) 63 | 64 | test_input = tf.random_uniform([2, 5, 32], dtype=tf.float32) 65 | test_output = brnn_object.predict(test_input) 66 | 67 | with self.test_session() as sess: 68 | tf.global_variables_initializer().run() 69 | sess_outputs = sess.run({'outputs': test_output}) 70 | self.assertAllEqual(sess_outputs['outputs'].shape, [2, 5, 31]) 71 | 72 | def test_bidirectional_rnn_nofc(self): 73 | text_proto = """ 74 | static: true 75 | fw_bw_rnn_cell { 76 | lstm_cell { 77 | num_units: 32 78 | forget_bias: 1.0 79 | initializer { orthogonal_initializer {} } 80 | } 81 | } 82 | rnn_regularizer { l2_regularizer { weight: 1e-4 } } 83 | """ 84 | config = bidirectional_rnn_pb2.BidirectionalRnn() 85 | text_format.Merge(text_proto, config) 86 | brnn_object = bidirectional_rnn_builder.build(config, True) 87 | 88 | test_input = tf.random_uniform([2, 5, 32], dtype=tf.float32) 89 | test_output = brnn_object.predict(test_input) 90 | 91 | with self.test_session() as sess: 92 | tf.global_variables_initializer().run() 93 | sess_outputs = sess.run({'outputs': test_output}) 94 | self.assertAllEqual(sess_outputs['outputs'].shape, [2, 5, 64]) 95 | 96 | if __name__ == '__main__': 97 | tf.test.main() 98 | -------------------------------------------------------------------------------- /1 - wei/aster/builders/ctc_recognition_model_builder_test.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | from google.protobuf import text_format 4 | from aster.builders import model_builder 5 | from aster.protos import model_pb2 6 | 7 | 8 | class CtcRecognitionModelBuilderTest(tf.test.TestCase): 9 | 10 | def test_build_ctc_model(self): 11 | model_text_proto = """ 12 | ctc_recognition_model { 13 | feature_extractor { 14 | convnet { 15 | crnn_net { 16 | net_type: SINGLE_BRANCH 17 | conv_hyperparams { 18 | op: CONV 19 | regularizer { l2_regularizer { weight: 1e-4 } } 20 | initializer { variance_scaling_initializer { } } 21 | batch_norm { } 22 | } 23 | summarize_activations: false 24 | } 25 | } 26 | 27 | bidirectional_rnn { 28 | fw_bw_rnn_cell { 29 | lstm_cell { 30 | num_units: 256 31 | forget_bias: 1.0 32 | initializer { orthogonal_initializer {} } 33 | } 34 | } 35 | rnn_regularizer { l2_regularizer { weight: 1e-4 } } 36 | num_output_units: 256 37 | fc_hyperparams { 38 | op: FC 39 | activation: RELU 40 | initializer { variance_scaling_initializer { } } 41 | regularizer { l2_regularizer { weight: 1e-4 } } 42 | } 43 | } 44 | 45 | summarize_activations: true 46 | } 47 | 48 | fc_hyperparams { 49 | op: FC 50 | initializer { variance_scaling_initializer {} } 51 | regularizer { l2_regularizer { weight: 1e-4 } } 52 | } 53 | 54 | label_map { 55 | character_set { 56 | text_string: "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ" 
57 | delimiter: "" 58 | } 59 | label_offset: 0 60 | } 61 | } 62 | """ 63 | model_proto = model_pb2.Model() 64 | text_format.Merge(model_text_proto, model_proto) 65 | model_object = model_builder.build(model_proto, True) 66 | 67 | test_groundtruth_text_list = [ 68 | tf.constant(b'hello', dtype=tf.string), 69 | tf.constant(b'world', dtype=tf.string)] 70 | model_object.provide_groundtruth(test_groundtruth_text_list) 71 | test_input_image = tf.random_uniform(shape=[2, 32, 100, 3], minval=0, maxval=255, 72 | dtype=tf.float32, seed=1) 73 | prediction_dict = model_object.predict(model_object.preprocess(test_input_image)) 74 | loss = model_object.loss(prediction_dict) 75 | 76 | with self.test_session() as sess: 77 | sess.run([ 78 | tf.global_variables_initializer(), 79 | tf.tables_initializer()]) 80 | outputs = sess.run({'loss': loss}) 81 | print(outputs['loss']) 82 | 83 | if __name__ == '__main__': 84 | tf.test.main() 85 | -------------------------------------------------------------------------------- /1 - wei/aster/builders/feature_extractor_builder.py: -------------------------------------------------------------------------------- 1 | import functools 2 | 3 | import tensorflow as tf 4 | 5 | from aster.core import feature_extractor 6 | from aster.protos import feature_extractor_pb2 7 | from aster.builders import convnet_builder 8 | from aster.builders import bidirectional_rnn_builder 9 | from aster.builders import hyperparams_builder 10 | 11 | 12 | def build(config, is_training): 13 | if not isinstance(config, feature_extractor_pb2.FeatureExtractor): 14 | raise ValueError('config not of type feature_extractor_pb2.FeatureExtractor') 15 | 16 | convnet_object = convnet_builder.build(config.convnet, is_training) 17 | brnn_fn_list = [ 18 | functools.partial(bidirectional_rnn_builder.build, brnn_config, is_training) 19 | for brnn_config in config.bidirectional_rnn 20 | ] 21 | feature_extractor_object = feature_extractor.FeatureExtractor( 22 | convnet=convnet_object, 23 | brnn_fn_list=brnn_fn_list, 24 | summarize_activations=config.summarize_activations, 25 | is_training=is_training 26 | ) 27 | return feature_extractor_object 28 | -------------------------------------------------------------------------------- /1 - wei/aster/builders/input_reader_builder.py: -------------------------------------------------------------------------------- 1 | """Input reader builder. 2 | 3 | Creates data sources for DetectionModels from an InputReader config. See 4 | input_reader.proto for options. 5 | 6 | Note: If users wishes to also use their own InputReaders with the Object 7 | Detection configuration framework, they should define their own builder function 8 | that wraps the build function. 9 | """ 10 | import os 11 | 12 | import tensorflow as tf 13 | 14 | from aster.data_decoders import tf_example_decoder 15 | from aster.protos import input_reader_pb2 16 | 17 | parallel_reader = tf.contrib.slim.parallel_reader 18 | 19 | 20 | def build(input_reader_config): 21 | """Builds a tensor dictionary based on the InputReader config. 22 | 23 | Args: 24 | input_reader_config: A input_reader_pb2.InputReader object. 25 | 26 | Returns: 27 | A tensor dict based on the input_reader_config. 28 | 29 | Raises: 30 | ValueError: On invalid input reader proto. 
31 | """ 32 | if not isinstance(input_reader_config, input_reader_pb2.InputReader): 33 | raise ValueError('input_reader_config not of type ' 34 | 'input_reader_pb2.InputReader.') 35 | 36 | input_reader_oneof = input_reader_config.WhichOneof('input_reader') 37 | if input_reader_oneof == 'tf_record_input_reader': 38 | config = input_reader_config.tf_record_input_reader 39 | if not os.path.exists(config.input_path): 40 | raise ValueError('Input path not found: {}'.format(config.input_path)) 41 | 42 | _, string_tensor = parallel_reader.parallel_read( 43 | config.input_path, 44 | reader_class=tf.TFRecordReader, 45 | num_epochs=(input_reader_config.num_epochs 46 | if input_reader_config.num_epochs else None), 47 | num_readers=input_reader_config.num_readers, 48 | shuffle=input_reader_config.shuffle, 49 | dtypes=[tf.string, tf.string], 50 | capacity=input_reader_config.queue_capacity, 51 | min_after_dequeue=input_reader_config.min_after_dequeue) 52 | 53 | return tf_example_decoder.TfExampleDecoder().Decode(string_tensor) 54 | 55 | raise ValueError('Unsupported input_reader_config: {}'.format(input_reader_oneof)) 56 | -------------------------------------------------------------------------------- /1 - wei/aster/builders/input_reader_builder_test.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import tensorflow as tf 4 | 5 | from google.protobuf import text_format 6 | 7 | from tensorflow.core.example import example_pb2 8 | from tensorflow.core.example import feature_pb2 9 | from aster.core import standard_fields as fields 10 | from aster.protos import input_reader_pb2 11 | from aster.builders import input_reader_builder 12 | 13 | 14 | class InputReaderTest(tf.test.TestCase): 15 | 16 | def create_tf_record(self): 17 | path = os.path.join(self.get_temp_dir(), 'tfrecord') 18 | writer = tf.python_io.TFRecordWriter(path) 19 | 20 | image_tensor = np.random.randint(255, size=(4, 5, 3)).astype(np.uint8) 21 | with self.test_session(): 22 | encoded_jpeg = tf.image.encode_jpeg(tf.constant(image_tensor)).eval() 23 | example = example_pb2.Example(features=feature_pb2.Features(feature={ 24 | 'image/encoded': feature_pb2.Feature( 25 | bytes_list=feature_pb2.BytesList(value=[encoded_jpeg])), 26 | 'image/format': feature_pb2.Feature( 27 | bytes_list=feature_pb2.BytesList(value=['jpeg'.encode('utf-8')])), 28 | 'image/transcript': feature_pb2.Feature( 29 | bytes_list=feature_pb2.BytesList(value=[ 30 | 'hello'.encode('utf-8')])) 31 | })) 32 | writer.write(example.SerializeToString()) 33 | writer.close() 34 | 35 | return path 36 | 37 | def test_build_tf_record_input_reader(self): 38 | tf_record_path = self.create_tf_record() 39 | 40 | input_reader_text_proto = """ 41 | shuffle: false 42 | num_readers: 3 43 | tf_record_input_reader {{ 44 | input_path: '{0}' 45 | }} 46 | """.format(tf_record_path) 47 | input_reader_proto = input_reader_pb2.InputReader() 48 | text_format.Merge(input_reader_text_proto, input_reader_proto) 49 | tensor_dict = input_reader_builder.build(input_reader_proto) 50 | 51 | sv = tf.train.Supervisor(logdir=self.get_temp_dir()) 52 | with sv.prepare_or_wait_for_session() as sess: 53 | sv.start_queue_runners(sess) 54 | output_dict = sess.run(tensor_dict) 55 | 56 | self.assertEqual( 57 | (4, 5, 3), 58 | output_dict[fields.InputDataFields.image].shape) 59 | self.assertEqual( 60 | 'hello'.encode('utf-8'), 61 | output_dict[fields.InputDataFields.groundtruth_text]) 62 | 63 | 64 | if __name__ == '__main__': 65 | tf.test.main() 66 | 
-------------------------------------------------------------------------------- /1 - wei/aster/builders/label_map_builder.py: -------------------------------------------------------------------------------- 1 | import os 2 | import string 3 | 4 | import tensorflow as tf 5 | 6 | from aster.core import label_map 7 | from aster.protos import label_map_pb2 8 | 9 | 10 | def build(config): 11 | if not isinstance(config, label_map_pb2.LabelMap): 12 | raise ValueError('config not of type label_map_pb2.LabelMap') 13 | 14 | character_set = _build_character_set(config.character_set) 15 | label_map_object = label_map.LabelMap( 16 | character_set=character_set, 17 | label_offset=config.label_offset, 18 | unk_label=config.unk_label) 19 | return label_map_object 20 | 21 | def _build_character_set(config): 22 | if not isinstance(config, label_map_pb2.CharacterSet): 23 | raise ValueError('config not of type label_map_pb2.CharacterSet') 24 | 25 | source_oneof = config.WhichOneof('source_oneof') 26 | character_set_string = None 27 | if source_oneof == 'text_file': 28 | file_path = config.text_file 29 | with open(file_path, 'r') as f: 30 | character_set_string = f.read() 31 | character_set = character_set_string.split('\n') 32 | elif source_oneof == 'text_string': 33 | character_set_string = config.text_string 34 | character_set = character_set_string.split() 35 | elif source_oneof == 'built_in_set': 36 | if config.built_in_set == label_map_pb2.CharacterSet.LOWERCASE: 37 | character_set = list(string.digits + string.ascii_lowercase) 38 | elif config.built_in_set == label_map_pb2.CharacterSet.ALLCASES: 39 | character_set = list(string.digits + string.ascii_letters) 40 | elif config.built_in_set == label_map_pb2.CharacterSet.ALLCASES_SYMBOLS: 41 | character_set = list(string.printable[:-6]) 42 | else: 43 | raise ValueError('Unknown built_in_set') 44 | else: 45 | raise ValueError('Unknown source_oneof: {}'.format(source_oneof)) 46 | 47 | return character_set 48 | -------------------------------------------------------------------------------- /1 - wei/aster/builders/label_map_builder_test.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from google.protobuf import text_format 3 | 4 | from aster.builders import label_map_builder 5 | from aster.protos import label_map_pb2 6 | 7 | 8 | class LabelMapTest(tf.test.TestCase): 9 | 10 | def test_build_label_map(self): 11 | label_map_text_proto = """ 12 | character_set { 13 | built_in_set: LOWERCASE 14 | } 15 | label_offset: 3 16 | unk_label: -2 17 | """ 18 | label_map_proto = label_map_pb2.LabelMap() 19 | text_format.Merge(label_map_text_proto, label_map_proto) 20 | label_map_object = label_map_builder.build(label_map_proto) 21 | 22 | test_text = tf.constant( 23 | ['a', 'b', '', 'abz', '0a='], 24 | tf.string 25 | ) 26 | test_labels, text_lengths = label_map_object.text_to_labels(test_text, return_lengths=True) 27 | test_text_from_labels = label_map_object.labels_to_text(test_labels) 28 | 29 | with self.test_session() as sess: 30 | tf.tables_initializer().run() 31 | outputs = sess.run({ 32 | 'test_labels': test_labels, 33 | 'text_lengths': text_lengths, 34 | 'text_from_labels': test_text_from_labels 35 | }) 36 | self.assertAllEqual( 37 | outputs['test_labels'], 38 | [[13, -1, -1], 39 | [14, -1, -1], 40 | [-1, -1, -1], 41 | [13, 14, 38], 42 | [3, 13, -2]] 43 | ) 44 | self.assertAllEqual( 45 | outputs['text_lengths'], 46 | [1, 1, 0, 3, 3] 47 | ) 48 | self.assertAllEqual( 49 | outputs['text_from_labels'], 
50 | [b'a', b'b', b'', b'abz', b'0a'] 51 | ) 52 | 53 | if __name__ == '__main__': 54 | tf.test.main() 55 | -------------------------------------------------------------------------------- /1 - wei/aster/builders/loss_builder.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | from aster.core import loss 4 | from aster.protos import loss_pb2 5 | 6 | 7 | def build(config): 8 | if not isinstance(config, loss_pb2.Loss): 9 | raise ValueError('config not of type loss_pb2.Loss') 10 | loss_oneof = config.WhichOneof('loss_oneof') 11 | if loss_oneof == 'sequence_cross_entropy_loss': 12 | sequence_cross_entropy_loss_config = config.sequence_cross_entropy_loss 13 | return loss.SequenceCrossEntropyLoss( 14 | sequence_normalize=sequence_cross_entropy_loss_config.sequence_normalize, 15 | sample_normalize=sequence_cross_entropy_loss_config.sample_normalize, 16 | weight=sequence_cross_entropy_loss_config.weight 17 | ) 18 | elif loss_oneof == 'tfseq2seq_loss': 19 | raise NotImplementedError 20 | elif loss_oneof == 'l2_regression_loss': 21 | loss_config = config.l2_regression_loss 22 | return loss.L2RegressionLoss(weight=loss_config.weight) 23 | else: 24 | raise ValueError('Unknown loss_oneof: {}'.format(loss_oneof)) 25 | -------------------------------------------------------------------------------- /1 - wei/aster/builders/loss_builder_test.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | 4 | from google.protobuf import text_format 5 | from aster.core import loss 6 | from aster.builders import loss_builder 7 | from aster.protos import loss_pb2 8 | 9 | 10 | class LossTest(tf.test.TestCase): 11 | 12 | def test_build_seq_loss(self): 13 | loss_text_proto = """ 14 | sequence_cross_entropy_loss { 15 | sequence_normalize: false 16 | sample_normalize: true 17 | weight: 0.5 18 | } 19 | """ 20 | loss_proto = loss_pb2.Loss() 21 | text_format.Merge(loss_text_proto, loss_proto) 22 | loss_object = loss_builder.build(loss_proto) 23 | 24 | test_logits = tf.constant( 25 | [ 26 | [ 27 | [0.0, 1.0], 28 | [0.5, 0.5], 29 | [0.3, 0.7], 30 | ], 31 | [ 32 | [0.0, -1.0], 33 | [1.0, 10.0], 34 | [1.0, 20.0], 35 | ], 36 | ], 37 | dtype=tf.float32 38 | ) 39 | test_labels = tf.constant( 40 | [ 41 | [0, 1, 0], 42 | [0, 0, 0] 43 | ], 44 | dtype=tf.int32 45 | ) 46 | test_lengths = tf.constant( 47 | [3, 1], 48 | dtype=tf.int32 49 | ) 50 | loss_tensor = loss_object(test_logits, test_labels, test_lengths, scope='loss') 51 | 52 | with self.test_session() as sess: 53 | outputs = sess.run({ 54 | 'loss': loss_tensor 55 | }) 56 | print(outputs) 57 | 58 | def test_build_reg_loss(self): 59 | loss_text_proto = """ 60 | l2_regression_loss { 61 | weight: 1.0 62 | } 63 | """ 64 | loss_proto = loss_pb2.Loss() 65 | text_format.Merge(loss_text_proto, loss_proto) 66 | loss_object = loss_builder.build(loss_proto) 67 | self.assertTrue(isinstance(loss_object, loss.L2RegressionLoss)) 68 | 69 | prediction = tf.constant(np.random.uniform(0, 1, (2, 20))) 70 | target = tf.constant(np.random.uniform(0, 1, (2, 20))) 71 | loss_tensor = loss_object(prediction, target) 72 | with self.test_session() as sess: 73 | print({'loss': loss_tensor.eval()}) 74 | 75 | if __name__ == '__main__': 76 | tf.test.main() 77 | -------------------------------------------------------------------------------- /1 - wei/aster/builders/model_builder.py: -------------------------------------------------------------------------------- 1 | 
import tensorflow as tf 2 | 3 | from aster.builders import spatial_transformer_builder 4 | from aster.builders import feature_extractor_builder 5 | from aster.builders import predictor_builder 6 | from aster.builders import loss_builder 7 | from aster.meta_architectures import multi_predictors_recognition_model 8 | from aster.protos import model_pb2 9 | 10 | 11 | def build(config, is_training): 12 | if not isinstance(config, model_pb2.Model): 13 | raise ValueError('config not of type model_pb2.Model') 14 | model_oneof = config.WhichOneof('model_oneof') 15 | if model_oneof == 'multi_predictors_recognition_model': 16 | return _build_multi_predictors_recognition_model( 17 | config.multi_predictors_recognition_model, is_training) 18 | else: 19 | raise ValueError('Unknown model_oneof: {}'.format(model_oneof)) 20 | 21 | def _build_multi_predictors_recognition_model(config, is_training): 22 | if not isinstance(config, model_pb2.MultiPredictorsRecognitionModel): 23 | raise ValueError('config not of type model_pb2.MultiPredictorsRecognitionModel') 24 | 25 | spatial_transformer_object = None 26 | if config.HasField('spatial_transformer'): 27 | spatial_transformer_object = spatial_transformer_builder.build( 28 | config.spatial_transformer, is_training) 29 | 30 | feature_extractor_object = feature_extractor_builder.build( 31 | config.feature_extractor, 32 | is_training=is_training 33 | ) 34 | predictors_dict = { 35 | predictor_config.name : predictor_builder.build(predictor_config, is_training=is_training) 36 | for predictor_config in config.predictor 37 | } 38 | regression_loss_object = ( 39 | None if not config.keypoint_supervision else 40 | loss_builder.build(config.regression_loss)) 41 | 42 | model_object = multi_predictors_recognition_model.MultiPredictorsRecognitionModel( 43 | spatial_transformer=spatial_transformer_object, 44 | feature_extractor=feature_extractor_object, 45 | predictors_dict=predictors_dict, 46 | keypoint_supervision=config.keypoint_supervision, 47 | regression_loss=regression_loss_object, 48 | is_training=is_training, 49 | ) 50 | return model_object 51 | -------------------------------------------------------------------------------- /1 - wei/aster/builders/predictor_builder.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from tensorflow.contrib import rnn 3 | 4 | from aster.protos import predictor_pb2 5 | from aster.builders import rnn_cell_builder 6 | from aster.builders import label_map_builder 7 | from aster.builders import loss_builder 8 | from aster.builders import hyperparams_builder 9 | from aster.predictors import attention_predictor 10 | # from aster.predictors import attention_predictor_with_lm 11 | 12 | 13 | def build(config, is_training): 14 | if not isinstance(config, predictor_pb2.Predictor): 15 | raise ValueError('config not of type predictor_pb2.AttentionPredictor') 16 | predictor_oneof = config.WhichOneof('predictor_oneof') 17 | 18 | if predictor_oneof == 'attention_predictor': 19 | predictor_config = config.attention_predictor 20 | rnn_cell_object = rnn_cell_builder.build(predictor_config.rnn_cell) 21 | rnn_regularizer_object = hyperparams_builder._build_regularizer(predictor_config.rnn_regularizer) 22 | label_map_object = label_map_builder.build(predictor_config.label_map) 23 | loss_object = loss_builder.build(predictor_config.loss) 24 | if not predictor_config.HasField('lm_rnn_cell'): 25 | lm_rnn_cell_object = None 26 | else: 27 | lm_rnn_cell_object = 
_build_language_model_rnn_cell(predictor_config.lm_rnn_cell) 28 | 29 | attention_predictor_object = attention_predictor.AttentionPredictor( 30 | rnn_cell=rnn_cell_object, 31 | rnn_regularizer=rnn_regularizer_object, 32 | num_attention_units=predictor_config.num_attention_units, 33 | max_num_steps=predictor_config.max_num_steps, 34 | multi_attention=predictor_config.multi_attention, 35 | beam_width=predictor_config.beam_width, 36 | reverse=predictor_config.reverse, 37 | label_map=label_map_object, 38 | loss=loss_object, 39 | sync=predictor_config.sync, 40 | lm_rnn_cell=lm_rnn_cell_object, 41 | is_training=is_training 42 | ) 43 | return attention_predictor_object 44 | else: 45 | raise ValueError('Unknown predictor_oneof: {}'.format(predictor_oneof)) 46 | 47 | 48 | def _build_language_model_rnn_cell(config): 49 | if not isinstance(config, predictor_pb2.LanguageModelRnnCell): 50 | raise ValueError('config not of type predictor_pb2.LanguageModelRnnCell') 51 | rnn_cell_list = [ 52 | rnn_cell_builder.build(rnn_cell_config) for rnn_cell_config in config.rnn_cell 53 | ] 54 | lm_rnn_cell = rnn.MultiRNNCell(rnn_cell_list) 55 | return lm_rnn_cell 56 | -------------------------------------------------------------------------------- /1 - wei/aster/builders/rnn_cell_builder.py: -------------------------------------------------------------------------------- 1 | import functools 2 | 3 | import tensorflow as tf 4 | 5 | from aster.protos import rnn_cell_pb2 6 | from aster.builders import hyperparams_builder 7 | 8 | 9 | def build(rnn_cell_config): 10 | if not isinstance(rnn_cell_config, rnn_cell_pb2.RnnCell): 11 | raise ValueError('rnn_cell_config not of type ' 12 | 'rnn_cell_pb2.RnnCell') 13 | rnn_cell_oneof = rnn_cell_config.WhichOneof('rnn_cell_oneof') 14 | 15 | if rnn_cell_oneof == 'lstm_cell': 16 | lstm_cell_config = rnn_cell_config.lstm_cell 17 | weights_initializer_object = hyperparams_builder._build_initializer( 18 | lstm_cell_config.initializer) 19 | lstm_cell_object = tf.contrib.rnn.LSTMCell( 20 | lstm_cell_config.num_units, 21 | use_peepholes=lstm_cell_config.use_peepholes, 22 | forget_bias=lstm_cell_config.forget_bias, 23 | initializer=weights_initializer_object 24 | ) 25 | return lstm_cell_object 26 | 27 | elif rnn_cell_oneof == 'gru_cell': 28 | gru_cell_config = rnn_cell_config.gru_cell 29 | weights_initializer_object = hyperparams_builder._build_initializer( 30 | gru_cell_config.initializer) 31 | gru_cell_object = tf.contrib.rnn.GRUCell( 32 | gru_cell_config.num_units, 33 | kernel_initializer=weights_initializer_object 34 | ) 35 | return gru_cell_object 36 | 37 | else: 38 | raise ValueError('Unknown rnn_cell_oneof: {}'.format(rnn_cell_oneof)) 39 | -------------------------------------------------------------------------------- /1 - wei/aster/builders/rnn_cell_builder_test.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | from google.protobuf import text_format 4 | from aster.protos import rnn_cell_pb2 5 | from aster.builders import rnn_cell_builder 6 | 7 | 8 | class RnnCellTest(tf.test.TestCase): 9 | 10 | def test_build_lstm_cell(self): 11 | rnn_cell_text_proto = """ 12 | lstm_cell { 13 | num_units: 1024 14 | use_peepholes: true 15 | forget_bias: 1.5 16 | initializer { orthogonal_initializer { seed: 1 } } 17 | } 18 | """ 19 | rnn_cell_proto = rnn_cell_pb2.RnnCell() 20 | text_format.Merge(rnn_cell_text_proto, rnn_cell_proto) 21 | rnn_cell_object = rnn_cell_builder.build(rnn_cell_proto) 22 | 23 | lstm_state_tuple = 
rnn_cell_object.state_size 24 | 25 | self.assertEqual(lstm_state_tuple[0], 1024) 26 | self.assertEqual(lstm_state_tuple[1], 1024) 27 | 28 | def test_build_gru_cell(self): 29 | rnn_cell_text_proto = """ 30 | gru_cell { 31 | num_units: 1024 32 | initializer { orthogonal_initializer { seed: 1 } } 33 | } 34 | """ 35 | rnn_cell_proto = rnn_cell_pb2.RnnCell() 36 | text_format.Merge(rnn_cell_text_proto, rnn_cell_proto) 37 | rnn_cell_object = rnn_cell_builder.build(rnn_cell_proto) 38 | 39 | self.assertEqual(rnn_cell_object.state_size, 1024) 40 | 41 | if __name__ == '__main__': 42 | tf.test.main() 43 | -------------------------------------------------------------------------------- /1 - wei/aster/builders/spatial_transformer_builder.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | from aster.core import spatial_transformer 4 | from aster.protos import spatial_transformer_pb2 5 | from aster.builders import hyperparams_builder 6 | from aster.builders import convnet_builder 7 | 8 | 9 | def build(config, is_training): 10 | if not isinstance(config, spatial_transformer_pb2.SpatialTransformer): 11 | raise ValueError('config not of type spatial_transformer_pb2.SpatialTransformer') 12 | 13 | convnet_object = convnet_builder.build(config.convnet, is_training) 14 | fc_hyperparams_object = hyperparams_builder.build(config.fc_hyperparams, is_training) 15 | return spatial_transformer.SpatialTransformer( 16 | convnet=convnet_object, 17 | fc_hyperparams=fc_hyperparams_object, 18 | localization_image_size=(config.localization_h, config.localization_w), 19 | output_image_size=(config.output_h, config.output_w), 20 | num_control_points=config.num_control_points, 21 | init_bias_pattern=config.init_bias_pattern, 22 | margins=(config.margin_x, config.margin_y), 23 | activation=config.activation, 24 | summarize_activations=config.summarize_activations 25 | ) 26 | -------------------------------------------------------------------------------- /1 - wei/aster/builders/spatial_transformer_builder_test.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | 4 | from google.protobuf import text_format 5 | 6 | from aster.protos import spatial_transformer_pb2 7 | from aster.builders import spatial_transformer_builder 8 | 9 | 10 | class SpatialTransformerBuilderTest(tf.test.TestCase): 11 | 12 | def test_build_spatial_transformer(self): 13 | text_proto = """ 14 | convnet { 15 | stn_resnet { 16 | conv_hyperparams { 17 | op: CONV 18 | regularizer { l2_regularizer { } } 19 | initializer { variance_scaling_initializer { } } 20 | batch_norm { decay: 0.99 } 21 | } 22 | summarize_activations: false 23 | } 24 | } 25 | fc_hyperparams { 26 | op: CONV 27 | regularizer { l2_regularizer { } } 28 | initializer { variance_scaling_initializer { } } 29 | batch_norm { decay: 0.99 } 30 | } 31 | localization_h: 64 32 | localization_w: 128 33 | output_h: 32 34 | output_w: 100 35 | num_control_points: 20 36 | margin: 0.05 37 | init_bias_pattern: "slope" 38 | summarize_activations: true 39 | """ 40 | config = spatial_transformer_pb2.SpatialTransformer() 41 | text_format.Merge(text_proto, config) 42 | spatial_transformer_object = spatial_transformer_builder.build(config, True) 43 | self.assertTrue(spatial_transformer_object._summarize_activations == True) 44 | 45 | test_input_images = tf.random_uniform( 46 | [2, 64, 512, 3], minval=0, maxval=255, dtype=tf.float32) 47 | output_dict = 
spatial_transformer_object.batch_transform(test_input_images) 48 | 49 | with self.test_session() as sess: 50 | sess.run(tf.global_variables_initializer()) 51 | sess_outputs = sess.run({ 52 | 'rectified_images': output_dict['rectified_images'], 53 | 'control_points': output_dict['control_points'], 54 | }) 55 | self.assertEqual(sess_outputs['rectified_images'].shape, (2, 32, 100, 3)) 56 | 57 | init_bias = spatial_transformer_object._init_bias 58 | init_ctrl_pts = (1. / (1. + np.exp(-init_bias))).reshape(20, 2) 59 | self.assertAllClose(sess_outputs['control_points'][0], init_ctrl_pts) 60 | 61 | if __name__ == '__main__': 62 | tf.test.main() 63 | -------------------------------------------------------------------------------- /1 - wei/aster/c_ops/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 2.8) 2 | project(aster) 3 | 4 | # compiler flags 5 | set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11 ${OpenMP_CXX_FLAGS} -Wall -fPIC -D_GLIBCXX_USE_CXX11_ABI=0") 6 | 7 | # TensorFlow dependencies 8 | execute_process(COMMAND python3 -c "import os, sys; os.environ['TF_CPP_MIN_LOG_LEVEL']='3'; import tensorflow as tf; sys.stdout.write(tf.sysconfig.get_include()); sys.stdout.flush()" OUTPUT_VARIABLE TF_INC) 9 | message(STATUS "Found TF_INC: " ${TF_INC}) 10 | execute_process(COMMAND python3 -c "import os, sys; os.environ['TF_CPP_MIN_LOG_LEVEL']='3'; import tensorflow as tf; sys.stdout.write(tf.sysconfig.get_lib()); sys.stdout.flush()" OUTPUT_VARIABLE TF_LIB) 11 | message(STATUS "Found TF_LIB: " ${TF_LIB}) 12 | 13 | # target 14 | include_directories(${TF_INC} "${TF_INC}/external/nsync/public") 15 | link_directories(${TF_LIB}) 16 | add_library(aster SHARED 17 | string_filtering_op.cc 18 | string_reverse_op.cc 19 | divide_curve_op.cc) 20 | target_link_libraries(aster tensorflow_framework) 21 | -------------------------------------------------------------------------------- /1 - wei/aster/c_ops/build.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -e 4 | 5 | mkdir -p build 6 | cd build 7 | cmake -DCMAKE_BUILD_TYPE=Release .. 8 | make -j8 9 | cp libaster.* .. 10 | -------------------------------------------------------------------------------- /1 - wei/aster/c_ops/ops.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | import uuid 4 | from os.path import join, dirname, realpath, exists 5 | import tensorflow as tf 6 | 7 | tf.app.flags.DEFINE_string('oplib_name', 'aster', 'Name of op library.') 8 | tf.app.flags.DEFINE_string('oplib_suffix', '.so', 'Library suffix.') 9 | FLAGS = tf.app.flags.FLAGS 10 | 11 | 12 | def _load_oplib(lib_name): 13 | """ 14 | Load TensorFlow operator library. 
15 | """ 16 | lib_path = join(dirname(realpath(__file__)), 'lib{0}{1}'.format(lib_name, FLAGS.oplib_suffix)) 17 | assert exists(lib_path), '{0} not found'.format(lib_path) 18 | 19 | # duplicate library with a random new name so that 20 | # a running program will not be interrupted when the lib file is updated 21 | lib_copy_path = '/tmp/lib{0}_{1}{2}'.format(lib_name, str(uuid.uuid4())[:8], FLAGS.oplib_suffix) 22 | shutil.copyfile(lib_path, lib_copy_path) 23 | oplib = tf.load_op_library(lib_copy_path) 24 | return oplib 25 | 26 | _oplib = _load_oplib(FLAGS.oplib_name) 27 | 28 | # map C++ operators to python objects 29 | string_filtering = _oplib.string_filtering 30 | string_reverse = _oplib.string_reverse 31 | divide_curve = _oplib.divide_curve 32 | -------------------------------------------------------------------------------- /1 - wei/aster/c_ops/ops_test.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | 4 | from aster.c_ops import ops 5 | 6 | class OpsTest(tf.test.TestCase): 7 | 8 | def test_string_reverse(self): 9 | test_input_strings = tf.constant( 10 | [b'Hello', b'world', b'1l08ck`-?=1', b'']) 11 | test_reversed_strings = ops.string_reverse(test_input_strings) 12 | 13 | with self.test_session() as sess: 14 | self.assertAllEqual( 15 | test_reversed_strings.eval(), 16 | np.asarray([b'olleH', b'dlrow', b'1=?-`kc80l1', b'']) 17 | ) 18 | 19 | def test_divide_curve(self): 20 | num_keypoints = 128 21 | fit_points = np.array([ 22 | [0.0, 1.0], 23 | [1.0, 2.0], 24 | [2.0, 1.0] 25 | ], dtype=np.float32) 26 | coeffs = np.polyfit(fit_points[:,0], fit_points[:,1], 2) 27 | poly_fn = np.poly1d(coeffs) 28 | xmin, xmax = np.min(fit_points[:,0]), np.max(fit_points[:,0]) 29 | xs = np.linspace(xmin, xmax, num=(num_keypoints // 2)) 30 | ys = poly_fn(xs) 31 | curve_points = np.stack([xs, ys], axis=1).flatten() 32 | curve_points = np.expand_dims(curve_points, axis=0) 33 | 34 | key_points = ops.divide_curve(curve_points, num_key_points=20) 35 | with self.test_session() as sess: 36 | sess_outputs = sess.run({ 37 | 'key_points': key_points 38 | }) 39 | self.assertAllEqual(sess_outputs['key_points'].shape, (1, 40)) 40 | 41 | import matplotlib.pyplot as plt 42 | plt.subplot(1,2,1) 43 | plt.scatter(curve_points[0,::2], curve_points[0,1::2]) 44 | plt.subplot(1,2,2) 45 | plt.scatter(sess_outputs['key_points'][0,::2], sess_outputs['key_points'][0,1::2]) 46 | plt.show() 47 | 48 | if __name__ == '__main__': 49 | tf.test.main() 50 | -------------------------------------------------------------------------------- /1 - wei/aster/c_ops/string_filtering_op.cc: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | 7 | #include "tensorflow/core/framework/op.h" 8 | #include "tensorflow/core/framework/op_kernel.h" 9 | #include "tensorflow/core/framework/shape_inference.h" 10 | #include "tensorflow/core/framework/register_types.h" 11 | #include "tensorflow/core/framework/tensor.h" 12 | #include "tensorflow/core/framework/types.h" 13 | 14 | using namespace std; 15 | using namespace tensorflow; 16 | 17 | REGISTER_OP("StringFiltering") 18 | .Input("input_string: string") 19 | .Output("output_string: string") 20 | .Attr("lower_case: bool = False") 21 | .Attr("include_charset: string") 22 | .SetShapeFn([](shape_inference::InferenceContext* c) { 23 | using namespace shape_inference; 24 | 25 | ShapeHandle input_string = c->input(0); 26 | 
TF_RETURN_IF_ERROR(c->WithRank(input_string, 1, &input_string)); 27 | DimensionHandle num_strings = c->Dim(input_string, 0); 28 | 29 | c->set_output(0, c->MakeShape({num_strings})); 30 | return Status::OK(); 31 | }); 32 | 33 | 34 | class StringFilteringOp : public OpKernel { 35 | public: 36 | explicit StringFilteringOp(OpKernelConstruction* context): OpKernel(context) { 37 | OP_REQUIRES_OK(context, 38 | context->GetAttr("lower_case", &lower_case_)); 39 | string charset_string; 40 | OP_REQUIRES_OK(context, 41 | context->GetAttr("include_charset", &charset_string)); 42 | for (char c : charset_string) { 43 | charset_.insert(c); 44 | } 45 | } 46 | 47 | void Compute(OpKernelContext* context) override { 48 | // input-0 input_string 49 | const Tensor& input_string = context->input(0); 50 | OP_REQUIRES(context, input_string.dims() == 1, 51 | errors::InvalidArgument("Expected 1D string input, got ", 52 | input_string.shape().DebugString())); 53 | auto input_string_tensor = input_string.tensor(); 54 | 55 | const int num_strings = input_string.dim_size(0); 56 | 57 | Tensor* output_string = nullptr; 58 | OP_REQUIRES_OK(context, context->allocate_output(0, {num_strings}, &output_string)); 59 | auto output_string_tensor = output_string->tensor(); 60 | 61 | for (int i = 0; i < num_strings; i++) { 62 | string orig_string = input_string_tensor(i); 63 | string processed_string = ""; 64 | if (lower_case_) { 65 | transform(orig_string.begin(), orig_string.end(), orig_string.begin(), ::tolower); 66 | } 67 | for (char c : orig_string) { 68 | if (charset_.find(c) != charset_.end()) { 69 | processed_string += c; 70 | } 71 | } 72 | output_string_tensor(i) = processed_string; 73 | } 74 | } 75 | 76 | private: 77 | unordered_set charset_; 78 | bool lower_case_; 79 | }; 80 | 81 | REGISTER_KERNEL_BUILDER(Name("StringFiltering").Device(DEVICE_CPU), 82 | StringFilteringOp) 83 | -------------------------------------------------------------------------------- /1 - wei/aster/c_ops/string_reverse_op.cc: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | #include "tensorflow/core/framework/op.h" 5 | #include "tensorflow/core/framework/op_kernel.h" 6 | #include "tensorflow/core/framework/shape_inference.h" 7 | #include "tensorflow/core/framework/register_types.h" 8 | #include "tensorflow/core/framework/tensor.h" 9 | #include "tensorflow/core/framework/types.h" 10 | 11 | using namespace std; 12 | using namespace tensorflow; 13 | 14 | REGISTER_OP("StringReverse") 15 | .Input("input_string: string") 16 | .Output("output_string: string") 17 | .SetShapeFn([](shape_inference::InferenceContext* c) { 18 | using namespace shape_inference; 19 | ShapeHandle input_string = c->input(0); 20 | TF_RETURN_IF_ERROR(c->WithRank(input_string, 1, &input_string)); 21 | DimensionHandle num_strings = c->Dim(input_string, 0); 22 | c->set_output(0, c->MakeShape({num_strings})); 23 | return Status::OK(); 24 | }); 25 | 26 | 27 | class StringReverseOp : public OpKernel { 28 | public: 29 | explicit StringReverseOp(OpKernelConstruction* context): OpKernel(context) {} 30 | 31 | void Compute(OpKernelContext* context) override { 32 | // input-0 input_string 33 | const Tensor& input_string = context->input(0); 34 | OP_REQUIRES(context, input_string.dims() == 1, 35 | errors::InvalidArgument("Expected 1D string input, got ", 36 | input_string.shape().DebugString())); 37 | auto input_string_tensor = input_string.tensor(); 38 | 39 | const int num_strings = input_string.dim_size(0); 40 | 41 | Tensor* 
output_string = nullptr; 42 | OP_REQUIRES_OK(context, context->allocate_output(0, {num_strings}, &output_string)); 43 | auto output_string_tensor = output_string->tensor(); 44 | 45 | for (int i = 0; i < num_strings; i++) { 46 | string orig_string = input_string_tensor(i); 47 | string reversed_string = orig_string; 48 | reverse(reversed_string.begin(), reversed_string.end()); 49 | output_string_tensor(i) = reversed_string; 50 | } 51 | } 52 | }; 53 | 54 | REGISTER_KERNEL_BUILDER(Name("StringReverse").Device(DEVICE_CPU), 55 | StringReverseOp) 56 | -------------------------------------------------------------------------------- /1 - wei/aster/convnets/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mind/RMB/cd656b3000fc612657b57dc70dfa76432a448014/1 - wei/aster/convnets/__init__.py -------------------------------------------------------------------------------- /1 - wei/aster/convnets/stn_convnet.py: -------------------------------------------------------------------------------- 1 | import functools 2 | 3 | import tensorflow as tf 4 | from tensorflow.contrib.layers import conv2d, max_pool2d 5 | from tensorflow.contrib.framework import arg_scope 6 | 7 | from aster.core import convnet 8 | 9 | 10 | class StnConvnet(convnet.Convnet): 11 | 12 | def _extract_features(self, preprocessed_inputs): 13 | """Extract features 14 | Args: 15 | preprocessed_inputs: float32 tensor of shape [batch_size, image_height, image_width, 3] 16 | Return: 17 | feature_maps: a list of extracted feature maps 18 | """ 19 | with arg_scope([conv2d], kernel_size=3, activation_fn=tf.nn.relu), \ 20 | arg_scope([max_pool2d], kernel_size=2, stride=2): 21 | conv1 = conv2d(preprocessed_inputs, 32, scope='conv1') # 64 22 | pool1 = max_pool2d(conv1, scope='pool1') 23 | conv2 = conv2d(pool1, 64, scope='conv2') # 32 24 | pool2 = max_pool2d(conv2, scope='pool2') 25 | conv3 = conv2d(pool2, 128, scope='conv3') # 16 26 | pool3 = max_pool2d(conv3, scope='pool3') 27 | conv4 = conv2d(pool3, 256, scope='conv4') # 8 28 | pool4 = max_pool2d(conv4, scope='pool4') 29 | conv5 = conv2d(pool4, 256, scope='conv5') # 4 30 | pool5 = max_pool2d(conv5, scope='pool5') 31 | conv6 = conv2d(pool5, 256, scope='conv6') # 2 32 | feature_maps_dict = { 33 | 'conv1': conv1, 'conv2': conv2, 'conv3': conv3, 34 | 'conv4': conv4, 'conv5': conv5, 'conv6': conv6 } 35 | return feature_maps_dict 36 | 37 | def _output_endpoints(self, feature_maps_dict): 38 | return [feature_maps_dict['conv6']] 39 | 40 | 41 | class StnConvnetTiny(convnet.Convnet): 42 | 43 | def _extract_features(self, preprocessed_inputs): 44 | """Extract features 45 | Args: 46 | preprocessed_inputs: float32 tensor of shape [batch_size, image_height, image_width, 3] 47 | Return: 48 | feature_maps: a list of extracted feature maps 49 | """ 50 | with arg_scope([conv2d], kernel_size=3, activation_fn=tf.nn.relu), \ 51 | arg_scope([max_pool2d], kernel_size=2, stride=2): 52 | conv1 = conv2d(preprocessed_inputs, 8, scope='conv1') # 64 53 | pool1 = max_pool2d(conv1, scope='pool1') 54 | conv2 = conv2d(pool1, 16, scope='conv2') # 32 55 | pool2 = max_pool2d(conv2, scope='pool2') 56 | conv3 = conv2d(pool2, 32, scope='conv3') # 16 57 | pool3 = max_pool2d(conv3, scope='pool3') 58 | conv4 = conv2d(pool3, 32, scope='conv4') # 8 59 | pool4 = max_pool2d(conv4, scope='pool4') 60 | conv5 = conv2d(pool4, 64, scope='conv5') # 4 61 | pool5 = max_pool2d(conv5, scope='pool5') 62 | conv6 = conv2d(pool5, 64, scope='conv6') # 2 63 | feature_maps_dict = { 64 
| 'conv1': conv1, 'conv2': conv2, 'conv3': conv3, 65 | 'conv4': conv4, 'conv5': conv5, 'conv6': conv6 } 66 | return feature_maps_dict 67 | 68 | def _output_endpoints(self, feature_maps_dict): 69 | return [feature_maps_dict['conv6']] 70 | -------------------------------------------------------------------------------- /1 - wei/aster/core/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mind/RMB/cd656b3000fc612657b57dc70dfa76432a448014/1 - wei/aster/core/__init__.py -------------------------------------------------------------------------------- /1 - wei/aster/core/convnet.py: -------------------------------------------------------------------------------- 1 | from abc import ABCMeta 2 | from abc import abstractmethod 3 | 4 | import tensorflow as tf 5 | from tensorflow.contrib.framework import arg_scope 6 | 7 | 8 | class Convnet(object): 9 | __metaclass__ = ABCMeta 10 | 11 | def __init__(self, 12 | conv_hyperparams=None, 13 | summarize_activations=False, 14 | is_training=True): 15 | self._conv_hyperparams = conv_hyperparams 16 | self._summarize_activations = summarize_activations 17 | self._is_training = is_training 18 | 19 | def preprocess(self, resized_inputs, scope=None): 20 | with tf.variable_scope(scope, 'ConvnetPreprocess', [resized_inputs]): 21 | preprocessed_inputs = (2.0 / 255.0) * resized_inputs - 1.0 22 | if self._summarize_activations: 23 | tf.summary.image('preprocessed_inputs', preprocessed_inputs, max_outputs=1) 24 | return preprocessed_inputs 25 | 26 | def extract_features(self, preprocessed_inputs, scope=None): 27 | with tf.variable_scope(scope, 'Convnet', [preprocessed_inputs]): 28 | shape_assert = self._shape_check(preprocessed_inputs) 29 | if shape_assert is None: 30 | shape_assert = tf.no_op() 31 | with tf.control_dependencies([shape_assert]), \ 32 | arg_scope(self._conv_hyperparams): 33 | feature_maps_dict = self._extract_features(preprocessed_inputs) 34 | if self._summarize_activations: 35 | for k, tensor in feature_maps_dict.items(): 36 | tf.summary.histogram('Activations/' + k, tensor) 37 | return self._output_endpoints(feature_maps_dict) 38 | 39 | def _shape_check(self, preprocessed_inputs): 40 | return None 41 | 42 | @abstractmethod 43 | def _output_endpoints(self, feature_maps_dict): 44 | raise NotImplementedError 45 | -------------------------------------------------------------------------------- /1 - wei/aster/core/feature_extractor.py: -------------------------------------------------------------------------------- 1 | import functools 2 | 3 | import tensorflow as tf 4 | from tensorflow.contrib.layers import conv2d, max_pool2d 5 | from tensorflow.contrib.framework import arg_scope 6 | 7 | from aster.utils import shape_utils 8 | 9 | 10 | class FeatureExtractor(object): 11 | def __init__(self, 12 | convnet=None, 13 | brnn_fn_list=[], 14 | summarize_activations=False, 15 | is_training=True): 16 | self._convnet = convnet 17 | self._brnn_fn_list = brnn_fn_list 18 | self._summarize_activations = summarize_activations 19 | self._is_training = is_training 20 | 21 | def preprocess(self, resized_inputs, scope=None): 22 | with tf.variable_scope(scope, 'FeatureExtractorPreprocess', [resized_inputs]) as preproc_scope: 23 | preprocessed_inputs = self._convnet.preprocess(resized_inputs, preproc_scope) 24 | return preprocessed_inputs 25 | 26 | def extract_features(self, preprocessed_inputs, scope=None): 27 | with tf.variable_scope(scope, 'FeatureExtractor', [preprocessed_inputs]): 28 | feature_maps 
= self._convnet.extract_features(preprocessed_inputs) 29 | 30 | if len(self._brnn_fn_list) > 0: 31 | feature_sequences_list = [] 32 | for i, feature_map in enumerate(feature_maps): 33 | shape_assert = tf.Assert( 34 | tf.equal(tf.shape(feature_map)[1], 1), 35 | ['Feature map height must be 1 if bidirectional RNN is going to be applied.'] 36 | ) 37 | batch_size, _, _, map_depth = shape_utils.combined_static_and_dynamic_shape(feature_map) 38 | with tf.control_dependencies([shape_assert]): 39 | feature_sequence = tf.reshape(feature_map, [batch_size, -1, map_depth]) 40 | for j, brnn_fn in enumerate(self._brnn_fn_list): 41 | brnn_object = brnn_fn() 42 | feature_sequence = brnn_object.predict(feature_sequence, scope='BidirectionalRnn_Branch_{}_{}'.format(i, j)) 43 | feature_sequences_list.append(feature_sequence) 44 | 45 | feature_maps = [tf.expand_dims(fmap, axis=1) for fmap in feature_sequences_list] 46 | return feature_maps 47 | -------------------------------------------------------------------------------- /1 - wei/aster/core/loss.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | from aster.utils import shape_utils 4 | 5 | 6 | class SequenceCrossEntropyLoss(object): 7 | def __init__(self, 8 | sequence_normalize=None, 9 | sample_normalize=None, 10 | weight=None): 11 | self._sequence_normalize = sequence_normalize 12 | self._sample_normalize = sample_normalize 13 | self._weight = weight 14 | 15 | def __call__(self, logits, labels, lengths, scope=None): 16 | """ 17 | Args: 18 | logits: float32 tensor with shape [batch_size, max_time, num_classes] 19 | labels: int32 tensor with shape [batch_size, max_time] 20 | lengths: int32 tensor with shape [batch_size] 21 | """ 22 | with tf.name_scope(scope, 'SequenceCrossEntropyLoss', [logits, labels, lengths]): 23 | raw_losses = tf.nn.sparse_softmax_cross_entropy_with_logits( 24 | labels=labels, 25 | logits=logits 26 | ) 27 | batch_size, max_time = shape_utils.combined_static_and_dynamic_shape(labels) 28 | mask = tf.less( 29 | tf.tile([tf.range(max_time)], [batch_size, 1]), 30 | tf.expand_dims(lengths, 1), 31 | name='mask' 32 | ) 33 | masked_losses = tf.multiply( 34 | raw_losses, 35 | tf.cast(mask, tf.float32), 36 | name='masked_losses' 37 | ) # => [batch_size, max_time] 38 | row_losses = tf.reduce_sum(masked_losses, 1, name='row_losses') 39 | if self._sequence_normalize: 40 | row_losses = tf.truediv( 41 | row_losses, 42 | tf.cast(tf.maximum(lengths, 1), tf.float32), 43 | name='seq_normed_losses')  # length-normalize rows so the sum below uses the normalized values 44 | loss = tf.reduce_sum(row_losses) 45 | if self._sample_normalize: 46 | loss = tf.truediv( 47 | loss, 48 | tf.cast(tf.maximum(batch_size, 1), tf.float32)) 49 | if self._weight: 50 | loss = loss * self._weight 51 | return loss 52 | 53 | 54 | class L2RegressionLoss(object): 55 | def __init__(self, weight=None): 56 | self._weight = weight 57 | 58 | def __call__(self, prediction, target, scope=None): 59 | with tf.name_scope(scope, 'L2RegressionLoss', [prediction, target]): 60 | diff = prediction - target 61 | losses = tf.reduce_sum(tf.square(diff), axis=1) 62 | loss = tf.reduce_mean(losses) 63 | if self._weight is not None: 64 | loss = loss * self._weight 65 | return loss 66 | -------------------------------------------------------------------------------- /1 - wei/aster/core/model.py: -------------------------------------------------------------------------------- 1 | from abc import ABCMeta 2 | from abc import abstractmethod 3 | 4 | import tensorflow as tf 5 | 6 | 7 | class Model(object): 8 |
__metaclass__ = ABCMeta 9 | 10 | def __init__(self, 11 | feature_extractor=None, 12 | is_training=True): 13 | self._feature_extractor = feature_extractor 14 | self._is_training = is_training 15 | self._predictors = {} 16 | self._groundtruth_dict = {} 17 | 18 | def preprocess(self, resized_inputs, scope=None): 19 | with tf.variable_scope(scope, 'ModelPreprocess', [resized_inputs]) as preprocess_scope: 20 | if resized_inputs.dtype is not tf.float32: 21 | raise ValueError('`preprocess` expects a tf.float32 tensor') 22 | preprocess_inputs = self._feature_extractor.preprocess(resized_inputs, scope=preprocess_scope) 23 | return preprocess_inputs 24 | 25 | @abstractmethod 26 | def predict(self, preprocessed_inputs, scope=None): 27 | pass 28 | 29 | @abstractmethod 30 | def loss(self, predictions_dict, scope=None): 31 | pass 32 | 33 | @abstractmethod 34 | def postprocess(self, predictions_dict, scope=None): 35 | pass 36 | 37 | @abstractmethod 38 | def provide_groundtruth(self, groundtruth_lists, scope=None): 39 | pass 40 | -------------------------------------------------------------------------------- /1 - wei/aster/core/predictor.py: -------------------------------------------------------------------------------- 1 | from abc import ABCMeta 2 | from abc import abstractmethod 3 | 4 | import tensorflow as tf 5 | 6 | 7 | class Predictor(object): 8 | __metaclass__ = ABCMeta 9 | 10 | def __init__(self, is_training=True): 11 | self._is_training = is_training 12 | self._groundtruth_dict = {} 13 | 14 | @property 15 | def name(self): 16 | return self._name 17 | 18 | @abstractmethod 19 | def predict(self, feature_maps, scope=None): 20 | pass 21 | 22 | @abstractmethod 23 | def loss(self, predictions_dict, scope=None): 24 | pass 25 | 26 | @abstractmethod 27 | def provide_groundtruth(self, groundtruth_lists, scope=None): 28 | pass 29 | 30 | @abstractmethod 31 | def postprocess(self, predictions_dict, scope=None): 32 | return predictions_dict 33 | -------------------------------------------------------------------------------- /1 - wei/aster/core/prefetcher.py: -------------------------------------------------------------------------------- 1 | """Provides functions to prefetch tensors to feed into models.""" 2 | import tensorflow as tf 3 | 4 | 5 | def prefetch(tensor_dict, capacity): 6 | """Creates a prefetch queue for tensors. 7 | 8 | Creates a FIFO queue to asynchronously enqueue tensor_dicts and returns a 9 | dequeue op that evaluates to a tensor_dict. This function is useful in 10 | prefetching preprocessed tensors so that the data is readily available for 11 | consumers. 12 | 13 | Example input pipeline when you don't need batching: 14 | ---------------------------------------------------- 15 | key, string_tensor = slim.parallel_reader.parallel_read(...) 16 | tensor_dict = decoder.decode(string_tensor) 17 | tensor_dict = preprocessor.preprocess(tensor_dict, ...) 18 | prefetch_queue = prefetcher.prefetch(tensor_dict, capacity=20) 19 | tensor_dict = prefetch_queue.dequeue() 20 | outputs = Model(tensor_dict) 21 | ... 22 | ---------------------------------------------------- 23 | 24 | For input pipelines with batching, refer to core/batcher.py 25 | 26 | Args: 27 | tensor_dict: a dictionary of tensors to prefetch. 28 | capacity: the size of the prefetch queue. 
29 | 30 | Returns: 31 | a FIFO prefetcher queue 32 | """ 33 | names = list(tensor_dict.keys()) 34 | dtypes = [t.dtype for t in tensor_dict.values()] 35 | shapes = [t.get_shape() for t in tensor_dict.values()] 36 | prefetch_queue = tf.PaddingFIFOQueue(capacity, dtypes=dtypes, 37 | shapes=shapes, 38 | names=names, 39 | name='prefetch_queue') 40 | enqueue_op = prefetch_queue.enqueue(tensor_dict) 41 | tf.train.queue_runner.add_queue_runner(tf.train.queue_runner.QueueRunner( 42 | prefetch_queue, [enqueue_op])) 43 | tf.summary.scalar('queue/%s/fraction_of_%d_full' % (prefetch_queue.name, 44 | capacity), 45 | tf.to_float(prefetch_queue.size()) * (1. / capacity)) 46 | return prefetch_queue 47 | -------------------------------------------------------------------------------- /1 - wei/aster/core/spatial_transformer_test.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | import matplotlib.pyplot as plt 4 | from PIL import Image 5 | 6 | from aster.core import spatial_transformer 7 | 8 | 9 | class SpatialTransformerTest(tf.test.TestCase): 10 | 11 | def test_batch_transform(self): 12 | transformer = spatial_transformer.SpatialTransformer( 13 | output_image_size=(32, 100), 14 | num_control_points=6, 15 | init_bias_pattern='identity', 16 | margin=0.05 17 | ) 18 | test_input_ctrl_pts = np.array([ 19 | [ 20 | [0.1, 0.4], [0.5, 0.1], [0.9, 0.4], 21 | [0.1, 0.9], [0.5, 0.6], [0.9, 0.9] 22 | ], 23 | [ 24 | [0.1, 0.1], [0.5, 0.4], [0.9, 0.1], 25 | [0.1, 0.6], [0.5, 0.9], [0.9, 0.6] 26 | ], 27 | [ 28 | [0.1, 0.1], [0.5, 0.1], [0.9, 0.1], 29 | [0.1, 0.9], [0.5, 0.9], [0.9, 0.9], 30 | ] 31 | ], dtype=np.float32) 32 | test_im = Image.open('aster/data/test_image.jpg').resize((128, 128)) 33 | test_image_array = np.array(test_im) 34 | test_image_array = np.array([test_image_array, test_image_array, test_image_array]) 35 | test_images = tf.cast(tf.constant(test_image_array), tf.float32) 36 | test_images = (test_images / 128.0) - 1.0 37 | 38 | sampling_grid = transformer._batch_generate_grid(test_input_ctrl_pts) 39 | rectified_images = transformer._batch_sample(test_images, sampling_grid) 40 | 41 | output_ctrl_pts = transformer._output_ctrl_pts 42 | with self.test_session() as sess: 43 | outputs = sess.run({ 44 | 'sampling_grid': sampling_grid, 45 | 'rectified_images': rectified_images 46 | }) 47 | 48 | rectified_images_ = (outputs['rectified_images'] + 1.0) * 128.0 49 | 50 | if True: 51 | plt.figure() 52 | plt.subplot(3,4,1) 53 | plt.scatter(test_input_ctrl_pts[0,:,0], test_input_ctrl_pts[0,:,1]) 54 | plt.subplot(3,4,2) 55 | plt.scatter(output_ctrl_pts[:,0], output_ctrl_pts[:,1]) 56 | plt.subplot(3,4,3) 57 | plt.scatter(outputs['sampling_grid'][0,:,0], outputs['sampling_grid'][0,:,1], marker='+') 58 | plt.subplot(3,4,4) 59 | plt.imshow(rectified_images_[0].astype(np.uint8)) 60 | 61 | plt.subplot(3,4,5) 62 | plt.scatter(test_input_ctrl_pts[1,:,0], test_input_ctrl_pts[1,:,1]) 63 | plt.subplot(3,4,6) 64 | plt.scatter(output_ctrl_pts[:,0], output_ctrl_pts[:,1]) 65 | plt.subplot(3,4,7) 66 | plt.scatter(outputs['sampling_grid'][1,:,0], outputs['sampling_grid'][1,:,1], marker='+') 67 | plt.subplot(3,4,8) 68 | plt.imshow(rectified_images_[1].astype(np.uint8)) 69 | 70 | plt.subplot(3,4,9) 71 | plt.scatter(test_input_ctrl_pts[2,:,0], test_input_ctrl_pts[2,:,1]) 72 | plt.subplot(3,4,10) 73 | plt.scatter(output_ctrl_pts[:,0], output_ctrl_pts[:,1]) 74 | plt.subplot(3,4,11) 75 | plt.scatter(outputs['sampling_grid'][2,:,0], 
outputs['sampling_grid'][2,:,1], marker='+') 76 | plt.subplot(3,4,12) 77 | plt.imshow(rectified_images_[2].astype(np.uint8)) 78 | 79 | plt.show() 80 | 81 | 82 | if __name__ == '__main__': 83 | tf.test.main() 84 | -------------------------------------------------------------------------------- /1 - wei/aster/core/standard_fields.py: -------------------------------------------------------------------------------- 1 | class InputDataFields(object): 2 | image = 'image' 3 | original_image = 'original_image' 4 | key = 'key' 5 | source_id = 'source_id' 6 | filename = 'filename' 7 | groundtruth_text = 'groundtruth_text' 8 | groundtruth_keypoints = 'groundtruth_keypoints' 9 | lexicon = 'lexicon' 10 | 11 | 12 | class TfExampleFields(object): 13 | image_encoded = 'image/encoded' 14 | image_format = 'image/format' # format is reserved keyword 15 | filename = 'image/filename' 16 | channels = 'image/channels' 17 | colorspace = 'image/colorspace' 18 | height = 'image/height' 19 | width = 'image/width' 20 | source_id = 'image/source_id' 21 | transcript = 'image/transcript' 22 | lexicon = 'image/lexicon' 23 | keypoints = 'image/keypoints' 24 | -------------------------------------------------------------------------------- /1 - wei/aster/core/sync_attention_wrapper.py: -------------------------------------------------------------------------------- 1 | from tensorflow.python.ops import array_ops 2 | from tensorflow.contrib import rnn 3 | from tensorflow.contrib import seq2seq 4 | from tensorflow.contrib.seq2seq.python.ops.attention_wrapper import _compute_attention 5 | 6 | 7 | class SyncAttentionWrapper(seq2seq.AttentionWrapper): 8 | 9 | def __init__(self, 10 | cell, 11 | attention_mechanism, 12 | attention_layer_size=None, 13 | alignment_history=False, 14 | cell_input_fn=None, 15 | output_attention=True, 16 | initial_cell_state=None, 17 | name=None): 18 | if not isinstance(cell, (rnn.LSTMCell, rnn.GRUCell)): 19 | raise ValueError('SyncAttentionWrapper only supports LSTMCell and GRUCell, ' 20 | 'Got: {}'.format(cell)) 21 | super(SyncAttentionWrapper, self).__init__( 22 | cell, 23 | attention_mechanism, 24 | attention_layer_size=attention_layer_size, 25 | alignment_history=alignment_history, 26 | cell_input_fn=cell_input_fn, 27 | output_attention=output_attention, 28 | initial_cell_state=initial_cell_state, 29 | name=name 30 | ) 31 | 32 | def call(self, inputs, state): 33 | if not isinstance(state, seq2seq.AttentionWrapperState): 34 | raise TypeError("Expected state to be instance of AttentionWrapperState. " 35 | "Received type %s instead." 
% type(state)) 36 | 37 | if self._is_multi: 38 | previous_alignments = state.alignments 39 | previous_alignment_history = state.alignment_history 40 | else: 41 | previous_alignments = [state.alignments] 42 | previous_alignment_history = [state.alignment_history] 43 | 44 | all_alignments = [] 45 | all_attentions = [] 46 | all_histories = [] 47 | for i, attention_mechanism in enumerate(self._attention_mechanisms): 48 | if isinstance(self._cell, rnn.LSTMCell): 49 | rnn_cell_state = state.cell_state.h 50 | else: 51 | rnn_cell_state = state.cell_state 52 | attention, alignments,attention_state = _compute_attention( 53 | attention_mechanism, rnn_cell_state, previous_alignments[i], 54 | self._attention_layers[i]if self._attention_layers else None) 55 | alignment_history = previous_alignment_history[i].write( 56 | state.time, alignments) if self._alignment_history else () 57 | 58 | all_alignments.append(alignments) 59 | all_histories.append(alignment_history) 60 | all_attentions.append(attention) 61 | 62 | attention = array_ops.concat(all_attentions, 1) 63 | 64 | cell_inputs = self._cell_input_fn(inputs, attention) 65 | cell_output, next_cell_state = self._cell(cell_inputs, state.cell_state) 66 | 67 | next_state = seq2seq.AttentionWrapperState( 68 | time=state.time + 1, 69 | cell_state=next_cell_state, 70 | attention=attention, 71 | attention_state=attention_state, 72 | alignments=self._item_or_tuple(all_alignments), 73 | alignment_history=self._item_or_tuple(all_histories)) 74 | 75 | if self._output_attention: 76 | return attention, next_state 77 | else: 78 | return cell_output, next_state 79 | -------------------------------------------------------------------------------- /1 - wei/aster/data_decoders/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mind/RMB/cd656b3000fc612657b57dc70dfa76432a448014/1 - wei/aster/data_decoders/__init__.py -------------------------------------------------------------------------------- /1 - wei/aster/demo.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | import tensorflow as tf 4 | from PIL import Image 5 | from google.protobuf import text_format 6 | import numpy as np 7 | 8 | from aster.protos import pipeline_pb2 9 | from aster.builders import model_builder 10 | 11 | # supress TF logging duplicates 12 | logging.getLogger('tensorflow').propagate = False 13 | tf.logging.set_verbosity(tf.logging.INFO) 14 | logging.basicConfig(level=logging.INFO) 15 | 16 | flags = tf.app.flags 17 | flags.DEFINE_string('exp_dir', 'aster/experiments/demo/', 18 | 'Directory containing config, training log and evaluations') 19 | flags.DEFINE_string('input_image', 'aster/data/demo.jpg', 'Demo image') 20 | FLAGS = flags.FLAGS 21 | 22 | 23 | def get_configs_from_exp_dir(): 24 | pipeline_config_path = os.path.join(FLAGS.exp_dir, 'config/trainval.prototxt') 25 | 26 | pipeline_config = pipeline_pb2.TrainEvalPipelineConfig() 27 | with tf.gfile.GFile(pipeline_config_path, 'r') as f: 28 | text_format.Merge(f.read(), pipeline_config) 29 | 30 | model_config = pipeline_config.model 31 | eval_config = pipeline_config.eval_config 32 | input_config = pipeline_config.eval_input_reader 33 | 34 | return model_config, eval_config, input_config 35 | 36 | 37 | def main(_): 38 | checkpoint_dir = os.path.join(FLAGS.exp_dir, 'log') 39 | # eval_dir = os.path.join(FLAGS.exp_dir, 'log/eval') 40 | model_config, _, _ = get_configs_from_exp_dir() 41 | 42 | model = 
model_builder.build(model_config, is_training=False) 43 | 44 | input_image_str_tensor = tf.placeholder( 45 | dtype=tf.string, 46 | shape=[]) 47 | input_image_tensor = tf.image.decode_jpeg( 48 | input_image_str_tensor, 49 | channels=3, 50 | ) 51 | resized_image_tensor = tf.image.resize_images( 52 | tf.to_float(input_image_tensor), 53 | [64, 256]) 54 | 55 | predictions_dict = model.predict(tf.expand_dims(resized_image_tensor, 0)) 56 | recognitions = model.postprocess(predictions_dict) 57 | recognition_text = recognitions['text'][0] 58 | control_points = predictions_dict['control_points'], 59 | rectified_images = predictions_dict['rectified_images'] 60 | 61 | saver = tf.train.Saver(tf.global_variables()) 62 | checkpoint = os.path.join(FLAGS.exp_dir, 'log/model.ckpt') 63 | 64 | fetches = { 65 | 'original_image': input_image_tensor, 66 | 'recognition_text': recognition_text, 67 | 'control_points': predictions_dict['control_points'], 68 | 'rectified_images': predictions_dict['rectified_images'], 69 | } 70 | 71 | with open(FLAGS.input_image, 'rb') as f: 72 | input_image_str = f.read() 73 | 74 | with tf.Session() as sess: 75 | sess.run([ 76 | tf.global_variables_initializer(), 77 | tf.local_variables_initializer(), 78 | tf.tables_initializer()]) 79 | saver.restore(sess, checkpoint) 80 | sess_outputs = sess.run(fetches, feed_dict={input_image_str_tensor: input_image_str}) 81 | 82 | print('Recognized text: {}'.format(sess_outputs['recognition_text'].decode('utf-8'))) 83 | 84 | rectified_image = sess_outputs['rectified_images'][0] 85 | rectified_image_pil = Image.fromarray((128 * (rectified_image + 1.0)).astype(np.uint8)) 86 | input_image_dir = os.path.dirname(FLAGS.input_image) 87 | rectified_image_save_path = os.path.join(input_image_dir, 'rectified_image.jpg') 88 | rectified_image_pil.save(rectified_image_save_path) 89 | print('Rectified image saved to {}'.format(rectified_image_save_path)) 90 | 91 | if __name__ == '__main__': 92 | tf.app.run() 93 | -------------------------------------------------------------------------------- /1 - wei/aster/meta_architectures/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mind/RMB/cd656b3000fc612657b57dc70dfa76432a448014/1 - wei/aster/meta_architectures/__init__.py -------------------------------------------------------------------------------- /1 - wei/aster/overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mind/RMB/cd656b3000fc612657b57dc70dfa76432a448014/1 - wei/aster/overview.png -------------------------------------------------------------------------------- /1 - wei/aster/predictors/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mind/RMB/cd656b3000fc612657b57dc70dfa76432a448014/1 - wei/aster/predictors/__init__.py -------------------------------------------------------------------------------- /1 - wei/aster/protos/bidirectional_rnn.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto2"; 2 | package aster.protos; 3 | 4 | import "aster/protos/rnn_cell.proto"; 5 | import "aster/protos/hyperparams.proto"; 6 | 7 | message BidirectionalRnn { 8 | optional bool static = 1 [default = true]; 9 | optional RnnCell fw_bw_rnn_cell = 2; 10 | optional Regularizer rnn_regularizer = 3; 11 | optional int32 num_output_units = 4 [default = 0]; 12 | optional Hyperparams fc_hyperparams 
= 5; 13 | optional bool summarize_activations = 6 [default = false]; 14 | } 15 | -------------------------------------------------------------------------------- /1 - wei/aster/protos/convnet.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto2"; 2 | package aster.protos; 3 | 4 | import "aster/protos/hyperparams.proto"; 5 | 6 | message Convnet { 7 | oneof convnet_oneof { 8 | CrnnNet crnn_net = 1; 9 | ResNet resnet = 2; 10 | StnConvnet stn_convnet = 3; 11 | } 12 | } 13 | 14 | message CrnnNet { 15 | enum NetType { 16 | SINGLE_BRANCH = 0; 17 | TWO_BRANCHES = 1; 18 | THREE_BRANCHES = 2; 19 | } 20 | optional NetType net_type = 1 [default = SINGLE_BRANCH]; 21 | optional Hyperparams conv_hyperparams = 2; 22 | optional bool summarize_activations = 3 [default=false]; 23 | optional bool tiny = 4 [default = false]; 24 | } 25 | 26 | message ResNet { 27 | enum NetType { 28 | SINGLE_BRANCH = 0; 29 | TWO_BRANCHES = 1; 30 | THREE_BRANCHES = 2; 31 | } 32 | enum NetDepth { 33 | RESNET_30 = 0; 34 | RESNET_50 = 1; 35 | RESNET_100 = 2; 36 | } 37 | optional NetType net_type = 1 [default = SINGLE_BRANCH]; 38 | optional NetDepth net_depth = 2 [default = RESNET_50]; 39 | optional Hyperparams conv_hyperparams = 3; 40 | optional bool summarize_activations = 4 [default=false]; 41 | } 42 | 43 | message StnConvnet { 44 | optional Hyperparams conv_hyperparams = 1; 45 | optional bool summarize_activations = 2 [default=false]; 46 | optional bool tiny = 3 [default = false]; 47 | } 48 | -------------------------------------------------------------------------------- /1 - wei/aster/protos/eval.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto2"; 2 | package aster.protos; 3 | 4 | import "aster/protos/preprocessor.proto"; 5 | 6 | // Message for configuring DetectionModel evaluation jobs (eval.py). 7 | message EvalConfig { 8 | // Number of visualization images to generate. 9 | optional uint32 num_visualizations = 1 [default=10]; 10 | 11 | optional bool only_visualize_incorrect = 14 [default=false]; 12 | 13 | // Number of examples to process of evaluation. 14 | optional uint32 num_examples = 2 [default=5000]; 15 | 16 | // How often to run evaluation. 17 | optional uint32 eval_interval_secs = 3 [default=300]; 18 | 19 | // Maximum number of times to run evaluation. If set to 0, will run forever. 20 | optional uint32 max_evals = 4 [default=0]; 21 | 22 | // Whether the TensorFlow graph used for evaluation should be saved to disk. 23 | optional bool save_graph = 5 [default=false]; 24 | 25 | // Path to directory to store visualizations in. If empty, visualization 26 | // images are not exported (only shown on Tensorboard). 27 | optional string visualization_export_dir = 6 [default=""]; 28 | 29 | // BNS name of the TensorFlow master. 30 | optional string eval_master = 7 [default=""]; 31 | 32 | // Type of metrics to use for evaluation. Currently supports only Pascal VOC 33 | // detection metrics. 34 | optional string metrics_set = 8 [default="recognition_metrics"]; 35 | 36 | // Path to export detections to COCO compatible JSON format. 37 | optional string export_path = 9 [default='']; 38 | 39 | // Option to not read groundtruth labels and only export detections to 40 | // COCO-compatible JSON file. 41 | optional bool ignore_groundtruth = 10 [default=false]; 42 | 43 | // Use exponential moving averages of variables for evaluation. 
44 | // TODO: When this is false make sure the model is constructed 45 | // without moving averages in restore_fn. 46 | optional bool use_moving_averages = 11 [default=false]; 47 | 48 | // Whether to evaluate instance masks. 49 | optional bool eval_instance_masks = 12 [default=false]; 50 | 51 | // Whether to evaluate with lexicon 52 | optional bool eval_with_lexicon = 15 [default=false]; 53 | 54 | // data preprocessing steps 55 | repeated PreprocessingStep data_preprocessing_steps = 13; 56 | } 57 | -------------------------------------------------------------------------------- /1 - wei/aster/protos/feature_extractor.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto2"; 2 | package aster.protos; 3 | 4 | import "aster/protos/convnet.proto"; 5 | import "aster/protos/bidirectional_rnn.proto"; 6 | 7 | message FeatureExtractor { 8 | optional Convnet convnet = 1; 9 | repeated BidirectionalRnn bidirectional_rnn = 2; 10 | optional bool summarize_activations = 3 [default=false]; 11 | } 12 | -------------------------------------------------------------------------------- /1 - wei/aster/protos/input_reader.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto2"; 2 | package aster.protos; 3 | 4 | 5 | message InputReader { 6 | // Path to StringIntLabelMap pbtxt file specifying the mapping from string 7 | // labels to integer ids. 8 | optional string label_map_path = 1 [default=""]; 9 | 10 | // Whether data should be processed in the order they are read in, or 11 | // shuffled randomly. 12 | optional bool shuffle = 2 [default=true]; 13 | 14 | // Maximum number of records to keep in reader queue. 15 | optional uint32 queue_capacity = 3 [default=2000]; 16 | 17 | // Minimum number of records to keep in reader queue. A large value is needed 18 | // to generate a good random shuffle. 19 | optional uint32 min_after_dequeue = 4 [default=1000]; 20 | 21 | // The number of times a data source is read. If set to zero, the data source 22 | // will be reused indefinitely. 23 | optional uint32 num_epochs = 5 [default=0]; 24 | 25 | // Number of reader instances to create. 26 | optional uint32 num_readers = 6 [default=8]; 27 | 28 | // Whether to load groundtruth instance masks. 29 | optional bool load_instance_masks = 7 [default = false]; 30 | 31 | oneof input_reader { 32 | TFRecordInputReader tf_record_input_reader = 8; 33 | ExternalInputReader external_input_reader = 9; 34 | } 35 | } 36 | 37 | // An input reader that reads TF Example protos from local TFRecord files. 38 | message TFRecordInputReader { 39 | // Path to TFRecordFile. 40 | optional string input_path = 1 [default=""]; 41 | } 42 | 43 | // An externally defined input reader. Users may define an extension to this 44 | // proto to interface their own input readers. 
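// For reference, a minimal text_format entry using the TFRecord reader above
// could look like the following; this is only an illustrative sketch and the
// input path is a placeholder, not a file shipped with this repository:
//
//   shuffle: true
//   num_readers: 4
//   tf_record_input_reader {
//     input_path: "data/train.tfrecord"
//   }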
45 | message ExternalInputReader { 46 | extensions 1 to 999; 47 | } 48 | -------------------------------------------------------------------------------- /1 - wei/aster/protos/label_map.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto2"; 2 | package aster.protos; 3 | 4 | message LabelMap { 5 | optional CharacterSet character_set = 1; 6 | optional int64 label_offset = 2 [default=0]; 7 | optional int64 unk_label = 3; 8 | } 9 | 10 | message CharacterSet { 11 | enum BuiltInSet { 12 | LOWERCASE = 0; 13 | ALLCASES = 1; 14 | ALLCASES_SYMBOLS = 2; 15 | } 16 | 17 | oneof source_oneof { 18 | string text_file = 1 [default=""]; 19 | string text_string = 2 [default=""]; 20 | BuiltInSet built_in_set = 3 [default=LOWERCASE]; 21 | } 22 | } 23 | -------------------------------------------------------------------------------- /1 - wei/aster/protos/loss.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto2"; 2 | package aster.protos; 3 | 4 | message Loss { 5 | oneof loss_oneof { 6 | SequenceCrossEntropyLoss sequence_cross_entropy_loss = 1; 7 | L2RegressionLoss l2_regression_loss = 2; 8 | } 9 | } 10 | 11 | message SequenceCrossEntropyLoss { 12 | optional bool sequence_normalize = 1 [default = false]; 13 | optional bool sample_normalize = 2 [default = true]; 14 | optional float weight = 3 [default = 1.0]; 15 | } 16 | 17 | message L2RegressionLoss { 18 | optional float weight = 3 [default = 1.0]; 19 | } 20 | -------------------------------------------------------------------------------- /1 - wei/aster/protos/model.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto2"; 2 | package aster.protos; 3 | 4 | import "aster/protos/feature_extractor.proto"; 5 | import "aster/protos/predictor.proto"; 6 | import "aster/protos/spatial_transformer.proto"; 7 | import "aster/protos/loss.proto"; 8 | 9 | 10 | message Model { 11 | oneof model_oneof { 12 | MultiPredictorsRecognitionModel multi_predictors_recognition_model = 1; 13 | } 14 | } 15 | 16 | message MultiPredictorsRecognitionModel { 17 | optional SpatialTransformer spatial_transformer = 1; 18 | optional FeatureExtractor feature_extractor = 2; 19 | repeated Predictor predictor = 3; 20 | optional bool keypoint_supervision = 4 [default = false]; 21 | optional Loss regression_loss = 5; 22 | } 23 | -------------------------------------------------------------------------------- /1 - wei/aster/protos/optimizer.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto2"; 2 | package aster.protos; 3 | 4 | // Messages for configuring the optimizing strategy for training object 5 | // detection models. 6 | 7 | // Top level optimizer message. 
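// For orientation, an illustrative Optimizer entry, as it would appear inside
// train_config of a pipeline config (text_format), is sketched below; the
// numeric values are placeholders rather than recommended settings:
//
//   optimizer {
//     adam_optimizer {
//       learning_rate {
//         exponential_decay_learning_rate {
//           initial_learning_rate: 0.001
//           decay_steps: 100000
//           decay_factor: 0.9
//         }
//       }
//     }
//     use_moving_average: false
//   }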
8 | message Optimizer { 9 | oneof optimizer { 10 | RMSPropOptimizer rms_prop_optimizer = 1; 11 | MomentumOptimizer momentum_optimizer = 2; 12 | AdamOptimizer adam_optimizer = 3; 13 | NadamOptimizer nadam_optimizer=4; 14 | AdadeltaOptimizer adadelta_optimizer = 5; 15 | } 16 | optional bool use_moving_average = 6 [default=true]; 17 | optional float moving_average_decay = 7 [default=0.9999]; 18 | } 19 | 20 | // Configuration message for the RMSPropOptimizer 21 | // See: https://www.tensorflow.org/api_docs/python/tf/train/RMSPropOptimizer 22 | message RMSPropOptimizer { 23 | optional LearningRate learning_rate = 1; 24 | optional float momentum_optimizer_value = 2 [default=0.9]; 25 | optional float decay = 3 [default=0.9]; 26 | optional float epsilon = 4 [default=1.0]; 27 | } 28 | 29 | // Configuration message for the MomentumOptimizer 30 | // See: https://www.tensorflow.org/api_docs/python/tf/train/MomentumOptimizer 31 | message MomentumOptimizer { 32 | optional LearningRate learning_rate = 1; 33 | optional float momentum_optimizer_value = 2 [default=0.9]; 34 | } 35 | 36 | // Configuration message for the AdamOptimizer 37 | // See: https://www.tensorflow.org/api_docs/python/tf/train/AdamOptimizer 38 | message AdamOptimizer { 39 | optional LearningRate learning_rate = 1; 40 | } 41 | 42 | message NadamOptimizer { 43 | optional LearningRate learning_rate = 1; 44 | } 45 | 46 | message AdadeltaOptimizer { 47 | optional LearningRate learning_rate = 1; 48 | optional float rho = 2 [default=0.95]; 49 | } 50 | 51 | // Configuration message for optimizer learning rate. 52 | message LearningRate { 53 | oneof learning_rate { 54 | ConstantLearningRate constant_learning_rate = 1; 55 | ExponentialDecayLearningRate exponential_decay_learning_rate = 2; 56 | ManualStepLearningRate manual_step_learning_rate = 3; 57 | } 58 | } 59 | 60 | // Configuration message for a constant learning rate. 61 | message ConstantLearningRate { 62 | optional float learning_rate = 1 [default=0.002]; 63 | } 64 | 65 | // Configuration message for an exponentially decaying learning rate. 66 | // See https://www.tensorflow.org/versions/master/api_docs/python/train/ \ 67 | // decaying_the_learning_rate#exponential_decay 68 | message ExponentialDecayLearningRate { 69 | optional float initial_learning_rate = 1 [default=0.002]; 70 | optional uint32 decay_steps = 2 [default=4000000]; 71 | optional float decay_factor = 3 [default=0.95]; 72 | optional bool staircase = 4 [default=true]; 73 | } 74 | 75 | // Configuration message for a manually defined learning rate schedule. 76 | message ManualStepLearningRate { 77 | optional float initial_learning_rate = 1 [default=0.002]; 78 | message LearningRateSchedule { 79 | optional uint32 step = 1; 80 | optional float learning_rate = 2 [default=0.002]; 81 | } 82 | repeated LearningRateSchedule schedule = 2; 83 | } 84 | -------------------------------------------------------------------------------- /1 - wei/aster/protos/pipeline.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto2"; 2 | package aster.protos; 3 | 4 | import "aster/protos/train.proto"; 5 | import "aster/protos/eval.proto"; 6 | import "aster/protos/input_reader.proto"; 7 | import "aster/protos/model.proto"; 8 | 9 | // Convenience message for configuring a training and eval pipeline. Allows all 10 | // of the pipeline parameters to be configured from one file. 
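// For reference, a trainval.prototxt written against this message has the
// overall shape sketched below (bodies elided; see the individual protos for
// the available fields):
//
//   model { multi_predictors_recognition_model { ... } }
//   train_config { ... }
//   train_input_reader { ... }
//   eval_config { ... }
//   eval_input_reader { ... }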
11 | message TrainEvalPipelineConfig { 12 | optional Model model = 1; 13 | optional TrainConfig train_config = 2; 14 | repeated InputReader train_input_reader = 3; 15 | optional EvalConfig eval_config = 4; 16 | optional InputReader eval_input_reader = 5; 17 | } 18 | -------------------------------------------------------------------------------- /1 - wei/aster/protos/predictor.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto2"; 2 | package aster.protos; 3 | 4 | import "aster/protos/rnn_cell.proto"; 5 | import "aster/protos/hyperparams.proto"; 6 | import "aster/protos/label_map.proto"; 7 | import "aster/protos/loss.proto"; 8 | 9 | 10 | message Predictor { 11 | optional string name = 1 [default = "Predictor"]; 12 | oneof predictor_oneof { 13 | AttentionPredictor attention_predictor = 2; 14 | CtcPredictor ctc_predictor = 3; 15 | } 16 | } 17 | 18 | message AttentionPredictor { 19 | optional RnnCell rnn_cell = 1; 20 | optional Regularizer rnn_regularizer = 2; 21 | optional int32 num_attention_units = 3 [default=128]; 22 | optional int32 max_num_steps = 4 [default=40]; 23 | optional bool multi_attention = 5 [default = false]; 24 | optional int32 beam_width = 6 [default = 1]; 25 | optional bool reverse = 7 [default = false]; 26 | optional LabelMap label_map = 8; 27 | optional Loss loss = 9; 28 | optional LanguageModelRnnCell lm_rnn_cell = 10; 29 | optional bool sync = 11 [default = true]; 30 | } 31 | 32 | message LanguageModelRnnCell { 33 | repeated RnnCell rnn_cell = 1; 34 | optional string restore_path = 2 [default = ""]; 35 | } 36 | 37 | message CtcPredictor { 38 | 39 | } -------------------------------------------------------------------------------- /1 - wei/aster/protos/preprocessor.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto2"; 2 | package aster.protos; 3 | 4 | import "aster/protos/label_map.proto"; 5 | 6 | 7 | message PreprocessingStep { 8 | oneof preprocessing_step { 9 | ResizeImageRandomMethod resize_image_random_method = 1; 10 | ResizeImage resize_image = 2; 11 | NormalizeImage normalize_image = 3; 12 | RandomPixelValueScale random_pixel_value_scale = 4; 13 | RandomRgbToGray random_rgb_to_gray = 5; 14 | RandomAdjustBrightness random_adjust_brightness = 6; 15 | RandomAdjustContrast random_adjust_contrast = 7; 16 | RandomAdjustHue random_adjust_hue = 8; 17 | RandomAdjustSaturation random_adjust_saturation = 9; 18 | RandomDistortColor random_distort_color = 10; 19 | ImageToFloat image_to_float = 11; 20 | SubtractChannelMean subtract_channel_mean = 12; 21 | RgbToGray rgb_to_gray = 13; 22 | StringFiltering string_filtering = 14; 23 | } 24 | } 25 | 26 | message ResizeImageRandomMethod { 27 | optional int32 target_height = 1 [default=512]; 28 | optional int32 target_width = 2 [default=512]; 29 | } 30 | 31 | message ResizeImage { 32 | enum Method { 33 | AREA=1; 34 | BICUBIC=2; 35 | BILINEAR=3; 36 | NEAREST_NEIGHBOR=4; 37 | } 38 | optional int32 target_height = 1 [default=512]; 39 | optional int32 target_width = 2 [default=512]; 40 | optional Method method = 3 [default=BILINEAR]; 41 | } 42 | 43 | message NormalizeImage { 44 | optional float original_minval = 1; 45 | optional float original_maxval = 2; 46 | optional float target_minval = 3 [default=0]; 47 | optional float target_maxval = 4 [default=1]; 48 | } 49 | 50 | message RandomPixelValueScale { 51 | optional float minval = 1 [default=0.9]; 52 | optional float maxval = 2 [default=1.1]; 53 | } 54 | 55 | message 
RandomRgbToGray { 56 | optional float probability = 1 [default=0.1]; 57 | } 58 | 59 | message RandomAdjustBrightness { 60 | optional float max_delta=1 [default=0.2]; 61 | } 62 | 63 | message RandomAdjustContrast { 64 | optional float min_delta = 1 [default=0.8]; 65 | optional float max_delta = 2 [default=1.25]; 66 | } 67 | 68 | message RandomAdjustHue { 69 | optional float max_delta = 1 [default=0.02]; 70 | } 71 | 72 | message RandomAdjustSaturation { 73 | optional float min_delta = 1 [default=0.8]; 74 | optional float max_delta = 2 [default=1.25]; 75 | } 76 | 77 | message RandomDistortColor { 78 | optional int32 color_ordering = 1; 79 | } 80 | 81 | message ImageToFloat { 82 | } 83 | 84 | message SubtractChannelMean { 85 | // The mean to subtract from each channel. Should be of same dimension of 86 | // channels in the input image. 87 | repeated float means = 1; 88 | } 89 | 90 | message RgbToGray { 91 | optional bool three_channels = 1 [default=false]; 92 | } 93 | 94 | message StringFiltering { 95 | optional bool lower_case = 1 [default=false]; 96 | optional CharacterSet include_charset = 2; 97 | } 98 | -------------------------------------------------------------------------------- /1 - wei/aster/protos/rnn_cell.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto2"; 2 | package aster.protos; 3 | 4 | import "aster/protos/hyperparams.proto"; 5 | 6 | message RnnCell { 7 | oneof rnn_cell_oneof { 8 | LstmCell lstm_cell = 1; 9 | GruCell gru_cell = 2; 10 | } 11 | } 12 | 13 | message LstmCell { 14 | optional uint32 num_units = 1 [default=128]; 15 | optional bool use_peepholes = 2 [default=false]; 16 | optional float forget_bias = 3 [default=1.0]; 17 | optional Initializer initializer = 4; 18 | } 19 | 20 | message GruCell { 21 | optional uint32 num_units = 1 [default=128]; 22 | optional Initializer initializer = 2; 23 | } 24 | -------------------------------------------------------------------------------- /1 - wei/aster/protos/spatial_transformer.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto2"; 2 | package aster.protos; 3 | 4 | import "aster/protos/convnet.proto"; 5 | import "aster/protos/hyperparams.proto"; 6 | 7 | 8 | message SpatialTransformer { 9 | optional Convnet convnet = 1; 10 | optional Hyperparams fc_hyperparams = 2; 11 | 12 | // image size for the localization network 13 | optional int32 localization_h = 3 [default = 64]; 14 | optional int32 localization_w = 4 [default = 128]; 15 | 16 | // rectified image size 17 | optional int32 output_h = 5 [default = 32]; 18 | optional int32 output_w = 6 [default = 100]; 19 | 20 | optional float margin_x = 7 [default = 0.1]; 21 | optional float margin_y = 8 [default = 0.1]; 22 | 23 | optional int32 num_control_points = 9 [default = 20]; 24 | optional string init_bias_pattern = 10 [default = "identity"]; 25 | optional string activation = 11 [default = "none"]; 26 | optional bool summarize_activations = 12 [default = false]; 27 | } 28 | -------------------------------------------------------------------------------- /1 - wei/aster/protos/train.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto2"; 2 | package aster.protos; 3 | 4 | import "aster/protos/optimizer.proto"; 5 | import "aster/protos/preprocessor.proto"; 6 | 7 | // Message for configuring DetectionModel training jobs (train.py). 8 | message TrainConfig { 9 | // Input queue batch size. 
10 | repeated uint32 batch_size = 1; 11 | 12 | // Data augmentation options. 13 | repeated PreprocessingStep data_augmentation_options = 2; 14 | 15 | // Whether to synchronize replicas during training. 16 | optional bool sync_replicas = 3 [default=false]; 17 | 18 | // How frequently to keep checkpoints. 19 | optional uint32 keep_checkpoint_every_n_hours = 4 [default=1000]; 20 | 21 | // Optimizer used to train the DetectionModel. 22 | optional Optimizer optimizer = 5; 23 | 24 | // If greater than 0, clips gradients by this value. 25 | optional float gradient_clipping_by_norm = 6 [default=0.0]; 26 | 27 | // Checkpoint to restore variables from. Typically used to load feature 28 | // extractor variables trained outside of object detection. 29 | optional string fine_tune_checkpoint = 7 [default=""]; 30 | 31 | // Specifies if the finetune checkpoint is from an object detection model. 32 | // If from an object detection model, the model being trained should have 33 | // the same parameters with the exception of the num_classes parameter. 34 | // If false, it assumes the checkpoint was a object classification model. 35 | optional bool from_detection_checkpoint = 8 [default=false]; 36 | 37 | // Number of steps to train the DetectionModel for. If 0, will train the model 38 | // indefinitely. 39 | optional uint32 num_steps = 9 [default=0]; 40 | 41 | // Number of training steps between replica startup. 42 | // This flag must be set to 0 if sync_replicas is set to true. 43 | optional float startup_delay_steps = 10 [default=15]; 44 | 45 | // If greater than 0, multiplies the gradient of bias variables by this 46 | // amount. 47 | optional float bias_grad_multiplier = 11 [default=0]; 48 | 49 | // Variables that should not be updated during training. 50 | repeated string freeze_variables = 12; 51 | 52 | // Number of replicas to aggregate before making parameter updates. 53 | optional int32 replicas_to_aggregate = 13 [default=1]; 54 | 55 | // Maximum number of elements to store within a queue. 56 | optional int32 batch_queue_capacity = 14 [default=600]; 57 | 58 | // Number of threads to use for batching. 59 | optional int32 num_batch_queue_threads = 15 [default=8]; 60 | 61 | // Maximum capacity of the queue used to prefetch assembled batches. 
62 | optional int32 prefetch_queue_capacity = 16 [default=10]; 63 | 64 | // Save checkpoint every n seconds 65 | optional int32 save_checkpoint_secs = 17 [default=600]; 66 | 67 | // save summaries every n steps 68 | optional int32 save_summaries_steps = 18 [default=100]; 69 | } 70 | -------------------------------------------------------------------------------- /1 - wei/aster/tools/create_cute80_tfrecord.py: -------------------------------------------------------------------------------- 1 | import os 2 | import io 3 | import xml.etree.ElementTree as ET 4 | 5 | from PIL import Image 6 | import tensorflow as tf 7 | 8 | from aster.utils import dataset_util 9 | from aster.core import standard_fields as fields 10 | 11 | flags = tf.app.flags 12 | flags.DEFINE_string('data_dir', '', 'Root directory to raw SynthText dataset.') 13 | FLAGS = flags.FLAGS 14 | 15 | def create_cute80(output_path): 16 | writer = tf.python_io.TFRecordWriter(output_path) 17 | image_list_file = os.path.join(FLAGS.data_dir, 'imagelist.txt') 18 | with open(image_list_file, 'r') as f: 19 | tlines = [tline.rstrip('\n') for tline in f.readlines()] 20 | 21 | count = 0 22 | 23 | for tline in tlines: 24 | image_rel_path, groundtruth_text, lexicon_length, lexicon = \ 25 | tline.split(' ') 26 | groundtruth_text = groundtruth_text.lower() 27 | 28 | image_path = os.path.join(FLAGS.data_dir, image_rel_path) 29 | with open(image_path, 'rb') as f: 30 | image_jpeg = f.read() 31 | 32 | example = tf.train.Example(features=tf.train.Features(feature={ 33 | fields.TfExampleFields.image_encoded: \ 34 | dataset_util.bytes_feature(image_jpeg), 35 | fields.TfExampleFields.image_format: \ 36 | dataset_util.bytes_feature('jpeg'.encode('utf-8')), 37 | fields.TfExampleFields.filename: \ 38 | dataset_util.bytes_feature(image_rel_path.encode('utf-8')), 39 | fields.TfExampleFields.channels: \ 40 | dataset_util.int64_feature(3), 41 | fields.TfExampleFields.colorspace: \ 42 | dataset_util.bytes_feature('rgb'.encode('utf-8')), 43 | fields.TfExampleFields.transcript: \ 44 | dataset_util.bytes_feature(groundtruth_text.encode('utf-8')), 45 | })) 46 | writer.write(example.SerializeToString()) 47 | count += 1 48 | 49 | writer.close() 50 | print('{} examples created'.format(count)) 51 | 52 | if __name__ == '__main__': 53 | create_cute80('data/cute80_test.tfrecord') 54 | -------------------------------------------------------------------------------- /1 - wei/aster/tools/create_ic15_tfrecord.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import os 4 | import io 5 | import random 6 | import re 7 | import glob 8 | 9 | from PIL import Image 10 | import tensorflow as tf 11 | 12 | from aster.utils import dataset_util 13 | from aster.core import standard_fields as fields 14 | 15 | flags = tf.app.flags 16 | flags.DEFINE_string('data_dir', '/home/mkyang/dataset/recognition/icdar2015/', 'Root directory to raw SynthText dataset.') 17 | flags.DEFINE_bool('exclude_difficult', False, 'Excluding non-alphanumeric examples.') 18 | flags.DEFINE_string('output_path', 'data/ic15_test_all.tfrecord', 'Output tfrecord path.') 19 | FLAGS = flags.FLAGS 20 | 21 | def _is_difficult(word): 22 | assert isinstance(word, str) 23 | return not re.match('^[\w]+$', word) 24 | 25 | def char_check(word): 26 | if not word.isalnum(): 27 | return False 28 | else: 29 | for char in word: 30 | if char < ' ' or char > '~': 31 | return False 32 | return True 33 | 34 | def create_ic15(output_path): 35 | writer = 
tf.python_io.TFRecordWriter(output_path) 36 | 37 | groundtruth_file_path = os.path.join(FLAGS.data_dir, 'test_groundtruth_all.txt') 38 | 39 | count = 0 40 | with open(groundtruth_file_path, 'r') as f: 41 | lines = f.readlines() 42 | img_gts = [line.strip() for line in lines] 43 | for img_gt in img_gts: 44 | img_rel_path, gt = img_gt.split(' ', 1) 45 | if FLAGS.exclude_difficult and not char_check(gt): 46 | continue 47 | img_path = os.path.join(FLAGS.data_dir, img_rel_path) 48 | img = Image.open(img_path) 49 | img_buff = io.BytesIO() 50 | img.save(img_buff, format='jpeg') 51 | word_crop_jpeg = img_buff.getvalue() 52 | crop_name = os.path.basename(img_path) 53 | 54 | example = tf.train.Example(features=tf.train.Features(feature={ 55 | fields.TfExampleFields.image_encoded: \ 56 | dataset_util.bytes_feature(word_crop_jpeg), 57 | fields.TfExampleFields.image_format: \ 58 | dataset_util.bytes_feature('jpeg'.encode('utf-8')), 59 | fields.TfExampleFields.filename: \ 60 | dataset_util.bytes_feature(crop_name.encode('utf-8')), 61 | fields.TfExampleFields.channels: \ 62 | dataset_util.int64_feature(3), 63 | fields.TfExampleFields.colorspace: \ 64 | dataset_util.bytes_feature('rgb'.encode('utf-8')), 65 | fields.TfExampleFields.transcript: \ 66 | dataset_util.bytes_feature(gt.encode('utf-8')), 67 | })) 68 | writer.write(example.SerializeToString()) 69 | count += 1 70 | 71 | writer.close() 72 | print('{} examples created'.format(count)) 73 | 74 | if __name__ == '__main__': 75 | create_ic15(FLAGS.output_path) 76 | -------------------------------------------------------------------------------- /1 - wei/aster/tools/create_iiit5k_tfrecord.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import random 4 | import PIL.Image 5 | 6 | import tensorflow as tf 7 | import numpy as np 8 | import scipy.io as sio 9 | from tqdm import tqdm 10 | 11 | from aster.utils import dataset_util 12 | from aster.core import standard_fields as fields 13 | 14 | flags = tf.app.flags 15 | flags.DEFINE_string('data_dir', '', 'Root directory to raw SynthText dataset.') 16 | FLAGS = flags.FLAGS 17 | 18 | 19 | def create_iiit5k_subset(output_path, train_subset=True, lexicon_index=None): 20 | writer = tf.python_io.TFRecordWriter(output_path) 21 | 22 | mat_file_name = 'traindata.mat' if train_subset else 'testdata.mat' 23 | data_key = 'traindata' if train_subset else 'testdata' 24 | groundtruth_mat_path = os.path.join(FLAGS.data_dir, mat_file_name) 25 | 26 | mat_dict = sio.loadmat(groundtruth_mat_path) 27 | entries = mat_dict[data_key].flatten() 28 | for entry in tqdm(entries): 29 | image_rel_path = str(entry[0][0]) 30 | groundtruth_text = str(entry[1][0]) 31 | if not train_subset: 32 | lexicon = [str(t[0]) for t in entry[lexicon_index].flatten()] 33 | 34 | image_path = os.path.join(FLAGS.data_dir, image_rel_path) 35 | with open(image_path, 'rb') as f: 36 | image_jpeg = f.read() 37 | 38 | example = tf.train.Example(features=tf.train.Features(feature={ 39 | fields.TfExampleFields.image_encoded: \ 40 | dataset_util.bytes_feature(image_jpeg), 41 | fields.TfExampleFields.image_format: \ 42 | dataset_util.bytes_feature('jpeg'.encode('utf-8')), 43 | fields.TfExampleFields.filename: \ 44 | dataset_util.bytes_feature(image_rel_path.encode('utf-8')), 45 | fields.TfExampleFields.channels: \ 46 | dataset_util.int64_feature(3), 47 | fields.TfExampleFields.colorspace: \ 48 | dataset_util.bytes_feature('rgb'.encode('utf-8')), 49 | fields.TfExampleFields.transcript: \ 50 | 
dataset_util.bytes_feature(groundtruth_text.encode('utf-8')), 51 | fields.TfExampleFields.lexicon: \ 52 | dataset_util.bytes_feature(('\t'.join(lexicon)).encode('utf-8')) 53 | })) 54 | writer.write(example.SerializeToString()) 55 | 56 | writer.close() 57 | 58 | 59 | if __name__ == '__main__': 60 | # create_iiit5k_subset('data/iiit5k_train.tfrecord', train_subset=True) 61 | create_iiit5k_subset('data/iiit5k_test_50.tfrecord', train_subset=False, lexicon_index=2) 62 | # create_iiit5k_subset('data/iiit5k_test_1k.tfrecord', train_subset=False, lexicon_index=3) 63 | -------------------------------------------------------------------------------- /1 - wei/aster/tools/create_svt_perspective_tfrecord.py: -------------------------------------------------------------------------------- 1 | import os 2 | import io 3 | import xml.etree.ElementTree as ET 4 | 5 | from PIL import Image 6 | import tensorflow as tf 7 | 8 | from aster.utils import dataset_util 9 | from aster.core import standard_fields as fields 10 | 11 | flags = tf.app.flags 12 | flags.DEFINE_string('data_dir', '', 'Root directory to raw SynthText dataset.') 13 | FLAGS = flags.FLAGS 14 | 15 | def create_svt_perspective(output_path): 16 | writer = tf.python_io.TFRecordWriter(output_path) 17 | image_list_file = os.path.join(FLAGS.data_dir, 'imagelist.txt') 18 | with open(image_list_file, 'r') as f: 19 | tlines = [tline.rstrip('\n') for tline in f.readlines()] 20 | 21 | count = 0 22 | 23 | for tline in tlines: 24 | image_rel_path, groundtruth_text, lexicon_length, lexicon = \ 25 | tline.split(' ') 26 | groundtruth_text = groundtruth_text.lower() 27 | lexicon_length = int(lexicon_length) 28 | lexicon_list = [w.lower() for w in lexicon.split(',')] 29 | 30 | image_path = os.path.join(FLAGS.data_dir, image_rel_path) 31 | with open(image_path, 'rb') as f: 32 | image_jpeg = f.read() 33 | 34 | example = tf.train.Example(features=tf.train.Features(feature={ 35 | fields.TfExampleFields.image_encoded: \ 36 | dataset_util.bytes_feature(image_jpeg), 37 | fields.TfExampleFields.image_format: \ 38 | dataset_util.bytes_feature('jpeg'.encode('utf-8')), 39 | fields.TfExampleFields.filename: \ 40 | dataset_util.bytes_feature(image_rel_path.encode('utf-8')), 41 | fields.TfExampleFields.channels: \ 42 | dataset_util.int64_feature(3), 43 | fields.TfExampleFields.colorspace: \ 44 | dataset_util.bytes_feature('rgb'.encode('utf-8')), 45 | fields.TfExampleFields.transcript: \ 46 | dataset_util.bytes_feature(groundtruth_text.encode('utf-8')), 47 | fields.TfExampleFields.lexicon: \ 48 | dataset_util.bytes_feature(('\t'.join(lexicon_list)).encode('utf-8')), 49 | })) 50 | writer.write(example.SerializeToString()) 51 | count += 1 52 | 53 | writer.close() 54 | print('{} examples created'.format(count)) 55 | 56 | if __name__ == '__main__': 57 | create_svt_perspective('data/svt_perspective_test.tfrecord') 58 | -------------------------------------------------------------------------------- /1 - wei/aster/tools/create_svt_tfrecord.py: -------------------------------------------------------------------------------- 1 | import os 2 | import io 3 | import xml.etree.ElementTree as ET 4 | 5 | from PIL import Image 6 | import tensorflow as tf 7 | 8 | from aster.utils import dataset_util 9 | from aster.core import standard_fields as fields 10 | 11 | flags = tf.app.flags 12 | flags.DEFINE_string('data_dir', '', 'Root directory to raw SynthText dataset.') 13 | flags.DEFINE_float('crop_margin', 0.05, 'Margin in percent of word height') 14 | FLAGS = flags.FLAGS 15 | 16 | 17 | def 
create_svt_subset(output_path): 18 | writer = tf.python_io.TFRecordWriter(output_path) 19 | test_xml_path = os.path.join(FLAGS.data_dir, 'test.xml') 20 | count = 0 21 | xml_root = ET.parse(test_xml_path).getroot() 22 | for image_node in xml_root.findall('image'): 23 | image_rel_path = image_node.find('imageName').text 24 | lexicon = image_node.find('lex').text.lower() 25 | lexicon = lexicon.split(',') 26 | image_path = os.path.join(FLAGS.data_dir, image_rel_path) 27 | image = Image.open(image_path) 28 | image_w, image_h = image.size 29 | 30 | for i, rect in enumerate(image_node.find('taggedRectangles')): 31 | bbox_x = float(rect.get('x')) 32 | bbox_y = float(rect.get('y')) 33 | bbox_w = float(rect.get('width')) 34 | bbox_h = float(rect.get('height')) 35 | if FLAGS.crop_margin > 0: 36 | margin = bbox_h * FLAGS.crop_margin 37 | bbox_x = bbox_x - margin 38 | bbox_y = bbox_y - margin 39 | bbox_w = bbox_w + 2 * margin 40 | bbox_h = bbox_h + 2 * margin 41 | bbox_xmin = int(round(max(0, bbox_x))) 42 | bbox_ymin = int(round(max(0, bbox_y))) 43 | bbox_xmax = int(round(min(image_w-1, bbox_x + bbox_w))) 44 | bbox_ymax = int(round(min(image_h-1, bbox_y + bbox_h))) 45 | 46 | word_crop_im = image.crop((bbox_xmin, bbox_ymin, bbox_xmax, bbox_ymax)) 47 | im_buff = io.BytesIO() 48 | word_crop_im.save(im_buff, format='jpeg') 49 | word_crop_jpeg = im_buff.getvalue() 50 | crop_name = '{}:{}'.format(image_rel_path, i) 51 | 52 | groundtruth_text = rect.find('tag').text.lower() 53 | 54 | example = tf.train.Example(features=tf.train.Features(feature={ 55 | fields.TfExampleFields.image_encoded: \ 56 | dataset_util.bytes_feature(word_crop_jpeg), 57 | fields.TfExampleFields.image_format: \ 58 | dataset_util.bytes_feature('jpeg'.encode('utf-8')), 59 | fields.TfExampleFields.filename: \ 60 | dataset_util.bytes_feature(crop_name.encode('utf-8')), 61 | fields.TfExampleFields.channels: \ 62 | dataset_util.int64_feature(3), 63 | fields.TfExampleFields.colorspace: \ 64 | dataset_util.bytes_feature('rgb'.encode('utf-8')), 65 | fields.TfExampleFields.transcript: \ 66 | dataset_util.bytes_feature(groundtruth_text.encode('utf-8')), 67 | fields.TfExampleFields.lexicon: \ 68 | dataset_util.bytes_feature(('\t'.join(lexicon)).encode('utf-8')), 69 | })) 70 | writer.write(example.SerializeToString()) 71 | count += 1 72 | 73 | writer.close() 74 | print('{} examples created'.format(count)) 75 | 76 | if __name__ == '__main__': 77 | create_svt_subset('data/svt_test.tfrecord') 78 | -------------------------------------------------------------------------------- /1 - wei/aster/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mind/RMB/cd656b3000fc612657b57dc70dfa76432a448014/1 - wei/aster/utils/__init__.py -------------------------------------------------------------------------------- /1 - wei/aster/utils/dataset_util.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Utility functions for creating TFRecord data sets.""" 17 | 18 | import tensorflow as tf 19 | 20 | 21 | def int64_feature(value): 22 | return tf.train.Feature(int64_list=tf.train.Int64List(value=[value])) 23 | 24 | 25 | def int64_list_feature(value): 26 | return tf.train.Feature(int64_list=tf.train.Int64List(value=value)) 27 | 28 | 29 | def bytes_feature(value): 30 | return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value])) 31 | 32 | 33 | def bytes_list_feature(value): 34 | return tf.train.Feature(bytes_list=tf.train.BytesList(value=value)) 35 | 36 | 37 | def float_list_feature(value): 38 | return tf.train.Feature(float_list=tf.train.FloatList(value=value)) 39 | 40 | 41 | def read_examples_list(path): 42 | """Read list of training or validation examples. 43 | 44 | The file is assumed to contain a single example per line where the first 45 | token in the line is an identifier that allows us to find the image and 46 | annotation xml for that example. 47 | 48 | For example, the line: 49 | xyz 3 50 | would allow us to find files xyz.jpg and xyz.xml (the 3 would be ignored). 51 | 52 | Args: 53 | path: absolute path to examples list file. 54 | 55 | Returns: 56 | list of example identifiers (strings). 57 | """ 58 | with tf.gfile.GFile(path) as fid: 59 | lines = fid.readlines() 60 | return [line.strip().split(' ')[0] for line in lines] 61 | 62 | 63 | def recursive_parse_xml_to_dict(xml): 64 | """Recursively parses XML contents to python dict. 65 | 66 | We assume that `object` tags are the only ones that can appear 67 | multiple times at the same level of a tree. 68 | 69 | Args: 70 | xml: xml tree obtained by parsing XML file contents using lxml.etree 71 | 72 | Returns: 73 | Python dictionary holding XML contents. 
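    For example (illustrative), parsing the element
      <image><height>5</height><object><name>a</name></object></image>
    yields
      {'image': {'height': '5', 'object': [{'name': 'a'}]}}
    i.e. repeated `object` tags are collected into a list while every other
    tag maps directly to its parsed contents.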
74 | """ 75 | if not xml: 76 | return {xml.tag: xml.text} 77 | result = {} 78 | for child in xml: 79 | child_result = recursive_parse_xml_to_dict(child) 80 | if child.tag != 'object': 81 | result[child.tag] = child_result[child.tag] 82 | else: 83 | if child.tag not in result: 84 | result[child.tag] = [] 85 | result[child.tag].append(child_result[child.tag]) 86 | return {xml.tag: result} 87 | -------------------------------------------------------------------------------- /1 - wei/aster/utils/profile_session_run_hooks.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | 4 | import tensorflow as tf 5 | from tensorflow.python.client import timeline 6 | from tensorflow.python.training import training_util 7 | from tensorflow.python.training import session_run_hook 8 | 9 | 10 | class ProfileAtStepHook(session_run_hook.SessionRunHook): 11 | """Hook that requests stop at a specified step.""" 12 | 13 | def __init__(self, at_step=None, checkpoint_dir=None, trace_level=tf.RunOptions.FULL_TRACE): 14 | self._at_step = at_step 15 | self._do_profile = False 16 | self._writer = tf.summary.FileWriter(checkpoint_dir) 17 | self._trace_level = trace_level 18 | 19 | def begin(self): 20 | self._global_step_tensor = tf.train.get_global_step() 21 | if self._global_step_tensor is None: 22 | raise RuntimeError("Global step should be created to use ProfileAtStepHook.") 23 | 24 | def before_run(self, run_context): # pylint: disable=unused-argument 25 | if self._do_profile: 26 | options = tf.RunOptions(trace_level=self._trace_level) 27 | else: 28 | options = None 29 | return tf.train.SessionRunArgs(self._global_step_tensor, options=options) 30 | 31 | def after_run(self, run_context, run_values): 32 | global_step = run_values.results - 1 33 | if self._do_profile: 34 | self._do_profile = False 35 | self._writer.add_run_metadata(run_values.run_metadata, 36 | 'trace_{}'.format(global_step), global_step) 37 | timeline_object = timeline.Timeline(run_values.run_metadata.step_stats) 38 | chrome_trace = timeline_object.generate_chrome_trace_format() 39 | chrome_trace_save_path = 'timeline_{}.json'.format(global_step) 40 | with open(chrome_trace_save_path, 'w') as f: 41 | f.write(chrome_trace) 42 | logging.info('Profile trace saved to {}'.format(chrome_trace_save_path)) 43 | if global_step == self._at_step: 44 | self._do_profile = True 45 | -------------------------------------------------------------------------------- /1 - wei/aster/utils/recognition_evaluation.py: -------------------------------------------------------------------------------- 1 | import string 2 | import logging 3 | 4 | import numpy as np 5 | import edit_distance 6 | 7 | 8 | class RecognitionEvaluation(object): 9 | def __init__(self): 10 | self.image_keys = set() 11 | self.all_recognition_text = [] 12 | self.all_groundtruth_text = [] 13 | 14 | def clear(self): 15 | self.image_keys = set() 16 | self.all_recognition_text = [] 17 | self.all_groundtruth_text = [] 18 | 19 | def add_single_image_recognition_info(self, image_key, recognition_text, groundtruth_text): 20 | """ 21 | Args: 22 | image_key: Python string 23 | recognition_text: Numpy scalar of string type 24 | groundtruth_text: Numpy scalar of string type 25 | """ 26 | if image_key in self.image_keys: 27 | logging.warning('{} already evaluated'.format(image_key)) 28 | return 29 | self.image_keys.add(image_key) 30 | 31 | self.all_recognition_text.append(recognition_text.decode('utf-8')) 32 | 
self.all_groundtruth_text.append(groundtruth_text.decode('utf-8')) 33 | 34 | def evaluate_all(self): 35 | num_samples = len(self.all_recognition_text) 36 | 37 | def _normalize_text(text): 38 | text = ''.join(filter(lambda x: x in (string.digits + string.ascii_letters), text)) 39 | return text.lower() 40 | 41 | num_correct = 0 42 | num_incorrect = 0 43 | total_edit_distance = 0 44 | incorrect_pairs = [] 45 | for i in range(num_samples): 46 | recognition = _normalize_text(self.all_recognition_text[i]) 47 | groundtruth = _normalize_text(self.all_groundtruth_text[i]) 48 | if recognition == groundtruth: 49 | num_correct += 1 50 | else: 51 | num_incorrect += 1 52 | incorrect_pairs.append((recognition, groundtruth)) 53 | sm = edit_distance.SequenceMatcher(a=recognition, b=groundtruth) 54 | normalized_ed = sm.distance() / len(groundtruth) 55 | total_edit_distance += normalized_ed 56 | num_print = min(len(incorrect_pairs), 100) 57 | # print('*** Groundtruth => Prediction ***') 58 | # for i in range(num_print): 59 | # recognition, groundtruth = incorrect_pairs[i] 60 | # print('{} => {}'.format(groundtruth, recognition)) 61 | # print('**********************************') 62 | case_insensitive_accuracy = num_correct / (num_correct + num_incorrect) 63 | 64 | metrics = { 65 | 'WordAccuracy': case_insensitive_accuracy, 66 | 'TotalEditDistance': total_edit_distance, 67 | } 68 | return metrics 69 | -------------------------------------------------------------------------------- /1 - wei/aster/utils/visualization_utils_test.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | from aster.utils import visualization_utils 4 | 5 | class VisualizationUtilsTest(tf.test.TestCase): 6 | 7 | def test_tile_activation_maps_with_padding(self): 8 | test_maps = tf.random_uniform([64, 32, 100, 16]) 9 | tiled_map = visualization_utils.tile_activation_maps_rows_cols(test_maps, 5, 5) 10 | 11 | with self.test_session() as sess: 12 | tiled_map_output = tiled_map.eval() 13 | self.assertAllEqual(tiled_map_output.shape, [64, 32 * 5, 100 * 5, 1]) 14 | 15 | def test_tile_activation_maps_with_slicing(self): 16 | test_maps = tf.random_uniform([64, 32, 100, 16]) 17 | tiled_map = visualization_utils.tile_activation_maps_rows_cols(test_maps, 5, 1) 18 | 19 | with self.test_session() as sess: 20 | tiled_map_output = tiled_map.eval() 21 | self.assertAllEqual(tiled_map_output.shape, [64, 32 * 5, 100 * 1, 1]) 22 | 23 | def test_tile_activation_maps_max_sizes(self): 24 | test_maps = tf.random_uniform([64, 32, 100, 16]) 25 | tiled_map = visualization_utils.tile_activation_maps_max_dimensions( 26 | test_maps, 512, 512) 27 | 28 | with self.test_session() as sess: 29 | tiled_map_output = tiled_map.eval() 30 | self.assertAllEqual(tiled_map_output.shape, [64, 512, 500, 1]) 31 | 32 | 33 | if __name__ == '__main__': 34 | tf.test.main() 35 | -------------------------------------------------------------------------------- /1 - wei/ctpn-crnn/README.md: -------------------------------------------------------------------------------- 1 | # ctpn-crnn 2 | 3 | Taken from other people's repositories with small modifications, suited to recognizing vertically written calligraphy. Originals: 4 | 5 | https://github.com/eragonruan/text-detection-ctpn 6 | 7 | https://github.com/Sierkinhane/crnn_chinese_characters_rec 8 | 9 | 10 | ## cptn 11 | 12 | ### Environment setup (install tqdm, opencv-python, Shapely, matplotlib, numpy, tensorflow-gpu or tensorflow, Cython and ipython yourself with pip3) 13 | 14 | cd cptn/utils/bbox 15 | sh make.sh 16 | 17 | ### Create the dataset 18 | cd cptn/utils/prepare 19 | python3 split_label.py (change DATA_FOLDER and OUTPUT to your own paths) 20 |
21 | Original data image: 22 | ![Image text](https://github.com/hwwu/cptn-crnn/blob/master/cptn/data/demo/source/img_calligraphy_70001_bg.jpg) 23 | 24 | The CSV annotations are: 25 | 26 | FileName | x1| y1| x2| y2| x3| y3| x4| y4| text 27 | ----------------------------|---|---|---|---|---|---|---|---|------ 28 | img_calligraphy_70001_bg.jpg|72 |53 |96 |53 |96 |358|72 |358|黎沈昨骑托那缝丁聚侮篮海炭 29 | img_calligraphy_70001_bg.jpg|46 |53 |70 |53 |70 |394|46 |394|缩蝇躁劣趋拴局伦绸启杭吭惯蛋仅 30 | img_calligraphy_70001_bg.jpg|20 |53 |44 |53 |44 |174|20 |174|效射市关蝉 31 | 32 | Image from the generated dataset (the green lines are the annotated coordinates; the images fed to ctpn do not contain the green lines): 33 | ![Image text](https://github.com/hwwu/cptn-crnn/blob/master/cptn/data/demo/img_calligraphy_70001_bg.jpg) 34 | 35 | The label format can be seen in data/demo/img_calligraphy_70001_bg.txt 36 | 37 | ### Training 38 | 39 | 1. First download the vgg16 model from 40 | https://github.com/tensorflow/models/tree/1af55e018eebce03fb61bba9959a04672536107d/research/slim 41 | , and put it under the data directory 42 | 43 | 2. Run python3 main/train.py 44 | You can also directly download my trained model: 45 | https://drive.google.com/open?id=1RwZb1HLG0vum-5RHZdSfqtDD2in_sNRD 46 | 47 | 48 | ## crnn 49 | 50 | ### Environment setup 51 | 52 | Install pytorch and warp-ctc 53 | Choose the pytorch wheel according to your cuda version 54 | pip3 install https://download.pytorch.org/whl/cu100/torch-1.0.1.post2-cp36-cp36m-linux_x86_64.whl 55 | pip3 install torchvision 56 | git clone https://github.com/SeanNaren/warp-ctc.git 57 | cd warp-ctc;mkdir build; cd build;cmake ..;make 58 | cd ./warp-ctc/pytorch_binding;python setup.py install 59 | Copy the warpctc_pytorch folder generated under pytorch_binding into crnn 60 | 61 | ### Build the training set 62 | python3 to_lmdb/tolmdb.py (split the training set into train and val by a chosen ratio; they are generated the same way, just written to different directories) 63 | Generate alphabets.py from your own label file (i.e. deduplicate all label characters and write them into this file) 64 | 65 | ### Training (for the first training run, use --model_path '') 66 | python3 crnn_main.py --trainroot './data/' --valroot './data/val' --cuda --model_path './expr/crnn_Rec_done_35_2019-03-27.pth' 67 | 68 | 69 | ### Training results 70 | My trained models (epoch 35 may overfit the training set somewhat; several intermediate models are provided, see which one fits best): 71 | https://drive.google.com/open?id=1Ckz1j5ZXfNILh1ePJlYcDpE_PQ-tqML- 72 | Training output: 73 | Test loss: 0.178429, accuray: 0.997500 74 | [35/300][24100/29790] Loss: 0.658935 75 | [35/300][24200/29790] Loss: 0.534306 76 | [35/300][24300/29790] Loss: 0.541349 77 | [35/300][24400/29790] Loss: 0.475645 78 | ![Image text](https://github.com/hwwu/cptn-crnn/blob/master/crnn/test_images/1img_calligraphy_80040_bg.jpg) : 兽亵播疒阌飨百怒逭纫 79 | 80 | ![Image text](https://github.com/hwwu/cptn-crnn/blob/master/crnn/test_images/3img_calligraphy_80011_bg.jpg) : 蜂肉昆材摄 81 | 82 | ![Image text](https://github.com/hwwu/cptn-crnn/blob/master/crnn/test_images/3img_calligraphy_80017_bg.jpg) : 不似周趋阙去 83 | -------------------------------------------------------------------------------- /1 - wei/ctpn-crnn/cptn/data/demo/img_calligraphy_70001_bg.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mind/RMB/cd656b3000fc612657b57dc70dfa76432a448014/1 - wei/ctpn-crnn/cptn/data/demo/img_calligraphy_70001_bg.jpg -------------------------------------------------------------------------------- /1 - wei/ctpn-crnn/cptn/data/demo/img_calligraphy_70001_bg.txt: -------------------------------------------------------------------------------- 1 | 146,83,159,150 2 | 160,83,175,150 3 | 176,83,191,150 4 | 192,83,207,150 5 | 208,83,223,150 6 | 224,83,239,150 7 | 240,83,255,150 8 | 256,83,271,150 9 | 272,83,287,150 10 | 288,83,303,150 11 | 304,83,319,150 12 | 320,83,335,150 13 | 336,83,351,150 14 | 352,83,367,150 15 | 368,83,383,150 16 | 384,83,399,150 17 | 400,83,415,150 18 | 416,83,431,150 19 | 432,83,447,150 20 | 448,83,463,150 21 | 464,83,479,150
22 | 480,83,495,150 23 | 496,83,511,150 24 | 512,83,527,150 25 | 528,83,543,150 26 | 544,83,559,150 27 | 560,83,575,150 28 | 576,83,591,150 29 | 592,83,607,150 30 | 608,83,623,150 31 | 624,83,639,150 32 | 640,83,655,150 33 | 656,83,671,150 34 | 672,83,687,150 35 | 688,83,703,150 36 | 704,83,719,150 37 | 720,83,735,150 38 | 736,83,751,150 39 | 752,83,767,150 40 | 768,83,783,150 41 | 784,83,799,150 42 | 800,83,815,150 43 | 816,83,831,150 44 | 832,83,847,150 45 | 848,83,863,150 46 | 864,83,879,150 47 | 880,83,895,150 48 | 896,83,911,150 49 | 912,83,927,150 50 | 928,83,943,150 51 | 944,83,959,150 52 | 960,83,975,150 53 | 976,83,991,150 54 | 146,156,159,223 55 | 160,156,175,223 56 | 176,156,191,223 57 | 192,156,207,223 58 | 208,156,223,223 59 | 224,156,239,223 60 | 240,156,255,223 61 | 256,156,271,223 62 | 272,156,287,223 63 | 288,156,303,223 64 | 304,156,319,223 65 | 320,156,335,223 66 | 336,156,351,223 67 | 352,156,367,223 68 | 368,156,383,223 69 | 384,156,399,223 70 | 400,156,415,223 71 | 416,156,431,223 72 | 432,156,447,223 73 | 448,156,463,223 74 | 464,156,479,223 75 | 480,156,495,223 76 | 496,156,511,223 77 | 512,156,527,223 78 | 528,156,543,223 79 | 544,156,559,223 80 | 560,156,575,223 81 | 576,156,591,223 82 | 592,156,607,223 83 | 608,156,623,223 84 | 624,156,639,223 85 | 640,156,655,223 86 | 656,156,671,223 87 | 672,156,687,223 88 | 688,156,703,223 89 | 704,156,719,223 90 | 720,156,735,223 91 | 736,156,751,223 92 | 752,156,767,223 93 | 768,156,783,223 94 | 784,156,799,223 95 | 800,156,815,223 96 | 816,156,831,223 97 | 832,156,847,223 98 | 848,156,863,223 99 | 864,156,879,223 100 | 880,156,895,223 101 | 896,156,911,223 102 | 912,156,927,223 103 | 928,156,943,223 104 | 944,156,959,223 105 | 960,156,975,223 106 | 976,156,991,223 107 | 992,156,1007,223 108 | 1008,156,1023,223 109 | 1024,156,1039,223 110 | 1040,156,1055,223 111 | 1056,156,1071,223 112 | 1072,156,1087,223 113 | 146,229,159,296 114 | 160,229,175,296 115 | 176,229,191,296 116 | 192,229,207,296 117 | 208,229,223,296 118 | 224,229,239,296 119 | 240,229,255,296 120 | 256,229,271,296 121 | 272,229,287,296 122 | 288,229,303,296 123 | 304,229,319,296 124 | 320,229,335,296 125 | 336,229,351,296 126 | 352,229,367,296 127 | 368,229,383,296 128 | 384,229,399,296 129 | 400,229,415,296 130 | 416,229,431,296 131 | 432,229,447,296 132 | 448,229,463,296 133 | 464,229,479,296 134 | -------------------------------------------------------------------------------- /1 - wei/ctpn-crnn/cptn/data/demo/source/img_calligraphy_70001_bg.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mind/RMB/cd656b3000fc612657b57dc70dfa76432a448014/1 - wei/ctpn-crnn/cptn/data/demo/source/img_calligraphy_70001_bg.jpg -------------------------------------------------------------------------------- /1 - wei/ctpn-crnn/cptn/data/demo/source/lable.csv: -------------------------------------------------------------------------------- 1 | FileName,x1,y1,x2,y2,x3,y3,x4,y4,text 2 | img_calligraphy_70001_bg.jpg,72,53,96,53,96,358,72,358,黎沈昨骑托那缝丁聚侮篮海炭 3 | img_calligraphy_70001_bg.jpg,46,53,70,53,70,394,46,394,缩蝇躁劣趋拴局伦绸启杭吭惯蛋仅 4 | img_calligraphy_70001_bg.jpg,20,53,44,53,44,174,20,174,效射市关蝉 5 | -------------------------------------------------------------------------------- /1 - wei/ctpn-crnn/cptn/nets/vgg.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | slim = tf.contrib.slim 4 | 5 | 6 | def vgg_arg_scope(weight_decay=0.0005): 7 | with 
slim.arg_scope([slim.conv2d, slim.fully_connected], 8 | activation_fn=tf.nn.relu, 9 | weights_regularizer=slim.l2_regularizer(weight_decay), 10 | biases_initializer=tf.zeros_initializer()): 11 | with slim.arg_scope([slim.conv2d], padding='SAME') as arg_sc: 12 | return arg_sc 13 | 14 | 15 | def vgg_16(inputs, scope='vgg_16'): 16 | with tf.variable_scope(scope, 'vgg_16', [inputs]) as sc: 17 | with slim.arg_scope([slim.conv2d, slim.fully_connected, slim.max_pool2d]): 18 | net = slim.repeat(inputs, 2, slim.conv2d, 64, [3, 3], scope='conv1') 19 | net = slim.max_pool2d(net, [2, 2], scope='pool1') 20 | net = slim.repeat(net, 2, slim.conv2d, 128, [3, 3], scope='conv2') 21 | net = slim.max_pool2d(net, [2, 2], scope='pool2') 22 | net = slim.repeat(net, 3, slim.conv2d, 256, [3, 3], scope='conv3') 23 | net = slim.max_pool2d(net, [2, 2], scope='pool3') 24 | net = slim.repeat(net, 3, slim.conv2d, 512, [3, 3], scope='conv4') 25 | net = slim.max_pool2d(net, [2, 2], scope='pool4') 26 | net = slim.repeat(net, 3, slim.conv2d, 512, [3, 3], scope='conv5') 27 | 28 | return net 29 | -------------------------------------------------------------------------------- /1 - wei/ctpn-crnn/cptn/utils/bbox/bbox.pyx: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Fast R-CNN 3 | # Copyright (c) 2015 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Sergey Karayev 6 | # -------------------------------------------------------- 7 | 8 | cimport cython 9 | import numpy as np 10 | cimport numpy as np 11 | 12 | DTYPE = np.float 13 | ctypedef np.float_t DTYPE_t 14 | 15 | def bbox_overlaps( 16 | np.ndarray[DTYPE_t, ndim=2] boxes, 17 | np.ndarray[DTYPE_t, ndim=2] query_boxes): 18 | """ 19 | Parameters 20 | ---------- 21 | boxes: (N, 4) ndarray of float 22 | query_boxes: (K, 4) ndarray of float 23 | Returns 24 | ------- 25 | overlaps: (N, K) ndarray of overlap between boxes and query_boxes 26 | """ 27 | cdef unsigned int N = boxes.shape[0] 28 | cdef unsigned int K = query_boxes.shape[0] 29 | cdef np.ndarray[DTYPE_t, ndim=2] overlaps = np.zeros((N, K), dtype=DTYPE) 30 | cdef DTYPE_t iw, ih, box_area 31 | cdef DTYPE_t ua 32 | cdef unsigned int k, n 33 | for k in range(K): 34 | box_area = ( 35 | (query_boxes[k, 2] - query_boxes[k, 0] + 1) * 36 | (query_boxes[k, 3] - query_boxes[k, 1] + 1) 37 | ) 38 | for n in range(N): 39 | iw = ( 40 | min(boxes[n, 2], query_boxes[k, 2]) - 41 | max(boxes[n, 0], query_boxes[k, 0]) + 1 42 | ) 43 | if iw > 0: 44 | ih = ( 45 | min(boxes[n, 3], query_boxes[k, 3]) - 46 | max(boxes[n, 1], query_boxes[k, 1]) + 1 47 | ) 48 | if ih > 0: 49 | ua = float( 50 | (boxes[n, 2] - boxes[n, 0] + 1) * 51 | (boxes[n, 3] - boxes[n, 1] + 1) + 52 | box_area - iw * ih 53 | ) 54 | overlaps[n, k] = iw * ih / ua 55 | return overlaps 56 | 57 | def bbox_intersections( 58 | np.ndarray[DTYPE_t, ndim=2] boxes, 59 | np.ndarray[DTYPE_t, ndim=2] query_boxes): 60 | """ 61 | For each query box compute the intersection ratio covered by boxes 62 | ---------- 63 | Parameters 64 | ---------- 65 | boxes: (N, 4) ndarray of float 66 | query_boxes: (K, 4) ndarray of float 67 | Returns 68 | ------- 69 | overlaps: (N, K) ndarray of intersec between boxes and query_boxes 70 | """ 71 | cdef unsigned int N = boxes.shape[0] 72 | cdef unsigned int K = query_boxes.shape[0] 73 | cdef np.ndarray[DTYPE_t, ndim=2] intersec = np.zeros((N, K), dtype=DTYPE) 74 | cdef DTYPE_t iw, ih, box_area 75 | cdef DTYPE_t ua 76 | cdef 
unsigned int k, n 77 | for k in range(K): 78 | box_area = ( 79 | (query_boxes[k, 2] - query_boxes[k, 0] + 1) * 80 | (query_boxes[k, 3] - query_boxes[k, 1] + 1) 81 | ) 82 | for n in range(N): 83 | iw = ( 84 | min(boxes[n, 2], query_boxes[k, 2]) - 85 | max(boxes[n, 0], query_boxes[k, 0]) + 1 86 | ) 87 | if iw > 0: 88 | ih = ( 89 | min(boxes[n, 3], query_boxes[k, 3]) - 90 | max(boxes[n, 1], query_boxes[k, 1]) + 1 91 | ) 92 | if ih > 0: 93 | intersec[n, k] = iw * ih / box_area 94 | return intersec -------------------------------------------------------------------------------- /1 - wei/ctpn-crnn/cptn/utils/bbox/bbox_transform.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def bbox_transform(ex_rois, gt_rois): 5 | """ 6 | computes the distance from ground-truth boxes to the given boxes, normed by their size 7 | :param ex_rois: n * 4 numpy array, given boxes 8 | :param gt_rois: n * 4 numpy array, ground-truth boxes 9 | :return: deltas: n * 4 numpy array, ground-truth boxes 10 | """ 11 | ex_widths = ex_rois[:, 2] - ex_rois[:, 0] + 1.0 12 | ex_heights = ex_rois[:, 3] - ex_rois[:, 1] + 1.0 13 | ex_ctr_x = ex_rois[:, 0] + 0.5 * ex_widths 14 | ex_ctr_y = ex_rois[:, 1] + 0.5 * ex_heights 15 | 16 | assert np.min(ex_widths) > 0.1 and np.min(ex_heights) > 0.1, \ 17 | 'Invalid boxes found: {} {}'.format(ex_rois[np.argmin(ex_widths), :], ex_rois[np.argmin(ex_heights), :]) 18 | 19 | gt_widths = gt_rois[:, 2] - gt_rois[:, 0] + 1.0 20 | gt_heights = gt_rois[:, 3] - gt_rois[:, 1] + 1.0 21 | gt_ctr_x = gt_rois[:, 0] + 0.5 * gt_widths 22 | gt_ctr_y = gt_rois[:, 1] + 0.5 * gt_heights 23 | 24 | # warnings.catch_warnings() 25 | # warnings.filterwarnings('error') 26 | targets_dx = (gt_ctr_x - ex_ctr_x) / ex_widths 27 | targets_dy = (gt_ctr_y - ex_ctr_y) / ex_heights 28 | targets_dw = np.log(gt_widths / ex_widths) 29 | targets_dh = np.log(gt_heights / ex_heights) 30 | 31 | targets = np.vstack( 32 | (targets_dx, targets_dy, targets_dw, targets_dh)).transpose() 33 | 34 | return targets 35 | 36 | 37 | def bbox_transform_inv(boxes, deltas): 38 | boxes = boxes.astype(deltas.dtype, copy=False) 39 | 40 | widths = boxes[:, 2] - boxes[:, 0] + 1.0 41 | heights = boxes[:, 3] - boxes[:, 1] + 1.0 42 | ctr_x = boxes[:, 0] + 0.5 * widths 43 | ctr_y = boxes[:, 1] + 0.5 * heights 44 | 45 | dx = deltas[:, 0::4] 46 | dy = deltas[:, 1::4] 47 | dw = deltas[:, 2::4] 48 | dh = deltas[:, 3::4] 49 | 50 | pred_ctr_x = ctr_x[:, np.newaxis] 51 | pred_ctr_y = dy * heights[:, np.newaxis] + ctr_y[:, np.newaxis] 52 | pred_w = widths[:, np.newaxis] 53 | pred_h = np.exp(dh) * heights[:, np.newaxis] 54 | 55 | pred_boxes = np.zeros(deltas.shape, dtype=deltas.dtype) 56 | # x1 57 | pred_boxes[:, 0::4] = pred_ctr_x - 0.5 * pred_w 58 | # y1 59 | pred_boxes[:, 1::4] = pred_ctr_y - 0.5 * pred_h 60 | # x2 61 | pred_boxes[:, 2::4] = pred_ctr_x + 0.5 * pred_w 62 | # y2 63 | pred_boxes[:, 3::4] = pred_ctr_y + 0.5 * pred_h 64 | 65 | return pred_boxes 66 | 67 | 68 | def clip_boxes(boxes, im_shape): 69 | """ 70 | Clip boxes to image boundaries. 
71 | """ 72 | # x1 >= 0 73 | boxes[:, 0::4] = np.maximum(np.minimum(boxes[:, 0::4], im_shape[1] - 1), 0) 74 | # y1 >= 0 75 | boxes[:, 1::4] = np.maximum(np.minimum(boxes[:, 1::4], im_shape[0] - 1), 0) 76 | # x2 < im_shape[1] 77 | boxes[:, 2::4] = np.maximum(np.minimum(boxes[:, 2::4], im_shape[1] - 1), 0) 78 | # y2 < im_shape[0] 79 | boxes[:, 3::4] = np.maximum(np.minimum(boxes[:, 3::4], im_shape[0] - 1), 0) 80 | return boxes 81 | -------------------------------------------------------------------------------- /1 - wei/ctpn-crnn/cptn/utils/bbox/make.sh: -------------------------------------------------------------------------------- 1 | python3 setup.py install 2 | mv build/*/*.so ./ 3 | rm -rf build/ -------------------------------------------------------------------------------- /1 - wei/ctpn-crnn/cptn/utils/bbox/setup.py: -------------------------------------------------------------------------------- 1 | from distutils.core import setup 2 | 3 | import numpy as np 4 | from Cython.Build import cythonize 5 | 6 | numpy_include = np.get_include() 7 | setup(ext_modules=cythonize("bbox.pyx"), include_dirs=[numpy_include]) 8 | setup(ext_modules=cythonize("nms.pyx"), include_dirs=[numpy_include]) 9 | -------------------------------------------------------------------------------- /1 - wei/ctpn-crnn/cptn/utils/dataset/data_util.py: -------------------------------------------------------------------------------- 1 | import multiprocessing 2 | import threading 3 | import time 4 | 5 | import numpy as np 6 | 7 | try: 8 | import queue 9 | except ImportError: 10 | import Queue as queue 11 | 12 | 13 | class GeneratorEnqueuer(): 14 | def __init__(self, generator, 15 | use_multiprocessing=False, 16 | wait_time=0.05, 17 | random_seed=None): 18 | self.wait_time = wait_time 19 | self._generator = generator 20 | self._use_multiprocessing = use_multiprocessing 21 | self._threads = [] 22 | self._stop_event = None 23 | self.queue = None 24 | self.random_seed = random_seed 25 | 26 | def start(self, workers=1, max_queue_size=10): 27 | def data_generator_task(): 28 | while not self._stop_event.is_set(): 29 | try: 30 | if self._use_multiprocessing or self.queue.qsize() < max_queue_size: 31 | generator_output = next(self._generator) 32 | self.queue.put(generator_output) 33 | else: 34 | time.sleep(self.wait_time) 35 | except Exception: 36 | self._stop_event.set() 37 | raise 38 | 39 | try: 40 | if self._use_multiprocessing: 41 | self.queue = multiprocessing.Queue(maxsize=max_queue_size) 42 | self._stop_event = multiprocessing.Event() 43 | else: 44 | self.queue = queue.Queue() 45 | self._stop_event = threading.Event() 46 | 47 | for _ in range(workers): 48 | if self._use_multiprocessing: 49 | # Reset random seed else all children processes 50 | # share the same seed 51 | np.random.seed(self.random_seed) 52 | thread = multiprocessing.Process(target=data_generator_task) 53 | thread.daemon = True 54 | if self.random_seed is not None: 55 | self.random_seed += 1 56 | else: 57 | thread = threading.Thread(target=data_generator_task) 58 | self._threads.append(thread) 59 | thread.start() 60 | except: 61 | self.stop() 62 | raise 63 | 64 | def is_running(self): 65 | return self._stop_event is not None and not self._stop_event.is_set() 66 | 67 | def stop(self, timeout=None): 68 | if self.is_running(): 69 | self._stop_event.set() 70 | 71 | for thread in self._threads: 72 | if thread.is_alive(): 73 | if self._use_multiprocessing: 74 | thread.terminate() 75 | else: 76 | thread.join(timeout) 77 | 78 | if self._use_multiprocessing: 79 | 
if self.queue is not None: 80 | self.queue.close() 81 | 82 | self._threads = [] 83 | self._stop_event = None 84 | self.queue = None 85 | 86 | def get(self): 87 | while self.is_running(): 88 | if not self.queue.empty(): 89 | inputs = self.queue.get() 90 | if inputs is not None: 91 | yield inputs 92 | else: 93 | time.sleep(self.wait_time) 94 | -------------------------------------------------------------------------------- /1 - wei/ctpn-crnn/cptn/utils/prepare/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*-coding:utf8 -*- 3 | # @TIME :2019/3/14 上午10:00 4 | # @Author :hwwu 5 | # @File :__init__.py.py -------------------------------------------------------------------------------- /1 - wei/ctpn-crnn/cptn/utils/prepare/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from shapely.geometry import Polygon 3 | 4 | 5 | def pickTopLeft(poly): 6 | idx = np.argsort(poly[:, 0]) 7 | if poly[idx[0], 1] < poly[idx[1], 1]: 8 | s = idx[0] 9 | else: 10 | s = idx[1] 11 | 12 | return poly[(s, (s + 1) % 4, (s + 2) % 4, (s + 3) % 4), :] 13 | 14 | 15 | def orderConvex(p): 16 | points = Polygon(p).convex_hull 17 | points = np.array(points.exterior.coords)[:4] 18 | points = points[::-1] 19 | points = pickTopLeft(points) 20 | points = np.array(points).reshape([4, 2]) 21 | return points 22 | 23 | 24 | def shrink_poly(poly, r=16): 25 | # y = kx + b 26 | x_min = int(np.min(poly[:, 0])) 27 | x_max = int(np.max(poly[:, 0])) 28 | 29 | k1 = (poly[1][1] - poly[0][1]) / (poly[1][0] - poly[0][0]) 30 | b1 = poly[0][1] - k1 * poly[0][0] 31 | 32 | k2 = (poly[2][1] - poly[3][1]) / (poly[2][0] - poly[3][0]) 33 | b2 = poly[3][1] - k2 * poly[3][0] 34 | 35 | res = [] 36 | 37 | start = int((x_min // 16 + 1) * 16) 38 | end = int((x_max // 16) * 16) 39 | 40 | p = x_min 41 | res.append([p, int(k1 * p + b1), 42 | start - 1, int(k1 * (p + 15) + b1), 43 | start - 1, int(k2 * (p + 15) + b2), 44 | p, int(k2 * p + b2)]) 45 | 46 | for p in range(start, end + 1, r): 47 | res.append([p, int(k1 * p + b1), 48 | (p + 15), int(k1 * (p + 15) + b1), 49 | (p + 15), int(k2 * (p + 15) + b2), 50 | p, int(k2 * p + b2)]) 51 | return np.array(res, dtype=np.int).reshape([-1, 8]) 52 | -------------------------------------------------------------------------------- /1 - wei/ctpn-crnn/cptn/utils/rpn_msr/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mind/RMB/cd656b3000fc612657b57dc70dfa76432a448014/1 - wei/ctpn-crnn/cptn/utils/rpn_msr/__init__.py -------------------------------------------------------------------------------- /1 - wei/ctpn-crnn/cptn/utils/rpn_msr/config.py: -------------------------------------------------------------------------------- 1 | class Config: 2 | EPS = 1e-14 3 | RPN_CLOBBER_POSITIVES = False 4 | RPN_NEGATIVE_OVERLAP = 0.3 5 | RPN_POSITIVE_OVERLAP = 0.7 6 | RPN_FG_FRACTION = 0.5 7 | RPN_BATCHSIZE = 300 8 | RPN_BBOX_INSIDE_WEIGHTS = (1.0, 1.0, 1.0, 1.0) 9 | RPN_POSITIVE_WEIGHT = -1.0 10 | 11 | RPN_PRE_NMS_TOP_N = 12000 12 | RPN_POST_NMS_TOP_N = 1000 13 | RPN_NMS_THRESH = 0.7 14 | RPN_MIN_SIZE = 8 15 | -------------------------------------------------------------------------------- /1 - wei/ctpn-crnn/cptn/utils/rpn_msr/generate_anchors.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def generate_basic_anchors(sizes, base_size=16): 5 | 
base_anchor = np.array([0, 0, base_size - 1, base_size - 1], np.int32) 6 | anchors = np.zeros((len(sizes), 4), np.int32) 7 | index = 0 8 | for h, w in sizes: 9 | anchors[index] = scale_anchor(base_anchor, h, w) 10 | index += 1 11 | return anchors 12 | 13 | 14 | def scale_anchor(anchor, h, w): 15 | x_ctr = (anchor[0] + anchor[2]) * 0.5 16 | y_ctr = (anchor[1] + anchor[3]) * 0.5 17 | scaled_anchor = anchor.copy() 18 | scaled_anchor[0] = x_ctr - w / 2 # xmin 19 | scaled_anchor[2] = x_ctr + w / 2 # xmax 20 | scaled_anchor[1] = y_ctr - h / 2 # ymin 21 | scaled_anchor[3] = y_ctr + h / 2 # ymax 22 | return scaled_anchor 23 | 24 | 25 | def generate_anchors(base_size=16, ratios=[0.5, 1, 2], 26 | scales=2 ** np.arange(3, 6)): 27 | heights = [11, 16, 23, 33, 48, 68, 97, 139, 198, 283] 28 | widths = [16] 29 | sizes = [] 30 | for h in heights: 31 | for w in widths: 32 | sizes.append((h, w)) 33 | return generate_basic_anchors(sizes) 34 | 35 | 36 | if __name__ == '__main__': 37 | import time 38 | 39 | t = time.time() 40 | a = generate_anchors() 41 | print(time.time() - t) 42 | print(a) 43 | from IPython import embed; 44 | 45 | embed() 46 | -------------------------------------------------------------------------------- /1 - wei/ctpn-crnn/cptn/utils/text_connector/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mind/RMB/cd656b3000fc612657b57dc70dfa76432a448014/1 - wei/ctpn-crnn/cptn/utils/text_connector/__init__.py -------------------------------------------------------------------------------- /1 - wei/ctpn-crnn/cptn/utils/text_connector/detectors.py: -------------------------------------------------------------------------------- 1 | # coding:utf-8 2 | import numpy as np 3 | from utils.bbox.nms import nms 4 | 5 | from .text_connect_cfg import Config as TextLineCfg 6 | from .text_proposal_connector import TextProposalConnector 7 | from .text_proposal_connector_oriented import TextProposalConnector as TextProposalConnectorOriented 8 | 9 | 10 | class TextDetector: 11 | def __init__(self, DETECT_MODE="H"): 12 | self.mode = DETECT_MODE 13 | if self.mode == "H": 14 | self.text_proposal_connector = TextProposalConnector() 15 | elif self.mode == "O": 16 | self.text_proposal_connector = TextProposalConnectorOriented() 17 | 18 | def detect(self, text_proposals, scores, size): 19 | # 删除得分较低的proposal 20 | keep_inds = np.where(scores > TextLineCfg.TEXT_PROPOSALS_MIN_SCORE)[0] 21 | text_proposals, scores = text_proposals[keep_inds], scores[keep_inds] 22 | 23 | # 按得分排序 24 | sorted_indices = np.argsort(scores.ravel())[::-1] 25 | text_proposals, scores = text_proposals[sorted_indices], scores[sorted_indices] 26 | 27 | # 对proposal做nms 28 | keep_inds = nms(np.hstack((text_proposals, scores)), TextLineCfg.TEXT_PROPOSALS_NMS_THRESH) 29 | text_proposals, scores = text_proposals[keep_inds], scores[keep_inds] 30 | 31 | # 获取检测结果 32 | text_recs = self.text_proposal_connector.get_text_lines(text_proposals, scores, size) 33 | keep_inds = self.filter_boxes(text_recs) 34 | return text_recs[keep_inds] 35 | 36 | def filter_boxes(self, boxes): 37 | heights = np.zeros((len(boxes), 1), np.float) 38 | widths = np.zeros((len(boxes), 1), np.float) 39 | scores = np.zeros((len(boxes), 1), np.float) 40 | index = 0 41 | for box in boxes: 42 | heights[index] = (abs(box[5] - box[1]) + abs(box[7] - box[3])) / 2.0 + 1 43 | widths[index] = (abs(box[2] - box[0]) + abs(box[6] - box[4])) / 2.0 + 1 44 | scores[index] = box[8] 45 | index += 1 46 | 47 | return np.where((widths 
/ heights > TextLineCfg.MIN_RATIO) & (scores > TextLineCfg.LINE_MIN_SCORE) & 48 | (widths > (TextLineCfg.TEXT_PROPOSALS_WIDTH * TextLineCfg.MIN_NUM_PROPOSALS)))[0] 49 | -------------------------------------------------------------------------------- /1 - wei/ctpn-crnn/cptn/utils/text_connector/other.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def threshold(coords, min_, max_): 5 | return np.maximum(np.minimum(coords, max_), min_) 6 | 7 | 8 | def clip_boxes(boxes, im_shape): 9 | """ 10 | Clip boxes to image boundaries. 11 | """ 12 | boxes[:, 0::2] = threshold(boxes[:, 0::2], 0, im_shape[1] - 1) 13 | boxes[:, 1::2] = threshold(boxes[:, 1::2], 0, im_shape[0] - 1) 14 | return boxes 15 | 16 | 17 | class Graph: 18 | def __init__(self, graph): 19 | self.graph = graph 20 | 21 | def sub_graphs_connected(self): 22 | sub_graphs = [] 23 | for index in range(self.graph.shape[0]): 24 | if not self.graph[:, index].any() and self.graph[index, :].any(): 25 | v = index 26 | sub_graphs.append([v]) 27 | while self.graph[v, :].any(): 28 | v = np.where(self.graph[v, :])[0][0] 29 | sub_graphs[-1].append(v) 30 | return sub_graphs 31 | -------------------------------------------------------------------------------- /1 - wei/ctpn-crnn/cptn/utils/text_connector/text_connect_cfg.py: -------------------------------------------------------------------------------- 1 | class Config: 2 | EPS = 1e-14? 3 |
float(len(tp_indices)) 42 | 43 | text_lines[index, 0] = x0 44 | text_lines[index, 1] = min(lt_y, rt_y) 45 | text_lines[index, 2] = x1 46 | text_lines[index, 3] = max(lb_y, rb_y) 47 | text_lines[index, 4] = score 48 | 49 | text_lines = clip_boxes(text_lines, im_size) 50 | 51 | text_recs = np.zeros((len(text_lines), 9), np.float) 52 | index = 0 53 | for line in text_lines: 54 | xmin, ymin, xmax, ymax = line[0], line[1], line[2], line[3] 55 | text_recs[index, 0] = xmin 56 | text_recs[index, 1] = ymin 57 | text_recs[index, 2] = xmax 58 | text_recs[index, 3] = ymin 59 | text_recs[index, 4] = xmax 60 | text_recs[index, 5] = ymax 61 | text_recs[index, 6] = xmin 62 | text_recs[index, 7] = ymax 63 | text_recs[index, 8] = line[4] 64 | index = index + 1 65 | 66 | return text_recs 67 | -------------------------------------------------------------------------------- /1 - wei/ctpn-crnn/crnn/alphabets.py: -------------------------------------------------------------------------------- 1 | alphabet = """ABCDEFGHIJKLMNOPQRSTUWXZY0123456789""" -------------------------------------------------------------------------------- /1 - wei/ctpn-crnn/crnn/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mind/RMB/cd656b3000fc612657b57dc70dfa76432a448014/1 - wei/ctpn-crnn/crnn/models/__init__.py -------------------------------------------------------------------------------- /1 - wei/ctpn-crnn/crnn/models/crnn.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | class BidirectionalLSTM(nn.Module): 5 | # Inputs hidden units Out 6 | def __init__(self, nIn, nHidden, nOut): 7 | super(BidirectionalLSTM, self).__init__() 8 | 9 | self.rnn = nn.LSTM(nIn, nHidden, bidirectional=True) 10 | self.embedding = nn.Linear(nHidden * 2, nOut) 11 | 12 | def forward(self, input): 13 | recurrent, _ = self.rnn(input) 14 | T, b, h = recurrent.size() 15 | t_rec = recurrent.view(T * b, h) 16 | 17 | output = self.embedding(t_rec) # [T * b, nOut] 18 | output = output.view(T, b, -1) 19 | 20 | return output 21 | 22 | 23 | class CRNN(nn.Module): 24 | # 32 1 37 256 25 | def __init__(self, imgH, nc, nclass, nh, n_rnn=2, leakyRelu=False): 26 | super(CRNN, self).__init__() 27 | assert imgH % 16 == 0, 'imgH has to be a multiple of 16' 28 | 29 | ks = [3, 3, 3, 3, 3, 3, 2] 30 | ps = [1, 1, 1, 1, 1, 1, 0] 31 | ss = [1, 1, 1, 1, 1, 1, 1] 32 | nm = [64, 128, 256, 256, 512, 512, 512] 33 | 34 | cnn = nn.Sequential() 35 | 36 | def convRelu(i, batchNormalization=False): 37 | nIn = nc if i == 0 else nm[i - 1] 38 | nOut = nm[i] 39 | cnn.add_module('conv{0}'.format(i), 40 | nn.Conv2d(nIn, nOut, ks[i], ss[i], ps[i])) 41 | if batchNormalization: 42 | cnn.add_module('batchnorm{0}'.format(i), nn.BatchNorm2d(nOut)) 43 | if leakyRelu: 44 | cnn.add_module('relu{0}'.format(i), 45 | nn.LeakyReLU(0.2, inplace=True)) 46 | else: 47 | cnn.add_module('relu{0}'.format(i), nn.ReLU(True)) 48 | 49 | convRelu(0) 50 | cnn.add_module('pooling{0}'.format(0), nn.MaxPool2d(2, 2)) # 64x16x64 51 | convRelu(1) 52 | cnn.add_module('pooling{0}'.format(1), nn.MaxPool2d(2, 2)) # 128x8x32 53 | convRelu(2, True) 54 | # cnn.add_module('pooling{0}'.format(4), nn.MaxPool2d(1, 2)) # 128x8x32 55 | convRelu(3) 56 | cnn.add_module('pooling{0}'.format(2), 57 | nn.MaxPool2d((2, 2), (2, 1), (0, 1))) # 256x4x16 58 | convRelu(4, True) 59 | convRelu(5) 60 | cnn.add_module('pooling{0}'.format(3), 61 | nn.MaxPool2d((2, 2), (2, 1), (0, 1))) # 512x2x16 62 | 
convRelu(6, True) # 512x1x16 63 | 64 | self.cnn = cnn 65 | self.rnn = nn.Sequential( 66 | BidirectionalLSTM(512, nh, nh), 67 | BidirectionalLSTM(nh, nh, nclass)) 68 | 69 | def forward(self, input): 70 | # conv features 71 | # print('---forward propagation---') 72 | conv = self.cnn(input) 73 | # print(conv.size()) # batch_size*512*1*with 74 | b, c, h, w = conv.size() 75 | assert h == 1, "the height of conv must be 1" 76 | conv = conv.squeeze(2) # b *512 * width 77 | # print(conv.size()) 78 | conv = conv.permute(2, 0, 1) # [w, b, c] 79 | # print(conv.size()) # width batch_size channel 80 | # rnn features 81 | output = self.rnn(conv) 82 | # print(output.size(0)) 83 | # print(output.size()) # width*batch_size*nclass 84 | return output 85 | -------------------------------------------------------------------------------- /1 - wei/ctpn-crnn/crnn/params.py: -------------------------------------------------------------------------------- 1 | import alphabets 2 | 3 | random_sample = True 4 | keep_ratio = False 5 | adam = True 6 | adadelta = False 7 | saveInterval = 3 8 | valInterval = 500 9 | n_test_disp = 10 10 | displayInterval = 100 11 | experiment = './expr' 12 | alphabet = alphabets.alphabet 13 | crnn = '' 14 | beta1 = 0.5 15 | lr = 0.00001 16 | niter = 50 17 | nh = 256 18 | imgW = 192 19 | imgH = 32 20 | batchSize = 16 21 | workers = 0 22 | -------------------------------------------------------------------------------- /1 - wei/ctpn-crnn/crnn/predict.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*-coding:utf8 -*- 3 | # @TIME :2019/3/25 下午4:33 4 | # @Author :hwwu 5 | # @File :predict.py 6 | 7 | 8 | import numpy as np 9 | import sys, os 10 | import time 11 | import tensorflow as tf 12 | 13 | sys.path.append(os.getcwd()) 14 | 15 | # crnn packages 16 | import torch 17 | from torch.autograd import Variable 18 | import utils 19 | import dataset 20 | from PIL import Image, ImageFilter 21 | import models.crnn as crnn 22 | import alphabets 23 | 24 | str1 = alphabets.alphabet 25 | 26 | import argparse 27 | 28 | tf.app.flags.DEFINE_string('test_data_path', '/content/test_cptn_result', '') 29 | tf.app.flags.DEFINE_string('output_path', './', '') 30 | FLAGS = tf.app.flags.FLAGS 31 | 32 | 33 | def get_images(): 34 | files = [] 35 | exts = ['jpg', 'png', 'jpeg', 'JPG'] 36 | for parent, dirnames, filenames in os.walk(FLAGS.test_data_path): 37 | for filename in filenames: 38 | for ext in exts: 39 | if filename.endswith(ext): 40 | files.append(os.path.join(parent, filename)) 41 | break 42 | print('Find {} images'.format(len(files))) 43 | return files 44 | 45 | 46 | crnn_model_path = './expr/best_model.pth' 47 | alphabet = str1 48 | nclass = len(alphabet) + 1 49 | 50 | 51 | # crnn文本信息识别 52 | def crnn_recognition(cropped_image, model): 53 | converter = utils.strLabelConverter(alphabet) 54 | 55 | image = cropped_image.convert('L') 56 | 57 | ## 58 | # w = int(image.size[0] / (280 * 1.0 / 160)) 59 | transformer = dataset.resizeNormalize((192, 32)) 60 | image = transformer(image) 61 | # if torch.cuda.is_available(): 62 | # image = image.cuda() 63 | image = image.view(1, *image.size()) 64 | image = Variable(image) 65 | 66 | model.eval() 67 | preds = model(image) 68 | 69 | _, preds = preds.max(2) 70 | preds = preds.transpose(1, 0).contiguous().view(-1) 71 | 72 | preds_size = Variable(torch.IntTensor([preds.size(0)])) 73 | sim_pred = converter.decode(preds.data, preds_size.data, raw=False) 74 | # print('results: {0}'.format(sim_pred)) 75 | return 
sim_pred 76 | 77 | 78 | if __name__ == '__main__': 79 | # crnn network 80 | model = crnn.CRNN(32, 1, nclass, 256) 81 | # if torch.cuda.is_available(): 82 | # model = model.cuda() 83 | print('loading pretrained model from {0}'.format(crnn_model_path)) 84 | # load the trained crnn model 85 | model.load_state_dict(torch.load(crnn_model_path, map_location='cpu')) 86 | 87 | started = time.time() 88 | ## read an image 89 | im_fn_list = get_images() 90 | with open(os.path.join(FLAGS.output_path, "crnn_train_result_0606.csv"), 91 | "a") as f: 92 | title ='name,label'+ "\r\n" 93 | f.writelines(title) 94 | for i, im_fn in enumerate(im_fn_list): 95 | if i%1000==0: 96 | print('.................'+str(i)+'................') 97 | image = Image.open(im_fn) 98 | result = crnn_recognition(image, model) 99 | line = os.path.basename(im_fn) 100 | # print(line,result) 101 | line += ',' + result + "\r\n" 102 | f.writelines(line) 103 | 104 | finished = time.time() 105 | print('elapsed time: {0}'.format(finished - started)) 106 | -------------------------------------------------------------------------------- /1 - wei/ctpn-crnn/crnn/to_lmdb/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*-coding:utf8 -*- 3 | # @TIME :2019/3/27 下午12:09 4 | # @Author :hwwu 5 | # @File :__init__.py.py -------------------------------------------------------------------------------- /1 - wei/ocr_densenet/README.md: -------------------------------------------------------------------------------- 1 | # OCR 2 | Winner of the [First Xi'an Jiaotong University AI Practice Competition (2018 AI Practice Competition -- Image Text Recognition)](http://competition.heils.cn/main.html) 3 | 4 | 5 | # Model results 6 | The competition computes an f1score for each entry and averages it over all entries; the exact formula is described [here](http://competition.heils.cn/main.html). The calculation used here does not count repeated characters within a sentence more than once, so the f1score is lower than the finally submitted result: 7 | 8 | | - | train | val | 9 | | :----------------: | :----------------: | :----------------: | 10 | | f1score | 0.9911 | 0.9582 | 11 | | recall | 0.9943 | 0.9574 | 12 | | precision | 0.9894 | 0.9637 | 13 | 14 | # Model description 15 | 1. Model 16 | 17 | A densenet backbone is used; the model takes a (64×512) image as input and outputs (8×64×2159) probabilities. 18 | 19 | The image is divided into (8×8) cells, and in each cell the probabilities of the 2159 characters are predicted. 20 | 21 | 2. Loss 22 | 23 | The (8×64×2159) probabilities are max-pooled over the height and width directions to obtain (2159) probabilities, i.e. the probability that each character appears somewhere in the image. 24 | 25 | balance: the loss is computed separately for positive and negative examples so that the total weight of the positive losses equals the total weight of the negative losses, which addresses the class-imbalance problem. 26 | 27 | hard-mining (a minimal illustrative sketch of this loss is given after the plot.py listing below) 28 | 29 | 3. Text detection 30 | The (8×64×2159) probabilities are max-pooled over the height direction to obtain (64×2159) probabilities, one distribution per column. 31 | Characters are then predicted cell by cell along the width direction, and concatenating them gives the complete sentence. 32 | 33 | Known issue: two consecutive occurrences of the same character cannot be detected twice. 34 | 35 | The image below is an example of correct recognition: 的长为半径作圆 36 | 37 | 38 | 39 | The image below is an example of incorrect recognition: 为10元;经粗加工后销售,每 40 | 41 | 42 | 43 | 44 | # Directory layout 45 | ocr 46 | | 47 | |--code 48 | | 49 | |--files 50 | | | 51 | | |--train.csv 52 | | 53 | |--data 54 | | 55 | |--dataset 56 | | | 57 | | |--train 58 | | | 59 | | |--test 60 | | 61 | |--result 62 | | | 63 | | |--test_result.csv 64 | | 65 | |--images this folder can contain any images; I put the celebA dataset here for pretraining 66 | 67 | # Environment 68 | Ubuntu16.04, python2.7, CUDA9.0 69 | 70 | Install [pytorch](https://pytorch.org/), recommended version: 0.2.0_3 71 | ``` 72 | pip install -r requirement.txt 73 | ``` 74 | 75 | # Download the data 76 | Download the preliminary- and final-round data and the models from [here](https://pan.baidu.com/s/1w0iEE7q84IolmZXwttOxVw), and merge the training and test sets. 77 | 78 | 79 | # Preprocessing 80 | If you keep the provided dataset, this step is not needed. 81 | 82 | If you switch to another dataset, replace files/train.csv as well 83 | ``` 84 | cd code/preprocessing 85 | python map_word_to_index.py 86 | python analysis_dataset.py 87 | ``` 88 | 89 | # Training 90 | ``` 91 | cd code/ocr 92 | python main.py 93 | ``` 94 | 95 | # Testing 96 | When the f1score is below 0.9, use lr=0.001 and no hard-mining; 97 | 98 | when the f1score is above 0.9, use lr=0.0001 and hard-mining; 99 | 100 | the resulting models are saved in separate folders. 101 | ``` 102 | cd code/ocr 103 | python main.py --phase test --resume ../../data/models-small/densenet/eval-16-1/best_f1score.ckpt 104 | ``` 105 | -------------------------------------------------------------------------------- /1 - wei/ocr_densenet/code/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mind/RMB/cd656b3000fc612657b57dc70dfa76432a448014/1 - wei/ocr_densenet/code/.DS_Store -------------------------------------------------------------------------------- /1 - wei/ocr_densenet/code/ocr/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mind/RMB/cd656b3000fc612657b57dc70dfa76432a448014/1 - wei/ocr_densenet/code/ocr/.DS_Store -------------------------------------------------------------------------------- /1 - wei/ocr_densenet/code/ocr/tools/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mind/RMB/cd656b3000fc612657b57dc70dfa76432a448014/1 - wei/ocr_densenet/code/ocr/tools/__init__.py -------------------------------------------------------------------------------- /1 - wei/ocr_densenet/code/ocr/tools/plot.py: -------------------------------------------------------------------------------- 1 | # coding=utf8 2 | import matplotlib.pyplot as plt 3 | import numpy as np 4 | 5 | def plot_multi_graph(image_list, name_list, save_path=None, show=False): 6 | graph_place = int(np.sqrt(len(name_list) - 1)) + 1 7 | for i, (image, name) in enumerate(zip(image_list, name_list)): 8 | ax1 = plt.subplot(graph_place,graph_place,i+1) 9 | ax1.set_title(name) 10 | # plt.imshow(image,cmap='gray') 11 | plt.imshow(image) 12 | plt.axis('off') 13 | if save_path: 14 | plt.savefig(save_path) 15 | pass 16 | if show: 17 | plt.show() 18 | 19 | def plot_multi_line(x_list, y_list, name_list, save_path=None, show=False): 20 | graph_place = int(np.sqrt(len(name_list) - 1)) + 1 21 | for i, (x, y, name) in enumerate(zip(x_list, y_list, name_list)): 22 | ax1 = plt.subplot(graph_place,graph_place,i+1) 23 | ax1.set_title(name) 24 | plt.plot(x,y) 25 | # plt.imshow(image,cmap='gray') 26 | if save_path: 27 | plt.savefig(save_path) 28 | if show: 29 | plt.show() 30 | 31 | 32 |
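The loss described in the ocr_densenet README above reduces the (8×64×2159) per-cell probabilities to one presence probability per character and then balances positive and negative terms. The repository implements this in code/ocr/main.py; the snippet below is only a minimal, self-contained sketch of the idea in current PyTorch, with hypothetical names (`grid_probs`, `target`, `balanced_presence_loss`) and without the hard-mining step.

```python
# Illustrative sketch only -- not the repository's implementation.
# grid_probs: (N, 8, 64, C) per-cell character probabilities in [0, 1]
# target:     (N, C) binary labels, 1 if the character occurs anywhere in the image
import torch


def balanced_presence_loss(grid_probs, target, eps=1e-6):
    # spatial max: probability that each character appears somewhere in the image
    probs = grid_probs.max(dim=1)[0].max(dim=1)[0]          # (N, C)
    probs = probs.clamp(eps, 1 - eps)
    pos_loss = -(target * probs.log())                       # characters that are present
    neg_loss = -((1 - target) * (1 - probs).log())           # characters that are absent
    n_pos = target.sum().clamp(min=1)
    n_neg = (1 - target).sum().clamp(min=1)
    # "balance": positives and negatives contribute with equal total weight
    return pos_loss.sum() / n_pos + neg_loss.sum() / n_neg


if __name__ == '__main__':
    grid_probs = torch.rand(2, 8, 64, 2159)
    target = (torch.rand(2, 2159) > 0.99).float()
    print(balanced_presence_loss(grid_probs, target).item())
```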
-------------------------------------------------------------------------------- /1 - wei/ocr_densenet/code/preprocessing/map_word_to_index.py: -------------------------------------------------------------------------------- 1 | # coding=utf8 2 | ######################################################################### 3 | # File Name: map_word_to_index.py 4 | # Author: ccyin 5 | # mail: ccyin04@gmail.com 6 | # Created Time: Fri 18 May 2018 03:30:26 PM CST 7 | ######################################################################### 8 | ''' 9 | 此代码用于将所有文字映射到index上,有两种方式 10 | 1. 映射每一个英文单词为一个index 11 | 2. 映射每一个英文字母为一个index 12 | ''' 13 | 14 | import os 15 | import sys 16 | reload(sys) 17 | sys.setdefaultencoding('utf8') 18 | import json 19 | from collections import OrderedDict 20 | 21 | def map_word_to_index(train_word_file, word_index_json, word_count_json, index_label_json, alphabet_to_index=True): 22 | with open(train_word_file, 'r') as f: 23 | labels = f.read().strip().decode('utf8') 24 | word_count_dict = { } 25 | for line in labels.split('\n')[1:]: 26 | line = line.strip() 27 | image, sentence = line.strip().split('.jpg,') 28 | sentence = sentence.strip() 29 | for w in sentence: 30 | word_count_dict[w] = word_count_dict.get(w,0) + 1 31 | print '一共有{:d}种字符,共{:d}个'.format(len(word_count_dict), sum(word_count_dict.values())) 32 | word_sorted = sorted(word_count_dict.keys(), key=lambda k:word_count_dict[k], reverse=True) 33 | word_index_dict = { w:i for i,w in enumerate(word_sorted)} 34 | # word_index_dict = json.load(open(word_index_json)) 35 | 36 | with open(word_count_json, 'w') as f: 37 | f.write(json.dumps(word_count_dict, indent=4, ensure_ascii=False)) 38 | with open(word_index_json, 'w') as f: 39 | f.write(json.dumps(word_index_dict, indent=4, ensure_ascii=False)) 40 | 41 | image_label_dict = OrderedDict() 42 | for line in labels.split('\n')[1:]: 43 | line = line.strip() 44 | image, sentence = line.strip().split('.jpg,') 45 | sentence = sentence.strip() 46 | 47 | # 换掉部分相似符号 48 | for c in u"  ": 49 | sentence = sentence.replace(c, '') 50 | replace_words = [ 51 | u'((', 52 | u'))', 53 | u',,', 54 | u"´'′", 55 | u"″"“", 56 | u"..", 57 | u"—-" 58 | ] 59 | for words in replace_words: 60 | for w in words[:-1]: 61 | sentence = sentence.replace(w, words[-1]) 62 | 63 | index_list = [] 64 | for w in sentence: 65 | index_list.append(str(word_index_dict[w])) 66 | image_label_dict[image + '.jpg'] = ' '.join(index_list) 67 | with open(index_label_json, 'w') as f: 68 | f.write(json.dumps(image_label_dict, indent=4)) 69 | 70 | 71 | def main(): 72 | 73 | # 映射字母为index 74 | train_word_file = '../../files/train.csv' 75 | word_index_json = '../../files/alphabet_index_dict.json' 76 | word_count_json = '../../files/alphabet_count_dict.json' 77 | index_label_json = '../../files/train_alphabet.json' 78 | map_word_to_index(train_word_file, word_index_json, word_count_json, index_label_json, True) 79 | 80 | if __name__ == '__main__': 81 | main() 82 | -------------------------------------------------------------------------------- /1 - wei/ocr_densenet/code/preprocessing/show_black.py: -------------------------------------------------------------------------------- 1 | # coding=utf8 2 | ######################################################################### 3 | # File Name: show_black.py 4 | # Author: ccyin 5 | # mail: ccyin04@gmail.com 6 | # Created Time: 2018年06月07日 星期四 01时06分22秒 7 | ######################################################################### 8 | 9 | import os 10 | import sys 11 | import json 
12 | sys.path.append('../ocr') 13 | from tools import parse, py_op 14 | args = parse.args 15 | 16 | def cp_black_list(black_json, black_dir): 17 | word_index_dict = json.load(open(args.word_index_json)) 18 | index_word_dict = { v:k for k,v in word_index_dict.items() } 19 | train_word_dict = json.load(open(args.image_label_json)) 20 | train_word_dict = { k:''.join([index_word_dict[int(i)] for i in v.split()]) for k,v in train_word_dict.items() } 21 | 22 | py_op.mkdir(black_dir) 23 | black_list = json.load(open(black_json))['black_list'] 24 | for i,name in enumerate(black_list): 25 | cmd = 'cp {:s} {:s}'.format(os.path.join(args.data_dir, 'train', name), black_dir) 26 | if train_word_dict[name] in ['Err:501', '#NAME?', '###']: 27 | continue 28 | print name 29 | print train_word_dict[name] 30 | os.system(cmd) 31 | if i > 30: 32 | break 33 | 34 | if __name__ == '__main__': 35 | black_dir = os.path.join(args.save_dir, 'black') 36 | cp_black_list(args.black_json, black_dir) 37 | -------------------------------------------------------------------------------- /1 - wei/ocr_densenet/files/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mind/RMB/cd656b3000fc612657b57dc70dfa76432a448014/1 - wei/ocr_densenet/files/.DS_Store -------------------------------------------------------------------------------- /1 - wei/ocr_densenet/files/alphabet_count_dict.json: -------------------------------------------------------------------------------- 1 | { 2 | "1": 31768, 3 | "0": 31779, 4 | "3": 29441, 5 | "2": 29762, 6 | "5": 28435, 7 | "4": 37300, 8 | "7": 33561, 9 | "6": 31284, 10 | "9": 33105, 11 | "8": 30524, 12 | "A": 3472, 13 | "C": 3314, 14 | "B": 3993, 15 | "E": 3863, 16 | "D": 4278, 17 | "G": 2830, 18 | "F": 3127, 19 | "I": 1930, 20 | "H": 3893, 21 | "K": 3261, 22 | "J": 2271, 23 | "M": 2692, 24 | "L": 2623, 25 | "O": 1954, 26 | "N": 1909, 27 | "Q": 2917, 28 | "P": 3120, 29 | "S": 3392, 30 | "R": 3145, 31 | "U": 3855, 32 | "T": 3130, 33 | "W": 3110, 34 | "Y": 5193, 35 | "X": 4128, 36 | "Z": 1841 37 | } -------------------------------------------------------------------------------- /1 - wei/ocr_densenet/files/alphabet_index_dict.json: -------------------------------------------------------------------------------- 1 | { 2 | "1": 3, 3 | "0": 4, 4 | "3": 8, 5 | "2": 7, 6 | "5": 9, 7 | "4": 0, 8 | "7": 1, 9 | "6": 5, 10 | "9": 2, 11 | "8": 6, 12 | "A": 18, 13 | "C": 17, 14 | "B": 14, 15 | "E": 16, 16 | "D": 11, 17 | "G": 26, 18 | "F": 24, 19 | "I": 33, 20 | "H": 15, 21 | "K": 25, 22 | "J": 30, 23 | "M": 28, 24 | "L": 29, 25 | "O": 31, 26 | "N": 32, 27 | "Q": 27, 28 | "P": 23, 29 | "S": 19, 30 | "R": 21, 31 | "U": 12, 32 | "T": 22, 33 | "W": 20, 34 | "Y": 10, 35 | "X": 13, 36 | "Z": 34 37 | } -------------------------------------------------------------------------------- /1 - wei/ocr_densenet/files/src/A81.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mind/RMB/cd656b3000fc612657b57dc70dfa76432a448014/1 - wei/ocr_densenet/files/src/A81.png -------------------------------------------------------------------------------- /1 - wei/ocr_densenet/files/src/B1000_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mind/RMB/cd656b3000fc612657b57dc70dfa76432a448014/1 - wei/ocr_densenet/files/src/B1000_0.png -------------------------------------------------------------------------------- /1 - 
wei/ocr_densenet/files/ttf/simsun.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mind/RMB/cd656b3000fc612657b57dc70dfa76432a448014/1 - wei/ocr_densenet/files/ttf/simsun.ttf -------------------------------------------------------------------------------- /1 - wei/ocr_densenet/requirement.txt: -------------------------------------------------------------------------------- 1 | Pillow 2 | fuzzywuzzy 3 | numpy==1.14.2 4 | tqdm==4.19.4 5 | scikit-image==0.13.0 6 | scikit-learn==0.19.1 7 | torchvision==0.2.0 8 | scipy==0.19.0 9 | matplotlib==2.0.2 10 | -------------------------------------------------------------------------------- /2 - TitanikData/README.md: -------------------------------------------------------------------------------- 1 | 1. install the dependency packages 2 | pip3 install torch==1.1.0 3 | pip3 install lmdb pillow torchvision nltk natsort 4 | 5 | 2. use the faster rcnn to crop the ROI images 6 | 1.首先安装mmdet,pytorch为1.1.0版 github地址为:https://github.com/open-mmlab/mmdetection 7 | 2.python test.py,启动预测脚本 8 | 3.得到预测结果,python pick_picname.py 启动切片脚本,进行切片 9 | 4.切片结束,完成检测任务和训练数据准备任务 10 | 11 | 12 | 3. 13 | cd deep-text-recognition-benchmark-master 14 | Train the models 15 | CRNN: 16 | CUDA_VISIBLE_DEVICES=0 python3 train.py \ 17 | --train_data data_lmdb_release/training --valid_data data_lmdb_release/validation \ 18 | --select_data MJ-ST --batch_ratio 0.5-0.5 \ 19 | --Transformation None --FeatureExtraction VGG --SequenceModeling BiLSTM --Prediction CTC 20 | 21 | ATTENTION: 22 | CUDA_VISIBLE_DEVICES=0 python3 train.py \ 23 | --train_data data_lmdb_release/training --valid_data data_lmdb_release/validation \ 24 | --select_data MJ-ST --batch_ratio 0.5-0.5 \ 25 | --Transformation TPS --FeatureExtraction ResNet --SequenceModeling BiLSTM --Prediction Attn 26 | 27 | 28 | 4. 
Prediction 29 | use the CTC model: 30 | CUDA_VISIBLE_DEVICES=0 python3 demo.py \ 31 | --Transformation TPS --FeatureExtraction ResNet --SequenceModeling BiLSTM --Prediction CTC \ 32 | --image_folder ****(DIY_PATH) \ 33 | --saved_model ./saved_models/TPS-ResNet-BiLSTM-CTC-Seed1111/best_accuracy.pth 34 | 35 | 36 | CUDA_VISIBLE_DEVICES=0 python3 demo.py \ 37 | --Transformation TPS --FeatureExtraction ResNet --SequenceModeling BiLSTM --Prediction Attn \ 38 | --image_folder ****(DIY_PATH)/ \ 39 | --saved_model ./saved_models/TPS-ResNet-BiLSTM-Attn-Seed1111/best_accuracy.pth 40 | 41 | 在这次比赛中,我们发现CTC比ATTN的acc更加的高,所以我们采用的是CTC。 42 | 43 | other: When you need to create lmdb dataset 44 | pip3 install fire 45 | python3 create_lmdb_dataset.py --inputPath data/ --gtFile data/gt.txt --outputPath result/ 46 | 47 | -------------------------------------------------------------------------------- /2 - TitanikData/deep-text-recognition-benchmark-master/create_lmdb_dataset.py: -------------------------------------------------------------------------------- 1 | """ a modified version of CRNN torch repository https://github.com/bgshih/crnn/blob/master/tool/create_dataset.py """ 2 | 3 | import fire 4 | import os 5 | import lmdb 6 | import cv2 7 | 8 | import numpy as np 9 | 10 | 11 | def checkImageIsValid(imageBin): 12 | if imageBin is None: 13 | return False 14 | imageBuf = np.fromstring(imageBin, dtype=np.uint8) 15 | img = cv2.imdecode(imageBuf, cv2.IMREAD_GRAYSCALE) 16 | imgH, imgW = img.shape[0], img.shape[1] 17 | if imgH * imgW == 0: 18 | return False 19 | return True 20 | 21 | 22 | def writeCache(env, cache): 23 | with env.begin(write=True) as txn: 24 | for k, v in cache.items(): 25 | txn.put(k, v) 26 | 27 | 28 | def createDataset(inputPath, gtFile, outputPath, checkValid=True): 29 | """ 30 | Create LMDB dataset for training and evaluation. 
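Each line of gtFile is expected to hold an image path and its label separated by a tab (the loop below splits on '\t').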
31 | ARGS: 32 | inputPath : input folder path where starts imagePath 33 | outputPath : LMDB output path 34 | gtFile : list of image path and label 35 | checkValid : if true, check the validity of every image 36 | """ 37 | os.makedirs(outputPath, exist_ok=True) 38 | env = lmdb.open(outputPath, map_size=1099511627776) 39 | cache = {} 40 | cnt = 1 41 | 42 | with open(gtFile, 'r', encoding='utf-8') as data: 43 | datalist = data.readlines() 44 | 45 | nSamples = len(datalist) 46 | for i in range(nSamples): 47 | imagePath, label = datalist[i].strip('\n').split('\t') 48 | imagePath = os.path.join(inputPath, imagePath) 49 | 50 | # # only use alphanumeric data 51 | # if re.search('[^a-zA-Z0-9]', label): 52 | # continue 53 | 54 | if not os.path.exists(imagePath): 55 | print('%s does not exist' % imagePath) 56 | continue 57 | with open(imagePath, 'rb') as f: 58 | imageBin = f.read() 59 | if checkValid: 60 | try: 61 | if not checkImageIsValid(imageBin): 62 | print('%s is not a valid image' % imagePath) 63 | continue 64 | except: 65 | print('error occured', i) 66 | with open(outputPath + '/error_image_log.txt', 'a') as log: 67 | log.write('%s-th image data occured error\n' % str(i)) 68 | continue 69 | 70 | imageKey = 'image-%09d'.encode() % cnt 71 | labelKey = 'label-%09d'.encode() % cnt 72 | cache[imageKey] = imageBin 73 | cache[labelKey] = label.encode() 74 | 75 | if cnt % 1000 == 0: 76 | writeCache(env, cache) 77 | cache = {} 78 | print('Written %d / %d' % (cnt, nSamples)) 79 | cnt += 1 80 | nSamples = cnt-1 81 | cache['num-samples'.encode()] = str(nSamples).encode() 82 | writeCache(env, cache) 83 | print('Created dataset with %d samples' % nSamples) 84 | 85 | 86 | if __name__ == '__main__': 87 | fire.Fire(createDataset) 88 | -------------------------------------------------------------------------------- /2 - TitanikData/deep-text-recognition-benchmark-master/modules/sequence_modeling.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | class BidirectionalLSTM(nn.Module): 5 | 6 | def __init__(self, input_size, hidden_size, output_size): 7 | super(BidirectionalLSTM, self).__init__() 8 | self.rnn = nn.LSTM(input_size, hidden_size, bidirectional=True, batch_first=True) 9 | self.linear = nn.Linear(hidden_size * 2, output_size) 10 | 11 | def forward(self, input): 12 | """ 13 | input : visual feature [batch_size x T x input_size] 14 | output : contextual feature [batch_size x T x output_size] 15 | """ 16 | self.rnn.flatten_parameters() 17 | recurrent, _ = self.rnn(input) # batch_size x T x input_size -> batch_size x T x (2*hidden_size) 18 | output = self.linear(recurrent) # batch_size x T x output_size 19 | return output 20 | -------------------------------------------------------------------------------- /2 - TitanikData/deep-text-recognition-benchmark-master/predict.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=1 python demo.py \ 2 | --Transformation TPS --FeatureExtraction ResNet --SequenceModeling BiLSTM --Prediction CTC \ 3 | --image_folder /home/dingzy/TorchModel_new/jue_5/ \ 4 | --saved_model saved_models/TPS-ResNet-BiLSTM-CTC-Seed1111/best_accuracy.pth 5 | -------------------------------------------------------------------------------- /2 - TitanikData/检测/README.txt: -------------------------------------------------------------------------------- 1 | 1.首先安装mmdet,pytorch为1.1.0版 github地址为:https://github.com/open-mmlab/mmdetection 2 | 2.python test.py,启动预测脚本 3 | 
3.得到预测结果,python pick_picname.py 启动切片脚本,进行切片 4 | 4.切片结束,完成检测任务和训练数据准备任务 -------------------------------------------------------------------------------- /2 - TitanikData/检测/pick_picname.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # -*- coding: UTF-8 -*- 3 | import cv2 4 | import os 5 | 6 | #预测结果的txt文件 7 | f = open("./result_rcnn.txt",'r') 8 | #待预测的图片 9 | ori_root = '../private_test_data' 10 | #切片后图片存储的地址 11 | save_root = '../juesai_t' 12 | for line in f.readlines(): 13 | line_split = line.strip().split(' ') 14 | pic_name = line_split[0] 15 | pic_path = os.path.join(ori_root,pic_name) 16 | pic_cor = line_split[1].split(',') 17 | img = cv2.imread(pic_path) 18 | img_split = img[int(pic_cor[1]):int(pic_cor[3]),int(pic_cor[0]):int(pic_cor[2])] 19 | save_path = os.path.join(save_root,pic_name) 20 | cv2.imwrite(save_path,img_split) 21 | print('{} is saved!'.format(pic_name)) 22 | -------------------------------------------------------------------------------- /2 - TitanikData/检测/test.py: -------------------------------------------------------------------------------- 1 | from mmdet.apis import init_detector, inference_detector, show_result 2 | import os 3 | import cv2 4 | 5 | #配置文件 6 | config_file = './faster_rcnn_r50_fpn_1x_voc0712.py' 7 | #训练模型的地址 8 | checkpoint_file = './faster_rcnn_r50_fpn_1x_voc0712/epoch_4.pth' 9 | #记录坐标的txt文件的地址 10 | f = open('./result_rcnn.txt','w') 11 | 12 | # build the model from a config file and a checkpoint file 13 | model = init_detector(config_file, checkpoint_file, device='cuda:0') 14 | 15 | # test a single image and show the results 16 | #待预测图片文件夹的地址 17 | for root,dirs,files in os.walk('../private_test_data'): 18 | for file in files: 19 | pic_path = os.path.join(root,file) 20 | img = cv2.imread(pic_path) 21 | result = inference_detector(model, img) 22 | print(result[0][0]) 23 | try: 24 | length = len(result[0][0]) 25 | except IndexError: 26 | f.write(file) 27 | f.write('\n') 28 | continue 29 | f.write(file) 30 | f.write(' ') 31 | f.write(str(int(result[0][0][0]))) 32 | f.write(',') 33 | f.write(str(int(result[0][0][1]))) 34 | f.write(',') 35 | f.write(str(int(result[0][0][2]))) 36 | f.write(',') 37 | f.write(str(int(result[0][0][3]))) 38 | f.write('\n') 39 | f.flush() 40 | 41 | -------------------------------------------------------------------------------- /3 - TechDing/.idea/RMB_TechDing.iml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 12 | -------------------------------------------------------------------------------- /3 - TechDing/.idea/dictionaries/pc.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | -------------------------------------------------------------------------------- /3 - TechDing/.idea/encodings.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | -------------------------------------------------------------------------------- /3 - TechDing/.idea/misc.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | -------------------------------------------------------------------------------- /3 - TechDing/.idea/modules.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /3 - TechDing/CTC_Models/__init__.py: 
-------------------------------------------------------------------------------- 1 | from CTC_Models.Utils import load_roi_setX,ensemble 2 | from CTC_Models.model_predict import ctc_model_predict -------------------------------------------------------------------------------- /3 - TechDing/CTC_Models/model_predict.py: -------------------------------------------------------------------------------- 1 | # -*- coding: UTF-8 -*- 2 | 3 | """ 4 | # version: python3.6 5 | # author: TechDing 6 | # email: dingustb@126.com 7 | # file: model_predict.py 8 | # time: 2019/7/2 10:54 9 | # doc: 使用成熟的CTC_Models来预测roi_setX.npy,并将结果保存 10 | """ 11 | 12 | import string 13 | CHARS = string.digits + string.ascii_uppercase 14 | 15 | import numpy as np 16 | import keras.backend as K 17 | def pred_imgs(base_model,setX): 18 | ''' 19 | 使用base_model来预测roi图片 20 | :param base_model: 已经训练好的model 21 | :param setX: roi图片组成的数据,必须为0-255, uint8,shape为(bs,80,240,3) 22 | :return: 23 | ''' 24 | def predict_batch(batchX): 25 | testX2=batchX.astype('float64')/255 26 | testX2=testX2.transpose(0,2,1,3) 27 | y_pred=base_model.predict(testX2) 28 | y_pred = y_pred[:,2:,:] # (bs, 24,37) 29 | out = K.get_value(K.ctc_decode(y_pred, input_length=np.ones(y_pred.shape[0])*y_pred.shape[1], )[0][0])[:, :10] 30 | return [''.join([CHARS[i] for i in line]) for line in out] 31 | 32 | all_results=[] 33 | if len(setX)>1000: 34 | for i in range(0,len(setX),1000): 35 | end=min(len(setX),i+1000) 36 | batchX=setX[i:end] 37 | all_results.extend(predict_batch(batchX)) 38 | else: 39 | all_results=predict_batch(setX) 40 | return all_results 41 | 42 | from keras.models import load_model 43 | import pandas as pd 44 | import os 45 | def ctc_model_predict(model_type,test_roiX_path,test_names_path,result_save_path): 46 | ''' 47 | 使用成熟的ctc模型来预测roi图片,该roi图片以npy的形式保存到test_roiX_path中,图片对应的名称为test_names_path中 48 | :param model_type: 所使用的模型类型,必须为{cv,ctpn1,ctpn2}三者之一 49 | :param test_roiX_path: roi图片,以npy的形式保存,0-255,uint8 50 | :param test_names_path: 图片名称保存的路径 51 | :param result_save_path: 最终预测的结果保存的路径,该结果为df,columns=['name', 'label'] 52 | :return: 53 | ''' 54 | # print('start to predict roi imgs by ctc_model') 55 | models_dict={'cv':os.path.abspath('./CTC_Models/Models/CV_Model.h5'), 56 | 'ctpn1':os.path.abspath('./CTC_Models/Models/FTPN_Model1.h5'), 57 | 'ctpn2':os.path.abspath('./CTC_Models/Models/FTPN_Model2.h5')} 58 | ctc_model_path=models_dict.get(model_type,'Error model path') 59 | testX=np.load(test_roiX_path) 60 | img_names=np.load(test_names_path) 61 | ctc_model = load_model(ctc_model_path) # 模型加载非常耗时,需要注意 62 | predicted=pred_imgs(ctc_model,testX) 63 | result=np.c_[img_names,np.array(predicted)] 64 | df = pd.DataFrame(result, columns=['name', 'label']) 65 | df.to_csv(result_save_path,index=False) 66 | print('predicted result is saved to {}'.format(result_save_path)) 67 | -------------------------------------------------------------------------------- /3 - TechDing/CTPN_ROI/__init__.py: -------------------------------------------------------------------------------- 1 | from .ctpn import get_save_ctpn_roi -------------------------------------------------------------------------------- /3 - TechDing/CTPN_ROI/checkpoints/checkpoint: -------------------------------------------------------------------------------- 1 | model_checkpoint_path: "VGGnet_fast_rcnn_iter_93000.ckpt" 2 | all_model_checkpoint_paths: "VGGnet_fast_rcnn_iter_93000.ckpt" 3 | -------------------------------------------------------------------------------- /3 - TechDing/CTPN_ROI/ctpn/__init__.py: 
-------------------------------------------------------------------------------- 1 | from .get_ROI_imgs import get_save_ctpn_roi 2 | -------------------------------------------------------------------------------- /3 - TechDing/CTPN_ROI/ctpn/text.yml: -------------------------------------------------------------------------------- 1 | DATA_DIR: /home/ray/DataSet/RMB/dataset # dataset path 2 | ROOT_DIR: /home/ray/DataSet/RMB/ctpn_trained # trained models and logs folder 3 | EXP_DIR: ctpn_rmb 4 | LOG_DIR: ctpn_log 5 | IS_MULTISCALE: False 6 | NET_NAME: VGGnet 7 | ANCHOR_SCALES: [16] 8 | NCLASSES: 2 9 | USE_GPU_NMS: True # whether to use nms implemented in cuda or not, in Windows must be False 10 | TEST: 11 | HAS_RPN: True 12 | DETECT_MODE: H 13 | # H represents horizontal mode, O represents oriented mode, default is H 14 | # checkpoints_path: E:\PyProjects\Codes\text-detection-ctpn-win/checkpoints\ 15 | # checkpoints_path: /home/ray/DataSet/RMB/ctpn_trained/output/ctpn_rmb/rmb_db 16 | # the model I provided is in checkpoints/, if you train the model by yourself,it will be saved in output/ 17 | -------------------------------------------------------------------------------- /3 - TechDing/CTPN_ROI/说明.txt: -------------------------------------------------------------------------------- 1 | 2 | 1. 本模块的功能:使用CTPN深度模型来提取出图像中的编码所在区域(ROI), 3 | CTPN模型是基于Faster_RCNN基础上发展起来的,网上有很多成熟代码 4 | 5 | 2. 本部分需要在Linux环境下运行,使用Tensorflow-gpu模块。 6 | 7 | 3. 提取的ROI图片保存到临时文件夹中,便于后续用CTC_Models来识别该ROI图片中的编码。 -------------------------------------------------------------------------------- /3 - TechDing/CV_ROI/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | from CV_ROI.extract_ROI import get_save_roi -------------------------------------------------------------------------------- /3 - TechDing/CV_ROI/说明.txt: -------------------------------------------------------------------------------- 1 | 2 | 1. 本模块功能:使用OpenCV(而不是深度学习模型)的方式提取出RMB图片中的编码所在区域(ROI). 3 | 4 | 2. 其基本思想是:首先建立一个面值识别模型,识别出图片的面值类型(9种面值的一种), 5 | 然后对不同面值图片通过转换为灰度图,直方图均衡化,膨胀腐蚀等操作,获取出图片的外边框, 6 | 由于每种面值图片的ROI位于图片固定区域,所以用固定的像素截取即可获取ROI. 7 | 8 | 3. 将ROI图片保存到一个临时文件夹中,便于后续用CTC_Models来识别该ROI图片中的编码。 -------------------------------------------------------------------------------- /3 - TechDing/config.cfg: -------------------------------------------------------------------------------- 1 | [PATH] 2 | src_imgs_folder=/home/ray/DataSet/RMB/private_test/private_test_data 3 | # src_imgs_folder=./TestImgs 4 | 5 | # final_csv_path=/home/ray/DataSet/RMB/private_test/final_submit_result.csv 6 | # final_csv_path=./test_result.csv 7 | final_csv_path=./private_test_final_submit.csv 8 | -------------------------------------------------------------------------------- /3 - TechDing/test_result.csv: -------------------------------------------------------------------------------- 1 | name,label 2 | 0A2TRC3N.jpg,WE40869855 3 | 0ABCNHXI.jpg,QD77693991 4 | 0ABZLHS8.jpg,U7D1015703 5 | 0A8IHWYD.jpg,BW66783608 6 | 0A8EBTSG.jpg,U7D1039466 7 | 0A8CRDOV.jpg,K4X4348976 8 | 0A2PDULI.jpg,B8Y4976754 9 | 0A2TX8IH.jpg,BZ22158519 10 | 0ACGDL68.jpg,XD53669071 11 | 0A4DSPGE.jpg,UD18741546 12 | -------------------------------------------------------------------------------- /3 - TechDing/使用方法.txt: -------------------------------------------------------------------------------- 1 | 本项目是专门用于TinyMind人民币编码识别项目,团队名称:TechDing, 最终排名:第三名。联系方式:dingustb@126.com 2 | 3 | 使用方法: 4 | 5 | 1. 
首先vim打开config.cfg,设置src_imgs_folder为需要预测的图片文件夹目录, 6 | 里面的所有jpg图片都会识别其编码,其他类型的图片不会识别。 7 | 设置final_csv_path为识别完成之后的csv结果的路径,必须以.csv结束。 8 | 上面两个路径最好都是绝对路径。 9 | 10 | 2. 开始运行:直接用 python Main.py即可运行,如果一切顺利,最终会打印出 "GOOD!. All Finished!!!"信息, 11 | 且最终的预测结果保存到final_csv_path指向的路径。 12 | 13 | 3. 注意:经测试,本项目可以运行在:Ubuntu1604系统, Python 3.6,tensorflow-gpu==1.13.1,keras==2.2.4, numpy,pandas等. 14 | 其他tensorflow-gpu版本或keras版本可能运行正常,但不确定。 15 | 在Windows下无法运行。 16 | 在预测时默认使用第0个GPU来运行,需要注意该GPU没有被其他应用程序占用。 -------------------------------------------------------------------------------- /4 - HLearning/readme.md: -------------------------------------------------------------------------------- 1 | ## 第一部分: 编码切割 2 | text-detection-ctpn: 编码位置识别代码, 主要来源于github:https://github.com/eragonruan/text-detection-ctpn 3 | 将最后的输入内容应用到了图像上, 进行了切图。 4 | 手动标记训练数据1000张, 训练的log日志, 文件夹中有 5 | 6 | ## 第二部分: 编码识别 7 | 代码: CRNN.ipynb 8 | 编码识别主要采用CNN + RNN + CTC 9 | 分别采用不同的CNN网络去提取图片特征, 进行训练, 最后进行预测 10 | 11 | ## 第三部分: 结果融合采用了4个主干网络, 对预测的60中结果进行投票融合, 得到最终结果 12 | -------------------------------------------------------------------------------- /4 - HLearning/text-detection-ctpn/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2017 shaohui ruan 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /4 - HLearning/text-detection-ctpn/nets/vgg.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | slim = tf.contrib.slim 4 | 5 | 6 | def vgg_arg_scope(weight_decay=0.0005): 7 | with slim.arg_scope([slim.conv2d, slim.fully_connected], 8 | activation_fn=tf.nn.relu, 9 | weights_regularizer=slim.l2_regularizer(weight_decay), 10 | biases_initializer=tf.zeros_initializer()): 11 | with slim.arg_scope([slim.conv2d], padding='SAME') as arg_sc: 12 | return arg_sc 13 | 14 | 15 | def vgg_16(inputs, scope='vgg_16'): 16 | with tf.variable_scope(scope, 'vgg_16', [inputs]) as sc: 17 | with slim.arg_scope([slim.conv2d, slim.fully_connected, slim.max_pool2d]): 18 | net = slim.repeat(inputs, 2, slim.conv2d, 64, [3, 3], scope='conv1') 19 | net = slim.max_pool2d(net, [2, 2], scope='pool1') 20 | net = slim.repeat(net, 2, slim.conv2d, 128, [3, 3], scope='conv2') 21 | net = slim.max_pool2d(net, [2, 2], scope='pool2') 22 | net = slim.repeat(net, 3, slim.conv2d, 256, [3, 3], scope='conv3') 23 | net = slim.max_pool2d(net, [2, 2], scope='pool3') 24 | net = slim.repeat(net, 3, slim.conv2d, 512, [3, 3], scope='conv4') 25 | net = slim.max_pool2d(net, [2, 2], scope='pool4') 26 | net = slim.repeat(net, 3, slim.conv2d, 512, [3, 3], scope='conv5') 27 | 28 | return net 29 | -------------------------------------------------------------------------------- /4 - HLearning/text-detection-ctpn/requirements.txt: -------------------------------------------------------------------------------- 1 | tqdm 2 | opencv-python 3 | Shapely 4 | matplotlib 5 | numpy 6 | tensorflow-gpu 7 | Cython 8 | ipython 9 | -------------------------------------------------------------------------------- /4 - HLearning/text-detection-ctpn/utils/bbox/.fuse_hidden0000c4ff00000003: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mind/RMB/cd656b3000fc612657b57dc70dfa76432a448014/4 - HLearning/text-detection-ctpn/utils/bbox/.fuse_hidden0000c4ff00000003 -------------------------------------------------------------------------------- /4 - HLearning/text-detection-ctpn/utils/bbox/bbox.pyx: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Fast R-CNN 3 | # Copyright (c) 2015 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Sergey Karayev 6 | # -------------------------------------------------------- 7 | 8 | cimport cython 9 | import numpy as np 10 | cimport numpy as np 11 | 12 | DTYPE = np.float 13 | ctypedef np.float_t DTYPE_t 14 | 15 | def bbox_overlaps( 16 | np.ndarray[DTYPE_t, ndim=2] boxes, 17 | np.ndarray[DTYPE_t, ndim=2] query_boxes): 18 | """ 19 | Parameters 20 | ---------- 21 | boxes: (N, 4) ndarray of float 22 | query_boxes: (K, 4) ndarray of float 23 | Returns 24 | ------- 25 | overlaps: (N, K) ndarray of overlap between boxes and query_boxes 26 | """ 27 | cdef unsigned int N = boxes.shape[0] 28 | cdef unsigned int K = query_boxes.shape[0] 29 | cdef np.ndarray[DTYPE_t, ndim=2] overlaps = np.zeros((N, K), dtype=DTYPE) 30 | cdef DTYPE_t iw, ih, box_area 31 | cdef DTYPE_t ua 32 | cdef unsigned int k, n 33 | for k in range(K): 34 | box_area = ( 35 | (query_boxes[k, 2] - query_boxes[k, 0] + 1) * 36 | (query_boxes[k, 3] - query_boxes[k, 1] + 1) 37 | ) 38 | for n in range(N): 39 | iw = ( 40 | min(boxes[n, 
2], query_boxes[k, 2]) - 41 | max(boxes[n, 0], query_boxes[k, 0]) + 1 42 | ) 43 | if iw > 0: 44 | ih = ( 45 | min(boxes[n, 3], query_boxes[k, 3]) - 46 | max(boxes[n, 1], query_boxes[k, 1]) + 1 47 | ) 48 | if ih > 0: 49 | ua = float( 50 | (boxes[n, 2] - boxes[n, 0] + 1) * 51 | (boxes[n, 3] - boxes[n, 1] + 1) + 52 | box_area - iw * ih 53 | ) 54 | overlaps[n, k] = iw * ih / ua 55 | return overlaps 56 | 57 | def bbox_intersections( 58 | np.ndarray[DTYPE_t, ndim=2] boxes, 59 | np.ndarray[DTYPE_t, ndim=2] query_boxes): 60 | """ 61 | For each query box compute the intersection ratio covered by boxes 62 | ---------- 63 | Parameters 64 | ---------- 65 | boxes: (N, 4) ndarray of float 66 | query_boxes: (K, 4) ndarray of float 67 | Returns 68 | ------- 69 | overlaps: (N, K) ndarray of intersec between boxes and query_boxes 70 | """ 71 | cdef unsigned int N = boxes.shape[0] 72 | cdef unsigned int K = query_boxes.shape[0] 73 | cdef np.ndarray[DTYPE_t, ndim=2] intersec = np.zeros((N, K), dtype=DTYPE) 74 | cdef DTYPE_t iw, ih, box_area 75 | cdef DTYPE_t ua 76 | cdef unsigned int k, n 77 | for k in range(K): 78 | box_area = ( 79 | (query_boxes[k, 2] - query_boxes[k, 0] + 1) * 80 | (query_boxes[k, 3] - query_boxes[k, 1] + 1) 81 | ) 82 | for n in range(N): 83 | iw = ( 84 | min(boxes[n, 2], query_boxes[k, 2]) - 85 | max(boxes[n, 0], query_boxes[k, 0]) + 1 86 | ) 87 | if iw > 0: 88 | ih = ( 89 | min(boxes[n, 3], query_boxes[k, 3]) - 90 | max(boxes[n, 1], query_boxes[k, 1]) + 1 91 | ) 92 | if ih > 0: 93 | intersec[n, k] = iw * ih / box_area 94 | return intersec -------------------------------------------------------------------------------- /4 - HLearning/text-detection-ctpn/utils/bbox/bbox_transform.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def bbox_transform(ex_rois, gt_rois): 5 | """ 6 | computes the distance from ground-truth boxes to the given boxes, normed by their size 7 | :param ex_rois: n * 4 numpy array, given boxes 8 | :param gt_rois: n * 4 numpy array, ground-truth boxes 9 | :return: deltas: n * 4 numpy array, ground-truth boxes 10 | """ 11 | ex_widths = ex_rois[:, 2] - ex_rois[:, 0] + 1.0 12 | ex_heights = ex_rois[:, 3] - ex_rois[:, 1] + 1.0 13 | ex_ctr_x = ex_rois[:, 0] + 0.5 * ex_widths 14 | ex_ctr_y = ex_rois[:, 1] + 0.5 * ex_heights 15 | 16 | assert np.min(ex_widths) > 0.1 and np.min(ex_heights) > 0.1, \ 17 | 'Invalid boxes found: {} {}'.format(ex_rois[np.argmin(ex_widths), :], ex_rois[np.argmin(ex_heights), :]) 18 | 19 | gt_widths = gt_rois[:, 2] - gt_rois[:, 0] + 1.0 20 | gt_heights = gt_rois[:, 3] - gt_rois[:, 1] + 1.0 21 | gt_ctr_x = gt_rois[:, 0] + 0.5 * gt_widths 22 | gt_ctr_y = gt_rois[:, 1] + 0.5 * gt_heights 23 | 24 | # warnings.catch_warnings() 25 | # warnings.filterwarnings('error') 26 | targets_dx = (gt_ctr_x - ex_ctr_x) / ex_widths 27 | targets_dy = (gt_ctr_y - ex_ctr_y) / ex_heights 28 | targets_dw = np.log(gt_widths / ex_widths) 29 | targets_dh = np.log(gt_heights / ex_heights) 30 | 31 | targets = np.vstack( 32 | (targets_dx, targets_dy, targets_dw, targets_dh)).transpose() 33 | 34 | return targets 35 | 36 | 37 | def bbox_transform_inv(boxes, deltas): 38 | boxes = boxes.astype(deltas.dtype, copy=False) 39 | 40 | widths = boxes[:, 2] - boxes[:, 0] + 1.0 41 | heights = boxes[:, 3] - boxes[:, 1] + 1.0 42 | ctr_x = boxes[:, 0] + 0.5 * widths 43 | ctr_y = boxes[:, 1] + 0.5 * heights 44 | 45 | dx = deltas[:, 0::4] 46 | dy = deltas[:, 1::4] 47 | dw = deltas[:, 2::4] 48 | dh = deltas[:, 3::4] 49 | 50 | 
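# Note: this CTPN-style variant regresses only the vertical coordinates;
# the anchors have a fixed width, so dx/dw computed above are intentionally
# left unused and the x-center and width are carried over from the input boxes.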
pred_ctr_x = ctr_x[:, np.newaxis] 51 | pred_ctr_y = dy * heights[:, np.newaxis] + ctr_y[:, np.newaxis] 52 | pred_w = widths[:, np.newaxis] 53 | pred_h = np.exp(dh) * heights[:, np.newaxis] 54 | 55 | pred_boxes = np.zeros(deltas.shape, dtype=deltas.dtype) 56 | # x1 57 | pred_boxes[:, 0::4] = pred_ctr_x - 0.5 * pred_w 58 | # y1 59 | pred_boxes[:, 1::4] = pred_ctr_y - 0.5 * pred_h 60 | # x2 61 | pred_boxes[:, 2::4] = pred_ctr_x + 0.5 * pred_w 62 | # y2 63 | pred_boxes[:, 3::4] = pred_ctr_y + 0.5 * pred_h 64 | 65 | return pred_boxes 66 | 67 | 68 | def clip_boxes(boxes, im_shape): 69 | """ 70 | Clip boxes to image boundaries. 71 | """ 72 | # x1 >= 0 73 | boxes[:, 0::4] = np.maximum(np.minimum(boxes[:, 0::4], im_shape[1] - 1), 0) 74 | # y1 >= 0 75 | boxes[:, 1::4] = np.maximum(np.minimum(boxes[:, 1::4], im_shape[0] - 1), 0) 76 | # x2 < im_shape[1] 77 | boxes[:, 2::4] = np.maximum(np.minimum(boxes[:, 2::4], im_shape[1] - 1), 0) 78 | # y2 < im_shape[0] 79 | boxes[:, 3::4] = np.maximum(np.minimum(boxes[:, 3::4], im_shape[0] - 1), 0) 80 | return boxes 81 | -------------------------------------------------------------------------------- /4 - HLearning/text-detection-ctpn/utils/bbox/make.sh: -------------------------------------------------------------------------------- 1 | python setup.py install 2 | mv build/*/*.so ./ 3 | rm -rf build/ -------------------------------------------------------------------------------- /4 - HLearning/text-detection-ctpn/utils/bbox/setup.py: -------------------------------------------------------------------------------- 1 | from distutils.core import setup 2 | 3 | import numpy as np 4 | from Cython.Build import cythonize 5 | 6 | numpy_include = np.get_include() 7 | setup(ext_modules=cythonize("bbox.pyx"), include_dirs=[numpy_include]) 8 | setup(ext_modules=cythonize("nms.pyx"), include_dirs=[numpy_include]) 9 | -------------------------------------------------------------------------------- /4 - HLearning/text-detection-ctpn/utils/dataset/data_util.py: -------------------------------------------------------------------------------- 1 | import multiprocessing 2 | import threading 3 | import time 4 | 5 | import numpy as np 6 | 7 | try: 8 | import queue 9 | except ImportError: 10 | import Queue as queue 11 | 12 | 13 | class GeneratorEnqueuer(): 14 | def __init__(self, generator, 15 | use_multiprocessing=False, 16 | wait_time=0.05, 17 | random_seed=None): 18 | self.wait_time = wait_time 19 | self._generator = generator 20 | self._use_multiprocessing = use_multiprocessing 21 | self._threads = [] 22 | self._stop_event = None 23 | self.queue = None 24 | self.random_seed = random_seed 25 | 26 | def start(self, workers=1, max_queue_size=10): 27 | def data_generator_task(): 28 | while not self._stop_event.is_set(): 29 | try: 30 | if self._use_multiprocessing or self.queue.qsize() < max_queue_size: 31 | generator_output = next(self._generator) 32 | self.queue.put(generator_output) 33 | else: 34 | time.sleep(self.wait_time) 35 | except Exception: 36 | self._stop_event.set() 37 | raise 38 | 39 | try: 40 | if self._use_multiprocessing: 41 | self.queue = multiprocessing.Queue(maxsize=max_queue_size) 42 | self._stop_event = multiprocessing.Event() 43 | else: 44 | self.queue = queue.Queue() 45 | self._stop_event = threading.Event() 46 | 47 | for _ in range(workers): 48 | if self._use_multiprocessing: 49 | # Reset random seed else all children processes 50 | # share the same seed 51 | np.random.seed(self.random_seed) 52 | thread = multiprocessing.Process(target=data_generator_task) 53 | 
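# run workers as daemon processes so they are cleaned up automatically when the parent exits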
thread.daemon = True 54 | if self.random_seed is not None: 55 | self.random_seed += 1 56 | else: 57 | thread = threading.Thread(target=data_generator_task) 58 | self._threads.append(thread) 59 | thread.start() 60 | except: 61 | self.stop() 62 | raise 63 | 64 | def is_running(self): 65 | return self._stop_event is not None and not self._stop_event.is_set() 66 | 67 | def stop(self, timeout=None): 68 | if self.is_running(): 69 | self._stop_event.set() 70 | 71 | for thread in self._threads: 72 | if thread.is_alive(): 73 | if self._use_multiprocessing: 74 | thread.terminate() 75 | else: 76 | thread.join(timeout) 77 | 78 | if self._use_multiprocessing: 79 | if self.queue is not None: 80 | self.queue.close() 81 | 82 | self._threads = [] 83 | self._stop_event = None 84 | self.queue = None 85 | 86 | def get(self): 87 | while self.is_running(): 88 | if not self.queue.empty(): 89 | inputs = self.queue.get() 90 | if inputs is not None: 91 | yield inputs 92 | else: 93 | time.sleep(self.wait_time) 94 | -------------------------------------------------------------------------------- /4 - HLearning/text-detection-ctpn/utils/prepare/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from shapely.geometry import Polygon 3 | 4 | 5 | def pickTopLeft(poly): 6 | idx = np.argsort(poly[:, 0]) 7 | if poly[idx[0], 1] < poly[idx[1], 1]: 8 | s = idx[0] 9 | else: 10 | s = idx[1] 11 | 12 | return poly[(s, (s + 1) % 4, (s + 2) % 4, (s + 3) % 4), :] 13 | 14 | 15 | def orderConvex(p): 16 | points = Polygon(p).convex_hull 17 | points = np.array(points.exterior.coords)[:4] 18 | points = points[::-1] 19 | points = pickTopLeft(points) 20 | points = np.array(points).reshape([4, 2]) 21 | return points 22 | 23 | 24 | def shrink_poly(poly, r=16): 25 | # y = kx + b 26 | x_min = int(np.min(poly[:, 0])) 27 | x_max = int(np.max(poly[:, 0])) 28 | 29 | k1 = (poly[1][1] - poly[0][1]) / (poly[1][0] - poly[0][0]) 30 | b1 = poly[0][1] - k1 * poly[0][0] 31 | 32 | k2 = (poly[2][1] - poly[3][1]) / (poly[2][0] - poly[3][0]) 33 | b2 = poly[3][1] - k2 * poly[3][0] 34 | 35 | res = [] 36 | 37 | start = int((x_min // 16 + 1) * 16) 38 | end = int((x_max // 16) * 16) 39 | 40 | p = x_min 41 | res.append([p, int(k1 * p + b1), 42 | start - 1, int(k1 * (p + 15) + b1), 43 | start - 1, int(k2 * (p + 15) + b2), 44 | p, int(k2 * p + b2)]) 45 | 46 | for p in range(start, end + 1, r): 47 | res.append([p, int(k1 * p + b1), 48 | (p + 15), int(k1 * (p + 15) + b1), 49 | (p + 15), int(k2 * (p + 15) + b2), 50 | p, int(k2 * p + b2)]) 51 | return np.array(res, dtype=np.int).reshape([-1, 8]) 52 | -------------------------------------------------------------------------------- /4 - HLearning/text-detection-ctpn/utils/rpn_msr/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mind/RMB/cd656b3000fc612657b57dc70dfa76432a448014/4 - HLearning/text-detection-ctpn/utils/rpn_msr/__init__.py -------------------------------------------------------------------------------- /4 - HLearning/text-detection-ctpn/utils/rpn_msr/config.py: -------------------------------------------------------------------------------- 1 | class Config: 2 | EPS = 1e-14 3 | RPN_CLOBBER_POSITIVES = False 4 | RPN_NEGATIVE_OVERLAP = 0.3 5 | RPN_POSITIVE_OVERLAP = 0.7 6 | RPN_FG_FRACTION = 0.5 7 | RPN_BATCHSIZE = 300 8 | RPN_BBOX_INSIDE_WEIGHTS = (1.0, 1.0, 1.0, 1.0) 9 | RPN_POSITIVE_WEIGHT = -1.0 10 | 11 | RPN_PRE_NMS_TOP_N = 12000 12 | RPN_POST_NMS_TOP_N = 1000 13 | 
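# IoU threshold for NMS over RPN proposals; proposals smaller than RPN_MIN_SIZE (in pixels) are dropped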
RPN_NMS_THRESH = 0.7 14 | RPN_MIN_SIZE = 8 15 | -------------------------------------------------------------------------------- /4 - HLearning/text-detection-ctpn/utils/rpn_msr/generate_anchors.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def generate_basic_anchors(sizes, base_size=16): 5 | base_anchor = np.array([0, 0, base_size - 1, base_size - 1], np.int32) 6 | anchors = np.zeros((len(sizes), 4), np.int32) 7 | index = 0 8 | for h, w in sizes: 9 | anchors[index] = scale_anchor(base_anchor, h, w) 10 | index += 1 11 | return anchors 12 | 13 | 14 | def scale_anchor(anchor, h, w): 15 | x_ctr = (anchor[0] + anchor[2]) * 0.5 16 | y_ctr = (anchor[1] + anchor[3]) * 0.5 17 | scaled_anchor = anchor.copy() 18 | scaled_anchor[0] = x_ctr - w / 2 # xmin 19 | scaled_anchor[2] = x_ctr + w / 2 # xmax 20 | scaled_anchor[1] = y_ctr - h / 2 # ymin 21 | scaled_anchor[3] = y_ctr + h / 2 # ymax 22 | return scaled_anchor 23 | 24 | 25 | def generate_anchors(base_size=16, ratios=[0.5, 1, 2], 26 | scales=2 ** np.arange(3, 6)): 27 | heights = [11, 16, 23, 33, 48, 68, 97, 139, 198, 283] 28 | widths = [16] 29 | sizes = [] 30 | for h in heights: 31 | for w in widths: 32 | sizes.append((h, w)) 33 | return generate_basic_anchors(sizes) 34 | 35 | 36 | if __name__ == '__main__': 37 | import time 38 | 39 | t = time.time() 40 | a = generate_anchors() 41 | print(time.time() - t) 42 | print(a) 43 | from IPython import embed; 44 | 45 | embed() 46 | -------------------------------------------------------------------------------- /4 - HLearning/text-detection-ctpn/utils/text_connector/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mind/RMB/cd656b3000fc612657b57dc70dfa76432a448014/4 - HLearning/text-detection-ctpn/utils/text_connector/__init__.py -------------------------------------------------------------------------------- /4 - HLearning/text-detection-ctpn/utils/text_connector/detectors.py: -------------------------------------------------------------------------------- 1 | # coding:utf-8 2 | import numpy as np 3 | from utils.bbox.nms import nms 4 | 5 | from .text_connect_cfg import Config as TextLineCfg 6 | from .text_proposal_connector import TextProposalConnector 7 | from .text_proposal_connector_oriented import TextProposalConnector as TextProposalConnectorOriented 8 | 9 | 10 | class TextDetector: 11 | def __init__(self, DETECT_MODE="H"): 12 | self.mode = DETECT_MODE 13 | if self.mode == "H": 14 | self.text_proposal_connector = TextProposalConnector() 15 | elif self.mode == "O": 16 | self.text_proposal_connector = TextProposalConnectorOriented() 17 | 18 | def detect(self, text_proposals, scores, size): 19 | # 删除得分较低的proposal 20 | keep_inds = np.where(scores > TextLineCfg.TEXT_PROPOSALS_MIN_SCORE)[0] 21 | text_proposals, scores = text_proposals[keep_inds], scores[keep_inds] 22 | 23 | # 按得分排序 24 | sorted_indices = np.argsort(scores.ravel())[::-1] 25 | text_proposals, scores = text_proposals[sorted_indices], scores[sorted_indices] 26 | 27 | # 对proposal做nms 28 | keep_inds = nms(np.hstack((text_proposals, scores)), TextLineCfg.TEXT_PROPOSALS_NMS_THRESH) 29 | text_proposals, scores = text_proposals[keep_inds], scores[keep_inds] 30 | 31 | # 获取检测结果 32 | text_recs = self.text_proposal_connector.get_text_lines(text_proposals, scores, size) 33 | keep_inds = self.filter_boxes(text_recs) 34 | return text_recs[keep_inds] 35 | 36 | def filter_boxes(self, boxes): 37 | heights = 
np.zeros((len(boxes), 1), np.float) 38 | widths = np.zeros((len(boxes), 1), np.float) 39 | scores = np.zeros((len(boxes), 1), np.float) 40 | index = 0 41 | for box in boxes: 42 | heights[index] = (abs(box[5] - box[1]) + abs(box[7] - box[3])) / 2.0 + 1 43 | widths[index] = (abs(box[2] - box[0]) + abs(box[6] - box[4])) / 2.0 + 1 44 | scores[index] = box[8] 45 | index += 1 46 | 47 | return np.where((widths / heights > TextLineCfg.MIN_RATIO) & (scores > TextLineCfg.LINE_MIN_SCORE) & 48 | (widths > (TextLineCfg.TEXT_PROPOSALS_WIDTH * TextLineCfg.MIN_NUM_PROPOSALS)))[0] 49 | -------------------------------------------------------------------------------- /4 - HLearning/text-detection-ctpn/utils/text_connector/other.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def threshold(coords, min_, max_): 5 | return np.maximum(np.minimum(coords, max_), min_) 6 | 7 | 8 | def clip_boxes(boxes, im_shape): 9 | """ 10 | Clip boxes to image boundaries. 11 | """ 12 | boxes[:, 0::2] = threshold(boxes[:, 0::2], 0, im_shape[1] - 1) 13 | boxes[:, 1::2] = threshold(boxes[:, 1::2], 0, im_shape[0] - 1) 14 | return boxes 15 | 16 | 17 | class Graph: 18 | def __init__(self, graph): 19 | self.graph = graph 20 | 21 | def sub_graphs_connected(self): 22 | sub_graphs = [] 23 | for index in range(self.graph.shape[0]): 24 | if not self.graph[:, index].any() and self.graph[index, :].any(): 25 | v = index 26 | sub_graphs.append([v]) 27 | while self.graph[v, :].any(): 28 | v = np.where(self.graph[v, :])[0][0] 29 | sub_graphs[-1].append(v) 30 | return sub_graphs 31 | -------------------------------------------------------------------------------- /4 - HLearning/text-detection-ctpn/utils/text_connector/text_connect_cfg.py: -------------------------------------------------------------------------------- 1 | class Config: 2 | MAX_HORIZONTAL_GAP = 50 3 | TEXT_PROPOSALS_MIN_SCORE = 0.7 4 | TEXT_PROPOSALS_NMS_THRESH = 0.2 5 | MIN_V_OVERLAPS = 0.7 6 | MIN_SIZE_SIM = 0.7 7 | MIN_RATIO = 0.5 8 | LINE_MIN_SCORE = 0.9 9 | TEXT_PROPOSALS_WIDTH = 16 10 | MIN_NUM_PROPOSALS = 2 11 | -------------------------------------------------------------------------------- /4 - HLearning/text-detection-ctpn/utils/text_connector/text_proposal_connector.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from utils.text_connector.other import clip_boxes 4 | from utils.text_connector.text_proposal_graph_builder import TextProposalGraphBuilder 5 | 6 | 7 | class TextProposalConnector: 8 | def __init__(self): 9 | self.graph_builder = TextProposalGraphBuilder() 10 | 11 | def group_text_proposals(self, text_proposals, scores, im_size): 12 | graph = self.graph_builder.build_graph(text_proposals, scores, im_size) 13 | return graph.sub_graphs_connected() 14 | 15 | def fit_y(self, X, Y, x1, x2): 16 | len(X) != 0 17 | # if X only include one point, the function will get line y=Y[0] 18 | if np.sum(X == X[0]) == len(X): 19 | return Y[0], Y[0] 20 | p = np.poly1d(np.polyfit(X, Y, 1)) 21 | return p(x1), p(x2) 22 | 23 | def get_text_lines(self, text_proposals, scores, im_size): 24 | # tp=text proposal 25 | tp_groups = self.group_text_proposals(text_proposals, scores, im_size) 26 | text_lines = np.zeros((len(tp_groups), 5), np.float32) 27 | 28 | for index, tp_indices in enumerate(tp_groups): 29 | text_line_boxes = text_proposals[list(tp_indices)] 30 | 31 | x0 = np.min(text_line_boxes[:, 0]) 32 | x1 = np.max(text_line_boxes[:, 2]) 33 | 34 | 
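# shift the end points inward by half a proposal width so the y-fit is evaluated inside the text line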
offset = (text_line_boxes[0, 2] - text_line_boxes[0, 0]) * 0.5 35 | 36 | lt_y, rt_y = self.fit_y(text_line_boxes[:, 0], text_line_boxes[:, 1], x0 + offset, x1 - offset) 37 | lb_y, rb_y = self.fit_y(text_line_boxes[:, 0], text_line_boxes[:, 3], x0 + offset, x1 - offset) 38 | 39 | # the score of a text line is the average score of the scores 40 | # of all text proposals contained in the text line 41 | score = scores[list(tp_indices)].sum() / float(len(tp_indices)) 42 | 43 | text_lines[index, 0] = x0 44 | text_lines[index, 1] = min(lt_y, rt_y) 45 | text_lines[index, 2] = x1 46 | text_lines[index, 3] = max(lb_y, rb_y) 47 | text_lines[index, 4] = score 48 | 49 | text_lines = clip_boxes(text_lines, im_size) 50 | 51 | text_recs = np.zeros((len(text_lines), 9), np.float) 52 | index = 0 53 | for line in text_lines: 54 | xmin, ymin, xmax, ymax = line[0], line[1], line[2], line[3] 55 | text_recs[index, 0] = xmin 56 | text_recs[index, 1] = ymin 57 | text_recs[index, 2] = xmax 58 | text_recs[index, 3] = ymin 59 | text_recs[index, 4] = xmax 60 | text_recs[index, 5] = ymax 61 | text_recs[index, 6] = xmin 62 | text_recs[index, 7] = ymax 63 | text_recs[index, 8] = line[4] 64 | index = index + 1 65 | 66 | return text_recs 67 | -------------------------------------------------------------------------------- /5 - ResNet34/README.md: -------------------------------------------------------------------------------- 1 | 运行环境: 2 | pytorch、sklearn、albumentations、cv2、numpy、pandas 3 | 4 | 以下是运行说明,大概需要运行1小时左右。 5 | 6 | cd multi-digit-pytorch/ 7 | python 2_predict.py 8 | 9 | cd ../crnn-pytorch/ 10 | python test2_tta.py --snapshot tmp/crnn_resnet18_0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ_best --visualize False --data-path ../data/ 11 | python submit.py 12 | 13 | 生成的`tmp_rcnn_tta10_pb_submit.csv`就是最终的提交文件。 -------------------------------------------------------------------------------- /5 - ResNet34/crnn-pytorch/.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | 49 | # Translations 50 | *.mo 51 | *.pot 52 | 53 | # Django stuff: 54 | *.log 55 | local_settings.py 56 | 57 | # Flask stuff: 58 | instance/ 59 | .webassets-cache 60 | 61 | # Scrapy stuff: 62 | .scrapy 63 | 64 | # Sphinx documentation 65 | docs/_build/ 66 | 67 | # PyBuilder 68 | target/ 69 | 70 | # Jupyter Notebook 71 | .ipynb_checkpoints 72 | 73 | # pyenv 74 | .python-version 75 | 76 | # celery beat schedule file 77 | celerybeat-schedule 78 | 79 | # SageMath parsed files 80 | *.sage.py 81 | 82 | # dotenv 83 | .env 84 | 85 | # virtualenv 86 | .venv 87 | venv/ 88 | ENV/ 89 | 90 | # Spyder project settings 91 | .spyderproject 92 | .spyproject 93 | 94 | # Rope project settings 95 | .ropeproject 96 | 97 | # mkdocs documentation 98 | /site 99 | 100 | # mypy 101 | .mypy_cache/ 102 | -------------------------------------------------------------------------------- /5 - ResNet34/crnn-pytorch/LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2008, Sergei Belousov aka BeS (belbes122@yandex.ru) 2 | 3 | All rights reserved. 4 | 5 | Redistribution and use in source and binary forms, with or without modification, 6 | are permitted provided that the following conditions are met: 7 | 8 | * Redistributions of source code must retain the above copyright notice, 9 | this list of conditions and the following disclaimer. 10 | * Redistributions in binary form must reproduce the above copyright notice, 11 | this list of conditions and the following disclaimer in the documentation 12 | and/or other materials provided with the distribution. 13 | 14 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 15 | "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 16 | LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 17 | A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER 18 | OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, 19 | EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, 20 | PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR 21 | PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF 22 | LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING 23 | NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 24 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 25 | -------------------------------------------------------------------------------- /5 - ResNet34/crnn-pytorch/README.md: -------------------------------------------------------------------------------- 1 | Convolutional Recurrent Neural Network 2 | ====================================== 3 | 4 | This software implements OCR system using CNN + RNN + CTCLoss, inspired by CRNN network. 5 | 6 | Usage 7 | ----- 8 | 9 | ` 10 | python ./train.py --help 11 | ` 12 | 13 | Demo 14 | ---- 15 | 16 | 1. Train simple OCR using TestDataset data generator. 17 | Training for ~60-100 epochs. 18 | ``` 19 | python train.py --test-init True --test-epoch 10 --output-dir 20 | ``` 21 | 22 | 2. Run test for trained model with visualization mode. 
23 | ``` 24 | python test.py --snapshot /crnn_resnet18_10_best --visualize True 25 | ``` 26 | 27 | Train on custom dataset 28 | ----------------------- 29 | 30 | 1. Create dataset 31 | 32 | - Structure of dataset: 33 | ``` 34 | 35 | ---- data 36 | -------- 37 | ... 38 | -------- 39 | ---- desc.json 40 | ``` 41 | 42 | - Structure of desc.json: 43 | ``` 44 | { 45 | "abc": , 46 | "train": [ 47 | { 48 | "text": 49 | "name": 50 | }, 51 | ... 52 | { 53 | "text": 54 | "name": 55 | } 56 | ], 57 | "test": [ 58 | { 59 | "text": 60 | "name": 61 | }, 62 | ... 63 | { 64 | "text": 65 | "name": 66 | } 67 | ] 68 | } 69 | ``` 70 | 71 | 2. Train simple OCR using custom dataset. 72 | ``` 73 | python train.pt --test-init True --test-epoch 10 --output-dir --data-path 74 | ``` 75 | 76 | 3. Run test for trained model with visualization mode. 77 | ``` 78 | python test.py --snapshot /crnn_resnet18_10_best --visualize True --data-path 79 | ``` 80 | 81 | 82 | Dependence 83 | ---------- 84 | * pytorch 0.3.0 + 85 | * [warp-ctc](https://github.com/SeanNaren/warp-ctc) 86 | 87 | Articles 88 | -------- 89 | 90 | * [An End-to-End Trainable Neural Network for Image-based Sequence Recognition and Its Application to Scene Text Recognition](https://arxiv.org/abs/1507.05717) 91 | * [Connectionist Temporal Classification: Labelling Unsegmented Sequence Data with Recurrent Neural Networks](https://dl.acm.org/citation.cfm?id=1143891) 92 | -------------------------------------------------------------------------------- /5 - ResNet34/crnn-pytorch/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mind/RMB/cd656b3000fc612657b57dc70dfa76432a448014/5 - ResNet34/crnn-pytorch/__init__.py -------------------------------------------------------------------------------- /5 - ResNet34/crnn-pytorch/dataset/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mind/RMB/cd656b3000fc612657b57dc70dfa76432a448014/5 - ResNet34/crnn-pytorch/dataset/__init__.py -------------------------------------------------------------------------------- /5 - ResNet34/crnn-pytorch/dataset/collate_fn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | def text_collate(batch): 5 | img = list() 6 | seq = list() 7 | seq_len = list() 8 | for sample in batch: 9 | img.append(torch.from_numpy(sample["img"].transpose((2, 0, 1))).float()) 10 | seq.extend(sample["seq"]) 11 | seq_len.append(sample["seq_len"]) 12 | img = torch.stack(img) 13 | seq = torch.Tensor(seq).int() 14 | seq_len = torch.Tensor(seq_len).int() 15 | batch = {"img": img, "seq": seq, "seq_len": seq_len} 16 | return batch 17 | -------------------------------------------------------------------------------- /5 - ResNet34/crnn-pytorch/dataset/test_data.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | import torch 4 | from torch.utils.data import Dataset 5 | import string 6 | import random 7 | 8 | class TestDataset(Dataset): 9 | def __init__(self, 10 | epoch_len = 10000, 11 | seq_len = 8, 12 | transform=None, 13 | abc=string.digits): 14 | super().__init__() 15 | self.abc = abc 16 | self.epoch_len = epoch_len 17 | self.seq_len = seq_len 18 | self.transform = transform 19 | 20 | def __len__(self): 21 | return self.epoch_len 22 | 23 | def get_abc(self): 24 | return self.abc 25 | 26 | def set_mode(self, 
mode='train'): 27 | return 28 | 29 | def generate_string(self): 30 | return ''.join(random.choice(self.abc) for _ in range(self.seq_len)) 31 | 32 | def get_sample(self): 33 | h, w = 64, int(self.seq_len * 64 * 2.5) 34 | pw = int(w / self.seq_len) 35 | seq = [] 36 | img = np.zeros((h, w), dtype=np.uint8) 37 | text = self.generate_string() 38 | for i in range(len(text)): 39 | c = text[i] 40 | seq.append(self.abc.find(c) + 1) 41 | hs, ws = 32, 32 42 | symb = np.zeros((hs, ws), dtype=np.uint8) 43 | font = cv2.FONT_HERSHEY_SIMPLEX 44 | cv2.putText(symb, str(c), (3, 30), font, 1.2, (255), 2, cv2.LINE_AA) 45 | # Rotation 46 | angle = 60 47 | ang_rot = np.random.uniform(angle) - angle/2 48 | transform = cv2.getRotationMatrix2D((ws/2, hs/2), ang_rot, 1) 49 | symb = cv2.warpAffine(symb, transform, (ws, hs), borderValue = 0) 50 | # Scale 51 | scale = np.random.uniform(0.7, 1.0) 52 | transform = np.float32([[scale, 0, 0],[0, scale, 0]]) 53 | symb = cv2.warpAffine(symb, transform, (ws, hs), borderValue = 0) 54 | y = np.random.randint(hs, h) 55 | x = np.random.randint(i * pw, (i + 1) * pw - ws) 56 | img[y-hs:y, x:x+ws] = symb 57 | nw = int(w * 32 / h) 58 | img = cv2.resize(img, (nw, 32)) 59 | img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) 60 | return img, seq 61 | 62 | def __getitem__(self, idx): 63 | img, seq = self.get_sample() 64 | sample = {"img": img, "seq": seq, "seq_len": len(seq), "aug": True} 65 | if self.transform: 66 | sample = self.transform(sample) 67 | return sample 68 | -------------------------------------------------------------------------------- /5 - ResNet34/crnn-pytorch/dataset/text_data.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset 2 | import json 3 | import os 4 | import cv2 5 | 6 | class TextDataset(Dataset): 7 | def __init__(self, data_path, mode="train", transform=None): 8 | super(Dataset, self).__init__() 9 | self.data_path = data_path 10 | self.mode = mode 11 | self.config = json.load(open(os.path.join(data_path, "desc.json"))) 12 | self.transform = transform 13 | 14 | def abc_len(self): 15 | return len(self.config["abc"]) 16 | 17 | def get_abc(self): 18 | return self.config["abc"] 19 | 20 | def set_mode(self, mode): 21 | self.mode = mode 22 | 23 | def __len__(self): 24 | if self.mode == "test": 25 | return len(self.config[self.mode]) 26 | return len(self.config[self.mode]) 27 | 28 | def __getitem__(self, idx): 29 | 30 | name = self.config[self.mode][idx]["name"] 31 | text = self.config[self.mode][idx]["text"] 32 | 33 | img = cv2.imread(os.path.join(self.data_path, "data", name)) 34 | # print(os.path.join(self.data_path, "data", name)) 35 | # img = cv2.imread(os.path.join(self.data_path, name)) 36 | seq = self.text_to_seq(text) 37 | sample = {"img": img, "seq": seq, "seq_len": len(seq), "aug": self.mode == "train"} 38 | if self.transform: 39 | # print('trans') 40 | sample = self.transform(sample) 41 | return sample 42 | 43 | def text_to_seq(self, text): 44 | seq = [] 45 | for c in text: 46 | seq.append(self.config["abc"].find(c) + 1) 47 | return seq -------------------------------------------------------------------------------- /5 - ResNet34/crnn-pytorch/fold_tta.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mind/RMB/cd656b3000fc612657b57dc70dfa76432a448014/5 - ResNet34/crnn-pytorch/fold_tta.pkl -------------------------------------------------------------------------------- /5 - ResNet34/crnn-pytorch/lr_policy.py: 
-------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class StepLR(object): 5 | def __init__(self, optimizer, step_size=1000, max_iter=10000): 6 | self.optimizer = optimizer 7 | self.max_iter = max_iter 8 | self.step_size = step_size 9 | self.last_iter = -1 10 | self.base_lrs = list(map(lambda group: group['lr'], optimizer.param_groups)) 11 | 12 | def get_lr(self): 13 | return self.optimizer.param_groups[0]['lr'] 14 | 15 | def step(self, last_iter=None): 16 | if last_iter is not None: 17 | self.last_iter = last_iter 18 | if self.last_iter + 1 == self.max_iter: 19 | self.last_iter = -1 20 | self.last_iter = (self.last_iter + 1) % self.max_iter 21 | for ids, param_group in enumerate(self.optimizer.param_groups): 22 | param_group['lr'] = self.base_lrs[ids] * 0.8 ** ( self.last_iter // self.step_size ) 23 | -------------------------------------------------------------------------------- /5 - ResNet34/crnn-pytorch/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mind/RMB/cd656b3000fc612657b57dc70dfa76432a448014/5 - ResNet34/crnn-pytorch/models/__init__.py -------------------------------------------------------------------------------- /5 - ResNet34/crnn-pytorch/models/model_loader.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | import torch 4 | from torch import nn 5 | 6 | from .crnn import CRNN 7 | 8 | def load_weights(target, source_state): 9 | new_dict = OrderedDict() 10 | for k, v in target.state_dict().items(): 11 | if k in source_state and v.size() == source_state[k].size(): 12 | new_dict[k] = source_state[k] 13 | else: 14 | new_dict[k] = v 15 | target.load_state_dict(new_dict) 16 | 17 | def load_model(abc, seq_proj=[0, 0], backend='resnet18', snapshot=None, cuda=True): 18 | net = CRNN(abc=abc, seq_proj=seq_proj, backend=backend) 19 | net = nn.DataParallel(net) 20 | if snapshot is not None: 21 | load_weights(net, torch.load(snapshot)) 22 | if cuda: 23 | net = net.cuda() 24 | return net 25 | -------------------------------------------------------------------------------- /5 - ResNet34/crnn-pytorch/submit.py: -------------------------------------------------------------------------------- 1 | def check_label(s): 2 | if '*' in s: 3 | return True 4 | if len(s) != 10: 5 | return True 6 | 7 | if len(set(s[3:]) & set(string.ascii_uppercase)) > 0: 8 | return True 9 | 10 | if s[0] in string.digits: 11 | return True 12 | 13 | if s[0] in string.ascii_uppercase and s[1] in string.ascii_uppercase and s[2] in string.ascii_uppercase: 14 | return True 15 | 16 | if s[0] in string.ascii_uppercase and s[1] in string.ascii_uppercase: 17 | return True 18 | elif s[0] in string.ascii_uppercase and s[2] in string.ascii_uppercase and s[1] in string.digits: 19 | return True 20 | else: 21 | return False 22 | 23 | 24 | import pandas as pd 25 | import string 26 | submit_df1 = pd.read_csv('./tmp_rcnn_tta10_pb.csv') 27 | submit_df2 = pd.read_csv('../multi-digit-pytorch/tmp_rcnn_tta10_cnn.csv') 28 | 29 | submit_df1.loc[submit_df1['name'] == 'OFTUHPVE.jpg', 'label'] = submit_df2[submit_df2['name'] == 'OFTUHPVE.jpg']['label'] 30 | submit_df1[~submit_df1['label'].apply(lambda x: check_label(x))] 31 | submit_df1.to_csv('tmp_rcnn_tta10_pb_submit.csv',index=None) -------------------------------------------------------------------------------- /README.md: 
-------------------------------------------------------------------------------- 1 | Post-competition share for the TinyMind RMB serial-number recognition contest (docs, models, code, video) 2 | === 3 | 4 | --- 5 | 6 | ## Competition page: https://www.tinymind.cn/competitions/47 (the talk videos are available on the competition page) 7 | 8 | --- 9 | 10 | 11 | About the competition 12 | --- 13 | 14 | As China's legal tender, the renminbi plays an obvious role in daily life, and every banknote carries a unique serial number that identifies it. To advance research on recognizing RMB denominations and serial numbers, TinyMind launched this challenge: given the training images, participants had to train a model that automatically reads the serial number from an arbitrary RMB photo. The competition has now concluded successfully and has entered the free-practice stage. 15 | 16 | On July 11 we invited the top five finalists to share their approaches in a live broadcast. All of their models reached an accuracy above 99.99%, and the audience learned a great deal. We are now sharing the code, models, and documents these winners submitted, in the hope that you can apply the same ideas and achieve good results on similar projects. 17 | 18 | --- 19 | 20 | 21 | Study / discussion group: 22 | --- 23 | 24 | ![tinymind客服](https://www.tinymind.cn/assets/images/wx-kf.jpg) 25 | 26 | Add `TinyMind小助手` as a WeChat friend and reply `编码识别` to join the discussion group 27 | 28 | 29 | More competition news 30 | --- 31 | 32 | ![tinymind社区](https://www.tinymind.cn/assets/images/wx-tm.jpg) 33 | 34 | 35 | Follow the WeChat official account `TinyMind社区` for more information -------------------------------------------------------------------------------- /赛后分享PPT/1 - wei.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mind/RMB/cd656b3000fc612657b57dc70dfa76432a448014/赛后分享PPT/1 - wei.pdf -------------------------------------------------------------------------------- /赛后分享PPT/2 - TitanikData.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mind/RMB/cd656b3000fc612657b57dc70dfa76432a448014/赛后分享PPT/2 - TitanikData.pdf -------------------------------------------------------------------------------- /赛后分享PPT/3 - TechDing.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mind/RMB/cd656b3000fc612657b57dc70dfa76432a448014/赛后分享PPT/3 - TechDing.pdf -------------------------------------------------------------------------------- /赛后分享PPT/4 - HLearning.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mind/RMB/cd656b3000fc612657b57dc70dfa76432a448014/赛后分享PPT/4 - HLearning.pdf -------------------------------------------------------------------------------- /赛后分享PPT/5 - ResNet34.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mind/RMB/cd656b3000fc612657b57dc70dfa76432a448014/赛后分享PPT/5 - ResNet34.pdf --------------------------------------------------------------------------------