├── README.md ├── __pycache__ └── program.cpython-38.pyc ├── configs ├── test_table_mv3.yml └── train_table_mv3.yml ├── inference_model ├── baidu │ └── infer │ │ └── table_rec │ │ └── pd2pt.pt └── juneli │ └── finetune │ └── table_rec │ └── best.pt ├── models ├── architectures │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-38.pyc │ │ └── base_model.cpython-38.pyc │ └── base_model.py ├── backbones │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-38.pyc │ │ ├── table_mobilenet_v3.cpython-38.pyc │ │ └── table_resnet_vd.cpython-38.pyc │ ├── table_mobilenet_v3.py │ └── table_resnet_vd.py └── heads │ ├── __init__.py │ ├── __pycache__ │ ├── __init__.cpython-38.pyc │ └── table_att_head.cpython-38.pyc │ └── table_att_head.py ├── recognizer_table.py ├── script ├── paddle2pytorch.py ├── t_0.py ├── t_1.py └── t_2.py ├── test.py ├── train.py └── utils ├── __pycache__ ├── logging.cpython-38.pyc ├── save_load.cpython-38.pyc ├── stats.cpython-38.pyc ├── torch_utils.cpython-38.pyc └── utility.cpython-38.pyc ├── data ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-38.pyc │ └── pubtab_dataset.cpython-38.pyc ├── imaug │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-38.pyc │ │ ├── copy_paste.cpython-38.pyc │ │ ├── east_process.cpython-38.pyc │ │ ├── gen_table_mask.cpython-38.pyc │ │ ├── iaa_augment.cpython-38.pyc │ │ ├── label_ops.cpython-38.pyc │ │ ├── make_border_map.cpython-38.pyc │ │ ├── make_shrink_map.cpython-38.pyc │ │ ├── operators.cpython-38.pyc │ │ ├── pg_process.cpython-38.pyc │ │ ├── randaugment.cpython-38.pyc │ │ ├── random_crop_data.cpython-38.pyc │ │ ├── rec_img_aug.cpython-38.pyc │ │ └── sast_process.cpython-38.pyc │ ├── copy_paste.py │ ├── east_process.py │ ├── gen_table_mask.py │ ├── iaa_augment.py │ ├── label_ops.py │ ├── make_border_map.py │ ├── make_shrink_map.py │ ├── operators.py │ ├── pg_process.py │ ├── randaugment.py │ ├── random_crop_data.py │ ├── rec_img_aug.py │ ├── sast_process.py │ └── text_image_aug │ │ ├── 
__init__.py │ │ ├── __pycache__ │ │ ├── __init__.cpython-38.pyc │ │ ├── augment.cpython-38.pyc │ │ └── warp_mls.cpython-38.pyc │ │ ├── augment.py │ │ └── warp_mls.py └── pubtab_dataset.py ├── dict ├── ar_dict.txt ├── arabic_dict.txt ├── be_dict.txt ├── bg_dict.txt ├── chinese_cht_dict.txt ├── cyrillic_dict.txt ├── devanagari_dict.txt ├── en_dict.txt ├── fa_dict.txt ├── french_dict.txt ├── german_dict.txt ├── hi_dict.txt ├── it_dict.txt ├── japan_dict.txt ├── ka_dict.txt ├── korean_dict.txt ├── latin_dict.txt ├── mr_dict.txt ├── ne_dict.txt ├── oc_dict.txt ├── pu_dict.txt ├── rs_dict.txt ├── rsc_dict.txt ├── ru_dict.txt ├── ta_dict.txt ├── table_dict.txt ├── table_structure_dict.txt ├── te_dict.txt ├── ug_dict.txt ├── uk_dict.txt ├── ur_dict.txt └── xi_dict.txt ├── logging.py ├── losses ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-38.pyc │ └── table_att_loss.cpython-38.pyc └── table_att_loss.py ├── metrics ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-38.pyc │ └── table_metric.cpython-38.pyc └── table_metric.py ├── optimizer ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-38.pyc │ └── optimizer.cpython-38.pyc └── optimizer.py ├── postprocess ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-38.pyc │ └── rec_postprocess.cpython-38.pyc └── rec_postprocess.py ├── save_load.py ├── stats.py ├── torch_utils.py └── utility.py /README.md: -------------------------------------------------------------------------------- 1 | *注:本项目只是做表格结构预测和cell location回归,如果需要填充内容,需要连接自己的OCR模块; 2 | - todo list 3 | - [x] 完成table cell box格式(box格式可以用来对单元格做检测,且可简化标注难度)转html算法 4 | - [x] paddle表格识别代码和预训练模型转pytorch,预训练模型转换代码script/paddle2pytorch.py 5 | - [x] 私有有线表格数据训练 6 | - [x] 训练、测试、infer代码重构 7 | - [ ] 私有无线表格数据训练 8 | - [ ] 表格识别模型cell location回归优化 9 | 10 | # 环境 11 | python:3.8 12 | pytorch:1.7.1 13 | CUDA:10.1 14 | 15 | # DATASETS 16 | ## Public: 17 | [PubTabNet](https://github.com/ibm-aur-nlp/PubTabNet): 预训练使用了这个数据集,可自行去官网下载。 18 | 19 | ## Private: 20 | 
私人数据暂不公开,后期会提供table cell box格式转HTML代码。 21 | 22 | # train 23 | python train.py --config=configs/train_table_mv3.yml 24 | 25 | # test 26 | python test.py --config=configs/test_table_mv3.yml 27 | 28 | # infer 29 | python recognizer_table.py 30 | 31 | # model 32 | 百度基于Pubtabnet训练的模型(paddle转pytorch): 33 | inference_model/baidu/infer/table_rec/pd2pt.pt 34 | 基于私人数据训练的有线表格识别模型(目前场景比较固定,后续会继续泛化,使用者也可自行finetune): 35 | inference_model/juneli/finetune/table_rec/best.pt 36 | 37 | # 参考 38 | [PPOCR](https://github.com/PaddlePaddle/PaddleOCR) 39 | -------------------------------------------------------------------------------- /__pycache__/program.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/June-Li/TableRecognition/a26956983ade1153ee91f9a92fface1a2cf1275e/__pycache__/program.cpython-38.pyc -------------------------------------------------------------------------------- /configs/test_table_mv3.yml: -------------------------------------------------------------------------------- 1 | Global: 2 | use_gpu: '3' 3 | # evaluation is run every 400 iterations after the 0th iteration 4 | save_model_dir: / 5 | checkpoints: /workspace/JuneLi/bbtv/ExpCode/RARE/inference_model/baidu/infer/table_rec/pd2pt.pt 6 | # checkpoints: /workspace/JuneLi/bbtv/ExpCode/RARE/inference_model/juneli/finetune/table_rec/last.pt 7 | character_dict_path: utils/dict/table_structure_dict.txt 8 | character_type: en 9 | max_text_length: 100 10 | max_elem_length: 800 11 | max_cell_num: 500 12 | 13 | Architecture: 14 | model_type: table 15 | algorithm: TableAttn 16 | Backbone: 17 | name: MobileNetV3 18 | scale: 1.0 19 | model_name: large 20 | Head: 21 | name: TableAttentionHead 22 | hidden_size: 256 23 | l2_decay: 0.00001 24 | loc_type: 2 25 | max_text_length: 100 26 | max_elem_length: 800 27 | max_cell_num: 500 28 | 29 | PostProcess: 30 | name: TableLabelDecode 31 | 32 | Metric: 33 | name: TableMetric 34 | main_indicator: acc 35 | 
36 | Eval: 37 | dataset: 38 | name: PubTabDataSet 39 | data_dir: /workspace/JuneLi/bbtv/ExpCode/RARE/datasets/public/pubtabnet_split/val/ 40 | label_file_path: /workspace/JuneLi/bbtv/ExpCode/RARE/datasets/public/pubtabnet_split/val.json 41 | # data_dir: /workspace/JuneLi/bbtv/ExpCode/RARE/datasets/private/FundScan/images/ 42 | # label_file_path: /workspace/JuneLi/bbtv/ExpCode/RARE/datasets/private/FundScan/val.json 43 | transforms: 44 | - DecodeImage: # load image 45 | img_mode: BGR 46 | channel_first: False 47 | - ResizeTableImage: 48 | max_len: 488 49 | - TableLabelEncode: 50 | - NormalizeImage: 51 | scale: 1./255. 52 | mean: [0.485, 0.456, 0.406] 53 | std: [0.229, 0.224, 0.225] 54 | order: 'hwc' 55 | - PaddingTableImage: 56 | - ToCHWImage: 57 | - KeepKeys: 58 | keep_keys: ['image', 'structure', 'bbox_list', 'sp_tokens', 'bbox_list_mask'] 59 | loader: 60 | shuffle: False 61 | drop_last: False 62 | batch_size_per_card: 128 63 | num_workers: 8 64 | -------------------------------------------------------------------------------- /configs/train_table_mv3.yml: -------------------------------------------------------------------------------- 1 | Global: 2 | use_gpu: '0,1,2,3' 3 | epoch_num: 400 4 | log_smooth_window: 20 5 | print_batch_step: 5 6 | checkpoints: /workspace/JuneLi/bbtv/ExpCode/RARE/inference_model/baidu/infer/table_rec/pd2pt.pt 7 | # checkpoints: /workspace/JuneLi/bbtv/ExpCode/RARE/inference_model/juneli/finetune/table_rec/last.pt 8 | save_model_dir: ./inference_model/juneli/finetune/table_rec/ 9 | save_epoch_step: 1 10 | # evaluation is run every 1000 iterations after the 0th iteration 11 | eval_batch_step: 1000 12 | cal_metric_during_train: True 13 | # for data or label process 14 | character_dict_path: utils/dict/table_structure_dict.txt 15 | character_type: en 16 | max_text_length: 100 17 | max_elem_length: 800 18 | max_cell_num: 500 19 | process_total_num: 0 20 | process_cut_num: 0 21 | 22 | Optimizer: 23 | name: Adam 24 | beta1: 0.9 25 | beta2: 0.999 
26 | clip_norm: 5.0 27 | lr: 28 | learning_rate: 0.0001 29 | regularizer: 30 | name: 'L2' 31 | factor: 0.00000 32 | 33 | Architecture: 34 | model_type: table 35 | algorithm: TableAttn 36 | Backbone: 37 | name: MobileNetV3 38 | scale: 1.0 39 | model_name: large 40 | Head: 41 | name: TableAttentionHead 42 | hidden_size: 256 43 | l2_decay: 0.00001 44 | loc_type: 2 45 | max_text_length: 100 46 | max_elem_length: 800 47 | max_cell_num: 500 48 | 49 | Loss: 50 | name: TableAttentionLoss 51 | structure_weight: 100.0 52 | loc_weight: 10000.0 53 | 54 | PostProcess: 55 | name: TableLabelDecode 56 | 57 | Metric: 58 | name: TableMetric 59 | main_indicator: acc 60 | 61 | Train: 62 | dataset: 63 | name: PubTabDataSet 64 | data_dir: /workspace/JuneLi/bbtv/ExpCode/RARE/datasets/private/FundScan/images/ 65 | label_file_path: /workspace/JuneLi/bbtv/ExpCode/RARE/datasets/private/FundScan/train.json 66 | transforms: 67 | - DecodeImage: # load image 68 | img_mode: BGR 69 | channel_first: False 70 | - ResizeTableImage: 71 | max_len: 488 72 | - TableLabelEncode: 73 | - NormalizeImage: 74 | scale: 1./255. 75 | mean: [0.485, 0.456, 0.406] 76 | std: [0.229, 0.224, 0.225] 77 | order: 'hwc' 78 | - PaddingTableImage: 79 | - ToCHWImage: 80 | - KeepKeys: 81 | keep_keys: ['image', 'structure', 'bbox_list', 'sp_tokens', 'bbox_list_mask'] 82 | loader: 83 | shuffle: True 84 | batch_size_per_card: 24 85 | drop_last: True 86 | num_workers: 8 87 | 88 | Eval: 89 | dataset: 90 | name: PubTabDataSet 91 | data_dir: /workspace/JuneLi/bbtv/ExpCode/RARE/datasets/private/FundScan/images/ 92 | label_file_path: /workspace/JuneLi/bbtv/ExpCode/RARE/datasets/private/FundScan/val.json 93 | transforms: 94 | - DecodeImage: # load image 95 | img_mode: BGR 96 | channel_first: False 97 | - ResizeTableImage: 98 | max_len: 488 99 | - TableLabelEncode: 100 | - NormalizeImage: 101 | scale: 1./255. 
102 | mean: [0.485, 0.456, 0.406] 103 | std: [0.229, 0.224, 0.225] 104 | order: 'hwc' 105 | - PaddingTableImage: 106 | - ToCHWImage: 107 | - KeepKeys: 108 | keep_keys: ['image', 'structure', 'bbox_list', 'sp_tokens', 'bbox_list_mask'] 109 | loader: 110 | shuffle: False 111 | drop_last: False 112 | batch_size_per_card: 32 113 | num_workers: 4 114 | -------------------------------------------------------------------------------- /inference_model/baidu/infer/table_rec/pd2pt.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/June-Li/TableRecognition/a26956983ade1153ee91f9a92fface1a2cf1275e/inference_model/baidu/infer/table_rec/pd2pt.pt -------------------------------------------------------------------------------- /inference_model/juneli/finetune/table_rec/best.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/June-Li/TableRecognition/a26956983ade1153ee91f9a92fface1a2cf1275e/inference_model/juneli/finetune/table_rec/best.pt -------------------------------------------------------------------------------- /models/architectures/__init__.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import importlib 3 | 4 | from .base_model import BaseModel 5 | # from .distillation_model import DistillationModel 6 | 7 | __all__ = ['build_model'] 8 | 9 | 10 | def build_model(config): 11 | config = copy.deepcopy(config) 12 | if not "name" in config: 13 | arch = BaseModel(config) 14 | else: 15 | name = config.pop("name") 16 | mod = importlib.import_module(__name__) 17 | arch = getattr(mod, name)(config) 18 | return arch 19 | -------------------------------------------------------------------------------- /models/architectures/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/June-Li/TableRecognition/a26956983ade1153ee91f9a92fface1a2cf1275e/models/architectures/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /models/architectures/__pycache__/base_model.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/June-Li/TableRecognition/a26956983ade1153ee91f9a92fface1a2cf1275e/models/architectures/__pycache__/base_model.cpython-38.pyc -------------------------------------------------------------------------------- /models/architectures/base_model.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | from __future__ import absolute_import 15 | from __future__ import division 16 | from __future__ import print_function 17 | from torch import nn 18 | from models.backbones import build_backbone 19 | from models.heads import build_head 20 | 21 | __all__ = ['BaseModel'] 22 | 23 | 24 | class BaseModel(nn.Module): 25 | def __init__(self, config): 26 | """ 27 | the module for OCR. 28 | args: 29 | config (dict): the super parameters for module. 
30 | """ 31 | super(BaseModel, self).__init__() 32 | in_channels = config.get('in_channels', 3) 33 | model_type = config['model_type'] 34 | # build transfrom, 35 | # for rec, transfrom can be TPS,None 36 | # for det and cls, transfrom should be None, 37 | # if you make model differently, you can use transfrom in det and cls 38 | if 'Transform' not in config or config['Transform'] is None: 39 | self.use_transform = False 40 | else: 41 | self.use_transform = True 42 | config['Transform']['in_channels'] = in_channels 43 | self.transform = build_transform(config['Transform']) 44 | in_channels = self.transform.out_channels 45 | 46 | # build backbone, backbone is needed for del, rec and cls 47 | config["Backbone"]['in_channels'] = in_channels 48 | self.backbone = build_backbone(config["Backbone"], model_type) 49 | in_channels = self.backbone.out_channels 50 | 51 | # build neck 52 | # for rec, neck can be cnn,rnn or reshape(None) 53 | # for det, neck can be FPN, BIFPN and so on. 54 | # for cls, neck should be none 55 | if 'Neck' not in config or config['Neck'] is None: 56 | self.use_neck = False 57 | else: 58 | self.use_neck = True 59 | config['Neck']['in_channels'] = in_channels 60 | self.neck = build_neck(config['Neck']) 61 | in_channels = self.neck.out_channels 62 | 63 | # # build head, head is needed for det, rec and cls 64 | config["Head"]['in_channels'] = in_channels 65 | self.head = build_head(config["Head"]) 66 | 67 | self.return_all_feats = config.get("return_all_feats", False) 68 | 69 | def forward(self, x, data=None): 70 | y = dict() 71 | if self.use_transform: 72 | x = self.transform(x) 73 | x = self.backbone(x) 74 | y["backbone_out"] = x 75 | if self.use_neck: 76 | x = self.neck(x) 77 | y["neck_out"] = x 78 | x = self.head(x, targets=data) 79 | if isinstance(x, dict): 80 | y.update(x) 81 | else: 82 | y["head_out"] = x 83 | if self.return_all_feats: 84 | return y 85 | else: 86 | return x 87 | 
-------------------------------------------------------------------------------- /models/backbones/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | __all__ = ["build_backbone"] 16 | 17 | 18 | def build_backbone(config, model_type): 19 | if model_type == "table": 20 | from .table_resnet_vd import ResNet 21 | from .table_mobilenet_v3 import MobileNetV3 22 | support_dict = ["ResNet", "MobileNetV3"] 23 | else: 24 | raise NotImplementedError 25 | 26 | module_name = config.pop("name") 27 | assert module_name in support_dict, Exception( 28 | "when model typs is {}, backbone only support {}".format(model_type, 29 | support_dict)) 30 | module_class = eval(module_name)(**config) 31 | return module_class 32 | -------------------------------------------------------------------------------- /models/backbones/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/June-Li/TableRecognition/a26956983ade1153ee91f9a92fface1a2cf1275e/models/backbones/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /models/backbones/__pycache__/table_mobilenet_v3.cpython-38.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/June-Li/TableRecognition/a26956983ade1153ee91f9a92fface1a2cf1275e/models/backbones/__pycache__/table_mobilenet_v3.cpython-38.pyc -------------------------------------------------------------------------------- /models/backbones/__pycache__/table_resnet_vd.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/June-Li/TableRecognition/a26956983ade1153ee91f9a92fface1a2cf1275e/models/backbones/__pycache__/table_resnet_vd.cpython-38.pyc -------------------------------------------------------------------------------- /models/backbones/table_mobilenet_v3.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import torch 6 | from torch import nn 7 | import torch.nn.functional as F 8 | 9 | __all__ = ['MobileNetV3'] 10 | 11 | 12 | def make_divisible(v, divisor=8, min_value=None): 13 | if min_value is None: 14 | min_value = divisor 15 | new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) 16 | if new_v < 0.9 * v: 17 | new_v += divisor 18 | return new_v 19 | 20 | 21 | class MobileNetV3(nn.Module): 22 | def __init__(self, 23 | in_channels=3, 24 | model_name='large', 25 | scale=0.5, 26 | disable_se=False, 27 | **kwargs): 28 | """ 29 | the MobilenetV3 backbone network for detection module. 
30 | Args: 31 | params(dict): the super parameters for build network 32 | """ 33 | super(MobileNetV3, self).__init__() 34 | 35 | self.disable_se = disable_se 36 | 37 | if model_name == "large": 38 | cfg = [ 39 | # k, exp, c, se, nl, s, 40 | [3, 16, 16, False, 'relu', 1], 41 | [3, 64, 24, False, 'relu', 2], 42 | [3, 72, 24, False, 'relu', 1], 43 | [5, 72, 40, True, 'relu', 2], 44 | [5, 120, 40, True, 'relu', 1], 45 | [5, 120, 40, True, 'relu', 1], 46 | [3, 240, 80, False, 'hardswish', 2], 47 | [3, 200, 80, False, 'hardswish', 1], 48 | [3, 184, 80, False, 'hardswish', 1], 49 | [3, 184, 80, False, 'hardswish', 1], 50 | [3, 480, 112, True, 'hardswish', 1], 51 | [3, 672, 112, True, 'hardswish', 1], 52 | [5, 672, 160, True, 'hardswish', 2], 53 | [5, 960, 160, True, 'hardswish', 1], 54 | [5, 960, 160, True, 'hardswish', 1], 55 | ] 56 | cls_ch_squeeze = 960 57 | elif model_name == "small": 58 | cfg = [ 59 | # k, exp, c, se, nl, s, 60 | [3, 16, 16, True, 'relu', 2], 61 | [3, 72, 24, False, 'relu', 2], 62 | [3, 88, 24, False, 'relu', 1], 63 | [5, 96, 40, True, 'hardswish', 2], 64 | [5, 240, 40, True, 'hardswish', 1], 65 | [5, 240, 40, True, 'hardswish', 1], 66 | [5, 120, 48, True, 'hardswish', 1], 67 | [5, 144, 48, True, 'hardswish', 1], 68 | [5, 288, 96, True, 'hardswish', 2], 69 | [5, 576, 96, True, 'hardswish', 1], 70 | [5, 576, 96, True, 'hardswish', 1], 71 | ] 72 | cls_ch_squeeze = 576 73 | else: 74 | raise NotImplementedError("mode[" + model_name + 75 | "_model] is not implemented!") 76 | 77 | supported_scale = [0.35, 0.5, 0.75, 1.0, 1.25] 78 | assert scale in supported_scale, \ 79 | "supported scale are {} but input scale is {}".format(supported_scale, scale) 80 | inplanes = 16 81 | # conv1 82 | self.conv = ConvBNLayer( 83 | in_channels=in_channels, 84 | out_channels=make_divisible(inplanes * scale), 85 | kernel_size=3, 86 | stride=2, 87 | padding=1, 88 | groups=1, 89 | if_act=True, 90 | act='hardswish', 91 | name='conv1') 92 | 93 | self.stages = [] 94 | 
self.out_channels = [] 95 | block_list = [] 96 | i = 0 97 | inplanes = make_divisible(inplanes * scale) 98 | for (k, exp, c, se, nl, s) in cfg: 99 | se = se and not self.disable_se 100 | start_idx = 2 if model_name == 'large' else 0 101 | if s == 2 and i > start_idx: 102 | self.out_channels.append(inplanes) 103 | self.stages.append(nn.Sequential(*block_list)) 104 | block_list = [] 105 | block_list.append( 106 | ResidualUnit( 107 | in_channels=inplanes, 108 | mid_channels=make_divisible(scale * exp), 109 | out_channels=make_divisible(scale * c), 110 | kernel_size=k, 111 | stride=s, 112 | use_se=se, 113 | act=nl, 114 | name="conv" + str(i + 2))) 115 | inplanes = make_divisible(scale * c) 116 | i += 1 117 | block_list.append( 118 | ConvBNLayer( 119 | in_channels=inplanes, 120 | out_channels=make_divisible(scale * cls_ch_squeeze), 121 | kernel_size=1, 122 | stride=1, 123 | padding=0, 124 | groups=1, 125 | if_act=True, 126 | act='hardswish', 127 | name='conv_last')) 128 | self.stages.append(nn.Sequential(*block_list)) 129 | self.stages_pipline = nn.Sequential(*self.stages) 130 | self.out_channels.append(make_divisible(scale * cls_ch_squeeze)) 131 | # for i, stage in enumerate(self.stages): 132 | # self.add_sublayer(sublayer=stage, name="stage{}".format(i)) 133 | 134 | def forward(self, x): 135 | x = self.conv(x) 136 | out_list = [] 137 | # for stage in self.stages: 138 | # x = stage(x) 139 | # out_list.append(x) 140 | x = self.stages_pipline(x) 141 | out_list.append(x) 142 | return out_list 143 | 144 | 145 | class ConvBNLayer(nn.Module): 146 | def __init__(self, 147 | in_channels, 148 | out_channels, 149 | kernel_size, 150 | stride, 151 | padding, 152 | groups=1, 153 | if_act=True, 154 | act=None, 155 | name=None): 156 | super(ConvBNLayer, self).__init__() 157 | self.if_act = if_act 158 | self.act = act 159 | self.conv = nn.Conv2d( 160 | in_channels=in_channels, 161 | out_channels=out_channels, 162 | kernel_size=kernel_size, 163 | stride=stride, 164 | padding=padding, 
165 | groups=groups, 166 | bias=False) 167 | 168 | self.bn = nn.BatchNorm2d(out_channels) 169 | 170 | def forward(self, x): 171 | x = self.conv(x) 172 | x = self.bn(x) 173 | if self.if_act: 174 | if self.act == "relu": 175 | x = F.relu(x) 176 | elif self.act == "hardswish": 177 | x = F.hardswish(x) 178 | else: 179 | print("The activation function({}) is selected incorrectly.". 180 | format(self.act)) 181 | exit() 182 | return x 183 | 184 | 185 | class ResidualUnit(nn.Module): 186 | def __init__(self, 187 | in_channels, 188 | mid_channels, 189 | out_channels, 190 | kernel_size, 191 | stride, 192 | use_se, 193 | act=None, 194 | name=''): 195 | super(ResidualUnit, self).__init__() 196 | self.if_shortcut = stride == 1 and in_channels == out_channels 197 | self.if_se = use_se 198 | 199 | self.expand_conv = ConvBNLayer( 200 | in_channels=in_channels, 201 | out_channels=mid_channels, 202 | kernel_size=1, 203 | stride=1, 204 | padding=0, 205 | if_act=True, 206 | act=act, 207 | name=name + "_expand") 208 | self.bottleneck_conv = ConvBNLayer( 209 | in_channels=mid_channels, 210 | out_channels=mid_channels, 211 | kernel_size=kernel_size, 212 | stride=stride, 213 | padding=int((kernel_size - 1) // 2), 214 | groups=mid_channels, 215 | if_act=True, 216 | act=act, 217 | name=name + "_depthwise") 218 | if self.if_se: 219 | self.mid_se = SEModule(mid_channels, name=name + "_se") 220 | self.linear_conv = ConvBNLayer( 221 | in_channels=mid_channels, 222 | out_channels=out_channels, 223 | kernel_size=1, 224 | stride=1, 225 | padding=0, 226 | if_act=False, 227 | act=None, 228 | name=name + "_linear") 229 | 230 | def forward(self, inputs): 231 | x = self.expand_conv(inputs) 232 | x = self.bottleneck_conv(x) 233 | if self.if_se: 234 | x = self.mid_se(x) 235 | x = self.linear_conv(x) 236 | if self.if_shortcut: 237 | x = torch.add(inputs, x) 238 | return x 239 | 240 | 241 | class SEModule(nn.Module): 242 | def __init__(self, in_channels, reduction=4, name=""): 243 | super(SEModule, 
self).__init__() 244 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 245 | self.conv1 = nn.Conv2d( 246 | in_channels=in_channels, 247 | out_channels=in_channels // reduction, 248 | kernel_size=1, 249 | stride=1, 250 | padding=0) 251 | self.conv2 = nn.Conv2d( 252 | in_channels=in_channels // reduction, 253 | out_channels=in_channels, 254 | kernel_size=1, 255 | stride=1, 256 | padding=0) 257 | 258 | def forward(self, inputs): 259 | outputs = self.avg_pool(inputs) 260 | outputs = self.conv1(outputs) 261 | outputs = F.relu(outputs) 262 | outputs = self.conv2(outputs) 263 | # outputs = F.hardsigmoid(outputs) 264 | outputs = F.relu6(1.2 * outputs + 3.) / 6. 265 | return inputs * outputs 266 | -------------------------------------------------------------------------------- /models/backbones/table_resnet_vd.py: -------------------------------------------------------------------------------- 1 | # copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

# NOTE(review): unlike the other visible backbone/architecture modules (which
# use torch.nn), this file is written against PaddlePaddle (paddle.nn.Layer,
# nn.Conv2D, ParamAttr). It presumably cannot be nested inside the torch
# BaseModel as-is, and importing it requires paddle to be installed. Confirm
# before selecting `name: ResNet` in a config.
import paddle
from paddle import ParamAttr
import paddle.nn as nn
import paddle.nn.functional as F

__all__ = ["ResNet"]


class ConvBNLayer(nn.Layer):
    """Conv2D (no bias) followed by BatchNorm with an optional fused activation ``act``.

    When ``is_vd_mode`` is True, the input is first average-pooled with a 2x2,
    stride-2, ceil-mode window. This is the "vd" downsampling used on
    projection shortcuts.

    ``name`` is effectively required: it is used to build the parameter names
    (``name + "_weights"``, ``bn*_scale`` / ``_offset`` / ``_mean`` /
    ``_variance``), and the default None would fail on string concatenation.
    """

    def __init__(
            self,
            in_channels,
            out_channels,
            kernel_size,
            stride=1,
            groups=1,
            is_vd_mode=False,
            act=None,
            name=None, ):
        super(ConvBNLayer, self).__init__()

        self.is_vd_mode = is_vd_mode
        self._pool2d_avg = nn.AvgPool2D(
            kernel_size=2, stride=2, padding=0, ceil_mode=True)
        # "Same" padding for odd kernel sizes.
        self._conv = nn.Conv2D(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=(kernel_size - 1) // 2,
            groups=groups,
            weight_attr=ParamAttr(name=name + "_weights"),
            bias_attr=False)
        # BN parameter prefix: "bn_conv1" for the stem; otherwise the leading
        # three characters (e.g. "res") are replaced by "bn".
        if name == "conv1":
            bn_name = "bn_" + name
        else:
            bn_name = "bn" + name[3:]
        self._batch_norm = nn.BatchNorm(
            out_channels,
            act=act,
            param_attr=ParamAttr(name=bn_name + '_scale'),
            bias_attr=ParamAttr(bn_name + '_offset'),
            moving_mean_name=bn_name + '_mean',
            moving_variance_name=bn_name + '_variance')

    def forward(self, inputs):
        if self.is_vd_mode:
            inputs = self._pool2d_avg(inputs)
        y = self._conv(inputs)
        y = self._batch_norm(y)
        return y


class BottleneckBlock(nn.Layer):
    """ResNet bottleneck: 1x1 -> 3x3 (carries ``stride``) -> 1x1 expanding to ``out_channels * 4``.

    With ``shortcut=True`` the input is added unchanged. Otherwise a 1x1
    projection is used; it average-pools first (vd mode) unless ``if_first``
    is set. The sum goes through ReLU.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 stride,
                 shortcut=True,
                 if_first=False,
                 name=None):
        super(BottleneckBlock, self).__init__()

        self.conv0 = ConvBNLayer(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=1,
            act='relu',
            name=name + "_branch2a")
        self.conv1 = ConvBNLayer(
            in_channels=out_channels,
            out_channels=out_channels,
            kernel_size=3,
            stride=stride,
            act='relu',
            name=name + "_branch2b")
        self.conv2 = ConvBNLayer(
            in_channels=out_channels,
            out_channels=out_channels * 4,
            kernel_size=1,
            act=None,
            name=name + "_branch2c")

        if not shortcut:
            # Projection shortcut: stride 1 conv; spatial downsampling (if any)
            # comes from the vd-mode average pool.
            self.short = ConvBNLayer(
                in_channels=in_channels,
                out_channels=out_channels * 4,
                kernel_size=1,
                stride=1,
                is_vd_mode=False if if_first else True,
                name=name + "_branch1")

        self.shortcut = shortcut

    def forward(self, inputs):
        y = self.conv0(inputs)
        conv1 = self.conv1(y)
        conv2 = self.conv2(conv1)

        if self.shortcut:
            short = inputs
        else:
            short = self.short(inputs)
        y = paddle.add(x=short, y=conv2)
        y = F.relu(y)
        return y


class BasicBlock(nn.Layer):
    """ResNet basic block: 3x3 (carries ``stride``) -> 3x3, plus identity or projection shortcut, then ReLU."""

    def __init__(self,
                 in_channels,
                 out_channels,
                 stride,
                 shortcut=True,
                 if_first=False,
                 name=None):
        super(BasicBlock, self).__init__()
        self.stride = stride
        self.conv0 = ConvBNLayer(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=3,
            stride=stride,
            act='relu',
            name=name + "_branch2a")
        self.conv1 = ConvBNLayer(
            in_channels=out_channels,
            out_channels=out_channels,
            kernel_size=3,
            act=None,
            name=name + "_branch2b")

        if not shortcut:
            self.short = ConvBNLayer(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=1,
                stride=1,
                is_vd_mode=False if if_first else True,
                name=name + "_branch1")

        self.shortcut = shortcut

    def forward(self, inputs):
        y = self.conv0(inputs)
        conv1 = self.conv1(y)

        if self.shortcut:
            short = inputs
        else:
            short = self.short(inputs)
        y = paddle.add(x=short, y=conv1)
        y = F.relu(y)
        return y


class ResNet(nn.Layer):
    """ResNet-vd backbone returning the output of every stage.

    Structure:
        - Stem: three 3x3 ConvBN layers (the first with stride 2), then a
          3x3/stride-2 max pool.
        - Four stages: BottleneckBlocks for ``layers >= 50``, otherwise
          BasicBlocks.

    ``self.out_channels`` lists the channel count of each stage output.
    """

    def __init__(self, in_channels=3, layers=50, **kwargs):
        super(ResNet, self).__init__()

        self.layers = layers
        supported_layers = [18, 34, 50, 101, 152, 200]
        # The assert guarantees `depth` is assigned by the chain below.
        assert layers in supported_layers, \
            "supported layers are {} but input layer is {}".format(
                supported_layers, layers)

        if layers == 18:
            depth = [2, 2, 2, 2]
        elif layers == 34 or layers == 50:
            depth = [3, 4, 6, 3]
        elif layers == 101:
            depth = [3, 4, 23, 3]
        elif layers == 152:
            depth = [3, 8, 36, 3]
        elif layers == 200:
            depth = [3, 12, 48, 3]
        # Input channels of each stage's first block.
        num_channels = [64, 256, 512,
                        1024] if layers >= 50 else [64, 64, 128, 256]
        num_filters = [64, 128, 256, 512]

        self.conv1_1 = ConvBNLayer(
            in_channels=in_channels,
            out_channels=32,
            kernel_size=3,
            stride=2,
            act='relu',
            name="conv1_1")
        self.conv1_2 = ConvBNLayer(
            in_channels=32,
            out_channels=32,
            kernel_size=3,
            stride=1,
            act='relu',
            name="conv1_2")
        self.conv1_3 = ConvBNLayer(
            in_channels=32,
            out_channels=64,
            kernel_size=3,
            stride=1,
            act='relu',
            name="conv1_3")
        self.pool2d_max = nn.MaxPool2D(kernel_size=3, stride=2, padding=1)

        # self.stages is a plain list; the blocks themselves are registered
        # through add_sublayer under the names 'bb_<stage>_<index>'.
        self.stages = []
        self.out_channels = []
        if layers >= 50:
            for block in range(len(depth)):
                block_list = []
                # First block of each stage uses a projection shortcut.
                shortcut = False
                for i in range(depth[block]):
                    # Parameter-name scheme: e.g. "res4a", "res4b1" for the
                    # long third stage of 101/152, else "res<N><letter>".
                    if layers in [101, 152] and block == 2:
                        if i == 0:
                            conv_name = "res" + str(block + 2) + "a"
                        else:
                            conv_name = "res" + str(block + 2) + "b" + str(i)
                    else:
                        conv_name = "res" + str(block + 2) + chr(97 + i)
                    bottleneck_block = self.add_sublayer(
                        'bb_%d_%d' % (block, i),
                        BottleneckBlock(
                            in_channels=num_channels[block]
                            if i == 0 else num_filters[block] * 4,
                            out_channels=num_filters[block],
                            # Downsample at the start of every stage but the first.
                            stride=2 if i == 0 and block != 0 else 1,
                            shortcut=shortcut,
                            # Chained comparison: True only for block 0, unit 0.
                            if_first=block == i == 0,
                            name=conv_name))
                    shortcut = True
                    block_list.append(bottleneck_block)
                self.out_channels.append(num_filters[block] * 4)
                self.stages.append(nn.Sequential(*block_list))
        else:
            for block in range(len(depth)):
                block_list = []
                shortcut = False
                for i in range(depth[block]):
                    conv_name = "res" + str(block + 2) + chr(97 + i)
                    basic_block = self.add_sublayer(
                        'bb_%d_%d' % (block, i),
                        BasicBlock(
                            in_channels=num_channels[block]
                            if i == 0 else num_filters[block],
                            out_channels=num_filters[block],
                            stride=2 if i == 0 and block != 0 else 1,
                            shortcut=shortcut,
                            if_first=block == i == 0,
                            name=conv_name))
                    shortcut = True
                    block_list.append(basic_block)
                self.out_channels.append(num_filters[block])
                self.stages.append(nn.Sequential(*block_list))

    def forward(self, inputs):
        """Return a list with the output feature map of each stage, in order."""
        y = self.conv1_1(inputs)
        y = self.conv1_2(y)
        y = self.conv1_3(y)
        y = self.pool2d_max(y)
        out = []
        for block in self.stages:
            y = block(y)
            out.append(y)
        return out
-------------------------------------------------------------------------------- /models/heads/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | __all__ = ['build_head'] 16 | 17 | 18 | def build_head(config): 19 | support_dict = ['TableAttentionHead'] 20 | #table head 21 | from .table_att_head import TableAttentionHead 22 | 23 | module_name = config.pop('name') 24 | assert module_name in support_dict, Exception('head only support {}'.format( 25 | support_dict)) 26 | module_class = eval(module_name)(**config) 27 | return module_class 28 | -------------------------------------------------------------------------------- /models/heads/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/June-Li/TableRecognition/a26956983ade1153ee91f9a92fface1a2cf1275e/models/heads/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /models/heads/__pycache__/table_att_head.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/June-Li/TableRecognition/a26956983ade1153ee91f9a92fface1a2cf1275e/models/heads/__pycache__/table_att_head.cpython-38.pyc -------------------------------------------------------------------------------- /recognizer_table.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2022/01/25 09:52 3 | # @Author : lijun 4 | from __future__ import absolute_import 5 | from __future__ import division 6 | from __future__ import print_function 7 | 8 | import os 9 | import sys 10 | import platform 11 | import time 12 | import cv2 13 | import torch 14 | import numpy as np 15 | from tqdm import tqdm 16 | 17 | from models.architectures import build_model 18 | from utils.save_load import load_model 19 | from utils.torch_utils import select_device 20 | from utils.data import transform, 
create_operators 21 | from utils.postprocess import build_post_process 22 | 23 | 24 | class Recognizer: 25 | def __init__(self, model_path, table_char_dict_path='utils/dict/table_structure_dict.txt', 26 | gpu='0', half_flag=False): 27 | self.weights = model_path 28 | self.table_char_dict_path = table_char_dict_path 29 | self.gpu = gpu 30 | self.device = select_device(self.gpu) 31 | self.half_flag = half_flag 32 | ckpt = torch.load(self.weights, map_location='cpu') 33 | cfg = ckpt['cfg'] 34 | 35 | pre_process_list = [] 36 | for pre_process in cfg['Train']['dataset']['transforms']: 37 | if 'DecodeImage' in pre_process.keys() or 'TableLabelEncode' in pre_process.keys(): 38 | continue 39 | if 'KeepKeys' in pre_process.keys(): 40 | pre_process['KeepKeys'] = {'keep_keys': ['image']} 41 | pre_process_list.append(pre_process) 42 | self.preprocess_op = create_operators(pre_process_list) 43 | postprocess_params = { 44 | 'name': 'TableLabelDecode', 45 | "character_type": 'en', 46 | "character_dict_path": self.table_char_dict_path, 47 | } 48 | self.postprocess_op = build_post_process(postprocess_params) 49 | 50 | self.model = build_model(cfg['Architecture']) 51 | load_model(self.model, model_path) 52 | self.model.to(self.device) 53 | self.half = self.device.type != 'cpu' # half precision only supported on CUDA 54 | if self.half_flag and self.half: 55 | self.model.half() # to FP16 56 | 57 | def inference(self, img): 58 | ori_im = img.copy() 59 | data = {'image': img} 60 | data = transform(data, self.preprocess_op) 61 | img = data[0] 62 | if img is None: 63 | return None 64 | img = np.expand_dims(img, axis=0) 65 | img = img.copy() 66 | img = torch.tensor(img).to(self.device) 67 | img = img.half() if self.half_flag and self.half else img 68 | preds = self.model(img) 69 | 70 | preds = {'structure_probs': preds['structure_probs'].cpu().detach().numpy(), 71 | 'loc_preds': preds['loc_preds'].cpu().detach().numpy()} 72 | post_result = self.postprocess_op(preds) 73 | 74 | 
structure_str_list = post_result['structure_str_list'] 75 | res_loc = post_result['res_loc'] 76 | imgh, imgw = ori_im.shape[0:2] 77 | res_loc_final = [] 78 | for rno in range(len(res_loc[0])): 79 | x0, y0, x1, y1 = res_loc[0][rno] 80 | left = max(int(imgw * x0), 0) 81 | top = max(int(imgh * y0), 0) 82 | right = min(int(imgw * x1), imgw - 1) 83 | bottom = min(int(imgh * y1), imgh - 1) 84 | res_loc_final.append([left, top, right, bottom]) 85 | structure_str_list = structure_str_list[0][:-1] 86 | structure_str_list = ['', '', ''] + structure_str_list + ['
', '', ''] 87 | return structure_str_list, res_loc_final 88 | 89 | 90 | if __name__ == '__main__': 91 | def html_configuration(): 92 | html_head = '\n' + \ 93 | '\n' + \ 94 | '\n' + \ 95 | ' \n' + \ 96 | ' Title\n' + \ 97 | '\n' + \ 98 | '\n' + \ 99 | '\n' + \ 107 | '\n' 108 | html_tail = '
\n' + \ 109 | '\n' + \ 110 | '\n' 111 | return [html_head, html_tail] 112 | 113 | 114 | # table_rec = Recognizer('inference_model/baidu/infer/table_rec/pd2pt.pt', gpu='1') 115 | table_rec = Recognizer('inference_model/juneli/finetune/table_rec/last.pt', gpu='cpu') 116 | base_path = './' 117 | in_path = base_path + 'test_data/FundScan/' 118 | out_path = base_path + 'test_out/FundScan/' 119 | if not os.path.exists(out_path): 120 | os.makedirs(out_path) 121 | # file_name_list = os.listdir(base_path + 'datasets/public/pubtabnet_subset/train/') 122 | file_name_list = os.listdir(in_path) 123 | for idx, file_name in enumerate(tqdm(file_name_list)): 124 | # if not file_name.startswith('0.jpg'): 125 | # continue 126 | # if file_name != '0.jpg': 127 | # continue 128 | img = cv2.imread(in_path + file_name) 129 | structure, loc = table_rec.inference(img) 130 | # my output->开始 131 | # print('html: \n', structure) 132 | open(out_path + file_name.replace('.jpg', '.html').replace('.png', '.html'), 'w').write( 133 | html_configuration()[0] + 134 | ''.join([i.replace('"', '"') for i in structure[3:-3]]) + '\n' + html_configuration()[1]) 135 | print(loc) 136 | # print(len(structure), len(loc)) 137 | show_img = img.copy() 138 | for box in loc: 139 | cv2.rectangle(show_img, tuple(box[:2]), tuple(box[2:]), (0, 0, 255), 2) 140 | cv2.imwrite(out_path + file_name, show_img) 141 | # my output->结束 142 | 143 | 144 | -------------------------------------------------------------------------------- /script/paddle2pytorch.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import paddle 3 | import numpy as np 4 | import yaml 5 | 6 | 7 | config = yaml.load(open('0.yml', 'r'), Loader=yaml.FullLoader) # 0.yml是配置文件,eg:configs/train_table_mv3.yml 8 | pd = paddle.load('/Volumes/my_disk/company/xxx/buffer_disk/a/best_accuracy.pdparams') 9 | ckpt = torch.load('/Volumes/my_disk/company/xxx/buffer_disk/a/last.pt', map_location='cpu') # ckpt表示pytorch模型 
10 | 11 | out_dict = {} 12 | for pd_k in pd.keys(): 13 | ckpt_k = pd_k.replace('stage', 'stages_pipline.').replace('_mean', 'running_mean').replace('_variance', 'running_var') 14 | if np.shape(pd[pd_k].numpy()) != np.shape(ckpt[ckpt_k].numpy()) or pd_k == 'head.structure_attention_cell.h2h.weight': 15 | pd[pd_k] = paddle.transpose(pd[pd_k], (1, 0)) 16 | out_dict[ckpt_k] = torch.tensor(pd[pd_k].numpy()) 17 | torch.save({'state_dict': out_dict, 'cfg': config}, '/Volumes/my_disk/company/xxx/buffer_disk/a/pd2pt.pt') -------------------------------------------------------------------------------- /script/t_0.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | import os 15 | import sys 16 | 17 | __dir__ = os.path.dirname(os.path.abspath(__file__)) 18 | sys.path.append(__dir__) 19 | sys.path.append(os.path.abspath(os.path.join(__dir__, '../..'))) 20 | 21 | os.environ["FLAGS_allocator_strategy"] = 'auto_growth' 22 | 23 | import cv2 24 | import numpy as np 25 | import time 26 | 27 | import tools.infer.utility as utility 28 | from ppocr.data import create_operators, transform 29 | from ppocr.postprocess import build_post_process 30 | from ppocr.utils.logging import get_logger 31 | from ppocr.utils.utility import get_image_file_list, check_and_read_gif 32 | from ppstructure.utility import parse_args 33 | 34 | logger = get_logger() 35 | 36 | 37 | class TableStructurer(object): 38 | def __init__(self, args): 39 | pre_process_list = [{ 40 | 'ResizeTableImage': { 41 | 'max_len': args.table_max_len 42 | } 43 | }, { 44 | 'NormalizeImage': { 45 | 'std': [0.229, 0.224, 0.225], 46 | 'mean': [0.485, 0.456, 0.406], 47 | 'scale': '1./255.', 48 | 'order': 'hwc' 49 | } 50 | }, { 51 | 'PaddingTableImage': None 52 | }, { 53 | 'ToCHWImage': None 54 | }, { 55 | 'KeepKeys': { 56 | 'keep_keys': ['image'] 57 | } 58 | }] 59 | postprocess_params = { 60 | 'name': 'TableLabelDecode', 61 | "character_type": args.table_char_type, 62 | "character_dict_path": args.table_char_dict_path, 63 | } 64 | 65 | self.preprocess_op = create_operators(pre_process_list) 66 | self.postprocess_op = build_post_process(postprocess_params) 67 | self.predictor, self.input_tensor, self.output_tensors, self.config = \ 68 | utility.create_predictor(args, 'table', logger) 69 | 70 | def __call__(self, img): 71 | ori_im = img.copy() 72 | data = {'image': img} 73 | data = transform(data, self.preprocess_op) 74 | img = data[0] 75 | if img is None: 76 | return None, 0 77 | img = np.expand_dims(img, axis=0) 78 | img = img.copy() 79 | starttime = time.time() 80 | 81 | self.input_tensor.copy_from_cpu(img) 82 | self.predictor.run() 83 | outputs = [] 84 | for output_tensor in 
self.output_tensors: 85 | output = output_tensor.copy_to_cpu() 86 | outputs.append(output) 87 | 88 | preds = {} 89 | preds['structure_probs'] = outputs[1] 90 | preds['loc_preds'] = outputs[0] 91 | 92 | post_result = self.postprocess_op(preds) 93 | 94 | structure_str_list = post_result['structure_str_list'] 95 | res_loc = post_result['res_loc'] 96 | imgh, imgw = ori_im.shape[0:2] 97 | res_loc_final = [] 98 | for rno in range(len(res_loc[0])): 99 | x0, y0, x1, y1 = res_loc[0][rno] 100 | left = max(int(imgw * x0), 0) 101 | top = max(int(imgh * y0), 0) 102 | right = min(int(imgw * x1), imgw - 1) 103 | bottom = min(int(imgh * y1), imgh - 1) 104 | res_loc_final.append([left, top, right, bottom]) 105 | 106 | structure_str_list = structure_str_list[0][:-1] 107 | structure_str_list = ['', '', ''] + structure_str_list + ['
', '', ''] 108 | 109 | elapse = time.time() - starttime 110 | return (structure_str_list, res_loc_final), elapse 111 | 112 | 113 | def main(args): 114 | image_file_list = get_image_file_list(args.image_dir) 115 | table_structurer = TableStructurer(args) 116 | count = 0 117 | total_time = 0 118 | for image_file in image_file_list: 119 | img, flag = check_and_read_gif(image_file) 120 | if not flag: 121 | img = cv2.imread(image_file) 122 | if img is None: 123 | logger.info("error in loading image:{}".format(image_file)) 124 | continue 125 | structure_res, elapse = table_structurer(img) 126 | 127 | logger.info("result: {}".format(structure_res)) 128 | 129 | if count > 0: 130 | total_time += elapse 131 | count += 1 132 | logger.info("Predict time of {}: {}".format(image_file, elapse)) 133 | 134 | 135 | if __name__ == "__main__": 136 | main(parse_args()) 137 | -------------------------------------------------------------------------------- /script/t_2.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import torch 6 | from torch import nn 7 | import torch.nn.functional as F 8 | 9 | __all__ = ['MobileNetV3'] 10 | 11 | 12 | def make_divisible(v, divisor=8, min_value=None): 13 | if min_value is None: 14 | min_value = divisor 15 | new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) 16 | if new_v < 0.9 * v: 17 | new_v += divisor 18 | return new_v 19 | 20 | 21 | class MobileNetV3(nn.Module): 22 | def __init__(self, 23 | in_channels=3, 24 | model_name='large', 25 | scale=0.5, 26 | disable_se=False, 27 | **kwargs): 28 | """ 29 | the MobilenetV3 backbone network for detection module. 
30 | Args: 31 | params(dict): the super parameters for build network 32 | """ 33 | super(MobileNetV3, self).__init__() 34 | 35 | self.disable_se = disable_se 36 | 37 | if model_name == "large": 38 | cfg = [ 39 | # k, exp, c, se, nl, s, 40 | [3, 16, 16, False, 'relu', 1], 41 | [3, 64, 24, False, 'relu', 2], 42 | [3, 72, 24, False, 'relu', 1], 43 | [5, 72, 40, True, 'relu', 2], 44 | [5, 120, 40, True, 'relu', 1], 45 | [5, 120, 40, True, 'relu', 1], 46 | [3, 240, 80, False, 'hardswish', 2], 47 | [3, 200, 80, False, 'hardswish', 1], 48 | [3, 184, 80, False, 'hardswish', 1], 49 | [3, 184, 80, False, 'hardswish', 1], 50 | [3, 480, 112, True, 'hardswish', 1], 51 | [3, 672, 112, True, 'hardswish', 1], 52 | [5, 672, 160, True, 'hardswish', 2], 53 | [5, 960, 160, True, 'hardswish', 1], 54 | [5, 960, 160, True, 'hardswish', 1], 55 | ] 56 | cls_ch_squeeze = 960 57 | elif model_name == "small": 58 | cfg = [ 59 | # k, exp, c, se, nl, s, 60 | [3, 16, 16, True, 'relu', 2], 61 | [3, 72, 24, False, 'relu', 2], 62 | [3, 88, 24, False, 'relu', 1], 63 | [5, 96, 40, True, 'hardswish', 2], 64 | [5, 240, 40, True, 'hardswish', 1], 65 | [5, 240, 40, True, 'hardswish', 1], 66 | [5, 120, 48, True, 'hardswish', 1], 67 | [5, 144, 48, True, 'hardswish', 1], 68 | [5, 288, 96, True, 'hardswish', 2], 69 | [5, 576, 96, True, 'hardswish', 1], 70 | [5, 576, 96, True, 'hardswish', 1], 71 | ] 72 | cls_ch_squeeze = 576 73 | else: 74 | raise NotImplementedError("mode[" + model_name + 75 | "_model] is not implemented!") 76 | 77 | supported_scale = [0.35, 0.5, 0.75, 1.0, 1.25] 78 | assert scale in supported_scale, \ 79 | "supported scale are {} but input scale is {}".format(supported_scale, scale) 80 | inplanes = 16 81 | # conv1 82 | self.conv = ConvBNLayer( 83 | in_channels=in_channels, 84 | out_channels=make_divisible(inplanes * scale), 85 | kernel_size=3, 86 | stride=2, 87 | padding=1, 88 | groups=1, 89 | if_act=True, 90 | act='hardswish', 91 | name='conv1') 92 | 93 | self.stages = [] 94 | 
self.out_channels = [] 95 | block_list = [] 96 | i = 0 97 | inplanes = make_divisible(inplanes * scale) 98 | for (k, exp, c, se, nl, s) in cfg: 99 | se = se and not self.disable_se 100 | start_idx = 2 if model_name == 'large' else 0 101 | if s == 2 and i > start_idx: 102 | self.out_channels.append(inplanes) 103 | self.stages.append(nn.Sequential(*block_list)) 104 | block_list = [] 105 | block_list.append( 106 | ResidualUnit( 107 | in_channels=inplanes, 108 | mid_channels=make_divisible(scale * exp), 109 | out_channels=make_divisible(scale * c), 110 | kernel_size=k, 111 | stride=s, 112 | use_se=se, 113 | act=nl, 114 | name="conv" + str(i + 2))) 115 | inplanes = make_divisible(scale * c) 116 | i += 1 117 | block_list.append( 118 | ConvBNLayer( 119 | in_channels=inplanes, 120 | out_channels=make_divisible(scale * cls_ch_squeeze), 121 | kernel_size=1, 122 | stride=1, 123 | padding=0, 124 | groups=1, 125 | if_act=True, 126 | act='hardswish', 127 | name='conv_last')) 128 | self.stages.append(nn.Sequential(*block_list)) 129 | self.stages_pipline = nn.Sequential(*self.stages) 130 | self.out_channels.append(make_divisible(scale * cls_ch_squeeze)) 131 | # for i, stage in enumerate(self.stages): 132 | # self.add_sublayer(sublayer=stage, name="stage{}".format(i)) 133 | 134 | def forward(self, x): 135 | x = self.conv(x) 136 | out_list = [] 137 | # for stage in self.stages: 138 | # x = stage(x) 139 | # out_list.append(x) 140 | x = self.stages_pipline(x) 141 | out_list.append(x) 142 | return out_list 143 | 144 | 145 | class ConvBNLayer(nn.Module): 146 | def __init__(self, 147 | in_channels, 148 | out_channels, 149 | kernel_size, 150 | stride, 151 | padding, 152 | groups=1, 153 | if_act=True, 154 | act=None, 155 | name=None): 156 | super(ConvBNLayer, self).__init__() 157 | self.if_act = if_act 158 | self.act = act 159 | self.conv = nn.Conv2d( 160 | in_channels=in_channels, 161 | out_channels=out_channels, 162 | kernel_size=kernel_size, 163 | stride=stride, 164 | padding=padding, 
165 | groups=groups, 166 | bias=False) 167 | 168 | self.bn = nn.BatchNorm2d(out_channels) 169 | 170 | def forward(self, x): 171 | x = self.conv(x) 172 | x = self.bn(x) 173 | if self.if_act: 174 | if self.act == "relu": 175 | x = F.relu(x) 176 | elif self.act == "hardswish": 177 | x = F.hardswish(x) 178 | else: 179 | print("The activation function({}) is selected incorrectly.". 180 | format(self.act)) 181 | exit() 182 | return x 183 | 184 | 185 | class ResidualUnit(nn.Module): 186 | def __init__(self, 187 | in_channels, 188 | mid_channels, 189 | out_channels, 190 | kernel_size, 191 | stride, 192 | use_se, 193 | act=None, 194 | name=''): 195 | super(ResidualUnit, self).__init__() 196 | self.if_shortcut = stride == 1 and in_channels == out_channels 197 | self.if_se = use_se 198 | 199 | self.expand_conv = ConvBNLayer( 200 | in_channels=in_channels, 201 | out_channels=mid_channels, 202 | kernel_size=1, 203 | stride=1, 204 | padding=0, 205 | if_act=True, 206 | act=act, 207 | name=name + "_expand") 208 | self.bottleneck_conv = ConvBNLayer( 209 | in_channels=mid_channels, 210 | out_channels=mid_channels, 211 | kernel_size=kernel_size, 212 | stride=stride, 213 | padding=int((kernel_size - 1) // 2), 214 | groups=mid_channels, 215 | if_act=True, 216 | act=act, 217 | name=name + "_depthwise") 218 | if self.if_se: 219 | self.mid_se = SEModule(mid_channels, name=name + "_se") 220 | self.linear_conv = ConvBNLayer( 221 | in_channels=mid_channels, 222 | out_channels=out_channels, 223 | kernel_size=1, 224 | stride=1, 225 | padding=0, 226 | if_act=False, 227 | act=None, 228 | name=name + "_linear") 229 | 230 | def forward(self, inputs): 231 | x = self.expand_conv(inputs) 232 | x = self.bottleneck_conv(x) 233 | if self.if_se: 234 | x = self.mid_se(x) 235 | x = self.linear_conv(x) 236 | if self.if_shortcut: 237 | x = torch.add(inputs, x) 238 | return x 239 | 240 | 241 | class SEModule(nn.Module): 242 | def __init__(self, in_channels, reduction=4, name=""): 243 | super(SEModule, 
self).__init__() 244 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 245 | self.conv1 = nn.Conv2d( 246 | in_channels=in_channels, 247 | out_channels=in_channels // reduction, 248 | kernel_size=1, 249 | stride=1, 250 | padding=0) 251 | self.conv2 = nn.Conv2d( 252 | in_channels=in_channels // reduction, 253 | out_channels=in_channels, 254 | kernel_size=1, 255 | stride=1, 256 | padding=0) 257 | 258 | def forward(self, inputs): 259 | outputs = self.avg_pool(inputs) 260 | outputs = self.conv1(outputs) 261 | outputs = F.relu(outputs) 262 | outputs = self.conv2(outputs) 263 | outputs = F.hardsigmoid(outputs) 264 | return inputs * outputs 265 | 266 | 267 | def test(): 268 | net = MobileNetV3() 269 | net.to('cuda:0') 270 | x = torch.randn(2, 3, 224, 224) 271 | y = net(x) 272 | print(y.size()) 273 | 274 | 275 | # test() 276 | a = torch.tensor([1, 2, 3]) 277 | print() 278 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2022/01/25 09:52 3 | # @Author : lijun 4 | from __future__ import absolute_import 5 | from __future__ import division 6 | from __future__ import print_function 7 | 8 | import os 9 | import sys 10 | import platform 11 | import time 12 | import torch 13 | from tqdm import tqdm 14 | 15 | from models.architectures import build_model 16 | from utils.save_load import load_model 17 | from utils.metrics import build_metric 18 | from utils.data import build_dataloader 19 | from utils.utility import preprocess 20 | 21 | 22 | def test(): 23 | """ 24 | :param: None 25 | :return: 准确率 26 | """ 27 | 28 | """ 读取配置文件--->开始 """ 29 | config, device, logger = preprocess(is_train=True) 30 | """ 读取配置文件--->结束 """ 31 | 32 | """ 构建dataloader--->开始 """ 33 | valid_dataloader = build_dataloader(config, 'Eval', device, logger) 34 | if len(valid_dataloader) == 0: 35 | logger.error("please check val_dataloader\n") 36 | return 37 | 
""" 构建dataloader--->结束 """ 38 | 39 | """ 构建模型--->开始 """ 40 | model = build_model(config['Architecture']) 41 | load_model(model, config['Global']['checkpoints']) 42 | model.to(device) 43 | """ 构建模型--->结束 """ 44 | 45 | """ 构建metric--->开始 """ 46 | eval_class = build_metric(config['Metric']) 47 | """ 构建metric--->结束 """ 48 | 49 | """ test--->开始 """ 50 | model.eval() 51 | with torch.no_grad(): 52 | total_frame = 0.0 53 | total_time = 0.0 54 | pbar = tqdm(total=len(valid_dataloader), desc='eval model:') 55 | max_iter = len(valid_dataloader) - 1 if platform.system() == "Windows" else len(valid_dataloader) 56 | for idx, batch in enumerate(valid_dataloader): 57 | batch = [i.to(device) for i in batch] 58 | if idx >= max_iter: 59 | break 60 | images = batch[0] 61 | start = time.time() 62 | 63 | preds = model(images, data=batch[1:]) 64 | 65 | # Obtain usable results from post-processing methods 66 | total_time += time.time() - start 67 | # Evaluate the results of the current batch 68 | 69 | preds = {'structure_probs': preds['structure_probs'].cpu().detach().numpy(), 70 | 'loc_preds': preds['loc_preds'].cpu().detach().numpy()} 71 | batch = [item.cpu().numpy() for item in batch] 72 | eval_class(preds, batch) 73 | 74 | pbar.update(1) 75 | total_frame += len(images) 76 | # Get final metric->acc 77 | metric = eval_class.get_metric() 78 | """ test--->结束 """ 79 | 80 | pbar.close() 81 | model.train() 82 | metric['fps'] = total_frame / total_time 83 | return metric 84 | 85 | 86 | if __name__ == '__main__': 87 | metric = test() 88 | print('metric: ', metric) 89 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2022/01/25 09:52 3 | # @Author : lijun 4 | 5 | from __future__ import absolute_import 6 | from __future__ import division 7 | from __future__ import print_function 8 | 9 | import os 10 | import sys 11 | import 
platform 12 | import yaml 13 | import time 14 | import shutil 15 | import torch 16 | from tqdm import tqdm 17 | import numpy as np 18 | 19 | from models.architectures import build_model 20 | from utils.save_load import load_model 21 | from utils.losses import build_loss 22 | from utils.optimizer import build_optimizer 23 | from utils.metrics import build_metric 24 | from utils.stats import TrainingStats 25 | from utils.save_load import save_model 26 | from utils.data import build_dataloader 27 | from utils.utility import preprocess 28 | 29 | 30 | def eval(model, 31 | valid_dataloader, 32 | eval_class, 33 | device): 34 | model.eval() 35 | with torch.no_grad(): 36 | total_frame = 0.0 37 | total_time = 0.0 38 | pbar = tqdm(total=len(valid_dataloader), desc='eval model:') 39 | max_iter = len(valid_dataloader) - 1 if platform.system() == "Windows" else len(valid_dataloader) 40 | for idx, batch in enumerate(valid_dataloader): 41 | batch = [i.to(device) for i in batch] 42 | if idx >= max_iter: 43 | break 44 | images = batch[0] 45 | start = time.time() 46 | 47 | preds = model(images, data=batch[1:]) 48 | 49 | # Obtain usable results from post-processing methods 50 | total_time += time.time() - start 51 | # Evaluate the results of the current batch 52 | 53 | preds = {'structure_probs': preds['structure_probs'].cpu().detach().numpy(), 54 | 'loc_preds': preds['loc_preds'].cpu().detach().numpy()} 55 | batch = [item.cpu().numpy() for item in batch] 56 | eval_class(preds, batch) 57 | 58 | pbar.update(1) 59 | total_frame += len(images) 60 | # Get final metric->acc 61 | metric = eval_class.get_metric() 62 | 63 | pbar.close() 64 | model.train() 65 | metric['fps'] = total_frame / total_time 66 | return metric 67 | 68 | 69 | def train(): 70 | """ 71 | :param:None 72 | :return: None 73 | """ 74 | 75 | """ 读取配置文件--->开始 """ 76 | config, device, logger = preprocess(is_train=True) 77 | """ 读取配置文件--->结束 """ 78 | 79 | """ 构建dataloader--->开始 """ 80 | train_dataloader = 
build_dataloader(config, 'Train', device, logger) 81 | valid_dataloader = build_dataloader(config, 'Eval', device, logger) 82 | if len(train_dataloader) == 0: 83 | logger.error("please check train_dataloader\n") 84 | return 85 | if len(valid_dataloader) == 0: 86 | logger.error("please check val_dataloader\n") 87 | return 88 | """ 构建dataloader--->结束 """ 89 | 90 | """ 构建模型--->开始 """ 91 | model = build_model(config['Architecture']) 92 | if config['Global']['checkpoints']: 93 | load_model(model, config['Global']['checkpoints']) 94 | model.to(device) 95 | model = torch.nn.DataParallel(model) 96 | """ 构建模型--->结束 """ 97 | 98 | """ 构建loss--->开始 """ 99 | loss_class = build_loss(config['Loss']) 100 | """ 构建loss--->结束 """ 101 | 102 | """ 构建optim--->开始 """ 103 | optimizer, lr_scheduler = build_optimizer( 104 | config['Optimizer'], 105 | epochs=config['Global']['epoch_num'], 106 | step_each_epoch=len(train_dataloader), 107 | parameters=model.parameters()) 108 | """ 构建optim--->结束 """ 109 | 110 | """ 构建metric--->开始 """ 111 | eval_class = build_metric(config['Metric']) 112 | """ 构建metric--->结束 """ 113 | 114 | """ 其他---> """ 115 | cal_metric_during_train = config['Global'].get('cal_metric_during_train', False) 116 | log_smooth_window = config['Global']['log_smooth_window'] 117 | epoch_num = config['Global']['epoch_num'] 118 | print_batch_step = config['Global']['print_batch_step'] 119 | eval_batch_step = config['Global']['eval_batch_step'] 120 | 121 | global_step = 0 122 | save_epoch_step = config['Global']['save_epoch_step'] 123 | save_model_dir = config['Global']['save_model_dir'] 124 | if not os.path.exists(save_model_dir): 125 | os.makedirs(save_model_dir) 126 | main_indicator = eval_class.main_indicator 127 | best_model_dict = {main_indicator: 0} 128 | train_stats = TrainingStats(log_smooth_window, ['lr']) 129 | """ 其他---> """ 130 | 131 | """ trian--->开始 """ 132 | model.train() 133 | for epoch in range(1, epoch_num + 1): 134 | train_dataloader = build_dataloader(config, 
'Train', device, logger, seed=epoch) 135 | train_batch_cost = 0.0 136 | train_reader_cost = 0.0 137 | batch_sum = 0 138 | batch_start = time.time() 139 | max_iter = len(train_dataloader) - 1 if platform.system() == "Windows" else len(train_dataloader) 140 | for idx, batch in enumerate(train_dataloader): 141 | batch = [i.to(device) for i in batch] 142 | train_reader_cost += time.time() - batch_start 143 | if idx >= max_iter: 144 | break 145 | lr = optimizer.defaults['lr'] 146 | images = batch[0] 147 | 148 | preds = model(images, data=batch[1:]) 149 | 150 | loss = loss_class(preds, batch) 151 | avg_loss = loss['loss'] 152 | avg_loss.backward() 153 | torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=config['Optimizer']['clip_norm']) 154 | optimizer.step() 155 | optimizer.zero_grad() 156 | 157 | train_batch_cost += time.time() - batch_start 158 | batch_sum += len(images) 159 | 160 | if not isinstance(lr_scheduler, float): 161 | lr_scheduler.step() 162 | 163 | stats = {k: v.cpu().detach().numpy().mean() for k, v in loss.items()} 164 | stats['lr'] = lr 165 | train_stats.update(stats) 166 | 167 | if cal_metric_during_train: # only rec and cls need 168 | batch = [item.cpu().numpy() for item in batch] 169 | preds = {'structure_probs': preds['structure_probs'].cpu().detach().numpy(), 170 | 'loc_preds': preds['loc_preds'].cpu().detach().numpy()} 171 | eval_class(preds, batch) 172 | metric = eval_class.get_metric() 173 | train_stats.update(metric) 174 | 175 | if (global_step > 0 and global_step % print_batch_step == 0) or (idx >= len(train_dataloader) - 1): 176 | logs = train_stats.log() 177 | strs = 'epoch: [{}/{}], iter: {}, {}, reader_cost: {:.5f} s, batch_cost: {:.5f} s, samples: {}, ips: {:.5f}'.format( 178 | epoch, epoch_num, global_step, logs, train_reader_cost / 179 | print_batch_step, train_batch_cost / print_batch_step, 180 | batch_sum, batch_sum / train_batch_cost) 181 | logger.info(strs) 182 | train_batch_cost = 0.0 183 | train_reader_cost = 0.0 184 | 
batch_sum = 0 185 | if global_step % 400 == 0: 186 | save_model({'state_dict': model.state_dict(), 'cfg': dict(config)}, 187 | os.path.join(save_model_dir, 'last.pt')) 188 | # eval 189 | if global_step % eval_batch_step == 0 and global_step > 0: 190 | cur_metric = eval( 191 | model, 192 | valid_dataloader, 193 | eval_class, 194 | device 195 | ) 196 | cur_metric_str = \ 197 | 'cur metric, {}'.format(', '.join(['{}: {}'.format(k, v) for k, v in cur_metric.items()])) 198 | logger.info(cur_metric_str) 199 | 200 | if cur_metric[main_indicator] >= best_model_dict[main_indicator]: 201 | best_model_dict.update(cur_metric) 202 | best_model_dict['best_epoch'] = epoch 203 | save_model({'state_dict': model.state_dict(), 'cfg': dict(config)}, 204 | os.path.join(save_model_dir, 'best.pt')) 205 | best_str = \ 206 | 'best metric, {}'.format(', '.join(['{}: {}'.format(k, v) for k, v in best_model_dict.items()])) 207 | logger.info(best_str) 208 | 209 | global_step += 1 210 | optimizer.zero_grad() 211 | batch_start = time.time() 212 | if epoch % save_epoch_step == 0: 213 | save_model({'state_dict': model.state_dict(), 'cfg': dict(config)}, 214 | os.path.join(save_model_dir, str(epoch) + '.pt')) 215 | # if dist.get_rank() == 0: 216 | # save_model( 217 | # model, 218 | # optimizer, 219 | # save_model_dir, 220 | # logger, 221 | # is_best=False, 222 | # prefix='latest', 223 | # best_model_dict=best_model_dict, 224 | # epoch=epoch, 225 | # global_step=global_step) 226 | # if dist.get_rank() == 0 and epoch > 0 and epoch % save_epoch_step == 0: 227 | # save_model( 228 | # model, 229 | # optimizer, 230 | # save_model_dir, 231 | # logger, 232 | # is_best=False, 233 | # prefix='iter_epoch_{}'.format(epoch), 234 | # best_model_dict=best_model_dict, 235 | # epoch=epoch, 236 | # global_step=global_step) 237 | best_str = 'best metric, {}'.format(', '.join(['{}: {}'.format(k, v) for k, v in best_model_dict.items()])) 238 | logger.info(best_str) 239 | """ trian--->结束 """ 240 | 241 | 242 | if 
__name__ == '__main__': 243 | train() 244 | -------------------------------------------------------------------------------- /utils/__pycache__/logging.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/June-Li/TableRecognition/a26956983ade1153ee91f9a92fface1a2cf1275e/utils/__pycache__/logging.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/save_load.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/June-Li/TableRecognition/a26956983ade1153ee91f9a92fface1a2cf1275e/utils/__pycache__/save_load.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/stats.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/June-Li/TableRecognition/a26956983ade1153ee91f9a92fface1a2cf1275e/utils/__pycache__/stats.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/torch_utils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/June-Li/TableRecognition/a26956983ade1153ee91f9a92fface1a2cf1275e/utils/__pycache__/torch_utils.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/utility.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/June-Li/TableRecognition/a26956983ade1153ee91f9a92fface1a2cf1275e/utils/__pycache__/utility.cpython-38.pyc -------------------------------------------------------------------------------- /utils/data/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import 
absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | from __future__ import unicode_literals 5 | 6 | import os 7 | import sys 8 | import numpy as np 9 | import signal 10 | import random 11 | 12 | __dir__ = os.path.dirname(os.path.abspath(__file__)) 13 | sys.path.append(os.path.abspath(os.path.join(__dir__, '../..'))) 14 | 15 | import copy 16 | from torch.utils.data import DataLoader 17 | from utils.data.imaug import transform, create_operators 18 | from utils.data.pubtab_dataset import PubTabDataSet 19 | 20 | __all__ = ['build_dataloader', 'transform', 'create_operators'] 21 | 22 | 23 | def term_mp(sig_num, frame): 24 | """ kill all child processes 25 | """ 26 | pid = os.getpid() 27 | pgid = os.getpgid(os.getpid()) 28 | print("main proc {} exit, kill process group " "{}".format(pid, pgid)) 29 | os.killpg(pgid, signal.SIGKILL) 30 | 31 | 32 | def build_dataloader(config, mode, device, logger, seed=None): 33 | config = copy.deepcopy(config) 34 | 35 | support_dict = ['PubTabDataSet'] 36 | module_name = config[mode]['dataset']['name'] 37 | assert module_name in support_dict, Exception('DataSet only support {}'.format(support_dict)) 38 | assert mode in ['Train', 'Eval', 'Test'], "Mode should be Train, Eval or Test." 
39 | 40 | dataset = eval(module_name)(config, mode, logger, seed) 41 | loader_config = config[mode]['loader'] 42 | batch_size = loader_config['batch_size_per_card'] 43 | drop_last = loader_config['drop_last'] 44 | shuffle = loader_config['shuffle'] 45 | num_workers = loader_config['num_workers'] 46 | 47 | data_loader = DataLoader(dataset=dataset, 48 | batch_size=batch_size, 49 | num_workers=num_workers, 50 | pin_memory=True, 51 | shuffle=shuffle, 52 | drop_last=drop_last) 53 | 54 | # support exit using ctrl+c 55 | signal.signal(signal.SIGINT, term_mp) 56 | signal.signal(signal.SIGTERM, term_mp) 57 | 58 | return data_loader 59 | -------------------------------------------------------------------------------- /utils/data/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/June-Li/TableRecognition/a26956983ade1153ee91f9a92fface1a2cf1275e/utils/data/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /utils/data/__pycache__/pubtab_dataset.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/June-Li/TableRecognition/a26956983ade1153ee91f9a92fface1a2cf1275e/utils/data/__pycache__/pubtab_dataset.cpython-38.pyc -------------------------------------------------------------------------------- /utils/data/imaug/__init__.py: -------------------------------------------------------------------------------- 1 | # copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | from __future__ import absolute_import 15 | from __future__ import division 16 | from __future__ import print_function 17 | from __future__ import unicode_literals 18 | 19 | from .iaa_augment import IaaAugment 20 | from .make_border_map import MakeBorderMap 21 | from .make_shrink_map import MakeShrinkMap 22 | from .random_crop_data import EastRandomCropData, PSERandomCrop 23 | 24 | from .rec_img_aug import RecAug, RecResizeImg, ClsResizeImg, SRNRecResizeImg, NRTRRecResizeImg 25 | from .randaugment import RandAugment 26 | from .copy_paste import CopyPaste 27 | from .operators import * 28 | from .label_ops import * 29 | 30 | from .east_process import * 31 | from .sast_process import * 32 | from .pg_process import * 33 | from .gen_table_mask import * 34 | 35 | 36 | def transform(data, ops=None): 37 | """ transform """ 38 | if ops is None: 39 | ops = [] 40 | for op in ops: 41 | data = op(data) 42 | if data is None: 43 | return None 44 | return data 45 | 46 | 47 | def create_operators(op_param_list, global_config=None): 48 | """ 49 | create operators based on the config 50 | 51 | Args: 52 | params(list): a dict list, used to create some operators 53 | """ 54 | assert isinstance(op_param_list, list), ('operator config should be a list') 55 | ops = [] 56 | for operator in op_param_list: 57 | assert isinstance(operator, 58 | dict) and len(operator) == 1, "yaml format error" 59 | op_name = list(operator)[0] 60 | param = {} if operator[op_name] is None else operator[op_name] 61 | if global_config is not None: 62 | 
param.update(global_config) 63 | op = eval(op_name)(**param) 64 | ops.append(op) 65 | return ops 66 | -------------------------------------------------------------------------------- /utils/data/imaug/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/June-Li/TableRecognition/a26956983ade1153ee91f9a92fface1a2cf1275e/utils/data/imaug/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /utils/data/imaug/__pycache__/copy_paste.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/June-Li/TableRecognition/a26956983ade1153ee91f9a92fface1a2cf1275e/utils/data/imaug/__pycache__/copy_paste.cpython-38.pyc -------------------------------------------------------------------------------- /utils/data/imaug/__pycache__/east_process.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/June-Li/TableRecognition/a26956983ade1153ee91f9a92fface1a2cf1275e/utils/data/imaug/__pycache__/east_process.cpython-38.pyc -------------------------------------------------------------------------------- /utils/data/imaug/__pycache__/gen_table_mask.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/June-Li/TableRecognition/a26956983ade1153ee91f9a92fface1a2cf1275e/utils/data/imaug/__pycache__/gen_table_mask.cpython-38.pyc -------------------------------------------------------------------------------- /utils/data/imaug/__pycache__/iaa_augment.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/June-Li/TableRecognition/a26956983ade1153ee91f9a92fface1a2cf1275e/utils/data/imaug/__pycache__/iaa_augment.cpython-38.pyc 
-------------------------------------------------------------------------------- /utils/data/imaug/__pycache__/label_ops.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/June-Li/TableRecognition/a26956983ade1153ee91f9a92fface1a2cf1275e/utils/data/imaug/__pycache__/label_ops.cpython-38.pyc -------------------------------------------------------------------------------- /utils/data/imaug/__pycache__/make_border_map.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/June-Li/TableRecognition/a26956983ade1153ee91f9a92fface1a2cf1275e/utils/data/imaug/__pycache__/make_border_map.cpython-38.pyc -------------------------------------------------------------------------------- /utils/data/imaug/__pycache__/make_shrink_map.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/June-Li/TableRecognition/a26956983ade1153ee91f9a92fface1a2cf1275e/utils/data/imaug/__pycache__/make_shrink_map.cpython-38.pyc -------------------------------------------------------------------------------- /utils/data/imaug/__pycache__/operators.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/June-Li/TableRecognition/a26956983ade1153ee91f9a92fface1a2cf1275e/utils/data/imaug/__pycache__/operators.cpython-38.pyc -------------------------------------------------------------------------------- /utils/data/imaug/__pycache__/pg_process.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/June-Li/TableRecognition/a26956983ade1153ee91f9a92fface1a2cf1275e/utils/data/imaug/__pycache__/pg_process.cpython-38.pyc -------------------------------------------------------------------------------- 
/utils/data/imaug/__pycache__/randaugment.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/June-Li/TableRecognition/a26956983ade1153ee91f9a92fface1a2cf1275e/utils/data/imaug/__pycache__/randaugment.cpython-38.pyc -------------------------------------------------------------------------------- /utils/data/imaug/__pycache__/random_crop_data.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/June-Li/TableRecognition/a26956983ade1153ee91f9a92fface1a2cf1275e/utils/data/imaug/__pycache__/random_crop_data.cpython-38.pyc -------------------------------------------------------------------------------- /utils/data/imaug/__pycache__/rec_img_aug.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/June-Li/TableRecognition/a26956983ade1153ee91f9a92fface1a2cf1275e/utils/data/imaug/__pycache__/rec_img_aug.cpython-38.pyc -------------------------------------------------------------------------------- /utils/data/imaug/__pycache__/sast_process.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/June-Li/TableRecognition/a26956983ade1153ee91f9a92fface1a2cf1275e/utils/data/imaug/__pycache__/sast_process.cpython-38.pyc -------------------------------------------------------------------------------- /utils/data/imaug/copy_paste.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import cv2 3 | import random 4 | import numpy as np 5 | from PIL import Image 6 | from shapely.geometry import Polygon 7 | 8 | from utils.data.imaug.iaa_augment import IaaAugment 9 | from utils.data.imaug.random_crop_data import is_poly_outside_rect 10 | from utils.utility import get_rotate_crop_image 11 | 12 | 13 | class CopyPaste(object): 14 | 
def __init__(self, objects_paste_ratio=0.2, limit_paste=True, **kwargs): 15 | self.ext_data_num = 1 16 | self.objects_paste_ratio = objects_paste_ratio 17 | self.limit_paste = limit_paste 18 | augmenter_args = [{'type': 'Resize', 'args': {'size': [0.5, 3]}}] 19 | self.aug = IaaAugment(augmenter_args) 20 | 21 | def __call__(self, data): 22 | src_img = data['image'] 23 | src_polys = data['polys'].tolist() 24 | src_ignores = data['ignore_tags'].tolist() 25 | ext_data = data['ext_data'][0] 26 | ext_image = ext_data['image'] 27 | ext_polys = ext_data['polys'] 28 | ext_ignores = ext_data['ignore_tags'] 29 | 30 | indexs = [i for i in range(len(ext_ignores)) if not ext_ignores[i]] 31 | select_num = max( 32 | 1, min(int(self.objects_paste_ratio * len(ext_polys)), 30)) 33 | 34 | random.shuffle(indexs) 35 | select_idxs = indexs[:select_num] 36 | select_polys = ext_polys[select_idxs] 37 | select_ignores = ext_ignores[select_idxs] 38 | 39 | src_img = cv2.cvtColor(src_img, cv2.COLOR_BGR2RGB) 40 | ext_image = cv2.cvtColor(ext_image, cv2.COLOR_BGR2RGB) 41 | src_img = Image.fromarray(src_img).convert('RGBA') 42 | for poly, tag in zip(select_polys, select_ignores): 43 | box_img = get_rotate_crop_image(ext_image, poly) 44 | 45 | src_img, box = self.paste_img(src_img, box_img, src_polys) 46 | if box is not None: 47 | src_polys.append(box) 48 | src_ignores.append(tag) 49 | src_img = cv2.cvtColor(np.array(src_img), cv2.COLOR_RGB2BGR) 50 | h, w = src_img.shape[:2] 51 | src_polys = np.array(src_polys) 52 | src_polys[:, :, 0] = np.clip(src_polys[:, :, 0], 0, w) 53 | src_polys[:, :, 1] = np.clip(src_polys[:, :, 1], 0, h) 54 | data['image'] = src_img 55 | data['polys'] = src_polys 56 | data['ignore_tags'] = np.array(src_ignores) 57 | return data 58 | 59 | def paste_img(self, src_img, box_img, src_polys): 60 | box_img_pil = Image.fromarray(box_img).convert('RGBA') 61 | src_w, src_h = src_img.size 62 | box_w, box_h = box_img_pil.size 63 | 64 | angle = np.random.randint(0, 360) 65 | box = 
np.array([[[0, 0], [box_w, 0], [box_w, box_h], [0, box_h]]]) 66 | box = rotate_bbox(box_img, box, angle)[0] 67 | box_img_pil = box_img_pil.rotate(angle, expand=1) 68 | box_w, box_h = box_img_pil.width, box_img_pil.height 69 | if src_w - box_w < 0 or src_h - box_h < 0: 70 | return src_img, None 71 | 72 | paste_x, paste_y = self.select_coord(src_polys, box, src_w - box_w, 73 | src_h - box_h) 74 | if paste_x is None: 75 | return src_img, None 76 | box[:, 0] += paste_x 77 | box[:, 1] += paste_y 78 | r, g, b, A = box_img_pil.split() 79 | src_img.paste(box_img_pil, (paste_x, paste_y), mask=A) 80 | 81 | return src_img, box 82 | 83 | def select_coord(self, src_polys, box, endx, endy): 84 | if self.limit_paste: 85 | xmin, ymin, xmax, ymax = box[:, 0].min(), box[:, 1].min( 86 | ), box[:, 0].max(), box[:, 1].max() 87 | for _ in range(50): 88 | paste_x = random.randint(0, endx) 89 | paste_y = random.randint(0, endy) 90 | xmin1 = xmin + paste_x 91 | xmax1 = xmax + paste_x 92 | ymin1 = ymin + paste_y 93 | ymax1 = ymax + paste_y 94 | 95 | num_poly_in_rect = 0 96 | for poly in src_polys: 97 | if not is_poly_outside_rect(poly, xmin1, ymin1, 98 | xmax1 - xmin1, ymax1 - ymin1): 99 | num_poly_in_rect += 1 100 | break 101 | if num_poly_in_rect == 0: 102 | return paste_x, paste_y 103 | return None, None 104 | else: 105 | paste_x = random.randint(0, endx) 106 | paste_y = random.randint(0, endy) 107 | return paste_x, paste_y 108 | 109 | 110 | def get_union(pD, pG): 111 | return Polygon(pD).union(Polygon(pG)).area 112 | 113 | 114 | def get_intersection_over_union(pD, pG): 115 | return get_intersection(pD, pG) / get_union(pD, pG) 116 | 117 | 118 | def get_intersection(pD, pG): 119 | return Polygon(pD).intersection(Polygon(pG)).area 120 | 121 | 122 | def rotate_bbox(img, text_polys, angle, scale=1): 123 | """ 124 | from https://github.com/WenmuZhou/DBNet.pytorch/blob/master/data_loader/modules/augment.py 125 | Args: 126 | img: np.ndarray 127 | text_polys: np.ndarray N*4*2 128 | angle: int 
129 | scale: int 130 | 131 | Returns: 132 | 133 | """ 134 | w = img.shape[1] 135 | h = img.shape[0] 136 | 137 | rangle = np.deg2rad(angle) 138 | nw = (abs(np.sin(rangle) * h) + abs(np.cos(rangle) * w)) 139 | nh = (abs(np.cos(rangle) * h) + abs(np.sin(rangle) * w)) 140 | rot_mat = cv2.getRotationMatrix2D((nw * 0.5, nh * 0.5), angle, scale) 141 | rot_move = np.dot(rot_mat, np.array([(nw - w) * 0.5, (nh - h) * 0.5, 0])) 142 | rot_mat[0, 2] += rot_move[0] 143 | rot_mat[1, 2] += rot_move[1] 144 | 145 | # ---------------------- rotate box ---------------------- 146 | rot_text_polys = list() 147 | for bbox in text_polys: 148 | point1 = np.dot(rot_mat, np.array([bbox[0, 0], bbox[0, 1], 1])) 149 | point2 = np.dot(rot_mat, np.array([bbox[1, 0], bbox[1, 1], 1])) 150 | point3 = np.dot(rot_mat, np.array([bbox[2, 0], bbox[2, 1], 1])) 151 | point4 = np.dot(rot_mat, np.array([bbox[3, 0], bbox[3, 1], 1])) 152 | rot_text_polys.append([point1, point2, point3, point4]) 153 | return np.array(rot_text_polys, dtype=np.float32) 154 | -------------------------------------------------------------------------------- /utils/data/imaug/gen_table_mask.py: -------------------------------------------------------------------------------- 1 | """ 2 | # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | """ 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | from __future__ import unicode_literals 21 | 22 | import sys 23 | import six 24 | import cv2 25 | import numpy as np 26 | 27 | 28 | class GenTableMask(object): 29 | """ gen table mask """ 30 | 31 | def __init__(self, shrink_h_max, shrink_w_max, mask_type=0, **kwargs): 32 | self.shrink_h_max = 5 33 | self.shrink_w_max = 5 34 | self.mask_type = mask_type 35 | 36 | def projection(self, erosion, h, w, spilt_threshold=0): 37 | # 水平投影 38 | projection_map = np.ones_like(erosion) 39 | project_val_array = [0 for _ in range(0, h)] 40 | 41 | for j in range(0, h): 42 | for i in range(0, w): 43 | if erosion[j, i] == 255: 44 | project_val_array[j] += 1 45 | # 根据数组,获取切割点 46 | start_idx = 0 # 记录进入字符区的索引 47 | end_idx = 0 # 记录进入空白区域的索引 48 | in_text = False # 是否遍历到了字符区内 49 | box_list = [] 50 | for i in range(len(project_val_array)): 51 | if in_text == False and project_val_array[i] > spilt_threshold: # 进入字符区了 52 | in_text = True 53 | start_idx = i 54 | elif project_val_array[i] <= spilt_threshold and in_text == True: # 进入空白区了 55 | end_idx = i 56 | in_text = False 57 | if end_idx - start_idx <= 2: 58 | continue 59 | box_list.append((start_idx, end_idx + 1)) 60 | 61 | if in_text: 62 | box_list.append((start_idx, h - 1)) 63 | # 绘制投影直方图 64 | for j in range(0, h): 65 | for i in range(0, project_val_array[j]): 66 | projection_map[j, i] = 0 67 | return box_list, projection_map 68 | 69 | def projection_cx(self, box_img): 70 | box_gray_img = cv2.cvtColor(box_img, cv2.COLOR_BGR2GRAY) 71 | h, w = box_gray_img.shape 72 | # 灰度图片进行二值化处理 73 | ret, thresh1 = cv2.threshold(box_gray_img, 200, 255, cv2.THRESH_BINARY_INV) 74 | # 纵向腐蚀 75 | if h < w: 76 | kernel = np.ones((2, 1), np.uint8) 77 | erode = cv2.erode(thresh1, kernel, iterations=1) 78 | else: 79 | erode = thresh1 80 | # 水平膨胀 81 | kernel = np.ones((1, 5), np.uint8) 82 | erosion = cv2.dilate(erode, kernel, 
iterations=1) 83 | # 水平投影 84 | projection_map = np.ones_like(erosion) 85 | project_val_array = [0 for _ in range(0, h)] 86 | 87 | for j in range(0, h): 88 | for i in range(0, w): 89 | if erosion[j, i] == 255: 90 | project_val_array[j] += 1 91 | # 根据数组,获取切割点 92 | start_idx = 0 # 记录进入字符区的索引 93 | end_idx = 0 # 记录进入空白区域的索引 94 | in_text = False # 是否遍历到了字符区内 95 | box_list = [] 96 | spilt_threshold = 0 97 | for i in range(len(project_val_array)): 98 | if in_text == False and project_val_array[i] > spilt_threshold: # 进入字符区了 99 | in_text = True 100 | start_idx = i 101 | elif project_val_array[i] <= spilt_threshold and in_text == True: # 进入空白区了 102 | end_idx = i 103 | in_text = False 104 | if end_idx - start_idx <= 2: 105 | continue 106 | box_list.append((start_idx, end_idx + 1)) 107 | 108 | if in_text: 109 | box_list.append((start_idx, h - 1)) 110 | # 绘制投影直方图 111 | for j in range(0, h): 112 | for i in range(0, project_val_array[j]): 113 | projection_map[j, i] = 0 114 | split_bbox_list = [] 115 | if len(box_list) > 1: 116 | for i, (h_start, h_end) in enumerate(box_list): 117 | if i == 0: 118 | h_start = 0 119 | if i == len(box_list): 120 | h_end = h 121 | word_img = erosion[h_start:h_end + 1, :] 122 | word_h, word_w = word_img.shape 123 | w_split_list, w_projection_map = self.projection(word_img.T, word_w, word_h) 124 | w_start, w_end = w_split_list[0][0], w_split_list[-1][1] 125 | if h_start > 0: 126 | h_start -= 1 127 | h_end += 1 128 | word_img = box_img[h_start:h_end + 1:, w_start:w_end + 1, :] 129 | split_bbox_list.append([w_start, h_start, w_end, h_end]) 130 | else: 131 | split_bbox_list.append([0, 0, w, h]) 132 | return split_bbox_list 133 | 134 | def shrink_bbox(self, bbox): 135 | left, top, right, bottom = bbox 136 | sh_h = min(max(int((bottom - top) * 0.1), 1), self.shrink_h_max) 137 | sh_w = min(max(int((right - left) * 0.1), 1), self.shrink_w_max) 138 | left_new = left + sh_w 139 | right_new = right - sh_w 140 | top_new = top + sh_h 141 | bottom_new = bottom - 
sh_h 142 | if left_new >= right_new: 143 | left_new = left 144 | right_new = right 145 | if top_new >= bottom_new: 146 | top_new = top 147 | bottom_new = bottom 148 | return [left_new, top_new, right_new, bottom_new] 149 | 150 | def __call__(self, data): 151 | img = data['image'] 152 | cells = data['cells'] 153 | height, width = img.shape[0:2] 154 | if self.mask_type == 1: 155 | mask_img = np.zeros((height, width), dtype=np.float32) 156 | else: 157 | mask_img = np.zeros((height, width, 3), dtype=np.float32) 158 | cell_num = len(cells) 159 | for cno in range(cell_num): 160 | if "bbox" in cells[cno]: 161 | bbox = cells[cno]['bbox'] 162 | left, top, right, bottom = bbox 163 | box_img = img[top:bottom, left:right, :].copy() 164 | split_bbox_list = self.projection_cx(box_img) 165 | for sno in range(len(split_bbox_list)): 166 | split_bbox_list[sno][0] += left 167 | split_bbox_list[sno][1] += top 168 | split_bbox_list[sno][2] += left 169 | split_bbox_list[sno][3] += top 170 | 171 | for sno in range(len(split_bbox_list)): 172 | left, top, right, bottom = split_bbox_list[sno] 173 | left, top, right, bottom = self.shrink_bbox([left, top, right, bottom]) 174 | if self.mask_type == 1: 175 | mask_img[top:bottom, left:right] = 1.0 176 | data['mask_img'] = mask_img 177 | else: 178 | mask_img[top:bottom, left:right, :] = (255, 255, 255) 179 | data['image'] = mask_img 180 | return data 181 | 182 | class ResizeTableImage(object): 183 | def __init__(self, max_len, **kwargs): 184 | super(ResizeTableImage, self).__init__() 185 | self.max_len = max_len 186 | 187 | def get_img_bbox(self, cells): 188 | bbox_list = [] 189 | if len(cells) == 0: 190 | return bbox_list 191 | cell_num = len(cells) 192 | for cno in range(cell_num): 193 | if "bbox" in cells[cno]: 194 | bbox = cells[cno]['bbox'] 195 | bbox_list.append(bbox) 196 | return bbox_list 197 | 198 | def resize_img_table(self, img, bbox_list, max_len): 199 | height, width = img.shape[0:2] 200 | ratio = max_len / (max(height, width) * 1.0) 
201 | resize_h = int(height * ratio) 202 | resize_w = int(width * ratio) 203 | img_new = cv2.resize(img, (resize_w, resize_h)) 204 | bbox_list_new = [] 205 | for bno in range(len(bbox_list)): 206 | left, top, right, bottom = bbox_list[bno].copy() 207 | left = int(left * ratio) 208 | top = int(top * ratio) 209 | right = int(right * ratio) 210 | bottom = int(bottom * ratio) 211 | bbox_list_new.append([left, top, right, bottom]) 212 | return img_new, bbox_list_new 213 | 214 | def __call__(self, data): 215 | img = data['image'] 216 | if 'cells' not in data: 217 | cells = [] 218 | else: 219 | cells = data['cells'] 220 | bbox_list = self.get_img_bbox(cells) 221 | img_new, bbox_list_new = self.resize_img_table(img, bbox_list, self.max_len) 222 | data['image'] = img_new 223 | cell_num = len(cells) 224 | bno = 0 225 | for cno in range(cell_num): 226 | if "bbox" in data['cells'][cno]: 227 | data['cells'][cno]['bbox'] = bbox_list_new[bno] 228 | bno += 1 229 | data['max_len'] = self.max_len 230 | return data 231 | 232 | class PaddingTableImage(object): 233 | def __init__(self, **kwargs): 234 | super(PaddingTableImage, self).__init__() 235 | 236 | def __call__(self, data): 237 | img = data['image'] 238 | max_len = data['max_len'] 239 | padding_img = np.zeros((max_len, max_len, 3), dtype=np.float32) 240 | height, width = img.shape[0:2] 241 | padding_img[0:height, 0:width, :] = img.copy() 242 | data['image'] = padding_img 243 | return data 244 | -------------------------------------------------------------------------------- /utils/data/imaug/iaa_augment.py: -------------------------------------------------------------------------------- 1 | # copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """ 15 | This code is refer from: 16 | https://github.com/WenmuZhou/DBNet.pytorch/blob/master/data_loader/modules/iaa_augment.py 17 | """ 18 | 19 | from __future__ import absolute_import 20 | from __future__ import division 21 | from __future__ import print_function 22 | from __future__ import unicode_literals 23 | 24 | import numpy as np 25 | import imgaug 26 | import imgaug.augmenters as iaa 27 | 28 | 29 | class AugmenterBuilder(object): 30 | def __init__(self): 31 | pass 32 | 33 | def build(self, args, root=True): 34 | if args is None or len(args) == 0: 35 | return None 36 | elif isinstance(args, list): 37 | if root: 38 | sequence = [self.build(value, root=False) for value in args] 39 | return iaa.Sequential(sequence) 40 | else: 41 | return getattr(iaa, args[0])( 42 | *[self.to_tuple_if_list(a) for a in args[1:]]) 43 | elif isinstance(args, dict): 44 | cls = getattr(iaa, args['type']) 45 | return cls(**{ 46 | k: self.to_tuple_if_list(v) 47 | for k, v in args['args'].items() 48 | }) 49 | else: 50 | raise RuntimeError('unknown augmenter arg: ' + str(args)) 51 | 52 | def to_tuple_if_list(self, obj): 53 | if isinstance(obj, list): 54 | return tuple(obj) 55 | return obj 56 | 57 | 58 | class IaaAugment(): 59 | def __init__(self, augmenter_args=None, **kwargs): 60 | if augmenter_args is None: 61 | augmenter_args = [{ 62 | 'type': 'Fliplr', 63 | 'args': { 64 | 'p': 0.5 65 | } 66 | }, { 67 | 'type': 'Affine', 68 | 'args': { 69 | 'rotate': [-10, 10] 70 | } 71 | }, { 72 | 'type': 'Resize', 73 | 'args': { 74 | 'size': [0.5, 3] 75 
| } 76 | }] 77 | self.augmenter = AugmenterBuilder().build(augmenter_args) 78 | 79 | def __call__(self, data): 80 | image = data['image'] 81 | shape = image.shape 82 | 83 | if self.augmenter: 84 | aug = self.augmenter.to_deterministic() 85 | data['image'] = aug.augment_image(image) 86 | data = self.may_augment_annotation(aug, data, shape) 87 | return data 88 | 89 | def may_augment_annotation(self, aug, data, shape): 90 | if aug is None: 91 | return data 92 | 93 | line_polys = [] 94 | for poly in data['polys']: 95 | new_poly = self.may_augment_poly(aug, shape, poly) 96 | line_polys.append(new_poly) 97 | data['polys'] = np.array(line_polys) 98 | return data 99 | 100 | def may_augment_poly(self, aug, img_shape, poly): 101 | keypoints = [imgaug.Keypoint(p[0], p[1]) for p in poly] 102 | keypoints = aug.augment_keypoints( 103 | [imgaug.KeypointsOnImage( 104 | keypoints, shape=img_shape)])[0].keypoints 105 | poly = [(p.x, p.y) for p in keypoints] 106 | return poly 107 | -------------------------------------------------------------------------------- /utils/data/imaug/make_border_map.py: -------------------------------------------------------------------------------- 1 | # copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | """ 15 | This code is refer from: 16 | https://github.com/WenmuZhou/DBNet.pytorch/blob/master/data_loader/modules/make_border_map.py 17 | """ 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | from __future__ import unicode_literals 22 | 23 | import numpy as np 24 | import cv2 25 | 26 | np.seterr(divide='ignore', invalid='ignore') 27 | import pyclipper 28 | from shapely.geometry import Polygon 29 | import sys 30 | import warnings 31 | 32 | warnings.simplefilter("ignore") 33 | 34 | __all__ = ['MakeBorderMap'] 35 | 36 | 37 | class MakeBorderMap(object): 38 | def __init__(self, 39 | shrink_ratio=0.4, 40 | thresh_min=0.3, 41 | thresh_max=0.7, 42 | **kwargs): 43 | self.shrink_ratio = shrink_ratio 44 | self.thresh_min = thresh_min 45 | self.thresh_max = thresh_max 46 | 47 | def __call__(self, data): 48 | 49 | img = data['image'] 50 | text_polys = data['polys'] 51 | ignore_tags = data['ignore_tags'] 52 | 53 | canvas = np.zeros(img.shape[:2], dtype=np.float32) 54 | mask = np.zeros(img.shape[:2], dtype=np.float32) 55 | 56 | for i in range(len(text_polys)): 57 | if ignore_tags[i]: 58 | continue 59 | self.draw_border_map(text_polys[i], canvas, mask=mask) 60 | canvas = canvas * (self.thresh_max - self.thresh_min) + self.thresh_min 61 | 62 | data['threshold_map'] = canvas 63 | data['threshold_mask'] = mask 64 | return data 65 | 66 | def draw_border_map(self, polygon, canvas, mask): 67 | polygon = np.array(polygon) 68 | assert polygon.ndim == 2 69 | assert polygon.shape[1] == 2 70 | 71 | polygon_shape = Polygon(polygon) 72 | if polygon_shape.area <= 0: 73 | return 74 | distance = polygon_shape.area * ( 75 | 1 - np.power(self.shrink_ratio, 2)) / polygon_shape.length 76 | subject = [tuple(l) for l in polygon] 77 | padding = pyclipper.PyclipperOffset() 78 | padding.AddPath(subject, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON) 79 | 80 | padded_polygon = np.array(padding.Execute(distance)[0]) 81 | 
cv2.fillPoly(mask, [padded_polygon.astype(np.int32)], 1.0) 82 | 83 | xmin = padded_polygon[:, 0].min() 84 | xmax = padded_polygon[:, 0].max() 85 | ymin = padded_polygon[:, 1].min() 86 | ymax = padded_polygon[:, 1].max() 87 | width = xmax - xmin + 1 88 | height = ymax - ymin + 1 89 | 90 | polygon[:, 0] = polygon[:, 0] - xmin 91 | polygon[:, 1] = polygon[:, 1] - ymin 92 | 93 | xs = np.broadcast_to( 94 | np.linspace( 95 | 0, width - 1, num=width).reshape(1, width), (height, width)) 96 | ys = np.broadcast_to( 97 | np.linspace( 98 | 0, height - 1, num=height).reshape(height, 1), (height, width)) 99 | 100 | distance_map = np.zeros( 101 | (polygon.shape[0], height, width), dtype=np.float32) 102 | for i in range(polygon.shape[0]): 103 | j = (i + 1) % polygon.shape[0] 104 | absolute_distance = self._distance(xs, ys, polygon[i], polygon[j]) 105 | distance_map[i] = np.clip(absolute_distance / distance, 0, 1) 106 | distance_map = distance_map.min(axis=0) 107 | 108 | xmin_valid = min(max(0, xmin), canvas.shape[1] - 1) 109 | xmax_valid = min(max(0, xmax), canvas.shape[1] - 1) 110 | ymin_valid = min(max(0, ymin), canvas.shape[0] - 1) 111 | ymax_valid = min(max(0, ymax), canvas.shape[0] - 1) 112 | canvas[ymin_valid:ymax_valid + 1, xmin_valid:xmax_valid + 1] = np.fmax( 113 | 1 - distance_map[ymin_valid - ymin:ymax_valid - ymax + height, 114 | xmin_valid - xmin:xmax_valid - xmax + width], 115 | canvas[ymin_valid:ymax_valid + 1, xmin_valid:xmax_valid + 1]) 116 | 117 | def _distance(self, xs, ys, point_1, point_2): 118 | ''' 119 | compute the distance from point to a line 120 | ys: coordinates in the first axis 121 | xs: coordinates in the second axis 122 | point_1, point_2: (x, y), the end of the line 123 | ''' 124 | height, width = xs.shape[:2] 125 | square_distance_1 = np.square(xs - point_1[0]) + np.square(ys - point_1[ 126 | 1]) 127 | square_distance_2 = np.square(xs - point_2[0]) + np.square(ys - point_2[ 128 | 1]) 129 | square_distance = np.square(point_1[0] - point_2[0]) + 
np.square( 130 | point_1[1] - point_2[1]) 131 | 132 | cosin = (square_distance - square_distance_1 - square_distance_2) / ( 133 | 2 * np.sqrt(square_distance_1 * square_distance_2)) 134 | square_sin = 1 - np.square(cosin) 135 | square_sin = np.nan_to_num(square_sin) 136 | result = np.sqrt(square_distance_1 * square_distance_2 * square_sin / 137 | square_distance) 138 | 139 | result[cosin < 140 | 0] = np.sqrt(np.fmin(square_distance_1, square_distance_2))[cosin 141 | < 0] 142 | # self.extend_line(point_1, point_2, result) 143 | return result 144 | 145 | def extend_line(self, point_1, point_2, result, shrink_ratio): 146 | ex_point_1 = (int( 147 | round(point_1[0] + (point_1[0] - point_2[0]) * (1 + shrink_ratio))), 148 | int( 149 | round(point_1[1] + (point_1[1] - point_2[1]) * ( 150 | 1 + shrink_ratio)))) 151 | cv2.line( 152 | result, 153 | tuple(ex_point_1), 154 | tuple(point_1), 155 | 4096.0, 156 | 1, 157 | lineType=cv2.LINE_AA, 158 | shift=0) 159 | ex_point_2 = (int( 160 | round(point_2[0] + (point_2[0] - point_1[0]) * (1 + shrink_ratio))), 161 | int( 162 | round(point_2[1] + (point_2[1] - point_1[1]) * ( 163 | 1 + shrink_ratio)))) 164 | cv2.line( 165 | result, 166 | tuple(ex_point_2), 167 | tuple(point_2), 168 | 4096.0, 169 | 1, 170 | lineType=cv2.LINE_AA, 171 | shift=0) 172 | return ex_point_1, ex_point_2 173 | -------------------------------------------------------------------------------- /utils/data/imaug/make_shrink_map.py: -------------------------------------------------------------------------------- 1 | # copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """ 15 | This code is refer from: 16 | https://github.com/WenmuZhou/DBNet.pytorch/blob/master/data_loader/modules/make_shrink_map.py 17 | """ 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | from __future__ import unicode_literals 22 | 23 | import numpy as np 24 | import cv2 25 | from shapely.geometry import Polygon 26 | import pyclipper 27 | 28 | __all__ = ['MakeShrinkMap'] 29 | 30 | 31 | class MakeShrinkMap(object): 32 | r''' 33 | Making binary mask from detection data with ICDAR format. 34 | Typically following the process of class `MakeICDARData`. 
35 | ''' 36 | 37 | def __init__(self, min_text_size=8, shrink_ratio=0.4, **kwargs): 38 | self.min_text_size = min_text_size 39 | self.shrink_ratio = shrink_ratio 40 | 41 | def __call__(self, data): 42 | image = data['image'] 43 | text_polys = data['polys'] 44 | ignore_tags = data['ignore_tags'] 45 | 46 | h, w = image.shape[:2] 47 | text_polys, ignore_tags = self.validate_polygons(text_polys, 48 | ignore_tags, h, w) 49 | gt = np.zeros((h, w), dtype=np.float32) 50 | mask = np.ones((h, w), dtype=np.float32) 51 | for i in range(len(text_polys)): 52 | polygon = text_polys[i] 53 | height = max(polygon[:, 1]) - min(polygon[:, 1]) 54 | width = max(polygon[:, 0]) - min(polygon[:, 0]) 55 | if ignore_tags[i] or min(height, width) < self.min_text_size: 56 | cv2.fillPoly(mask, 57 | polygon.astype(np.int32)[np.newaxis, :, :], 0) 58 | ignore_tags[i] = True 59 | else: 60 | polygon_shape = Polygon(polygon) 61 | subject = [tuple(l) for l in polygon] 62 | padding = pyclipper.PyclipperOffset() 63 | padding.AddPath(subject, pyclipper.JT_ROUND, 64 | pyclipper.ET_CLOSEDPOLYGON) 65 | shrinked = [] 66 | 67 | # Increase the shrink ratio every time we get multiple polygon returned back 68 | possible_ratios = np.arange(self.shrink_ratio, 1, 69 | self.shrink_ratio) 70 | np.append(possible_ratios, 1) 71 | # print(possible_ratios) 72 | for ratio in possible_ratios: 73 | # print(f"Change shrink ratio to {ratio}") 74 | distance = polygon_shape.area * ( 75 | 1 - np.power(ratio, 2)) / polygon_shape.length 76 | shrinked = padding.Execute(-distance) 77 | if len(shrinked) == 1: 78 | break 79 | 80 | if shrinked == []: 81 | cv2.fillPoly(mask, 82 | polygon.astype(np.int32)[np.newaxis, :, :], 0) 83 | ignore_tags[i] = True 84 | continue 85 | 86 | for each_shirnk in shrinked: 87 | shirnk = np.array(each_shirnk).reshape(-1, 2) 88 | cv2.fillPoly(gt, [shirnk.astype(np.int32)], 1) 89 | 90 | data['shrink_map'] = gt 91 | data['shrink_mask'] = mask 92 | return data 93 | 94 | def validate_polygons(self, polygons, 
ignore_tags, h, w): 95 | ''' 96 | polygons (numpy.array, required): of shape (num_instances, num_points, 2) 97 | ''' 98 | if len(polygons) == 0: 99 | return polygons, ignore_tags 100 | assert len(polygons) == len(ignore_tags) 101 | for polygon in polygons: 102 | polygon[:, 0] = np.clip(polygon[:, 0], 0, w - 1) 103 | polygon[:, 1] = np.clip(polygon[:, 1], 0, h - 1) 104 | 105 | for i in range(len(polygons)): 106 | area = self.polygon_area(polygons[i]) 107 | if abs(area) < 1: 108 | ignore_tags[i] = True 109 | if area > 0: 110 | polygons[i] = polygons[i][::-1, :] 111 | return polygons, ignore_tags 112 | 113 | def polygon_area(self, polygon): 114 | """ 115 | compute polygon area 116 | """ 117 | area = 0 118 | q = polygon[-1] 119 | for p in polygon: 120 | area += p[0] * q[1] - p[1] * q[0] 121 | q = p 122 | return area / 2.0 123 | -------------------------------------------------------------------------------- /utils/data/imaug/randaugment.py: -------------------------------------------------------------------------------- 1 | # copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | from __future__ import unicode_literals 19 | 20 | from PIL import Image, ImageEnhance, ImageOps 21 | import numpy as np 22 | import random 23 | import six 24 | 25 | 26 | class RawRandAugment(object): 27 | def __init__(self, 28 | num_layers=2, 29 | magnitude=5, 30 | fillcolor=(128, 128, 128), 31 | **kwargs): 32 | self.num_layers = num_layers 33 | self.magnitude = magnitude 34 | self.max_level = 10 35 | 36 | abso_level = self.magnitude / self.max_level 37 | self.level_map = { 38 | "shearX": 0.3 * abso_level, 39 | "shearY": 0.3 * abso_level, 40 | "translateX": 150.0 / 331 * abso_level, 41 | "translateY": 150.0 / 331 * abso_level, 42 | "rotate": 30 * abso_level, 43 | "color": 0.9 * abso_level, 44 | "posterize": int(4.0 * abso_level), 45 | "solarize": 256.0 * abso_level, 46 | "contrast": 0.9 * abso_level, 47 | "sharpness": 0.9 * abso_level, 48 | "brightness": 0.9 * abso_level, 49 | "autocontrast": 0, 50 | "equalize": 0, 51 | "invert": 0 52 | } 53 | 54 | # from https://stackoverflow.com/questions/5252170/ 55 | # specify-image-filling-color-when-rotating-in-python-with-pil-and-setting-expand 56 | def rotate_with_fill(img, magnitude): 57 | rot = img.convert("RGBA").rotate(magnitude) 58 | return Image.composite(rot, 59 | Image.new("RGBA", rot.size, (128, ) * 4), 60 | rot).convert(img.mode) 61 | 62 | rnd_ch_op = random.choice 63 | 64 | self.func = { 65 | "shearX": lambda img, magnitude: img.transform( 66 | img.size, 67 | Image.AFFINE, 68 | (1, magnitude * rnd_ch_op([-1, 1]), 0, 0, 1, 0), 69 | Image.BICUBIC, 70 | fillcolor=fillcolor), 71 | "shearY": lambda img, magnitude: img.transform( 72 | img.size, 73 | Image.AFFINE, 74 | (1, 0, 0, magnitude * rnd_ch_op([-1, 1]), 1, 0), 75 | Image.BICUBIC, 76 | fillcolor=fillcolor), 77 | "translateX": lambda img, magnitude: img.transform( 78 | img.size, 79 | Image.AFFINE, 80 | (1, 0, magnitude * img.size[0] 
* rnd_ch_op([-1, 1]), 0, 1, 0), 81 | fillcolor=fillcolor), 82 | "translateY": lambda img, magnitude: img.transform( 83 | img.size, 84 | Image.AFFINE, 85 | (1, 0, 0, 0, 1, magnitude * img.size[1] * rnd_ch_op([-1, 1])), 86 | fillcolor=fillcolor), 87 | "rotate": lambda img, magnitude: rotate_with_fill(img, magnitude), 88 | "color": lambda img, magnitude: ImageEnhance.Color(img).enhance( 89 | 1 + magnitude * rnd_ch_op([-1, 1])), 90 | "posterize": lambda img, magnitude: 91 | ImageOps.posterize(img, magnitude), 92 | "solarize": lambda img, magnitude: 93 | ImageOps.solarize(img, magnitude), 94 | "contrast": lambda img, magnitude: 95 | ImageEnhance.Contrast(img).enhance( 96 | 1 + magnitude * rnd_ch_op([-1, 1])), 97 | "sharpness": lambda img, magnitude: 98 | ImageEnhance.Sharpness(img).enhance( 99 | 1 + magnitude * rnd_ch_op([-1, 1])), 100 | "brightness": lambda img, magnitude: 101 | ImageEnhance.Brightness(img).enhance( 102 | 1 + magnitude * rnd_ch_op([-1, 1])), 103 | "autocontrast": lambda img, magnitude: 104 | ImageOps.autocontrast(img), 105 | "equalize": lambda img, magnitude: ImageOps.equalize(img), 106 | "invert": lambda img, magnitude: ImageOps.invert(img) 107 | } 108 | 109 | def __call__(self, img): 110 | avaiable_op_names = list(self.level_map.keys()) 111 | for layer_num in range(self.num_layers): 112 | op_name = np.random.choice(avaiable_op_names) 113 | img = self.func[op_name](img, self.level_map[op_name]) 114 | return img 115 | 116 | 117 | class RandAugment(RawRandAugment): 118 | """ RandAugment wrapper to auto fit different img types """ 119 | 120 | def __init__(self, prob=0.5, *args, **kwargs): 121 | self.prob = prob 122 | if six.PY2: 123 | super(RandAugment, self).__init__(*args, **kwargs) 124 | else: 125 | super().__init__(*args, **kwargs) 126 | 127 | def __call__(self, data): 128 | if np.random.rand() > self.prob: 129 | return data 130 | img = data['image'] 131 | if not isinstance(img, Image.Image): 132 | img = np.ascontiguousarray(img) 133 | img = 
Image.fromarray(img) 134 | 135 | if six.PY2: 136 | img = super(RandAugment, self).__call__(img) 137 | else: 138 | img = super().__call__(img) 139 | 140 | if isinstance(img, Image.Image): 141 | img = np.asarray(img) 142 | data['image'] = img 143 | return data 144 | -------------------------------------------------------------------------------- /utils/data/imaug/random_crop_data.py: -------------------------------------------------------------------------------- 1 | # copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | """ 15 | This code is refer from: 16 | https://github.com/WenmuZhou/DBNet.pytorch/blob/master/data_loader/modules/random_crop_data.py 17 | """ 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | from __future__ import unicode_literals 22 | 23 | import numpy as np 24 | import cv2 25 | import random 26 | 27 | 28 | def is_poly_in_rect(poly, x, y, w, h): 29 | poly = np.array(poly) 30 | if poly[:, 0].min() < x or poly[:, 0].max() > x + w: 31 | return False 32 | if poly[:, 1].min() < y or poly[:, 1].max() > y + h: 33 | return False 34 | return True 35 | 36 | 37 | def is_poly_outside_rect(poly, x, y, w, h): 38 | poly = np.array(poly) 39 | if poly[:, 0].max() < x or poly[:, 0].min() > x + w: 40 | return True 41 | if poly[:, 1].max() < y or poly[:, 1].min() > y + h: 42 | return True 43 | return False 44 | 45 | 46 | def split_regions(axis): 47 | regions = [] 48 | min_axis = 0 49 | for i in range(1, axis.shape[0]): 50 | if axis[i] != axis[i - 1] + 1: 51 | region = axis[min_axis:i] 52 | min_axis = i 53 | regions.append(region) 54 | return regions 55 | 56 | 57 | def random_select(axis, max_size): 58 | xx = np.random.choice(axis, size=2) 59 | xmin = np.min(xx) 60 | xmax = np.max(xx) 61 | xmin = np.clip(xmin, 0, max_size - 1) 62 | xmax = np.clip(xmax, 0, max_size - 1) 63 | return xmin, xmax 64 | 65 | 66 | def region_wise_random_select(regions, max_size): 67 | selected_index = list(np.random.choice(len(regions), 2)) 68 | selected_values = [] 69 | for index in selected_index: 70 | axis = regions[index] 71 | xx = int(np.random.choice(axis, size=1)) 72 | selected_values.append(xx) 73 | xmin = min(selected_values) 74 | xmax = max(selected_values) 75 | return xmin, xmax 76 | 77 | 78 | def crop_area(im, text_polys, min_crop_side_ratio, max_tries): 79 | h, w, _ = im.shape 80 | h_array = np.zeros(h, dtype=np.int32) 81 | w_array = np.zeros(w, dtype=np.int32) 82 | for points in text_polys: 83 | points = 
np.round(points, decimals=0).astype(np.int32) 84 | minx = np.min(points[:, 0]) 85 | maxx = np.max(points[:, 0]) 86 | w_array[minx:maxx] = 1 87 | miny = np.min(points[:, 1]) 88 | maxy = np.max(points[:, 1]) 89 | h_array[miny:maxy] = 1 90 | # ensure the cropped area not across a text 91 | h_axis = np.where(h_array == 0)[0] 92 | w_axis = np.where(w_array == 0)[0] 93 | 94 | if len(h_axis) == 0 or len(w_axis) == 0: 95 | return 0, 0, w, h 96 | 97 | h_regions = split_regions(h_axis) 98 | w_regions = split_regions(w_axis) 99 | 100 | for i in range(max_tries): 101 | if len(w_regions) > 1: 102 | xmin, xmax = region_wise_random_select(w_regions, w) 103 | else: 104 | xmin, xmax = random_select(w_axis, w) 105 | if len(h_regions) > 1: 106 | ymin, ymax = region_wise_random_select(h_regions, h) 107 | else: 108 | ymin, ymax = random_select(h_axis, h) 109 | 110 | if xmax - xmin < min_crop_side_ratio * w or ymax - ymin < min_crop_side_ratio * h: 111 | # area too small 112 | continue 113 | num_poly_in_rect = 0 114 | for poly in text_polys: 115 | if not is_poly_outside_rect(poly, xmin, ymin, xmax - xmin, 116 | ymax - ymin): 117 | num_poly_in_rect += 1 118 | break 119 | 120 | if num_poly_in_rect > 0: 121 | return xmin, ymin, xmax - xmin, ymax - ymin 122 | 123 | return 0, 0, w, h 124 | 125 | 126 | class EastRandomCropData(object): 127 | def __init__(self, 128 | size=(640, 640), 129 | max_tries=10, 130 | min_crop_side_ratio=0.1, 131 | keep_ratio=True, 132 | **kwargs): 133 | self.size = size 134 | self.max_tries = max_tries 135 | self.min_crop_side_ratio = min_crop_side_ratio 136 | self.keep_ratio = keep_ratio 137 | 138 | def __call__(self, data): 139 | img = data['image'] 140 | text_polys = data['polys'] 141 | ignore_tags = data['ignore_tags'] 142 | texts = data['texts'] 143 | all_care_polys = [ 144 | text_polys[i] for i, tag in enumerate(ignore_tags) if not tag 145 | ] 146 | # 计算crop区域 147 | crop_x, crop_y, crop_w, crop_h = crop_area( 148 | img, all_care_polys, self.min_crop_side_ratio, 
self.max_tries) 149 | # crop 图片 保持比例填充 150 | scale_w = self.size[0] / crop_w 151 | scale_h = self.size[1] / crop_h 152 | scale = min(scale_w, scale_h) 153 | h = int(crop_h * scale) 154 | w = int(crop_w * scale) 155 | if self.keep_ratio: 156 | padimg = np.zeros((self.size[1], self.size[0], img.shape[2]), 157 | img.dtype) 158 | padimg[:h, :w] = cv2.resize( 159 | img[crop_y:crop_y + crop_h, crop_x:crop_x + crop_w], (w, h)) 160 | img = padimg 161 | else: 162 | img = cv2.resize( 163 | img[crop_y:crop_y + crop_h, crop_x:crop_x + crop_w], 164 | tuple(self.size)) 165 | # crop 文本框 166 | text_polys_crop = [] 167 | ignore_tags_crop = [] 168 | texts_crop = [] 169 | for poly, text, tag in zip(text_polys, texts, ignore_tags): 170 | poly = ((poly - (crop_x, crop_y)) * scale).tolist() 171 | if not is_poly_outside_rect(poly, 0, 0, w, h): 172 | text_polys_crop.append(poly) 173 | ignore_tags_crop.append(tag) 174 | texts_crop.append(text) 175 | data['image'] = img 176 | data['polys'] = np.array(text_polys_crop) 177 | data['ignore_tags'] = ignore_tags_crop 178 | data['texts'] = texts_crop 179 | return data 180 | 181 | 182 | class PSERandomCrop(object): 183 | def __init__(self, size, **kwargs): 184 | self.size = size 185 | 186 | def __call__(self, data): 187 | imgs = data['imgs'] 188 | 189 | h, w = imgs[0].shape[0:2] 190 | th, tw = self.size 191 | if w == tw and h == th: 192 | return imgs 193 | 194 | # label中存在文本实例,并且按照概率进行裁剪,使用threshold_label_map控制 195 | if np.max(imgs[2]) > 0 and random.random() > 3 / 8: 196 | # 文本实例的左上角点 197 | tl = np.min(np.where(imgs[2] > 0), axis=1) - self.size 198 | tl[tl < 0] = 0 199 | # 文本实例的右下角点 200 | br = np.max(np.where(imgs[2] > 0), axis=1) - self.size 201 | br[br < 0] = 0 202 | # 保证选到右下角点时,有足够的距离进行crop 203 | br[0] = min(br[0], h - th) 204 | br[1] = min(br[1], w - tw) 205 | 206 | for _ in range(50000): 207 | i = random.randint(tl[0], br[0]) 208 | j = random.randint(tl[1], br[1]) 209 | # 保证shrink_label_map有文本 210 | if imgs[1][i:i + th, j:j + tw].sum() <= 0: 
211 | continue 212 | else: 213 | break 214 | else: 215 | i = random.randint(0, h - th) 216 | j = random.randint(0, w - tw) 217 | 218 | # return i, j, th, tw 219 | for idx in range(len(imgs)): 220 | if len(imgs[idx].shape) == 3: 221 | imgs[idx] = imgs[idx][i:i + th, j:j + tw, :] 222 | else: 223 | imgs[idx] = imgs[idx][i:i + th, j:j + tw] 224 | data['imgs'] = imgs 225 | return data 226 | -------------------------------------------------------------------------------- /utils/data/imaug/text_image_aug/__init__.py: -------------------------------------------------------------------------------- 1 | # copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | from .augment import tia_perspective, tia_distort, tia_stretch 16 | 17 | __all__ = ['tia_distort', 'tia_stretch', 'tia_perspective'] 18 | -------------------------------------------------------------------------------- /utils/data/imaug/text_image_aug/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/June-Li/TableRecognition/a26956983ade1153ee91f9a92fface1a2cf1275e/utils/data/imaug/text_image_aug/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /utils/data/imaug/text_image_aug/__pycache__/augment.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/June-Li/TableRecognition/a26956983ade1153ee91f9a92fface1a2cf1275e/utils/data/imaug/text_image_aug/__pycache__/augment.cpython-38.pyc -------------------------------------------------------------------------------- /utils/data/imaug/text_image_aug/__pycache__/warp_mls.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/June-Li/TableRecognition/a26956983ade1153ee91f9a92fface1a2cf1275e/utils/data/imaug/text_image_aug/__pycache__/warp_mls.cpython-38.pyc -------------------------------------------------------------------------------- /utils/data/imaug/text_image_aug/augment.py: -------------------------------------------------------------------------------- 1 | # copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """ 15 | This code is refer from: 16 | https://github.com/RubanSeven/Text-Image-Augmentation-python/blob/master/augment.py 17 | """ 18 | import numpy as np 19 | from .warp_mls import WarpMLS 20 | 21 | 22 | def tia_distort(src, segment=4): 23 | img_h, img_w = src.shape[:2] 24 | 25 | cut = img_w // segment 26 | thresh = cut // 3 27 | 28 | src_pts = list() 29 | dst_pts = list() 30 | 31 | src_pts.append([0, 0]) 32 | src_pts.append([img_w, 0]) 33 | src_pts.append([img_w, img_h]) 34 | src_pts.append([0, img_h]) 35 | 36 | dst_pts.append([np.random.randint(thresh), np.random.randint(thresh)]) 37 | dst_pts.append( 38 | [img_w - np.random.randint(thresh), np.random.randint(thresh)]) 39 | dst_pts.append( 40 | [img_w - np.random.randint(thresh), img_h - np.random.randint(thresh)]) 41 | dst_pts.append( 42 | [np.random.randint(thresh), img_h - np.random.randint(thresh)]) 43 | 44 | half_thresh = thresh * 0.5 45 | 46 | for cut_idx in np.arange(1, segment, 1): 47 | src_pts.append([cut * cut_idx, 0]) 48 | src_pts.append([cut * cut_idx, img_h]) 49 | dst_pts.append([ 50 | cut * cut_idx + np.random.randint(thresh) - half_thresh, 51 | np.random.randint(thresh) - half_thresh 52 | ]) 53 | dst_pts.append([ 54 | cut * cut_idx + np.random.randint(thresh) - half_thresh, 55 | img_h + np.random.randint(thresh) - half_thresh 56 | ]) 57 | 58 | trans = WarpMLS(src, src_pts, dst_pts, img_w, img_h) 59 | dst = trans.generate() 60 | 61 | return dst 62 | 63 | 64 | def tia_stretch(src, segment=4): 65 | img_h, img_w = src.shape[:2] 66 | 67 | cut = img_w // 
segment 68 | thresh = cut * 4 // 5 69 | 70 | src_pts = list() 71 | dst_pts = list() 72 | 73 | src_pts.append([0, 0]) 74 | src_pts.append([img_w, 0]) 75 | src_pts.append([img_w, img_h]) 76 | src_pts.append([0, img_h]) 77 | 78 | dst_pts.append([0, 0]) 79 | dst_pts.append([img_w, 0]) 80 | dst_pts.append([img_w, img_h]) 81 | dst_pts.append([0, img_h]) 82 | 83 | half_thresh = thresh * 0.5 84 | 85 | for cut_idx in np.arange(1, segment, 1): 86 | move = np.random.randint(thresh) - half_thresh 87 | src_pts.append([cut * cut_idx, 0]) 88 | src_pts.append([cut * cut_idx, img_h]) 89 | dst_pts.append([cut * cut_idx + move, 0]) 90 | dst_pts.append([cut * cut_idx + move, img_h]) 91 | 92 | trans = WarpMLS(src, src_pts, dst_pts, img_w, img_h) 93 | dst = trans.generate() 94 | 95 | return dst 96 | 97 | 98 | def tia_perspective(src): 99 | img_h, img_w = src.shape[:2] 100 | 101 | thresh = img_h // 2 102 | 103 | src_pts = list() 104 | dst_pts = list() 105 | 106 | src_pts.append([0, 0]) 107 | src_pts.append([img_w, 0]) 108 | src_pts.append([img_w, img_h]) 109 | src_pts.append([0, img_h]) 110 | 111 | dst_pts.append([0, np.random.randint(thresh)]) 112 | dst_pts.append([img_w, np.random.randint(thresh)]) 113 | dst_pts.append([img_w, img_h - np.random.randint(thresh)]) 114 | dst_pts.append([0, img_h - np.random.randint(thresh)]) 115 | 116 | trans = WarpMLS(src, src_pts, dst_pts, img_w, img_h) 117 | dst = trans.generate() 118 | 119 | return dst -------------------------------------------------------------------------------- /utils/data/imaug/text_image_aug/warp_mls.py: -------------------------------------------------------------------------------- 1 | # copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """ 15 | This code is refer from: 16 | https://github.com/RubanSeven/Text-Image-Augmentation-python/blob/master/warp_mls.py 17 | """ 18 | import numpy as np 19 | 20 | 21 | class WarpMLS: 22 | def __init__(self, src, src_pts, dst_pts, dst_w, dst_h, trans_ratio=1.): 23 | self.src = src 24 | self.src_pts = src_pts 25 | self.dst_pts = dst_pts 26 | self.pt_count = len(self.dst_pts) 27 | self.dst_w = dst_w 28 | self.dst_h = dst_h 29 | self.trans_ratio = trans_ratio 30 | self.grid_size = 100 31 | self.rdx = np.zeros((self.dst_h, self.dst_w)) 32 | self.rdy = np.zeros((self.dst_h, self.dst_w)) 33 | 34 | @staticmethod 35 | def __bilinear_interp(x, y, v11, v12, v21, v22): 36 | return (v11 * (1 - y) + v12 * y) * (1 - x) + (v21 * 37 | (1 - y) + v22 * y) * x 38 | 39 | def generate(self): 40 | self.calc_delta() 41 | return self.gen_img() 42 | 43 | def calc_delta(self): 44 | w = np.zeros(self.pt_count, dtype=np.float32) 45 | 46 | if self.pt_count < 2: 47 | return 48 | 49 | i = 0 50 | while 1: 51 | if self.dst_w <= i < self.dst_w + self.grid_size - 1: 52 | i = self.dst_w - 1 53 | elif i >= self.dst_w: 54 | break 55 | 56 | j = 0 57 | while 1: 58 | if self.dst_h <= j < self.dst_h + self.grid_size - 1: 59 | j = self.dst_h - 1 60 | elif j >= self.dst_h: 61 | break 62 | 63 | sw = 0 64 | swp = np.zeros(2, dtype=np.float32) 65 | swq = np.zeros(2, dtype=np.float32) 66 | new_pt = np.zeros(2, dtype=np.float32) 67 | cur_pt = np.array([i, j], dtype=np.float32) 68 | 69 | k = 0 70 | for k in range(self.pt_count): 71 | if i == self.dst_pts[k][0] and j 
== self.dst_pts[k][1]: 72 | break 73 | 74 | w[k] = 1. / ( 75 | (i - self.dst_pts[k][0]) * (i - self.dst_pts[k][0]) + 76 | (j - self.dst_pts[k][1]) * (j - self.dst_pts[k][1])) 77 | 78 | sw += w[k] 79 | swp = swp + w[k] * np.array(self.dst_pts[k]) 80 | swq = swq + w[k] * np.array(self.src_pts[k]) 81 | 82 | if k == self.pt_count - 1: 83 | pstar = 1 / sw * swp 84 | qstar = 1 / sw * swq 85 | 86 | miu_s = 0 87 | for k in range(self.pt_count): 88 | if i == self.dst_pts[k][0] and j == self.dst_pts[k][1]: 89 | continue 90 | pt_i = self.dst_pts[k] - pstar 91 | miu_s += w[k] * np.sum(pt_i * pt_i) 92 | 93 | cur_pt -= pstar 94 | cur_pt_j = np.array([-cur_pt[1], cur_pt[0]]) 95 | 96 | for k in range(self.pt_count): 97 | if i == self.dst_pts[k][0] and j == self.dst_pts[k][1]: 98 | continue 99 | 100 | pt_i = self.dst_pts[k] - pstar 101 | pt_j = np.array([-pt_i[1], pt_i[0]]) 102 | 103 | tmp_pt = np.zeros(2, dtype=np.float32) 104 | tmp_pt[0] = np.sum(pt_i * cur_pt) * self.src_pts[k][0] - \ 105 | np.sum(pt_j * cur_pt) * self.src_pts[k][1] 106 | tmp_pt[1] = -np.sum(pt_i * cur_pt_j) * self.src_pts[k][0] + \ 107 | np.sum(pt_j * cur_pt_j) * self.src_pts[k][1] 108 | tmp_pt *= (w[k] / miu_s) 109 | new_pt += tmp_pt 110 | 111 | new_pt += qstar 112 | else: 113 | new_pt = self.src_pts[k] 114 | 115 | self.rdx[j, i] = new_pt[0] - i 116 | self.rdy[j, i] = new_pt[1] - j 117 | 118 | j += self.grid_size 119 | i += self.grid_size 120 | 121 | def gen_img(self): 122 | src_h, src_w = self.src.shape[:2] 123 | dst = np.zeros_like(self.src, dtype=np.float32) 124 | 125 | for i in np.arange(0, self.dst_h, self.grid_size): 126 | for j in np.arange(0, self.dst_w, self.grid_size): 127 | ni = i + self.grid_size 128 | nj = j + self.grid_size 129 | w = h = self.grid_size 130 | if ni >= self.dst_h: 131 | ni = self.dst_h - 1 132 | h = ni - i + 1 133 | if nj >= self.dst_w: 134 | nj = self.dst_w - 1 135 | w = nj - j + 1 136 | 137 | di = np.reshape(np.arange(h), (-1, 1)) 138 | dj = np.reshape(np.arange(w), (1, -1)) 139 
| delta_x = self.__bilinear_interp( 140 | di / h, dj / w, self.rdx[i, j], self.rdx[i, nj], 141 | self.rdx[ni, j], self.rdx[ni, nj]) 142 | delta_y = self.__bilinear_interp( 143 | di / h, dj / w, self.rdy[i, j], self.rdy[i, nj], 144 | self.rdy[ni, j], self.rdy[ni, nj]) 145 | nx = j + dj + delta_x * self.trans_ratio 146 | ny = i + di + delta_y * self.trans_ratio 147 | nx = np.clip(nx, 0, src_w - 1) 148 | ny = np.clip(ny, 0, src_h - 1) 149 | nxi = np.array(np.floor(nx), dtype=np.int32) 150 | nyi = np.array(np.floor(ny), dtype=np.int32) 151 | nxi1 = np.array(np.ceil(nx), dtype=np.int32) 152 | nyi1 = np.array(np.ceil(ny), dtype=np.int32) 153 | 154 | if len(self.src.shape) == 3: 155 | x = np.tile(np.expand_dims(ny - nyi, axis=-1), (1, 1, 3)) 156 | y = np.tile(np.expand_dims(nx - nxi, axis=-1), (1, 1, 3)) 157 | else: 158 | x = ny - nyi 159 | y = nx - nxi 160 | dst[i:i + h, j:j + w] = self.__bilinear_interp( 161 | x, y, self.src[nyi, nxi], self.src[nyi, nxi1], 162 | self.src[nyi1, nxi], self.src[nyi1, nxi1]) 163 | 164 | dst = np.clip(dst, 0, 255) 165 | dst = np.array(dst, dtype=np.uint8) 166 | 167 | return dst -------------------------------------------------------------------------------- /utils/data/pubtab_dataset.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import random 4 | from torch.utils.data import Dataset 5 | import json 6 | from .imaug import transform, create_operators 7 | 8 | 9 | class PubTabDataSet(Dataset): 10 | def __init__(self, config, mode, logger, seed=None): 11 | super(PubTabDataSet, self).__init__() 12 | self.logger = logger 13 | 14 | global_config = config['Global'] 15 | dataset_config = config[mode]['dataset'] 16 | loader_config = config[mode]['loader'] 17 | 18 | label_file_path = dataset_config.pop('label_file_path') 19 | 20 | self.data_dir = dataset_config['data_dir'] 21 | self.do_shuffle = loader_config['shuffle'] 22 | self.do_hard_select = False 23 | if 'hard_select' 
in loader_config: 24 | self.do_hard_select = loader_config['hard_select'] 25 | self.hard_prob = loader_config['hard_prob'] 26 | if self.do_hard_select: 27 | self.img_select_prob = self.load_hard_select_prob() 28 | self.table_select_type = None 29 | if 'table_select_type' in loader_config: 30 | self.table_select_type = loader_config['table_select_type'] 31 | self.table_select_prob = loader_config['table_select_prob'] 32 | 33 | self.seed = seed 34 | logger.info("Initialize indexs of datasets:%s" % label_file_path) 35 | with open(label_file_path, "rb") as f: 36 | self.data_lines = f.readlines() 37 | self.data_idx_order_list = list(range(len(self.data_lines))) 38 | if mode.lower() == "train": 39 | self.shuffle_data_random() 40 | self.ops = create_operators(dataset_config['transforms'], global_config) 41 | 42 | def shuffle_data_random(self): 43 | if self.do_shuffle: 44 | random.seed(self.seed) 45 | random.shuffle(self.data_lines) 46 | return 47 | 48 | def __getitem__(self, idx): 49 | try: 50 | data_line = self.data_lines[idx] 51 | data_line = data_line.decode('utf-8').strip("\n") 52 | info = json.loads(data_line) 53 | file_name = info['filename'] 54 | select_flag = True 55 | if self.do_hard_select: 56 | prob = self.img_select_prob[file_name] 57 | if prob < random.uniform(0, 1): 58 | select_flag = False 59 | 60 | if self.table_select_type: 61 | structure = info['html']['structure']['tokens'].copy() 62 | structure_str = ''.join(structure) 63 | table_type = "simple" 64 | if 'colspan' in structure_str or 'rowspan' in structure_str: 65 | table_type = "complex" 66 | if table_type == "complex": 67 | if self.table_select_prob < random.uniform(0, 1): 68 | select_flag = False 69 | 70 | if select_flag: 71 | cells = info['html']['cells'].copy() 72 | structure = info['html']['structure'].copy() 73 | img_path = os.path.join(self.data_dir, file_name) 74 | data = {'img_path': img_path, 'cells': cells, 'structure': structure} 75 | if not os.path.exists(img_path): 76 | raise 
Exception("{} does not exist!".format(img_path)) 77 | with open(data['img_path'], 'rb') as f: 78 | img = f.read() 79 | data['image'] = img 80 | outs = transform(data, self.ops) 81 | else: 82 | outs = None 83 | except Exception as e: 84 | self.logger.error( 85 | "When parsing line {}, error happened with msg: {}".format( 86 | data_line, e)) 87 | outs = None 88 | if outs is None: 89 | return self.__getitem__(np.random.randint(self.__len__())) 90 | return outs 91 | 92 | def __len__(self): 93 | return len(self.data_idx_order_list) -------------------------------------------------------------------------------- /utils/dict/ar_dict.txt: -------------------------------------------------------------------------------- 1 | a 2 | r 3 | b 4 | i 5 | c 6 | _ 7 | m 8 | g 9 | / 10 | 1 11 | 0 12 | I 13 | L 14 | S 15 | V 16 | R 17 | C 18 | 2 19 | v 20 | l 21 | 6 22 | 3 23 | 9 24 | . 25 | j 26 | p 27 | ا 28 | ل 29 | م 30 | ر 31 | ج 32 | و 33 | ح 34 | ي 35 | ة 36 | 5 37 | 8 38 | 7 39 | أ 40 | ب 41 | ض 42 | 4 43 | ك 44 | س 45 | ه 46 | ث 47 | ن 48 | ط 49 | ع 50 | ت 51 | غ 52 | خ 53 | ف 54 | ئ 55 | ز 56 | إ 57 | د 58 | ص 59 | ظ 60 | ذ 61 | ش 62 | ى 63 | ق 64 | ؤ 65 | آ 66 | ء 67 | s 68 | e 69 | n 70 | w 71 | t 72 | u 73 | z 74 | d 75 | A 76 | N 77 | G 78 | h 79 | o 80 | E 81 | T 82 | H 83 | O 84 | B 85 | y 86 | F 87 | U 88 | J 89 | X 90 | W 91 | P 92 | Z 93 | M 94 | k 95 | q 96 | Y 97 | Q 98 | D 99 | f 100 | K 101 | x 102 | ' 103 | % 104 | - 105 | # 106 | @ 107 | ! 108 | & 109 | $ 110 | , 111 | : 112 | é 113 | ? 114 | + 115 | É 116 | ( 117 | 118 | -------------------------------------------------------------------------------- /utils/dict/arabic_dict.txt: -------------------------------------------------------------------------------- 1 | 2 | ! 3 | # 4 | $ 5 | % 6 | & 7 | ' 8 | ( 9 | + 10 | , 11 | - 12 | . 13 | / 14 | 0 15 | 1 16 | 2 17 | 3 18 | 4 19 | 5 20 | 6 21 | 7 22 | 8 23 | 9 24 | : 25 | ? 
26 | @ 27 | A 28 | B 29 | C 30 | D 31 | E 32 | F 33 | G 34 | H 35 | I 36 | J 37 | K 38 | L 39 | M 40 | N 41 | O 42 | P 43 | Q 44 | R 45 | S 46 | T 47 | U 48 | V 49 | W 50 | X 51 | Y 52 | Z 53 | _ 54 | a 55 | b 56 | c 57 | d 58 | e 59 | f 60 | g 61 | h 62 | i 63 | j 64 | k 65 | l 66 | m 67 | n 68 | o 69 | p 70 | q 71 | r 72 | s 73 | t 74 | u 75 | v 76 | w 77 | x 78 | y 79 | z 80 | É 81 | é 82 | ء 83 | آ 84 | أ 85 | ؤ 86 | إ 87 | ئ 88 | ا 89 | ب 90 | ة 91 | ت 92 | ث 93 | ج 94 | ح 95 | خ 96 | د 97 | ذ 98 | ر 99 | ز 100 | س 101 | ش 102 | ص 103 | ض 104 | ط 105 | ظ 106 | ع 107 | غ 108 | ف 109 | ق 110 | ك 111 | ل 112 | م 113 | ن 114 | ه 115 | و 116 | ى 117 | ي 118 | ً 119 | ٌ 120 | ٍ 121 | َ 122 | ُ 123 | ِ 124 | ّ 125 | ْ 126 | ٓ 127 | ٔ 128 | ٰ 129 | ٱ 130 | ٹ 131 | پ 132 | چ 133 | ڈ 134 | ڑ 135 | ژ 136 | ک 137 | ڭ 138 | گ 139 | ں 140 | ھ 141 | ۀ 142 | ہ 143 | ۂ 144 | ۃ 145 | ۆ 146 | ۇ 147 | ۈ 148 | ۋ 149 | ی 150 | ې 151 | ے 152 | ۓ 153 | ە 154 | ١ 155 | ٢ 156 | ٣ 157 | ٤ 158 | ٥ 159 | ٦ 160 | ٧ 161 | ٨ 162 | ٩ 163 | -------------------------------------------------------------------------------- /utils/dict/be_dict.txt: -------------------------------------------------------------------------------- 1 | b 2 | e 3 | _ 4 | i 5 | m 6 | g 7 | / 8 | 2 9 | 0 10 | I 11 | L 12 | S 13 | V 14 | R 15 | C 16 | 1 17 | v 18 | a 19 | l 20 | 6 21 | 9 22 | 4 23 | 3 24 | . 
25 | j 26 | p 27 | п 28 | а 29 | з 30 | б 31 | у 32 | г 33 | н 34 | ц 35 | ь 36 | 8 37 | м 38 | л 39 | і 40 | о 41 | ў 42 | ы 43 | 7 44 | 5 45 | М 46 | х 47 | с 48 | р 49 | ф 50 | я 51 | е 52 | д 53 | ж 54 | ю 55 | ч 56 | й 57 | к 58 | Д 59 | в 60 | Б 61 | т 62 | І 63 | ш 64 | ё 65 | э 66 | К 67 | Л 68 | Н 69 | А 70 | Ж 71 | Г 72 | В 73 | П 74 | З 75 | Е 76 | О 77 | Р 78 | С 79 | У 80 | Ё 81 | Й 82 | Т 83 | Ч 84 | Э 85 | Ц 86 | Ю 87 | Ш 88 | Ф 89 | Х 90 | Я 91 | Ь 92 | Ы 93 | Ў 94 | s 95 | c 96 | n 97 | w 98 | M 99 | o 100 | t 101 | T 102 | E 103 | A 104 | B 105 | u 106 | h 107 | y 108 | k 109 | r 110 | H 111 | d 112 | Y 113 | O 114 | U 115 | F 116 | f 117 | x 118 | D 119 | G 120 | N 121 | K 122 | P 123 | z 124 | J 125 | X 126 | W 127 | Z 128 | Q 129 | % 130 | - 131 | q 132 | @ 133 | ' 134 | ! 135 | # 136 | & 137 | , 138 | : 139 | $ 140 | ( 141 | ? 142 | é 143 | + 144 | É 145 | 146 | -------------------------------------------------------------------------------- /utils/dict/bg_dict.txt: -------------------------------------------------------------------------------- 1 | ! 2 | # 3 | $ 4 | % 5 | & 6 | ' 7 | ( 8 | + 9 | , 10 | - 11 | . 12 | / 13 | 0 14 | 1 15 | 2 16 | 3 17 | 4 18 | 5 19 | 6 20 | 7 21 | 8 22 | 9 23 | : 24 | ? 
25 | @ 26 | A 27 | B 28 | C 29 | D 30 | E 31 | F 32 | G 33 | H 34 | I 35 | J 36 | K 37 | L 38 | M 39 | N 40 | O 41 | P 42 | Q 43 | R 44 | S 45 | T 46 | U 47 | V 48 | W 49 | X 50 | Y 51 | Z 52 | _ 53 | a 54 | b 55 | c 56 | d 57 | e 58 | f 59 | g 60 | h 61 | i 62 | j 63 | k 64 | l 65 | m 66 | n 67 | o 68 | p 69 | q 70 | r 71 | s 72 | t 73 | u 74 | v 75 | w 76 | x 77 | y 78 | z 79 | É 80 | é 81 | А 82 | Б 83 | В 84 | Г 85 | Д 86 | Е 87 | Ж 88 | З 89 | И 90 | Й 91 | К 92 | Л 93 | М 94 | Н 95 | О 96 | П 97 | Р 98 | С 99 | Т 100 | У 101 | Ф 102 | Х 103 | Ц 104 | Ч 105 | Ш 106 | Щ 107 | Ъ 108 | Ю 109 | Я 110 | а 111 | б 112 | в 113 | г 114 | д 115 | е 116 | ж 117 | з 118 | и 119 | й 120 | к 121 | л 122 | м 123 | н 124 | о 125 | п 126 | р 127 | с 128 | т 129 | у 130 | ф 131 | х 132 | ц 133 | ч 134 | ш 135 | щ 136 | ъ 137 | ь 138 | ю 139 | я 140 | 141 | -------------------------------------------------------------------------------- /utils/dict/cyrillic_dict.txt: -------------------------------------------------------------------------------- 1 | 2 | ! 3 | # 4 | $ 5 | % 6 | & 7 | ' 8 | ( 9 | + 10 | , 11 | - 12 | . 13 | / 14 | 0 15 | 1 16 | 2 17 | 3 18 | 4 19 | 5 20 | 6 21 | 7 22 | 8 23 | 9 24 | : 25 | ? 
26 | @ 27 | A 28 | B 29 | C 30 | D 31 | E 32 | F 33 | G 34 | H 35 | I 36 | J 37 | K 38 | L 39 | M 40 | N 41 | O 42 | P 43 | Q 44 | R 45 | S 46 | T 47 | U 48 | V 49 | W 50 | X 51 | Y 52 | Z 53 | _ 54 | a 55 | b 56 | c 57 | d 58 | e 59 | f 60 | g 61 | h 62 | i 63 | j 64 | k 65 | l 66 | m 67 | n 68 | o 69 | p 70 | q 71 | r 72 | s 73 | t 74 | u 75 | v 76 | w 77 | x 78 | y 79 | z 80 | É 81 | é 82 | Ё 83 | Є 84 | І 85 | Ј 86 | Љ 87 | Ў 88 | А 89 | Б 90 | В 91 | Г 92 | Д 93 | Е 94 | Ж 95 | З 96 | И 97 | Й 98 | К 99 | Л 100 | М 101 | Н 102 | О 103 | П 104 | Р 105 | С 106 | Т 107 | У 108 | Ф 109 | Х 110 | Ц 111 | Ч 112 | Ш 113 | Щ 114 | Ъ 115 | Ы 116 | Ь 117 | Э 118 | Ю 119 | Я 120 | а 121 | б 122 | в 123 | г 124 | д 125 | е 126 | ж 127 | з 128 | и 129 | й 130 | к 131 | л 132 | м 133 | н 134 | о 135 | п 136 | р 137 | с 138 | т 139 | у 140 | ф 141 | х 142 | ц 143 | ч 144 | ш 145 | щ 146 | ъ 147 | ы 148 | ь 149 | э 150 | ю 151 | я 152 | ё 153 | ђ 154 | є 155 | і 156 | ј 157 | љ 158 | њ 159 | ћ 160 | ў 161 | џ 162 | Ґ 163 | ґ 164 | -------------------------------------------------------------------------------- /utils/dict/devanagari_dict.txt: -------------------------------------------------------------------------------- 1 | 2 | ! 3 | # 4 | $ 5 | % 6 | & 7 | ' 8 | ( 9 | + 10 | , 11 | - 12 | . 13 | / 14 | 0 15 | 1 16 | 2 17 | 3 18 | 4 19 | 5 20 | 6 21 | 7 22 | 8 23 | 9 24 | : 25 | ? 
26 | @ 27 | A 28 | B 29 | C 30 | D 31 | E 32 | F 33 | G 34 | H 35 | I 36 | J 37 | K 38 | L 39 | M 40 | N 41 | O 42 | P 43 | Q 44 | R 45 | S 46 | T 47 | U 48 | V 49 | W 50 | X 51 | Y 52 | Z 53 | _ 54 | a 55 | b 56 | c 57 | d 58 | e 59 | f 60 | g 61 | h 62 | i 63 | j 64 | k 65 | l 66 | m 67 | n 68 | o 69 | p 70 | q 71 | r 72 | s 73 | t 74 | u 75 | v 76 | w 77 | x 78 | y 79 | z 80 | É 81 | é 82 | ँ 83 | ं 84 | ः 85 | अ 86 | आ 87 | इ 88 | ई 89 | उ 90 | ऊ 91 | ऋ 92 | ए 93 | ऐ 94 | ऑ 95 | ओ 96 | औ 97 | क 98 | ख 99 | ग 100 | घ 101 | ङ 102 | च 103 | छ 104 | ज 105 | झ 106 | ञ 107 | ट 108 | ठ 109 | ड 110 | ढ 111 | ण 112 | त 113 | थ 114 | द 115 | ध 116 | न 117 | ऩ 118 | प 119 | फ 120 | ब 121 | भ 122 | म 123 | य 124 | र 125 | ऱ 126 | ल 127 | ळ 128 | व 129 | श 130 | ष 131 | स 132 | ह 133 | ़ 134 | ा 135 | ि 136 | ी 137 | ु 138 | ू 139 | ृ 140 | ॅ 141 | े 142 | ै 143 | ॉ 144 | ो 145 | ौ 146 | ् 147 | ॒ 148 | क़ 149 | ख़ 150 | ग़ 151 | ज़ 152 | ड़ 153 | ढ़ 154 | फ़ 155 | ॠ 156 | । 157 | ० 158 | १ 159 | २ 160 | ३ 161 | ४ 162 | ५ 163 | ६ 164 | ७ 165 | ८ 166 | ९ 167 | ॰ 168 | -------------------------------------------------------------------------------- /utils/dict/en_dict.txt: -------------------------------------------------------------------------------- 1 | 0 2 | 1 3 | 2 4 | 3 5 | 4 6 | 5 7 | 6 8 | 7 9 | 8 10 | 9 11 | a 12 | b 13 | c 14 | d 15 | e 16 | f 17 | g 18 | h 19 | i 20 | j 21 | k 22 | l 23 | m 24 | n 25 | o 26 | p 27 | q 28 | r 29 | s 30 | t 31 | u 32 | v 33 | w 34 | x 35 | y 36 | z 37 | A 38 | B 39 | C 40 | D 41 | E 42 | F 43 | G 44 | H 45 | I 46 | J 47 | K 48 | L 49 | M 50 | N 51 | O 52 | P 53 | Q 54 | R 55 | S 56 | T 57 | U 58 | V 59 | W 60 | X 61 | Y 62 | Z 63 | 64 | -------------------------------------------------------------------------------- /utils/dict/fa_dict.txt: -------------------------------------------------------------------------------- 1 | f 2 | a 3 | _ 4 | i 5 | m 6 | g 7 | / 8 | 1 9 | 3 10 | I 11 | L 12 | S 13 | V 14 | R 15 | C 16 | 2 17 | 0 18 | 
v 19 | l 20 | 6 21 | 8 22 | 5 23 | . 24 | j 25 | p 26 | و 27 | د 28 | ر 29 | ك 30 | ن 31 | ش 32 | ه 33 | ا 34 | 4 35 | 9 36 | ی 37 | ج 38 | ِ 39 | 7 40 | غ 41 | ل 42 | س 43 | ز 44 | ّ 45 | ت 46 | ک 47 | گ 48 | ي 49 | م 50 | ب 51 | ف 52 | چ 53 | خ 54 | ق 55 | ژ 56 | آ 57 | ص 58 | پ 59 | َ 60 | ع 61 | ئ 62 | ح 63 | ٔ 64 | ض 65 | ُ 66 | ذ 67 | أ 68 | ى 69 | ط 70 | ظ 71 | ث 72 | ة 73 | ً 74 | ء 75 | ؤ 76 | ْ 77 | ۀ 78 | إ 79 | ٍ 80 | ٌ 81 | ٰ 82 | ٓ 83 | ٱ 84 | s 85 | c 86 | e 87 | n 88 | w 89 | N 90 | E 91 | W 92 | Y 93 | D 94 | O 95 | H 96 | A 97 | d 98 | z 99 | r 100 | T 101 | G 102 | o 103 | t 104 | x 105 | h 106 | b 107 | B 108 | M 109 | Z 110 | u 111 | P 112 | F 113 | y 114 | q 115 | U 116 | K 117 | k 118 | J 119 | Q 120 | ' 121 | X 122 | # 123 | ? 124 | % 125 | $ 126 | , 127 | : 128 | & 129 | ! 130 | - 131 | ( 132 | É 133 | @ 134 | é 135 | + 136 | 137 | -------------------------------------------------------------------------------- /utils/dict/french_dict.txt: -------------------------------------------------------------------------------- 1 | f 2 | e 3 | n 4 | c 5 | h 6 | _ 7 | i 8 | m 9 | g 10 | / 11 | r 12 | v 13 | a 14 | l 15 | t 16 | w 17 | o 18 | d 19 | 6 20 | 1 21 | . 22 | p 23 | B 24 | u 25 | 2 26 | à 27 | 3 28 | R 29 | y 30 | 4 31 | U 32 | E 33 | A 34 | 5 35 | P 36 | O 37 | S 38 | T 39 | D 40 | 7 41 | Z 42 | 8 43 | I 44 | N 45 | L 46 | G 47 | M 48 | H 49 | 0 50 | J 51 | K 52 | - 53 | 9 54 | F 55 | C 56 | V 57 | é 58 | X 59 | ' 60 | s 61 | Q 62 | : 63 | è 64 | x 65 | b 66 | Y 67 | Œ 68 | É 69 | z 70 | W 71 | Ç 72 | È 73 | k 74 | Ô 75 | ô 76 | € 77 | À 78 | Ê 79 | q 80 | ù 81 | ° 82 | ê 83 | î 84 | * 85 |  86 | j 87 | " 88 | , 89 | â 90 | % 91 | û 92 | ç 93 | ü 94 | ? 95 | ! 
96 | ; 97 | ö 98 | ( 99 | ) 100 | ï 101 | º 102 | ó 103 | ø 104 | å 105 | + 106 | ™ 107 | á 108 | Ë 109 | < 110 | ² 111 | Á 112 | Î 113 | & 114 | @ 115 | œ 116 | ε 117 | Ü 118 | ë 119 | [ 120 | ] 121 | í 122 | ò 123 | Ö 124 | ä 125 | ß 126 | « 127 | » 128 | ú 129 | ñ 130 | æ 131 | µ 132 | ³ 133 | Å 134 | $ 135 | # 136 | 137 | -------------------------------------------------------------------------------- /utils/dict/german_dict.txt: -------------------------------------------------------------------------------- 1 | 2 | ! 3 | " 4 | # 5 | $ 6 | % 7 | & 8 | ' 9 | ( 10 | ) 11 | * 12 | + 13 | , 14 | - 15 | . 16 | / 17 | 0 18 | 1 19 | 2 20 | 3 21 | 4 22 | 5 23 | 6 24 | 7 25 | 8 26 | 9 27 | : 28 | ; 29 | = 30 | > 31 | ? 32 | @ 33 | A 34 | B 35 | C 36 | D 37 | E 38 | F 39 | G 40 | H 41 | I 42 | J 43 | K 44 | L 45 | M 46 | N 47 | O 48 | P 49 | Q 50 | R 51 | S 52 | T 53 | U 54 | V 55 | W 56 | X 57 | Y 58 | Z 59 | [ 60 | ] 61 | _ 62 | a 63 | b 64 | c 65 | d 66 | e 67 | f 68 | g 69 | h 70 | i 71 | j 72 | k 73 | l 74 | m 75 | n 76 | o 77 | p 78 | q 79 | r 80 | s 81 | t 82 | u 83 | v 84 | w 85 | x 86 | y 87 | z 88 | £ 89 | § 90 | ­ 91 | ° 92 | ´ 93 | µ 94 | · 95 | º 96 | ¿ 97 | Á 98 | Ä 99 | Å 100 | É 101 | Ï 102 | Ô 103 | Ö 104 | Ü 105 | ß 106 | à 107 | á 108 | â 109 | ã 110 | ä 111 | å 112 | æ 113 | ç 114 | è 115 | é 116 | ê 117 | ë 118 | í 119 | ï 120 | ñ 121 | ò 122 | ó 123 | ô 124 | ö 125 | ø 126 | ù 127 | ú 128 | û 129 | ü 130 | ō 131 | Š 132 | Ÿ 133 | ʒ 134 | β 135 | δ 136 | з 137 | Ṡ 138 | ‘ 139 | € 140 | © 141 | ª 142 | « 143 | ¬ 144 | -------------------------------------------------------------------------------- /utils/dict/hi_dict.txt: -------------------------------------------------------------------------------- 1 | 2 | ! 3 | # 4 | $ 5 | % 6 | & 7 | ' 8 | ( 9 | + 10 | , 11 | - 12 | . 13 | / 14 | 0 15 | 1 16 | 2 17 | 3 18 | 4 19 | 5 20 | 6 21 | 7 22 | 8 23 | 9 24 | : 25 | ? 
26 | @ 27 | A 28 | B 29 | C 30 | D 31 | E 32 | F 33 | G 34 | H 35 | I 36 | J 37 | K 38 | L 39 | M 40 | N 41 | O 42 | P 43 | Q 44 | R 45 | S 46 | T 47 | U 48 | V 49 | W 50 | X 51 | Y 52 | Z 53 | _ 54 | a 55 | b 56 | c 57 | d 58 | e 59 | f 60 | g 61 | h 62 | i 63 | j 64 | k 65 | l 66 | m 67 | n 68 | o 69 | p 70 | q 71 | r 72 | s 73 | t 74 | u 75 | v 76 | w 77 | x 78 | y 79 | z 80 | É 81 | é 82 | ँ 83 | ं 84 | ः 85 | अ 86 | आ 87 | इ 88 | ई 89 | उ 90 | ऊ 91 | ऋ 92 | ए 93 | ऐ 94 | ऑ 95 | ओ 96 | औ 97 | क 98 | ख 99 | ग 100 | घ 101 | ङ 102 | च 103 | छ 104 | ज 105 | झ 106 | ञ 107 | ट 108 | ठ 109 | ड 110 | ढ 111 | ण 112 | त 113 | थ 114 | द 115 | ध 116 | न 117 | प 118 | फ 119 | ब 120 | भ 121 | म 122 | य 123 | र 124 | ल 125 | ळ 126 | व 127 | श 128 | ष 129 | स 130 | ह 131 | ़ 132 | ा 133 | ि 134 | ी 135 | ु 136 | ू 137 | ृ 138 | ॅ 139 | े 140 | ै 141 | ॉ 142 | ो 143 | ौ 144 | ् 145 | क़ 146 | ख़ 147 | ग़ 148 | ज़ 149 | ड़ 150 | ढ़ 151 | फ़ 152 | ० 153 | १ 154 | २ 155 | ३ 156 | ४ 157 | ५ 158 | ६ 159 | ७ 160 | ८ 161 | ९ 162 | ॰ 163 | -------------------------------------------------------------------------------- /utils/dict/it_dict.txt: -------------------------------------------------------------------------------- 1 | i 2 | t 3 | _ 4 | m 5 | g 6 | / 7 | 5 8 | I 9 | L 10 | S 11 | V 12 | R 13 | C 14 | 2 15 | 0 16 | 1 17 | v 18 | a 19 | l 20 | 7 21 | 8 22 | 9 23 | 6 24 | . 25 | j 26 | p 27 | 28 | e 29 | r 30 | o 31 | d 32 | s 33 | n 34 | 3 35 | 4 36 | P 37 | u 38 | c 39 | A 40 | - 41 | , 42 | " 43 | z 44 | h 45 | f 46 | b 47 | q 48 | ì 49 | ' 50 | à 51 | O 52 | è 53 | G 54 | ù 55 | é 56 | ò 57 | ; 58 | F 59 | E 60 | B 61 | N 62 | H 63 | k 64 | : 65 | U 66 | T 67 | X 68 | D 69 | K 70 | ? 71 | [ 72 | M 73 | ­ 74 | x 75 | y 76 | ( 77 | ) 78 | W 79 | ö 80 | º 81 | w 82 | ] 83 | Q 84 | J 85 | + 86 | ü 87 | ! 
88 | È 89 | á 90 | % 91 | = 92 | » 93 | ñ 94 | Ö 95 | Y 96 | ä 97 | í 98 | Z 99 | « 100 | @ 101 | ó 102 | ø 103 | ï 104 | ú 105 | ê 106 | ç 107 | Á 108 | É 109 | Å 110 | ß 111 | { 112 | } 113 | & 114 | ` 115 | û 116 | î 117 | # 118 | $ 119 | -------------------------------------------------------------------------------- /utils/dict/ka_dict.txt: -------------------------------------------------------------------------------- 1 | k 2 | a 3 | _ 4 | i 5 | m 6 | g 7 | / 8 | 1 9 | 2 10 | I 11 | L 12 | S 13 | V 14 | R 15 | C 16 | 0 17 | v 18 | l 19 | 6 20 | 4 21 | 8 22 | . 23 | j 24 | p 25 | ಗ 26 | ು 27 | ಣ 28 | ಪ 29 | ಡ 30 | ಿ 31 | ಸ 32 | ಲ 33 | ಾ 34 | ದ 35 | ್ 36 | 7 37 | 5 38 | 3 39 | ವ 40 | ಷ 41 | ಬ 42 | ಹ 43 | ೆ 44 | 9 45 | ಅ 46 | ಳ 47 | ನ 48 | ರ 49 | ಉ 50 | ಕ 51 | ಎ 52 | ೇ 53 | ಂ 54 | ೈ 55 | ೊ 56 | ೀ 57 | ಯ 58 | ೋ 59 | ತ 60 | ಶ 61 | ಭ 62 | ಧ 63 | ಚ 64 | ಜ 65 | ೂ 66 | ಮ 67 | ಒ 68 | ೃ 69 | ಥ 70 | ಇ 71 | ಟ 72 | ಖ 73 | ಆ 74 | ಞ 75 | ಫ 76 | - 77 | ಢ 78 | ಊ 79 | ಓ 80 | ಐ 81 | ಃ 82 | ಘ 83 | ಝ 84 | ೌ 85 | ಠ 86 | ಛ 87 | ಔ 88 | ಏ 89 | ಈ 90 | ಋ 91 | ೨ 92 | ೦ 93 | ೧ 94 | ೮ 95 | ೯ 96 | ೪ 97 | , 98 | ೫ 99 | ೭ 100 | ೩ 101 | ೬ 102 | ಙ 103 | s 104 | c 105 | e 106 | n 107 | w 108 | o 109 | u 110 | t 111 | d 112 | E 113 | A 114 | T 115 | B 116 | Z 117 | N 118 | G 119 | O 120 | q 121 | z 122 | r 123 | x 124 | P 125 | K 126 | M 127 | J 128 | U 129 | D 130 | f 131 | F 132 | h 133 | b 134 | W 135 | Y 136 | y 137 | H 138 | X 139 | Q 140 | ' 141 | # 142 | & 143 | ! 144 | @ 145 | $ 146 | : 147 | % 148 | é 149 | É 150 | ( 151 | ? 152 | + 153 | 154 | -------------------------------------------------------------------------------- /utils/dict/latin_dict.txt: -------------------------------------------------------------------------------- 1 | 2 | ! 3 | " 4 | # 5 | $ 6 | % 7 | & 8 | ' 9 | ( 10 | ) 11 | * 12 | + 13 | , 14 | - 15 | . 16 | / 17 | 0 18 | 1 19 | 2 20 | 3 21 | 4 22 | 5 23 | 6 24 | 7 25 | 8 26 | 9 27 | : 28 | ; 29 | < 30 | = 31 | > 32 | ? 
33 | @ 34 | A 35 | B 36 | C 37 | D 38 | E 39 | F 40 | G 41 | H 42 | I 43 | J 44 | K 45 | L 46 | M 47 | N 48 | O 49 | P 50 | Q 51 | R 52 | S 53 | T 54 | U 55 | V 56 | W 57 | X 58 | Y 59 | Z 60 | [ 61 | ] 62 | _ 63 | ` 64 | a 65 | b 66 | c 67 | d 68 | e 69 | f 70 | g 71 | h 72 | i 73 | j 74 | k 75 | l 76 | m 77 | n 78 | o 79 | p 80 | q 81 | r 82 | s 83 | t 84 | u 85 | v 86 | w 87 | x 88 | y 89 | z 90 | { 91 | } 92 | ¡ 93 | £ 94 | § 95 | ª 96 | « 97 | ­ 98 | ° 99 | ² 100 | ³ 101 | ´ 102 | µ 103 | · 104 | º 105 | » 106 | ¿ 107 | À 108 | Á 109 |  110 | Ä 111 | Å 112 | Ç 113 | È 114 | É 115 | Ê 116 | Ë 117 | Ì 118 | Í 119 | Î 120 | Ï 121 | Ò 122 | Ó 123 | Ô 124 | Õ 125 | Ö 126 | Ú 127 | Ü 128 | Ý 129 | ß 130 | à 131 | á 132 | â 133 | ã 134 | ä 135 | å 136 | æ 137 | ç 138 | è 139 | é 140 | ê 141 | ë 142 | ì 143 | í 144 | î 145 | ï 146 | ñ 147 | ò 148 | ó 149 | ô 150 | õ 151 | ö 152 | ø 153 | ù 154 | ú 155 | û 156 | ü 157 | ý 158 | ą 159 | Ć 160 | ć 161 | Č 162 | č 163 | Đ 164 | đ 165 | ę 166 | ı 167 | Ł 168 | ł 169 | ō 170 | Œ 171 | œ 172 | Š 173 | š 174 | Ÿ 175 | Ž 176 | ž 177 | ʒ 178 | β 179 | δ 180 | ε 181 | з 182 | Ṡ 183 | ‘ 184 | € 185 | ™ 186 | -------------------------------------------------------------------------------- /utils/dict/mr_dict.txt: -------------------------------------------------------------------------------- 1 | 2 | ! 3 | # 4 | $ 5 | % 6 | & 7 | ' 8 | ( 9 | + 10 | , 11 | - 12 | . 13 | / 14 | 0 15 | 1 16 | 2 17 | 3 18 | 4 19 | 5 20 | 6 21 | 7 22 | 8 23 | 9 24 | : 25 | ? 
26 | @ 27 | A 28 | B 29 | C 30 | D 31 | E 32 | F 33 | G 34 | H 35 | I 36 | J 37 | K 38 | L 39 | M 40 | N 41 | O 42 | P 43 | Q 44 | R 45 | S 46 | T 47 | U 48 | V 49 | W 50 | X 51 | Y 52 | Z 53 | _ 54 | a 55 | b 56 | c 57 | d 58 | e 59 | f 60 | g 61 | h 62 | i 63 | j 64 | k 65 | l 66 | m 67 | n 68 | o 69 | p 70 | q 71 | r 72 | s 73 | t 74 | u 75 | v 76 | w 77 | x 78 | y 79 | z 80 | É 81 | é 82 | ँ 83 | ं 84 | ः 85 | अ 86 | आ 87 | इ 88 | ई 89 | उ 90 | ऊ 91 | ए 92 | ऐ 93 | ऑ 94 | ओ 95 | औ 96 | क 97 | ख 98 | ग 99 | घ 100 | च 101 | छ 102 | ज 103 | झ 104 | ञ 105 | ट 106 | ठ 107 | ड 108 | ढ 109 | ण 110 | त 111 | थ 112 | द 113 | ध 114 | न 115 | प 116 | फ 117 | ब 118 | भ 119 | म 120 | य 121 | र 122 | ऱ 123 | ल 124 | ळ 125 | व 126 | श 127 | ष 128 | स 129 | ह 130 | ़ 131 | ा 132 | ि 133 | ी 134 | ु 135 | ू 136 | ृ 137 | ॅ 138 | े 139 | ै 140 | ॉ 141 | ो 142 | ौ 143 | ् 144 | ० 145 | १ 146 | २ 147 | ३ 148 | ४ 149 | ५ 150 | ६ 151 | ७ 152 | ८ 153 | ९ 154 | -------------------------------------------------------------------------------- /utils/dict/ne_dict.txt: -------------------------------------------------------------------------------- 1 | 2 | ! 3 | # 4 | $ 5 | % 6 | & 7 | ' 8 | ( 9 | + 10 | , 11 | - 12 | . 13 | / 14 | 0 15 | 1 16 | 2 17 | 3 18 | 4 19 | 5 20 | 6 21 | 7 22 | 8 23 | 9 24 | : 25 | ? 
26 | @ 27 | A 28 | B 29 | C 30 | D 31 | E 32 | F 33 | G 34 | H 35 | I 36 | J 37 | K 38 | L 39 | M 40 | N 41 | O 42 | P 43 | Q 44 | R 45 | S 46 | T 47 | U 48 | V 49 | W 50 | X 51 | Y 52 | Z 53 | _ 54 | a 55 | b 56 | c 57 | d 58 | e 59 | f 60 | g 61 | h 62 | i 63 | j 64 | k 65 | l 66 | m 67 | n 68 | o 69 | p 70 | q 71 | r 72 | s 73 | t 74 | u 75 | v 76 | w 77 | x 78 | y 79 | z 80 | É 81 | é 82 | ः 83 | अ 84 | आ 85 | इ 86 | ई 87 | उ 88 | ऊ 89 | ऋ 90 | ए 91 | ऐ 92 | ओ 93 | औ 94 | क 95 | ख 96 | ग 97 | घ 98 | ङ 99 | च 100 | छ 101 | ज 102 | झ 103 | ञ 104 | ट 105 | ठ 106 | ड 107 | ढ 108 | ण 109 | त 110 | थ 111 | द 112 | ध 113 | न 114 | ऩ 115 | प 116 | फ 117 | ब 118 | भ 119 | म 120 | य 121 | र 122 | ऱ 123 | ल 124 | व 125 | श 126 | ष 127 | स 128 | ह 129 | ़ 130 | ा 131 | ि 132 | ी 133 | ु 134 | ू 135 | ृ 136 | े 137 | ै 138 | ो 139 | ौ 140 | ् 141 | ॒ 142 | ॠ 143 | । 144 | ० 145 | १ 146 | २ 147 | ३ 148 | ४ 149 | ५ 150 | ६ 151 | ७ 152 | ८ 153 | ९ 154 | -------------------------------------------------------------------------------- /utils/dict/oc_dict.txt: -------------------------------------------------------------------------------- 1 | o 2 | c 3 | _ 4 | i 5 | m 6 | g 7 | / 8 | 2 9 | 0 10 | I 11 | L 12 | S 13 | V 14 | R 15 | C 16 | 1 17 | v 18 | a 19 | l 20 | 4 21 | 3 22 | . 23 | j 24 | p 25 | r 26 | e 27 | è 28 | t 29 | 9 30 | 7 31 | 5 32 | 8 33 | n 34 | ' 35 | b 36 | s 37 | 6 38 | q 39 | u 40 | á 41 | d 42 | ò 43 | à 44 | h 45 | z 46 | f 47 | ï 48 | í 49 | A 50 | ç 51 | x 52 | ó 53 | é 54 | P 55 | O 56 | Ò 57 | ü 58 | k 59 | À 60 | F 61 | - 62 | ú 63 | ­ 64 | æ 65 | Á 66 | D 67 | E 68 | w 69 | K 70 | T 71 | N 72 | y 73 | U 74 | Z 75 | G 76 | B 77 | J 78 | H 79 | M 80 | W 81 | Y 82 | X 83 | Q 84 | % 85 | $ 86 | , 87 | @ 88 | & 89 | ! 90 | : 91 | ( 92 | # 93 | ? 
94 | + 95 | É 96 | 97 | -------------------------------------------------------------------------------- /utils/dict/pu_dict.txt: -------------------------------------------------------------------------------- 1 | p 2 | u 3 | _ 4 | i 5 | m 6 | g 7 | / 8 | 8 9 | I 10 | L 11 | S 12 | V 13 | R 14 | C 15 | 2 16 | 0 17 | 1 18 | v 19 | a 20 | l 21 | 6 22 | 7 23 | 4 24 | 5 25 | . 26 | j 27 | 28 | q 29 | e 30 | s 31 | t 32 | ã 33 | o 34 | x 35 | 9 36 | c 37 | n 38 | r 39 | z 40 | ç 41 | õ 42 | 3 43 | A 44 | U 45 | d 46 | º 47 | ô 48 | ­ 49 | , 50 | E 51 | ; 52 | ó 53 | á 54 | b 55 | D 56 | ? 57 | ú 58 | ê 59 | - 60 | h 61 | P 62 | f 63 | à 64 | N 65 | í 66 | O 67 | M 68 | G 69 | É 70 | é 71 | â 72 | F 73 | : 74 | T 75 | Á 76 | " 77 | Q 78 | ) 79 | W 80 | J 81 | B 82 | H 83 | ( 84 | ö 85 | % 86 | Ö 87 | « 88 | w 89 | K 90 | y 91 | ! 92 | k 93 | ] 94 | ' 95 | Z 96 | + 97 | Ç 98 | Õ 99 | Y 100 | À 101 | X 102 | µ 103 | » 104 | ª 105 | Í 106 | ü 107 | ä 108 | ´ 109 | è 110 | ñ 111 | ß 112 | ï 113 | Ú 114 | ë 115 | Ô 116 | Ï 117 | Ó 118 | [ 119 | Ì 120 | < 121 |  122 | ò 123 | § 124 | ³ 125 | ø 126 | å 127 | # 128 | $ 129 | & 130 | @ 131 | -------------------------------------------------------------------------------- /utils/dict/rs_dict.txt: -------------------------------------------------------------------------------- 1 | r 2 | s 3 | _ 4 | i 5 | m 6 | g 7 | / 8 | 1 9 | I 10 | L 11 | S 12 | V 13 | R 14 | C 15 | 2 16 | 0 17 | v 18 | a 19 | l 20 | 7 21 | 5 22 | 8 23 | 6 24 | . 25 | j 26 | p 27 | 28 | t 29 | d 30 | 9 31 | 3 32 | e 33 | š 34 | 4 35 | k 36 | u 37 | ć 38 | c 39 | n 40 | đ 41 | o 42 | z 43 | č 44 | b 45 | ž 46 | f 47 | Z 48 | T 49 | h 50 | M 51 | F 52 | O 53 | Š 54 | B 55 | H 56 | A 57 | E 58 | Đ 59 | Ž 60 | D 61 | P 62 | G 63 | Č 64 | K 65 | U 66 | N 67 | J 68 | Ć 69 | w 70 | y 71 | W 72 | x 73 | Y 74 | X 75 | q 76 | Q 77 | # 78 | & 79 | $ 80 | , 81 | - 82 | % 83 | ' 84 | @ 85 | ! 86 | : 87 | ? 
88 | ( 89 | É 90 | é 91 | + 92 | -------------------------------------------------------------------------------- /utils/dict/rsc_dict.txt: -------------------------------------------------------------------------------- 1 | r 2 | s 3 | c 4 | _ 5 | i 6 | m 7 | g 8 | / 9 | 5 10 | I 11 | L 12 | S 13 | V 14 | R 15 | C 16 | 2 17 | 0 18 | 1 19 | v 20 | a 21 | l 22 | 9 23 | 7 24 | 8 25 | . 26 | j 27 | p 28 | м 29 | а 30 | с 31 | и 32 | р 33 | ћ 34 | е 35 | ш 36 | 3 37 | 4 38 | о 39 | г 40 | н 41 | з 42 | в 43 | л 44 | 6 45 | т 46 | ж 47 | у 48 | к 49 | п 50 | њ 51 | д 52 | ч 53 | С 54 | ј 55 | ф 56 | ц 57 | љ 58 | х 59 | О 60 | И 61 | А 62 | б 63 | Ш 64 | К 65 | ђ 66 | џ 67 | М 68 | В 69 | З 70 | Д 71 | Р 72 | У 73 | Н 74 | Т 75 | Б 76 | ? 77 | П 78 | Х 79 | Ј 80 | Ц 81 | Г 82 | Љ 83 | Л 84 | Ф 85 | e 86 | n 87 | w 88 | E 89 | F 90 | A 91 | N 92 | f 93 | o 94 | b 95 | M 96 | G 97 | t 98 | y 99 | W 100 | k 101 | P 102 | u 103 | H 104 | B 105 | T 106 | z 107 | h 108 | O 109 | Y 110 | d 111 | U 112 | K 113 | D 114 | x 115 | X 116 | J 117 | Z 118 | Q 119 | q 120 | ' 121 | - 122 | @ 123 | é 124 | # 125 | ! 
126 | , 127 | % 128 | $ 129 | : 130 | & 131 | + 132 | ( 133 | É 134 | 135 | -------------------------------------------------------------------------------- /utils/dict/ru_dict.txt: -------------------------------------------------------------------------------- 1 | к 2 | в 3 | а 4 | з 5 | и 6 | у 7 | р 8 | о 9 | н 10 | я 11 | х 12 | п 13 | л 14 | ы 15 | г 16 | е 17 | т 18 | м 19 | д 20 | ж 21 | ш 22 | ь 23 | с 24 | ё 25 | б 26 | й 27 | ч 28 | ю 29 | ц 30 | щ 31 | М 32 | э 33 | ф 34 | А 35 | ъ 36 | С 37 | Ф 38 | Ю 39 | В 40 | К 41 | Т 42 | Н 43 | О 44 | Э 45 | У 46 | И 47 | Г 48 | Л 49 | Р 50 | Д 51 | Б 52 | Ш 53 | П 54 | З 55 | Х 56 | Е 57 | Ж 58 | Я 59 | Ц 60 | Ч 61 | Й 62 | Щ 63 | 0 64 | 1 65 | 2 66 | 3 67 | 4 68 | 5 69 | 6 70 | 7 71 | 8 72 | 9 73 | a 74 | b 75 | c 76 | d 77 | e 78 | f 79 | g 80 | h 81 | i 82 | j 83 | k 84 | l 85 | m 86 | n 87 | o 88 | p 89 | q 90 | r 91 | s 92 | t 93 | u 94 | v 95 | w 96 | x 97 | y 98 | z 99 | A 100 | B 101 | C 102 | D 103 | E 104 | F 105 | G 106 | H 107 | I 108 | J 109 | K 110 | L 111 | M 112 | N 113 | O 114 | P 115 | Q 116 | R 117 | S 118 | T 119 | U 120 | V 121 | W 122 | X 123 | Y 124 | Z 125 | 126 | -------------------------------------------------------------------------------- /utils/dict/ta_dict.txt: -------------------------------------------------------------------------------- 1 | t 2 | a 3 | _ 4 | i 5 | m 6 | g 7 | / 8 | 3 9 | I 10 | L 11 | S 12 | V 13 | R 14 | C 15 | 2 16 | 0 17 | 1 18 | v 19 | l 20 | 9 21 | 7 22 | 8 23 | . 
24 | j 25 | p 26 | ப 27 | ூ 28 | த 29 | ம 30 | ி 31 | வ 32 | ர 33 | ் 34 | ந 35 | ோ 36 | ன 37 | 6 38 | ஆ 39 | ற 40 | ல 41 | 5 42 | ள 43 | ா 44 | ொ 45 | ழ 46 | ு 47 | 4 48 | ெ 49 | ண 50 | க 51 | ட 52 | ை 53 | ே 54 | ச 55 | ய 56 | ஒ 57 | இ 58 | அ 59 | ங 60 | உ 61 | ீ 62 | ஞ 63 | எ 64 | ஓ 65 | ஃ 66 | ஜ 67 | ஷ 68 | ஸ 69 | ஏ 70 | ஊ 71 | ஹ 72 | ஈ 73 | ஐ 74 | ௌ 75 | ஔ 76 | s 77 | c 78 | e 79 | n 80 | w 81 | F 82 | T 83 | O 84 | P 85 | K 86 | A 87 | N 88 | G 89 | Y 90 | E 91 | M 92 | H 93 | U 94 | B 95 | o 96 | b 97 | D 98 | d 99 | r 100 | W 101 | u 102 | y 103 | f 104 | X 105 | k 106 | q 107 | h 108 | J 109 | z 110 | Z 111 | Q 112 | x 113 | - 114 | ' 115 | $ 116 | , 117 | % 118 | @ 119 | é 120 | ! 121 | # 122 | + 123 | É 124 | & 125 | : 126 | ( 127 | ? 128 | 129 | -------------------------------------------------------------------------------- /utils/dict/table_dict.txt: -------------------------------------------------------------------------------- 1 | ← 2 | 3 | ☆ 4 | ─ 5 | α 6 |  7 | 8 | ⋅ 9 | $ 10 | ω 11 | ψ 12 | χ 13 | ( 14 | υ 15 | ≥ 16 | σ 17 | , 18 | ρ 19 | ε 20 | 0 21 | ■ 22 | 4 23 | 8 24 | ✗ 25 | b 26 | < 27 | ✓ 28 | Ψ 29 | Ω 30 | € 31 | D 32 | 3 33 | Π 34 | H 35 | ║ 36 | 37 | L 38 | Φ 39 | Χ 40 | θ 41 | P 42 | κ 43 | λ 44 | μ 45 | T 46 | ξ 47 | X 48 | β 49 | γ 50 | δ 51 | \ 52 | ζ 53 | η 54 | ` 55 | d 56 | 57 | h 58 | f 59 | l 60 | Θ 61 | p 62 | √ 63 | t 64 | 65 | x 66 | Β 67 | Γ 68 | Δ 69 | | 70 | ǂ 71 | ɛ 72 | j 73 | ̧ 74 | ➢ 75 | ⁡ 76 | ̌ 77 | ′ 78 | « 79 | △ 80 | ▲ 81 | # 82 | 83 | ' 84 | Ι 85 | + 86 | ¶ 87 | / 88 | ▼ 89 | ⇑ 90 | □ 91 | · 92 | 7 93 | ▪ 94 | ; 95 | ? 
96 | ➔ 97 | ∩ 98 | C 99 | ÷ 100 | G 101 | ⇒ 102 | K 103 | 104 | O 105 | S 106 | С 107 | W 108 | Α 109 | [ 110 | ○ 111 | _ 112 | ● 113 | ‡ 114 | c 115 | z 116 | g 117 | 118 | o 119 | 120 | 〈 121 | 〉 122 | s 123 | ⩽ 124 | w 125 | φ 126 | ʹ 127 | { 128 | » 129 | ∣ 130 | ̆ 131 | e 132 | ˆ 133 | ∈ 134 | τ 135 | ◆ 136 | ι 137 | ∅ 138 | ∆ 139 | ∙ 140 | ∘ 141 | Ø 142 | ß 143 | ✔ 144 | ∞ 145 | ∑ 146 | − 147 | × 148 | ◊ 149 | ∗ 150 | ∖ 151 | ˃ 152 | ˂ 153 | ∫ 154 | " 155 | i 156 | & 157 | π 158 | ↔ 159 | * 160 | ∥ 161 | æ 162 | ∧ 163 | . 164 | ⁄ 165 | ø 166 | Q 167 | ∼ 168 | 6 169 | ⁎ 170 | : 171 | ★ 172 | > 173 | a 174 | B 175 | ≈ 176 | F 177 | J 178 | ̄ 179 | N 180 | ♯ 181 | R 182 | V 183 | 184 | ― 185 | Z 186 | ♣ 187 | ^ 188 | ¤ 189 | ¥ 190 | § 191 | 192 | ¢ 193 | £ 194 | ≦ 195 | ­ 196 | ≤ 197 | ‖ 198 | Λ 199 | © 200 | n 201 | ↓ 202 | → 203 | ↑ 204 | r 205 | ° 206 | ± 207 | v 208 | 209 | ♂ 210 | k 211 | ♀ 212 | ~ 213 | ᅟ 214 | ̇ 215 | @ 216 | ” 217 | ♦ 218 | ł 219 | ® 220 | ⊕ 221 | „ 222 | ! 223 | 224 | % 225 | ⇓ 226 | ) 227 | - 228 | 1 229 | 5 230 | 9 231 | = 232 | А 233 | A 234 | ‰ 235 | ⋆ 236 | Σ 237 | E 238 | ◦ 239 | I 240 | ※ 241 | M 242 | m 243 | ̨ 244 | ⩾ 245 | † 246 | 247 | • 248 | U 249 | Y 250 | 
 251 | ] 252 | ̸ 253 | 2 254 | ‐ 255 | – 256 | ‒ 257 | ̂ 258 | — 259 | ̀ 260 | ́ 261 | ’ 262 | ‘ 263 | ⋮ 264 | ⋯ 265 | ̊ 266 | “ 267 | ̈ 268 | ≧ 269 | q 270 | u 271 | ı 272 | y 273 | 274 | ​ 275 | ̃ 276 | } 277 | ν 278 | -------------------------------------------------------------------------------- /utils/dict/te_dict.txt: -------------------------------------------------------------------------------- 1 | t 2 | e 3 | _ 4 | i 5 | m 6 | g 7 | / 8 | 5 9 | I 10 | L 11 | S 12 | V 13 | R 14 | C 15 | 2 16 | 0 17 | 1 18 | v 19 | a 20 | l 21 | 3 22 | 4 23 | 8 24 | 9 25 | . 26 | j 27 | p 28 | త 29 | ె 30 | ర 31 | క 32 | ్ 33 | ి 34 | ం 35 | చ 36 | ే 37 | ద 38 | ు 39 | 7 40 | 6 41 | ఉ 42 | ా 43 | మ 44 | ట 45 | ో 46 | వ 47 | ప 48 | ల 49 | శ 50 | ఆ 51 | య 52 | ై 53 | భ 54 | ' 55 | ీ 56 | గ 57 | ూ 58 | డ 59 | ధ 60 | హ 61 | న 62 | జ 63 | స 64 | [ 65 | ‌ 66 | ష 67 | అ 68 | ణ 69 | ఫ 70 | బ 71 | ఎ 72 | ; 73 | ళ 74 | థ 75 | ొ 76 | ఠ 77 | ృ 78 | ఒ 79 | ఇ 80 | ః 81 | ఊ 82 | ఖ 83 | - 84 | ఐ 85 | ఘ 86 | ౌ 87 | ఏ 88 | ఈ 89 | ఛ 90 | , 91 | ఓ 92 | ఞ 93 | | 94 | ? 95 | : 96 | ఢ 97 | " 98 | ( 99 | ” 100 | ! 101 | + 102 | ) 103 | * 104 | = 105 | & 106 | “ 107 | € 108 | ] 109 | £ 110 | $ 111 | s 112 | c 113 | n 114 | w 115 | k 116 | J 117 | G 118 | u 119 | d 120 | r 121 | E 122 | o 123 | h 124 | y 125 | b 126 | f 127 | B 128 | M 129 | O 130 | T 131 | N 132 | D 133 | P 134 | A 135 | F 136 | x 137 | W 138 | Y 139 | U 140 | H 141 | K 142 | X 143 | z 144 | Z 145 | Q 146 | q 147 | É 148 | % 149 | # 150 | @ 151 | é 152 | -------------------------------------------------------------------------------- /utils/dict/ug_dict.txt: -------------------------------------------------------------------------------- 1 | u 2 | g 3 | _ 4 | i 5 | m 6 | / 7 | 1 8 | I 9 | L 10 | S 11 | V 12 | R 13 | C 14 | 2 15 | 0 16 | v 17 | a 18 | l 19 | 8 20 | 5 21 | 3 22 | 6 23 | 9 24 | . 
25 | j 26 | p 27 | 28 | ق 29 | ا 30 | پ 31 | ل 32 | 4 33 | 7 34 | ئ 35 | ى 36 | ش 37 | ت 38 | ي 39 | ك 40 | د 41 | ف 42 | ر 43 | و 44 | ن 45 | ب 46 | ە 47 | خ 48 | ې 49 | چ 50 | ۇ 51 | ز 52 | س 53 | م 54 | ۋ 55 | گ 56 | ڭ 57 | ۆ 58 | ۈ 59 | ج 60 | غ 61 | ھ 62 | ژ 63 | s 64 | c 65 | e 66 | n 67 | w 68 | P 69 | E 70 | D 71 | U 72 | d 73 | r 74 | b 75 | y 76 | B 77 | o 78 | O 79 | Y 80 | N 81 | T 82 | k 83 | t 84 | h 85 | A 86 | H 87 | F 88 | z 89 | W 90 | K 91 | G 92 | M 93 | f 94 | Z 95 | X 96 | Q 97 | J 98 | x 99 | q 100 | - 101 | ! 102 | % 103 | # 104 | ? 105 | : 106 | $ 107 | , 108 | & 109 | ' 110 | É 111 | @ 112 | é 113 | ( 114 | + 115 | -------------------------------------------------------------------------------- /utils/dict/uk_dict.txt: -------------------------------------------------------------------------------- 1 | u 2 | k 3 | _ 4 | i 5 | m 6 | g 7 | / 8 | 1 9 | 6 10 | I 11 | L 12 | S 13 | V 14 | R 15 | C 16 | 2 17 | 0 18 | v 19 | a 20 | l 21 | 7 22 | 9 23 | . 24 | j 25 | p 26 | в 27 | і 28 | д 29 | п 30 | о 31 | н 32 | с 33 | т 34 | ю 35 | 4 36 | 5 37 | 3 38 | а 39 | и 40 | м 41 | е 42 | р 43 | ч 44 | у 45 | Б 46 | з 47 | л 48 | к 49 | 8 50 | А 51 | В 52 | г 53 | є 54 | б 55 | ь 56 | х 57 | ґ 58 | ш 59 | ц 60 | ф 61 | я 62 | щ 63 | ж 64 | Г 65 | Х 66 | У 67 | Т 68 | Е 69 | І 70 | Н 71 | П 72 | З 73 | Л 74 | Ю 75 | С 76 | Д 77 | М 78 | К 79 | Р 80 | Ф 81 | О 82 | Ц 83 | И 84 | Я 85 | Ч 86 | Ш 87 | Ж 88 | Є 89 | Ґ 90 | Ь 91 | s 92 | c 93 | e 94 | n 95 | w 96 | A 97 | P 98 | r 99 | E 100 | t 101 | o 102 | h 103 | d 104 | y 105 | M 106 | G 107 | N 108 | F 109 | B 110 | T 111 | D 112 | U 113 | O 114 | W 115 | Z 116 | f 117 | H 118 | Y 119 | b 120 | K 121 | z 122 | x 123 | Q 124 | X 125 | q 126 | J 127 | $ 128 | - 129 | ' 130 | # 131 | & 132 | % 133 | ? 134 | : 135 | ! 
136 | , 137 | + 138 | @ 139 | ( 140 | é 141 | É 142 | 143 | -------------------------------------------------------------------------------- /utils/dict/ur_dict.txt: -------------------------------------------------------------------------------- 1 | u 2 | r 3 | _ 4 | i 5 | m 6 | g 7 | / 8 | 3 9 | I 10 | L 11 | S 12 | V 13 | R 14 | C 15 | 2 16 | 0 17 | 1 18 | v 19 | a 20 | l 21 | 9 22 | 7 23 | 8 24 | . 25 | j 26 | p 27 | 28 | چ 29 | ٹ 30 | پ 31 | ا 32 | ئ 33 | ی 34 | ے 35 | 4 36 | 6 37 | و 38 | ل 39 | ن 40 | ڈ 41 | ھ 42 | ک 43 | ت 44 | ش 45 | ف 46 | ق 47 | ر 48 | د 49 | 5 50 | ب 51 | ج 52 | خ 53 | ہ 54 | س 55 | ز 56 | غ 57 | ڑ 58 | ں 59 | آ 60 | م 61 | ؤ 62 | ط 63 | ص 64 | ح 65 | ع 66 | گ 67 | ث 68 | ض 69 | ذ 70 | ۓ 71 | ِ 72 | ء 73 | ظ 74 | ً 75 | ي 76 | ُ 77 | ۃ 78 | أ 79 | ٰ 80 | ە 81 | ژ 82 | ۂ 83 | ة 84 | ّ 85 | ك 86 | ه 87 | s 88 | c 89 | e 90 | n 91 | w 92 | o 93 | d 94 | t 95 | D 96 | M 97 | T 98 | U 99 | E 100 | b 101 | P 102 | h 103 | y 104 | W 105 | H 106 | A 107 | x 108 | B 109 | O 110 | N 111 | G 112 | Y 113 | Q 114 | F 115 | k 116 | K 117 | q 118 | J 119 | Z 120 | f 121 | z 122 | X 123 | ' 124 | @ 125 | & 126 | ! 127 | , 128 | : 129 | $ 130 | - 131 | # 132 | ? 133 | % 134 | é 135 | + 136 | ( 137 | É 138 | -------------------------------------------------------------------------------- /utils/dict/xi_dict.txt: -------------------------------------------------------------------------------- 1 | x 2 | i 3 | _ 4 | m 5 | g 6 | / 7 | 1 8 | 0 9 | I 10 | L 11 | S 12 | V 13 | R 14 | C 15 | 2 16 | v 17 | a 18 | l 19 | 3 20 | 6 21 | 4 22 | 5 23 | . 24 | j 25 | p 26 | 27 | Q 28 | u 29 | e 30 | r 31 | o 32 | 8 33 | 7 34 | n 35 | c 36 | 9 37 | t 38 | b 39 | é 40 | q 41 | d 42 | ó 43 | y 44 | F 45 | s 46 | , 47 | O 48 | í 49 | T 50 | f 51 | " 52 | U 53 | M 54 | h 55 | : 56 | P 57 | H 58 | A 59 | E 60 | D 61 | z 62 | N 63 | á 64 | ñ 65 | ú 66 | % 67 | ; 68 | è 69 | + 70 | Y 71 | - 72 | B 73 | G 74 | ( 75 | ) 76 | ¿ 77 | ? 78 | w 79 | ¡ 80 | ! 
81 | X 82 | É 83 | K 84 | k 85 | Á 86 | ü 87 | Ú 88 | « 89 | » 90 | J 91 | ' 92 | ö 93 | W 94 | Z 95 | º 96 | Ö 97 | ­ 98 | [ 99 | ] 100 | Ç 101 | ç 102 | à 103 | ä 104 | û 105 | ò 106 | Í 107 | ê 108 | ô 109 | ø 110 | ª 111 | -------------------------------------------------------------------------------- /utils/logging.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2020/6/15 14:25 3 | # @Author : zhoujun 4 | 5 | import logging 6 | 7 | import torch.distributed as dist 8 | 9 | logger_initialized = {} 10 | 11 | 12 | def get_logger(name, log_file=None, log_level=logging.INFO): 13 | """Initialize and get a logger by name. 14 | If the logger has not been initialized, this method will initialize the 15 | logger by adding one or two handlers, otherwise the initialized logger will 16 | be directly returned. During initialization, a StreamHandler will always be 17 | added. If `log_file` is specified and the process rank is 0, a FileHandler 18 | will also be added. 19 | Args: 20 | name (str): Logger name. 21 | log_file (str | None): The log filename. If specified, a FileHandler 22 | will be added to the logger. 23 | log_level (int): The logger level. Note that only the process of 24 | rank 0 is affected, and other processes will set the level to 25 | "Error" thus be silent most of the time. 26 | Returns: 27 | logging.Logger: The expected logger. 28 | """ 29 | logger = logging.getLogger(name) 30 | if name in logger_initialized: 31 | return logger 32 | # handle hierarchical names 33 | # e.g., logger "a" is initialized, then logger "a.b" will skip the 34 | # initialization since it is a child of "a". 
35 | for logger_name in logger_initialized: 36 | if name.startswith(logger_name): 37 | return logger 38 | 39 | stream_handler = logging.StreamHandler() 40 | handlers = [stream_handler] 41 | 42 | if dist.is_available() and dist.is_initialized(): 43 | rank = dist.get_rank() 44 | else: 45 | rank = 0 46 | 47 | # only rank 0 will add a FileHandler 48 | if rank == 0 and log_file is not None: 49 | file_handler = logging.FileHandler(log_file, 'w') 50 | handlers.append(file_handler) 51 | 52 | formatter = logging.Formatter( 53 | '%(asctime)s - %(name)s - %(levelname)s - %(message)s') 54 | for handler in handlers: 55 | handler.setFormatter(formatter) 56 | handler.setLevel(log_level) 57 | logger.addHandler(handler) 58 | 59 | if rank == 0: 60 | logger.setLevel(log_level) 61 | else: 62 | logger.setLevel(logging.ERROR) 63 | 64 | logger_initialized[name] = True 65 | 66 | return logger 67 | 68 | 69 | def print_log(msg, logger=None, level=logging.INFO): 70 | """Print a log message. 71 | Args: 72 | msg (str): The message to be logged. 73 | logger (logging.Logger | str | None): The logger to be used. 74 | Some special loggers are: 75 | - "silent": no message will be printed. 76 | - other str: the logger obtained with `get_root_logger(logger)`. 77 | - None: The `print()` method will be used to print log messages. 78 | level (int): Logging level. Only available when `logger` is a Logger 79 | object or "root". 
80 | """ 81 | if logger is None: 82 | print(msg) 83 | elif isinstance(logger, logging.Logger): 84 | logger.log(level, msg) 85 | elif logger == 'silent': 86 | pass 87 | elif isinstance(logger, str): 88 | _logger = get_logger(logger) 89 | _logger.log(level, msg) 90 | else: 91 | raise TypeError( 92 | 'logger should be either a logging.Logger object, str, ' 93 | f'"silent" or None, but got {type(logger)}') -------------------------------------------------------------------------------- /utils/losses/__init__.py: -------------------------------------------------------------------------------- 1 | import copy 2 | from .table_att_loss import TableAttentionLoss 3 | 4 | 5 | def build_loss(config): 6 | support_dict = ['TableAttentionLoss'] 7 | 8 | config = copy.deepcopy(config) 9 | module_name = config.pop('name') 10 | assert module_name in support_dict, Exception('loss only support {}'.format( 11 | support_dict)) 12 | module_class = eval(module_name)(**config) 13 | return module_class 14 | -------------------------------------------------------------------------------- /utils/losses/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/June-Li/TableRecognition/a26956983ade1153ee91f9a92fface1a2cf1275e/utils/losses/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /utils/losses/__pycache__/table_att_loss.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/June-Li/TableRecognition/a26956983ade1153ee91f9a92fface1a2cf1275e/utils/losses/__pycache__/table_att_loss.cpython-38.pyc -------------------------------------------------------------------------------- /utils/losses/table_att_loss.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import 
division 3 | from __future__ import print_function 4 | 5 | import torch 6 | from torch import nn 7 | from torch.nn import functional as F 8 | 9 | 10 | class TableAttentionLoss(nn.Module): 11 | def __init__(self, structure_weight, loc_weight, use_giou=False, giou_weight=1.0, **kwargs): 12 | super(TableAttentionLoss, self).__init__() 13 | self.loss_func = nn.CrossEntropyLoss(weight=None, reduction='none') 14 | self.structure_weight = structure_weight 15 | self.loc_weight = loc_weight 16 | self.use_giou = use_giou 17 | self.giou_weight = giou_weight 18 | 19 | def giou_loss(self, preds, bbox, eps=1e-7, reduction='mean'): 20 | ''' 21 | :param preds:[[x1,y1,x2,y2], [x1,y1,x2,y2],,,] 22 | :param bbox:[[x1,y1,x2,y2], [x1,y1,x2,y2],,,] 23 | :return: loss 24 | ''' 25 | # ix1 = fluid.layers.elementwise_max(preds[:, 0], bbox[:, 0]) 26 | # iy1 = fluid.layers.elementwise_max(preds[:, 1], bbox[:, 1]) 27 | # ix2 = fluid.layers.elementwise_min(preds[:, 2], bbox[:, 2]) 28 | # iy2 = fluid.layers.elementwise_min(preds[:, 3], bbox[:, 3]) 29 | ix1 = (preds[:, 0] >= bbox[:, 0]) * preds[:, 0] + (preds[:, 0] < bbox[:, 0]) * bbox[:, 0] 30 | iy1 = (preds[:, 1] >= bbox[:, 1]) * preds[:, 1] + (preds[:, 1] < bbox[:, 1]) * bbox[:, 1] 31 | ix2 = (preds[:, 2] >= bbox[:, 2]) * bbox[:, 2] + (preds[:, 2] < bbox[:, 2]) * preds[:, 2] 32 | iy2 = (preds[:, 3] >= bbox[:, 3]) * bbox[:, 3] + (preds[:, 3] < bbox[:, 3]) * preds[:, 3] 33 | 34 | iw = torch.clamp(ix2 - ix1 + 1e-3, 0., 1e10) 35 | ih = torch.clamp(iy2 - iy1 + 1e-3, 0., 1e10) 36 | 37 | # overlap 38 | inters = iw * ih 39 | 40 | # union 41 | uni = (preds[:, 2] - preds[:, 0] + 1e-3) * (preds[:, 3] - preds[:, 1] + 1e-3 42 | ) + (bbox[:, 2] - bbox[:, 0] + 1e-3) * ( 43 | bbox[:, 3] - bbox[:, 1] + 1e-3) - inters + eps 44 | 45 | # ious 46 | ious = inters / uni 47 | 48 | # ex1 = fluid.layers.elementwise_min(preds[:, 0], bbox[:, 0]) 49 | # ey1 = fluid.layers.elementwise_min(preds[:, 1], bbox[:, 1]) 50 | # ex2 = fluid.layers.elementwise_max(preds[:, 2], 
bbox[:, 2]) 51 | # ey2 = fluid.layers.elementwise_max(preds[:, 3], bbox[:, 3]) 52 | ex1 = (preds[:, 0] >= bbox[:, 0]) * bbox[:, 0] + (preds[:, 0] < bbox[:, 0]) * preds[:, 0] 53 | ey1 = (preds[:, 1] >= bbox[:, 1]) * bbox[:, 1] + (preds[:, 1] < bbox[:, 1]) * preds[:, 1] 54 | ex2 = (preds[:, 2] >= bbox[:, 2]) * preds[:, 2] + (preds[:, 2] < bbox[:, 2]) * bbox[:, 2] 55 | ey2 = (preds[:, 3] >= bbox[:, 3]) * preds[:, 3] + (preds[:, 3] < bbox[:, 3]) * bbox[:, 3] 56 | ew = torch.clamp(ex2 - ex1 + 1e-3, 0., 1e10) 57 | eh = torch.clamp(ey2 - ey1 + 1e-3, 0., 1e10) 58 | 59 | # enclose erea 60 | enclose = ew * eh + eps 61 | giou = ious - (enclose - uni) / enclose 62 | 63 | loss = 1 - giou 64 | 65 | if reduction == 'mean': 66 | loss = torch.mean(loss) 67 | elif reduction == 'sum': 68 | loss = torch.sum(loss) 69 | else: 70 | raise NotImplementedError 71 | return loss 72 | 73 | def forward(self, predicts, batch): 74 | structure_probs = predicts['structure_probs'] 75 | # structure_targets = batch[1].astype("int64") 76 | structure_targets = torch.tensor(batch[1], dtype=torch.int64) 77 | structure_targets = structure_targets[:, 1:] 78 | if len(batch) == 6: 79 | # structure_mask = batch[5].astype("int64") 80 | structure_mask = torch.tensor(batch[5], dtype=torch.int64) 81 | structure_mask = structure_mask[:, 1:] 82 | structure_mask = torch.reshape(structure_mask, [-1]) 83 | structure_probs = torch.reshape(structure_probs, [-1, structure_probs.shape[-1]]) 84 | structure_targets = torch.reshape(structure_targets, [-1]) 85 | structure_loss = self.loss_func(structure_probs, structure_targets) 86 | 87 | if len(batch) == 6: 88 | structure_loss = structure_loss * structure_mask 89 | 90 | # structure_loss = paddle.sum(structure_loss) * self.structure_weight 91 | structure_loss = torch.mean(structure_loss) * self.structure_weight 92 | 93 | loc_preds = predicts['loc_preds'] 94 | # loc_targets = batch[2].astype("float32") 95 | loc_targets = torch.tensor(batch[2], dtype=torch.float32) 96 | # 
loc_targets_mask = batch[4].astype("float32") 97 | loc_targets_mask = torch.tensor(batch[4], dtype=torch.float32) 98 | loc_targets = loc_targets[:, 1:, :] 99 | loc_targets_mask = loc_targets_mask[:, 1:, :] 100 | loc_loss = F.mse_loss(loc_preds * loc_targets_mask, loc_targets) * self.loc_weight 101 | if self.use_giou: 102 | loc_loss_giou = self.giou_loss(loc_preds * loc_targets_mask, loc_targets) * self.giou_weight 103 | total_loss = structure_loss + loc_loss + loc_loss_giou 104 | return {'loss':total_loss, "structure_loss":structure_loss, "loc_loss":loc_loss, "loc_loss_giou":loc_loss_giou} 105 | else: 106 | total_loss = structure_loss + loc_loss 107 | return {'loss':total_loss, "structure_loss":structure_loss, "loc_loss":loc_loss} -------------------------------------------------------------------------------- /utils/metrics/__init__.py: -------------------------------------------------------------------------------- 1 | # copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | from __future__ import unicode_literals 19 | 20 | import copy 21 | 22 | __all__ = ["build_metric"] 23 | 24 | from .table_metric import TableMetric 25 | 26 | 27 | def build_metric(config): 28 | support_dict = [ 29 | "DetMetric", "RecMetric", "ClsMetric", "E2EMetric", "DistillationMetric", "TableMetric" 30 | ] 31 | 32 | config = copy.deepcopy(config) 33 | module_name = config.pop("name") 34 | assert module_name in support_dict, Exception( 35 | "metric only support {}".format(support_dict)) 36 | module_class = eval(module_name)(**config) 37 | return module_class 38 | -------------------------------------------------------------------------------- /utils/metrics/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/June-Li/TableRecognition/a26956983ade1153ee91f9a92fface1a2cf1275e/utils/metrics/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /utils/metrics/__pycache__/table_metric.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/June-Li/TableRecognition/a26956983ade1153ee91f9a92fface1a2cf1275e/utils/metrics/__pycache__/table_metric.cpython-38.pyc -------------------------------------------------------------------------------- /utils/metrics/table_metric.py: -------------------------------------------------------------------------------- 1 | # copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | import numpy as np 15 | 16 | 17 | class TableMetric(object): 18 | def __init__(self, main_indicator='acc', **kwargs): 19 | self.main_indicator = main_indicator 20 | self.reset() 21 | 22 | def __call__(self, pred, batch, *args, **kwargs): 23 | structure_probs = pred['structure_probs'] # .detach().numpy() 24 | structure_labels = batch[1] 25 | correct_num = 0 26 | all_num = 0 27 | structure_probs = np.argmax(structure_probs, axis=2) 28 | structure_labels = structure_labels[:, 1:] 29 | batch_size = structure_probs.shape[0] 30 | for bno in range(batch_size): 31 | all_num += 1 32 | if (structure_probs[bno] == structure_labels[bno]).all(): 33 | correct_num += 1 34 | self.correct_num += correct_num 35 | self.all_num += all_num 36 | return { 37 | 'acc': correct_num * 1.0 / all_num, 38 | } 39 | 40 | def get_metric(self): 41 | """ 42 | return metrics { 43 | 'acc': 0, 44 | } 45 | """ 46 | acc = 1.0 * self.correct_num / self.all_num 47 | self.reset() 48 | return {'acc': acc} 49 | 50 | def reset(self): 51 | self.correct_num = 0 52 | self.all_num = 0 53 | -------------------------------------------------------------------------------- /utils/optimizer/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | from __future__ import unicode_literals 5 | import copy 6 | 7 | __all__ = ['build_optimizer'] 8 | 9 | 10 | def build_optimizer(config, epochs, step_each_epoch, parameters): 
11 | from . import optimizer 12 | config = copy.deepcopy(config) 13 | # # step1 build lr 14 | lr = config.pop('lr')['learning_rate'] 15 | 16 | # step2 build optimizer 17 | optim_name = config.pop('name') 18 | optim = getattr(optimizer, optim_name)(learning_rate=lr, **config) 19 | return optim(parameters), lr 20 | -------------------------------------------------------------------------------- /utils/optimizer/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/June-Li/TableRecognition/a26956983ade1153ee91f9a92fface1a2cf1275e/utils/optimizer/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /utils/optimizer/__pycache__/optimizer.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/June-Li/TableRecognition/a26956983ade1153ee91f9a92fface1a2cf1275e/utils/optimizer/__pycache__/optimizer.cpython-38.pyc -------------------------------------------------------------------------------- /utils/optimizer/optimizer.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | from __future__ import unicode_literals 5 | 6 | from torch import optim 7 | 8 | 9 | class Momentum(object): 10 | """ 11 | Simple Momentum optimizer with velocity state. 12 | Args: 13 | learning_rate (float|Variable) - The learning rate used to update parameters. 14 | Can be a float value or a Variable with one float value as data element. 15 | momentum (float) - Momentum factor. 16 | regularization (WeightDecayRegularizer, optional) - The strategy of regularization. 
17 | """ 18 | 19 | def __init__(self, 20 | learning_rate, 21 | momentum, 22 | weight_decay=None, 23 | grad_clip=None, 24 | **args): 25 | super(Momentum, self).__init__() 26 | self.learning_rate = learning_rate 27 | self.momentum = momentum 28 | self.weight_decay = weight_decay 29 | self.grad_clip = grad_clip 30 | 31 | def __call__(self, parameters): 32 | opt = optim.Momentum( 33 | learning_rate=self.learning_rate, 34 | momentum=self.momentum, 35 | weight_decay=self.weight_decay, 36 | grad_clip=self.grad_clip, 37 | parameters=parameters) 38 | return opt 39 | 40 | 41 | class Adam(object): 42 | def __init__(self, 43 | learning_rate=0.001, 44 | beta1=0.9, 45 | beta2=0.999, 46 | epsilon=1e-08, 47 | parameter_list=None, 48 | name=None, 49 | lazy_mode=False, 50 | **kwargs): 51 | self.learning_rate = learning_rate 52 | self.beta1 = beta1 53 | self.beta2 = beta2 54 | self.epsilon = epsilon 55 | self.parameter_list = parameter_list 56 | self.learning_rate = learning_rate 57 | self.name = name 58 | self.lazy_mode = lazy_mode 59 | 60 | def __call__(self, parameters): 61 | opt = optim.Adam( 62 | lr=self.learning_rate, 63 | betas=(self.beta1, self.beta2), 64 | eps=self.epsilon, 65 | amsgrad=self.lazy_mode, 66 | params=parameters) 67 | return opt 68 | 69 | 70 | class RMSProp(object): 71 | """ 72 | Root Mean Squared Propagation (RMSProp) is an unpublished, adaptive learning rate method. 73 | Args: 74 | learning_rate (float|Variable) - The learning rate used to update parameters. 75 | Can be a float value or a Variable with one float value as data element. 76 | momentum (float) - Momentum factor. 77 | rho (float) - rho value in equation. 78 | epsilon (float) - avoid division by zero, default is 1e-6. 79 | regularization (WeightDecayRegularizer, optional) - The strategy of regularization. 
80 | """ 81 | 82 | def __init__(self, 83 | learning_rate, 84 | momentum=0.0, 85 | rho=0.95, 86 | epsilon=1e-6, 87 | weight_decay=None, 88 | grad_clip=None, 89 | **args): 90 | super(RMSProp, self).__init__() 91 | self.learning_rate = learning_rate 92 | self.momentum = momentum 93 | self.rho = rho 94 | self.epsilon = epsilon 95 | self.weight_decay = weight_decay 96 | self.grad_clip = grad_clip 97 | 98 | def __call__(self, parameters): 99 | opt = optim.RMSProp( 100 | learning_rate=self.learning_rate, 101 | momentum=self.momentum, 102 | rho=self.rho, 103 | epsilon=self.epsilon, 104 | weight_decay=self.weight_decay, 105 | grad_clip=self.grad_clip, 106 | parameters=parameters) 107 | return opt 108 | -------------------------------------------------------------------------------- /utils/postprocess/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | from __future__ import unicode_literals 5 | 6 | import copy 7 | 8 | __all__ = ['build_post_process'] 9 | 10 | from .rec_postprocess import TableLabelDecode 11 | 12 | 13 | def build_post_process(config, global_config=None): 14 | support_dict = ['TableLabelDecode'] 15 | 16 | config = copy.deepcopy(config) 17 | module_name = config.pop('name') 18 | if global_config is not None: 19 | config.update(global_config) 20 | assert module_name in support_dict, Exception( 21 | 'post process only support {}'.format(support_dict)) 22 | module_class = eval(module_name)(**config) 23 | return module_class 24 | -------------------------------------------------------------------------------- /utils/postprocess/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/June-Li/TableRecognition/a26956983ade1153ee91f9a92fface1a2cf1275e/utils/postprocess/__pycache__/__init__.cpython-38.pyc 
-------------------------------------------------------------------------------- /utils/postprocess/__pycache__/rec_postprocess.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/June-Li/TableRecognition/a26956983ade1153ee91f9a92fface1a2cf1275e/utils/postprocess/__pycache__/rec_postprocess.cpython-38.pyc -------------------------------------------------------------------------------- /utils/postprocess/rec_postprocess.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import string 3 | import torch 4 | 5 | 6 | class TableLabelDecode(object): 7 | """ """ 8 | def __init__(self, 9 | character_dict_path, 10 | **kwargs): 11 | list_character, list_elem = self.load_char_elem_dict(character_dict_path) 12 | list_character = self.add_special_char(list_character) 13 | list_elem = self.add_special_char(list_elem) 14 | self.dict_character = {} 15 | self.dict_idx_character = {} 16 | for i, char in enumerate(list_character): 17 | self.dict_idx_character[i] = char 18 | self.dict_character[char] = i 19 | self.dict_elem = {} 20 | self.dict_idx_elem = {} 21 | for i, elem in enumerate(list_elem): 22 | self.dict_idx_elem[i] = elem 23 | self.dict_elem[elem] = i 24 | 25 | def load_char_elem_dict(self, character_dict_path): 26 | list_character = [] 27 | list_elem = [] 28 | with open(character_dict_path, "rb") as fin: 29 | lines = fin.readlines() 30 | substr = lines[0].decode('utf-8').strip("\n").strip("\r\n").split("\t") 31 | character_num = int(substr[0]) 32 | elem_num = int(substr[1]) 33 | for cno in range(1, 1 + character_num): 34 | character = lines[cno].decode('utf-8').strip("\n").strip("\r\n") 35 | list_character.append(character) 36 | for eno in range(1 + character_num, 1 + character_num + elem_num): 37 | elem = lines[eno].decode('utf-8').strip("\n").strip("\r\n") 38 | list_elem.append(elem) 39 | return list_character, list_elem 40 | 41 | def 
add_special_char(self, list_character): 42 | self.beg_str = "sos" 43 | self.end_str = "eos" 44 | list_character = [self.beg_str] + list_character + [self.end_str] 45 | return list_character 46 | 47 | def __call__(self, preds): 48 | structure_probs = preds['structure_probs'] 49 | loc_preds = preds['loc_preds'] 50 | if isinstance(structure_probs, torch.Tensor): 51 | structure_probs = structure_probs.numpy() 52 | if isinstance(loc_preds, torch.Tensor): 53 | loc_preds = loc_preds.numpy() 54 | structure_idx = structure_probs.argmax(axis=2) 55 | structure_probs = structure_probs.max(axis=2) 56 | structure_str, structure_pos, result_score_list, result_elem_idx_list = self.decode(structure_idx, 57 | structure_probs, 'elem') 58 | res_html_code_list = [] 59 | res_loc_list = [] 60 | batch_num = len(structure_str) 61 | for bno in range(batch_num): 62 | res_loc = [] 63 | for sno in range(len(structure_str[bno])): 64 | text = structure_str[bno][sno] 65 | if text in ['', ' 0 and tmp_elem_idx == end_idx: 98 | break 99 | if tmp_elem_idx in ignored_tokens: 100 | continue 101 | 102 | char_list.append(current_dict[tmp_elem_idx]) 103 | elem_pos_list.append(idx) 104 | score_list.append(structure_probs[batch_idx, idx]) 105 | elem_idx_list.append(tmp_elem_idx) 106 | result_list.append(char_list) 107 | result_pos_list.append(elem_pos_list) 108 | result_score_list.append(score_list) 109 | result_elem_idx_list.append(elem_idx_list) 110 | return result_list, result_pos_list, result_score_list, result_elem_idx_list 111 | 112 | def get_ignored_tokens(self, char_or_elem): 113 | beg_idx = self.get_beg_end_flag_idx("beg", char_or_elem) 114 | end_idx = self.get_beg_end_flag_idx("end", char_or_elem) 115 | return [beg_idx, end_idx] 116 | 117 | def get_beg_end_flag_idx(self, beg_or_end, char_or_elem): 118 | if char_or_elem == "char": 119 | if beg_or_end == "beg": 120 | idx = self.dict_character[self.beg_str] 121 | elif beg_or_end == "end": 122 | idx = self.dict_character[self.end_str] 123 | else: 124 
| assert False, "Unsupport type %s in get_beg_end_flag_idx of char" \ 125 | % beg_or_end 126 | elif char_or_elem == "elem": 127 | if beg_or_end == "beg": 128 | idx = self.dict_elem[self.beg_str] 129 | elif beg_or_end == "end": 130 | idx = self.dict_elem[self.end_str] 131 | else: 132 | assert False, "Unsupport type %s in get_beg_end_flag_idx of elem" \ 133 | % beg_or_end 134 | else: 135 | assert False, "Unsupport type %s in char_or_elem" \ 136 | % char_or_elem 137 | return idx 138 | -------------------------------------------------------------------------------- /utils/save_load.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2022/01/25 09:52 3 | # @Author : lijun 4 | 5 | import torch 6 | 7 | 8 | # def load_model(model, model_path): 9 | # model.load_state_dict(torch.load(model_path, map_location='cpu')['state_dict']) 10 | # 11 | # 12 | # def save_model(save_dict, save_path): 13 | # torch.save(save_dict, save_path) 14 | 15 | 16 | def load_model(model, model_path): 17 | state_dict = {} 18 | for k, v in torch.load(model_path, map_location='cpu')['state_dict'].items(): 19 | state_dict[k.replace('module.', '')] = v 20 | model.load_state_dict(state_dict) 21 | 22 | 23 | def save_model(save_dict, save_path): 24 | state_dict = {} 25 | for k, v in save_dict['state_dict'].items(): 26 | state_dict[k.replace('module.', '')] = v 27 | save_dict['state_dict'] = state_dict 28 | torch.save(save_dict, save_path) 29 | -------------------------------------------------------------------------------- /utils/stats.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import collections
import numpy as np
import datetime

__all__ = ['TrainingStats', 'Time']


class SmoothedValue(object):
    """Keep the most recent ``window_size`` values and report their median."""

    def __init__(self, window_size):
        # A bounded deque drops the oldest value automatically once full.
        self.deque = collections.deque(maxlen=window_size)

    def add_value(self, value):
        """Append ``value``, evicting the oldest entry if the window is full."""
        self.deque.append(value)

    def get_median_value(self):
        """Return the median of the values currently in the window."""
        return np.median(self.deque)


def Time():
    """Return the current local time formatted as ``%Y-%m-%d %H:%M:%S.%f``."""
    return datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S.%f')


class TrainingStats(object):
    """Track windowed medians of named losses/metrics for periodic logging.

    Args:
        window_size (int): number of most recent values kept per key.
        stats_keys (iterable): keys tracked from the start; keys first seen
            in :meth:`update` are added lazily.
    """

    def __init__(self, window_size, stats_keys):
        self.window_size = window_size
        self.smoothed_losses_and_metrics = {
            key: SmoothedValue(window_size)
            for key in stats_keys
        }

    def update(self, stats):
        """Record one new value for every ``key: value`` pair in ``stats``."""
        for k, v in stats.items():
            if k not in self.smoothed_losses_and_metrics:
                self.smoothed_losses_and_metrics[k] = SmoothedValue(
                    self.window_size)
            self.smoothed_losses_and_metrics[k].add_value(v)

    def get(self, extras=None):
        """Return an OrderedDict of ``extras`` followed by the windowed median
        (rounded to 6 decimals) of every tracked key.

        A tracked key that also appears in ``extras`` is overwritten by its
        median.
        """
        stats = collections.OrderedDict()
        if extras:
            for k, v in extras.items():
                stats[k] = v
        for k, v in self.smoothed_losses_and_metrics.items():
            stats[k] = round(v.get_median_value(), 6)

        return stats

    def log(self, extras=None):
        """Render :meth:`get` as ``'key: value, key: value, ...'``.

        Every value goes through the ``f`` format spec, so all values
        (including ``extras``) must be numeric.
        """
        d = self.get(extras)
        strs = []
        for k, v in d.items():
            strs.append('{}: {:x<6f}'.format(k, v))
        strs = ', '.join(strs)
        return strs

# --------------------------------------------------------------------------------
# /utils/utility.py:
# --------------------------------------------------------------------------------

import logging
import os
import imghdr
import cv2
import yaml
import numpy as np
from utils.torch_utils import select_device
from argparse import ArgumentParser, RawDescriptionHelpFormatter
from utils.logging import get_logger
import torch


def print_dict(d, logger, delimiter=0):
    """
    Recursively log a (possibly nested) dict, indenting each nesting level
    by 4 more spaces.

    Args:
        d (dict): mapping to print; items are visited in sorted order.
        logger: object with an ``info`` method (e.g. ``logging.Logger``).
        delimiter (int): number of leading spaces for the current level.
    """
    for k, v in sorted(d.items()):
        if isinstance(v, dict):
            logger.info("{}{} : ".format(delimiter * " ", str(k)))
            print_dict(v, logger, delimiter + 4)
        elif isinstance(v, list) and len(v) >= 1 and isinstance(v[0], dict):
            # A list of dicts is printed as consecutive indented sub-dicts.
            logger.info("{}{} : ".format(delimiter * " ", str(k)))
            for value in v:
                print_dict(value, logger, delimiter + 4)
        else:
            logger.info("{}{} : {}".format(delimiter * " ", k, v))


def get_check_global_params(mode):
    """
    Return the names of global config params expected for ``mode``.

    Args:
        mode (str): ``"train_eval"`` adds the train and test batch sizes,
            ``"test"`` adds the test batch size; anything else adds nothing.
    """
    # 'image_shape' was previously listed twice.
    check_params = ['use_gpu', 'max_text_length', 'image_shape',
                    'character_type', 'loss_type']
    if mode == "train_eval":
        check_params = check_params + [
            'train_batch_size_per_card', 'test_batch_size_per_card']
    elif mode == "test":
        check_params = check_params + ['test_batch_size_per_card']
    return check_params


def get_image_file_list(img_file):
    """
    Collect image paths from a file or a directory.

    Args:
        img_file (str): an image file, or a directory whose direct children
            (non-recursive) are scanned.
    Returns:
        list[str]: sorted paths whose content ``imghdr`` recognizes as one of
        the accepted image types.
    Raises:
        Exception: if ``img_file`` is None, does not exist, or no image is
        found.
    """
    imgs_lists = []
    if img_file is None or not os.path.exists(img_file):
        raise Exception("not found any img file in {}".format(img_file))

    # Compared against imghdr.what(), which sniffs file content rather than
    # trusting the extension.
    img_end = {'jpg', 'bmp', 'png', 'jpeg', 'rgb', 'tif', 'tiff', 'gif', 'GIF','webp','ppm'}
    if os.path.isfile(img_file) and imghdr.what(img_file) in img_end:
        imgs_lists.append(img_file)
    elif os.path.isdir(img_file):
        for single_file in os.listdir(img_file):
            file_path = os.path.join(img_file, single_file)
            if os.path.isfile(file_path) and imghdr.what(file_path) in img_end:
                imgs_lists.append(file_path)
    if len(imgs_lists) == 0:
        raise Exception("not found any img file in {}".format(img_file))
    imgs_lists = sorted(imgs_lists)
    return imgs_lists


def check_and_read_gif(img_path):
    """
    Read the first frame of a ``.gif`` file.

    Args:
        img_path (str): path to the image; the extension check is
            case-insensitive.
    Returns:
        tuple: ``(frame, True)`` with the first frame's channel order reversed
        (OpenCV decodes frames as BGR) when the path is a readable gif;
        ``(None, False)`` for non-gif paths and unreadable gifs.
    """
    if os.path.basename(img_path)[-3:].lower() != 'gif':
        return None, False
    gif = cv2.VideoCapture(img_path)
    try:
        ret, frame = gif.read()
    finally:
        # Release the decoder handle; it was previously leaked.
        gif.release()
    if not ret:
        logger = logging.getLogger('ppocr')
        logger.info("Cannot read {}. This gif image maybe corrupted.".format(img_path))
        return None, False
    if len(frame.shape) == 2 or frame.shape[-1] == 1:
        frame = cv2.cvtColor(frame, cv2.COLOR_GRAY2RGB)
    imgvalue = frame[:, :, ::-1]
    return imgvalue, True


def get_rotate_crop_image(img, points):
    """
    Perspective-warp the quadrilateral ``points`` out of ``img``.

    Args:
        img (ndarray): source image.
        points: 4x2 corner coordinates. They are presumably ordered
            top-left, top-right, bottom-right, bottom-left to match
            ``pts_std`` below (TODO confirm with callers).
    Returns:
        ndarray: the rectified crop. If its height/width ratio is at least
        1.5, it is rotated with ``np.rot90`` so tall text reads horizontally.
    """
    assert len(points) == 4, "shape of points must be 4*2"
    # cv2.getPerspectiveTransform only accepts float32 corner arrays.
    points = np.float32(points)
    # Output size: the longer of each pair of opposite edges.
    img_crop_width = int(
        max(
            np.linalg.norm(points[0] - points[1]),
            np.linalg.norm(points[2] - points[3])))
    img_crop_height = int(
        max(
            np.linalg.norm(points[0] - points[3]),
            np.linalg.norm(points[1] - points[2])))
    pts_std = np.float32([[0, 0], [img_crop_width, 0],
                          [img_crop_width, img_crop_height],
                          [0, img_crop_height]])
    M = cv2.getPerspectiveTransform(points, pts_std)
    dst_img = cv2.warpPerspective(
        img,
        M, (img_crop_width, img_crop_height),
        borderMode=cv2.BORDER_REPLICATE,
        flags=cv2.INTER_CUBIC)
    dst_img_height, dst_img_width = dst_img.shape[0:2]
    if dst_img_height * 1.0 / dst_img_width >= 1.5:
        dst_img = np.rot90(dst_img)
    return dst_img


def preprocess(is_train=False):
    """
    Parse CLI args, load and merge the YAML config, select the device, and
    build the logger.

    When ``is_train`` is true, the merged config is also written to
    ``<Global.save_model_dir>/config.yml`` and the log goes to
    ``<Global.save_model_dir>/train.log``.

    Returns:
        tuple: ``(config, device, logger)``.
    """
    FLAGS = ArgsParser().parse_args()
    config = load_config(FLAGS.config)
    # "-o key=value" CLI overrides are applied on top of the file config.
    merge_config(FLAGS.opt)

    use_gpu = config['Global']['use_gpu']

    alg = config['Architecture']['algorithm']
    assert alg in ['TableAttn']

    # NOTE(review): device selection is delegated to the project's
    # select_device; its CPU fallback behavior is not visible here.
    device = select_device(use_gpu)

    if is_train:
        # Persist the effective config next to the checkpoints.
        save_model_dir = config['Global']['save_model_dir']
        os.makedirs(save_model_dir, exist_ok=True)
        with open(os.path.join(save_model_dir, 'config.yml'), 'w') as f:
            yaml.dump(
                dict(config), f, default_flow_style=False, sort_keys=False)
        log_file = '{}/train.log'.format(save_model_dir)
    else:
        log_file = None
    logger = get_logger(name='root', log_file=log_file)

    print_dict(config, logger)
    logger.info('train with torch {} and device {}'.format(torch.__version__, device))
    return config, device, logger


class ArgsParser(ArgumentParser):
    """CLI parser with ``-c/--config`` (required) and ``-o/--opt`` overrides."""

    def __init__(self):
        super(ArgsParser, self).__init__(
            formatter_class=RawDescriptionHelpFormatter)
        self.add_argument("-c", "--config", help="configuration file to use")
        self.add_argument(
            "-o", "--opt", nargs='+', help="set configuration options")

    def parse_args(self, argv=None):
        """Parse ``argv`` and turn ``--opt`` into a dict of overrides."""
        args = super(ArgsParser, self).parse_args(argv)
        assert args.config is not None, \
            "Please specify --config=configure_file_path."
        args.opt = self._parse_opt(args.opt)
        return args

    def _parse_opt(self, opts):
        """
        Convert ``["a.b=1", "c=[1, 2]"]`` into ``{"a.b": 1, "c": [1, 2]}``.

        Only the first ``=`` separates key from value, so values may
        themselves contain ``=``. Values are parsed as YAML.
        """
        config = {}
        if not opts:
            return config
        for s in opts:
            s = s.strip()
            k, v = s.split('=', 1)
            # NOTE(review): yaml.Loader can build arbitrary Python objects;
            # acceptable only because the values come from the local CLI.
            config[k] = yaml.load(v, Loader=yaml.Loader)
        return config


class AttrDict(dict):
    """Single level attribute dict, NOT recursive"""

    def __init__(self, **kwargs):
        super(AttrDict, self).__init__()
        super(AttrDict, self).update(kwargs)

    def __getattr__(self, key):
        if key in self:
            return self[key]
        raise AttributeError("object has no attribute '{}'".format(key))


# Process-wide config that load_config / merge_config populate in place.
global_config = AttrDict()
default_config = {'Global': {'debug': False, }}


def load_config(file_path):
    """
    Load config from yml/yaml file.
    Args:
        file_path (str): Path of the config file to be loaded.
    Returns: global config
    """
    import copy

    # Merge a copy so later updates to global_config['Global'] do not mutate
    # the module-level default_config through a shared dict reference.
    merge_config(copy.deepcopy(default_config))
    _, ext = os.path.splitext(file_path)
    assert ext in ['.yml', '.yaml'], "only support yaml files for now"
    with open(file_path, 'rb') as f:
        # NOTE(review): yaml.Loader is not a safe loader; only load trusted
        # config files.
        merge_config(yaml.load(f, Loader=yaml.Loader))
    return global_config


def merge_config(config):
    """
    Merge config into global config.

    A top-level dict value updates the existing section shallowly; any other
    value replaces it. A dotted key such as ``"Global.use_gpu"`` sets the
    nested entry, and its first component must already exist.
    Args:
        config (dict or None): Config to be merged; None (e.g. an empty YAML
            file) is ignored.
    Returns: global config
    """
    if not config:
        return global_config
    for key, value in config.items():
        if "." not in key:
            if isinstance(value, dict) and key in global_config:
                global_config[key].update(value)
            else:
                global_config[key] = value
        else:
            sub_keys = key.split('.')
            assert (
                sub_keys[0] in global_config
            ), "the sub_keys can only be one of global_config: {}, but get: {}, please check your running command".format(
                global_config.keys(), sub_keys[0])
            cur = global_config[sub_keys[0]]
            for idx, sub_key in enumerate(sub_keys[1:]):
                if idx == len(sub_keys) - 2:
                    cur[sub_key] = value
                else:
                    cur = cur[sub_key]
    return global_config

# --------------------------------------------------------------------------------