├── README.md
├── __pycache__
└── program.cpython-38.pyc
├── configs
├── test_table_mv3.yml
└── train_table_mv3.yml
├── inference_model
├── baidu
│ └── infer
│ │ └── table_rec
│ │ └── pd2pt.pt
└── juneli
│ └── finetune
│ └── table_rec
│ └── best.pt
├── models
├── architectures
│ ├── __init__.py
│ ├── __pycache__
│ │ ├── __init__.cpython-38.pyc
│ │ └── base_model.cpython-38.pyc
│ └── base_model.py
├── backbones
│ ├── __init__.py
│ ├── __pycache__
│ │ ├── __init__.cpython-38.pyc
│ │ ├── table_mobilenet_v3.cpython-38.pyc
│ │ └── table_resnet_vd.cpython-38.pyc
│ ├── table_mobilenet_v3.py
│ └── table_resnet_vd.py
└── heads
│ ├── __init__.py
│ ├── __pycache__
│ ├── __init__.cpython-38.pyc
│ └── table_att_head.cpython-38.pyc
│ └── table_att_head.py
├── recognizer_table.py
├── script
├── paddle2pytorch.py
├── t_0.py
├── t_1.py
└── t_2.py
├── test.py
├── train.py
└── utils
├── __pycache__
├── logging.cpython-38.pyc
├── save_load.cpython-38.pyc
├── stats.cpython-38.pyc
├── torch_utils.cpython-38.pyc
└── utility.cpython-38.pyc
├── data
├── __init__.py
├── __pycache__
│ ├── __init__.cpython-38.pyc
│ └── pubtab_dataset.cpython-38.pyc
├── imaug
│ ├── __init__.py
│ ├── __pycache__
│ │ ├── __init__.cpython-38.pyc
│ │ ├── copy_paste.cpython-38.pyc
│ │ ├── east_process.cpython-38.pyc
│ │ ├── gen_table_mask.cpython-38.pyc
│ │ ├── iaa_augment.cpython-38.pyc
│ │ ├── label_ops.cpython-38.pyc
│ │ ├── make_border_map.cpython-38.pyc
│ │ ├── make_shrink_map.cpython-38.pyc
│ │ ├── operators.cpython-38.pyc
│ │ ├── pg_process.cpython-38.pyc
│ │ ├── randaugment.cpython-38.pyc
│ │ ├── random_crop_data.cpython-38.pyc
│ │ ├── rec_img_aug.cpython-38.pyc
│ │ └── sast_process.cpython-38.pyc
│ ├── copy_paste.py
│ ├── east_process.py
│ ├── gen_table_mask.py
│ ├── iaa_augment.py
│ ├── label_ops.py
│ ├── make_border_map.py
│ ├── make_shrink_map.py
│ ├── operators.py
│ ├── pg_process.py
│ ├── randaugment.py
│ ├── random_crop_data.py
│ ├── rec_img_aug.py
│ ├── sast_process.py
│ └── text_image_aug
│ │ ├── __init__.py
│ │ ├── __pycache__
│ │ ├── __init__.cpython-38.pyc
│ │ ├── augment.cpython-38.pyc
│ │ └── warp_mls.cpython-38.pyc
│ │ ├── augment.py
│ │ └── warp_mls.py
└── pubtab_dataset.py
├── dict
├── ar_dict.txt
├── arabic_dict.txt
├── be_dict.txt
├── bg_dict.txt
├── chinese_cht_dict.txt
├── cyrillic_dict.txt
├── devanagari_dict.txt
├── en_dict.txt
├── fa_dict.txt
├── french_dict.txt
├── german_dict.txt
├── hi_dict.txt
├── it_dict.txt
├── japan_dict.txt
├── ka_dict.txt
├── korean_dict.txt
├── latin_dict.txt
├── mr_dict.txt
├── ne_dict.txt
├── oc_dict.txt
├── pu_dict.txt
├── rs_dict.txt
├── rsc_dict.txt
├── ru_dict.txt
├── ta_dict.txt
├── table_dict.txt
├── table_structure_dict.txt
├── te_dict.txt
├── ug_dict.txt
├── uk_dict.txt
├── ur_dict.txt
└── xi_dict.txt
├── logging.py
├── losses
├── __init__.py
├── __pycache__
│ ├── __init__.cpython-38.pyc
│ └── table_att_loss.cpython-38.pyc
└── table_att_loss.py
├── metrics
├── __init__.py
├── __pycache__
│ ├── __init__.cpython-38.pyc
│ └── table_metric.cpython-38.pyc
└── table_metric.py
├── optimizer
├── __init__.py
├── __pycache__
│ ├── __init__.cpython-38.pyc
│ └── optimizer.cpython-38.pyc
└── optimizer.py
├── postprocess
├── __init__.py
├── __pycache__
│ ├── __init__.cpython-38.pyc
│ └── rec_postprocess.cpython-38.pyc
└── rec_postprocess.py
├── save_load.py
├── stats.py
├── torch_utils.py
└── utility.py
/README.md:
--------------------------------------------------------------------------------
1 | *注:本项目只是做表格结构预测和cell location回归,如果需要填充内容,需要连接自己的OCR模块;
2 | - todo list
3 | - [x] 完成table cell box格式(box格式可以用来对单元格做检测,且可简化标注难度)转html算法
4 | - [x] paddle表格识别代码和预训练模型转pytorch,预训练模型转换代码script/paddle2pytorch.py
5 | - [x] 私有有线表格数据训练
6 | - [x] 训练、测试、infer代码重构
7 | - [ ] 私有无线表格数据训练
8 | - [ ] 表格识别模型cell location回归优化
9 |
10 | # 环境
11 | python:3.8
12 | pytorch:1.7.1
13 | CUDA:10.1
14 |
15 | # DATASETS
16 | ## Public:
17 | [PubTabNet](https://github.com/ibm-aur-nlp/PubTabNet): 预训练使用了这个数据集,可自行去官网下载。
18 |
19 | ## Private:
20 | 私人数据暂不公开,后期会提供table cell box格式转HTML代码。
21 |
22 | # train
23 | python train.py --config=configs/train_table_mv3.yml
24 |
25 | # test
26 | python test.py --config=configs/test_table_mv3.yml
27 |
28 | # infer
29 | python recognizer_table.py
30 |
31 | # model
32 | 百度基于Pubtabnet训练的模型(paddle转pytorch):
33 | inference_model/baidu/infer/table_rec/pd2pt.pt
34 | 基于私人数据训练的有线表格识别模型(目前场景比较固定,后续会继续泛化,使用者也可自行finetune):
35 | inference_model/juneli/finetune/table_rec/best.pt
36 |
37 | # 参考
38 | [PPOCR](https://github.com/PaddlePaddle/PaddleOCR)
39 |
--------------------------------------------------------------------------------
/__pycache__/program.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/June-Li/TableRecognition/a26956983ade1153ee91f9a92fface1a2cf1275e/__pycache__/program.cpython-38.pyc
--------------------------------------------------------------------------------
/configs/test_table_mv3.yml:
--------------------------------------------------------------------------------
1 | Global:
2 | use_gpu: '3'
3 | # evaluation is run every 400 iterations after the 0th iteration
4 | save_model_dir: /
5 | checkpoints: /workspace/JuneLi/bbtv/ExpCode/RARE/inference_model/baidu/infer/table_rec/pd2pt.pt
6 | # checkpoints: /workspace/JuneLi/bbtv/ExpCode/RARE/inference_model/juneli/finetune/table_rec/last.pt
7 | character_dict_path: utils/dict/table_structure_dict.txt
8 | character_type: en
9 | max_text_length: 100
10 | max_elem_length: 800
11 | max_cell_num: 500
12 |
13 | Architecture:
14 | model_type: table
15 | algorithm: TableAttn
16 | Backbone:
17 | name: MobileNetV3
18 | scale: 1.0
19 | model_name: large
20 | Head:
21 | name: TableAttentionHead
22 | hidden_size: 256
23 | l2_decay: 0.00001
24 | loc_type: 2
25 | max_text_length: 100
26 | max_elem_length: 800
27 | max_cell_num: 500
28 |
29 | PostProcess:
30 | name: TableLabelDecode
31 |
32 | Metric:
33 | name: TableMetric
34 | main_indicator: acc
35 |
36 | Eval:
37 | dataset:
38 | name: PubTabDataSet
39 | data_dir: /workspace/JuneLi/bbtv/ExpCode/RARE/datasets/public/pubtabnet_split/val/
40 | label_file_path: /workspace/JuneLi/bbtv/ExpCode/RARE/datasets/public/pubtabnet_split/val.json
41 | # data_dir: /workspace/JuneLi/bbtv/ExpCode/RARE/datasets/private/FundScan/images/
42 | # label_file_path: /workspace/JuneLi/bbtv/ExpCode/RARE/datasets/private/FundScan/val.json
43 | transforms:
44 | - DecodeImage: # load image
45 | img_mode: BGR
46 | channel_first: False
47 | - ResizeTableImage:
48 | max_len: 488
49 | - TableLabelEncode:
50 | - NormalizeImage:
51 | scale: 1./255.
52 | mean: [0.485, 0.456, 0.406]
53 | std: [0.229, 0.224, 0.225]
54 | order: 'hwc'
55 | - PaddingTableImage:
56 | - ToCHWImage:
57 | - KeepKeys:
58 | keep_keys: ['image', 'structure', 'bbox_list', 'sp_tokens', 'bbox_list_mask']
59 | loader:
60 | shuffle: False
61 | drop_last: False
62 | batch_size_per_card: 128
63 | num_workers: 8
64 |
--------------------------------------------------------------------------------
/configs/train_table_mv3.yml:
--------------------------------------------------------------------------------
1 | Global:
2 | use_gpu: '0,1,2,3'
3 | epoch_num: 400
4 | log_smooth_window: 20
5 | print_batch_step: 5
6 | checkpoints: /workspace/JuneLi/bbtv/ExpCode/RARE/inference_model/baidu/infer/table_rec/pd2pt.pt
7 | # checkpoints: /workspace/JuneLi/bbtv/ExpCode/RARE/inference_model/juneli/finetune/table_rec/last.pt
8 | save_model_dir: ./inference_model/juneli/finetune/table_rec/
9 | save_epoch_step: 1
10 | # evaluation is run every 1000 iterations after the 0th iteration
11 | eval_batch_step: 1000
12 | cal_metric_during_train: True
13 | # for data or label process
14 | character_dict_path: utils/dict/table_structure_dict.txt
15 | character_type: en
16 | max_text_length: 100
17 | max_elem_length: 800
18 | max_cell_num: 500
19 | process_total_num: 0
20 | process_cut_num: 0
21 |
22 | Optimizer:
23 | name: Adam
24 | beta1: 0.9
25 | beta2: 0.999
26 | clip_norm: 5.0
27 | lr:
28 | learning_rate: 0.0001
29 | regularizer:
30 | name: 'L2'
31 | factor: 0.00000
32 |
33 | Architecture:
34 | model_type: table
35 | algorithm: TableAttn
36 | Backbone:
37 | name: MobileNetV3
38 | scale: 1.0
39 | model_name: large
40 | Head:
41 | name: TableAttentionHead
42 | hidden_size: 256
43 | l2_decay: 0.00001
44 | loc_type: 2
45 | max_text_length: 100
46 | max_elem_length: 800
47 | max_cell_num: 500
48 |
49 | Loss:
50 | name: TableAttentionLoss
51 | structure_weight: 100.0
52 | loc_weight: 10000.0
53 |
54 | PostProcess:
55 | name: TableLabelDecode
56 |
57 | Metric:
58 | name: TableMetric
59 | main_indicator: acc
60 |
61 | Train:
62 | dataset:
63 | name: PubTabDataSet
64 | data_dir: /workspace/JuneLi/bbtv/ExpCode/RARE/datasets/private/FundScan/images/
65 | label_file_path: /workspace/JuneLi/bbtv/ExpCode/RARE/datasets/private/FundScan/train.json
66 | transforms:
67 | - DecodeImage: # load image
68 | img_mode: BGR
69 | channel_first: False
70 | - ResizeTableImage:
71 | max_len: 488
72 | - TableLabelEncode:
73 | - NormalizeImage:
74 | scale: 1./255.
75 | mean: [0.485, 0.456, 0.406]
76 | std: [0.229, 0.224, 0.225]
77 | order: 'hwc'
78 | - PaddingTableImage:
79 | - ToCHWImage:
80 | - KeepKeys:
81 | keep_keys: ['image', 'structure', 'bbox_list', 'sp_tokens', 'bbox_list_mask']
82 | loader:
83 | shuffle: True
84 | batch_size_per_card: 24
85 | drop_last: True
86 | num_workers: 8
87 |
88 | Eval:
89 | dataset:
90 | name: PubTabDataSet
91 | data_dir: /workspace/JuneLi/bbtv/ExpCode/RARE/datasets/private/FundScan/images/
92 | label_file_path: /workspace/JuneLi/bbtv/ExpCode/RARE/datasets/private/FundScan/val.json
93 | transforms:
94 | - DecodeImage: # load image
95 | img_mode: BGR
96 | channel_first: False
97 | - ResizeTableImage:
98 | max_len: 488
99 | - TableLabelEncode:
100 | - NormalizeImage:
101 | scale: 1./255.
102 | mean: [0.485, 0.456, 0.406]
103 | std: [0.229, 0.224, 0.225]
104 | order: 'hwc'
105 | - PaddingTableImage:
106 | - ToCHWImage:
107 | - KeepKeys:
108 | keep_keys: ['image', 'structure', 'bbox_list', 'sp_tokens', 'bbox_list_mask']
109 | loader:
110 | shuffle: False
111 | drop_last: False
112 | batch_size_per_card: 32
113 | num_workers: 4
114 |
--------------------------------------------------------------------------------
/inference_model/baidu/infer/table_rec/pd2pt.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/June-Li/TableRecognition/a26956983ade1153ee91f9a92fface1a2cf1275e/inference_model/baidu/infer/table_rec/pd2pt.pt
--------------------------------------------------------------------------------
/inference_model/juneli/finetune/table_rec/best.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/June-Li/TableRecognition/a26956983ade1153ee91f9a92fface1a2cf1275e/inference_model/juneli/finetune/table_rec/best.pt
--------------------------------------------------------------------------------
/models/architectures/__init__.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import importlib
3 |
4 | from .base_model import BaseModel
5 | # from .distillation_model import DistillationModel
6 |
7 | __all__ = ['build_model']
8 |
9 |
def build_model(config):
    """Instantiate the network described by ``config``.

    The caller's dict is deep-copied first, so it is never modified here.

    Args:
        config (dict): architecture configuration. When it contains a
            ``"name"`` key, the class of that name is looked up on this
            module and called with the remaining entries; otherwise a
            ``BaseModel`` is built from the whole dict.

    Returns:
        The constructed architecture object.
    """
    cfg = copy.deepcopy(config)
    if "name" in cfg:
        arch_name = cfg.pop("name")
        # Resolve the class by name among the symbols of this module.
        arch_cls = getattr(importlib.import_module(__name__), arch_name)
        return arch_cls(cfg)
    return BaseModel(cfg)
19 |
--------------------------------------------------------------------------------
/models/architectures/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/June-Li/TableRecognition/a26956983ade1153ee91f9a92fface1a2cf1275e/models/architectures/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/models/architectures/__pycache__/base_model.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/June-Li/TableRecognition/a26956983ade1153ee91f9a92fface1a2cf1275e/models/architectures/__pycache__/base_model.cpython-38.pyc
--------------------------------------------------------------------------------
/models/architectures/base_model.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | from __future__ import absolute_import
15 | from __future__ import division
16 | from __future__ import print_function
17 | from torch import nn
18 | from models.backbones import build_backbone
19 | from models.heads import build_head
20 |
21 | __all__ = ['BaseModel']
22 |
23 |
class BaseModel(nn.Module):
    """Network assembled from a config dict: backbone followed by a head.

    Only the backbone and head builders are imported by this module, so the
    optional ``Transform`` and ``Neck`` sections are rejected explicitly.
    """

    def __init__(self, config):
        """
        Build the sub-networks named in ``config``.

        Args:
            config (dict): the hyper-parameters for the module. Reads
                ``model_type``, ``Backbone`` and ``Head`` (required) and
                ``in_channels`` (default 3), ``Transform``, ``Neck`` and
                ``return_all_feats`` (optional). ``in_channels`` is written
                into the ``Backbone`` and ``Head`` sub-dicts.

        Raises:
            NotImplementedError: if a ``Transform`` or ``Neck`` section is
                configured; no builder for either exists in this module.
        """
        super(BaseModel, self).__init__()
        in_channels = config.get('in_channels', 3)
        model_type = config['model_type']

        # build transform
        # No transform builder is available here; fail with a clear message
        # instead of a NameError on the undefined builder.
        if config.get('Transform') is not None:
            raise NotImplementedError(
                "'Transform' is configured but no transform builder is "
                "available in this project")
        self.use_transform = False

        # build backbone, backbone is needed for det, rec and cls
        config["Backbone"]['in_channels'] = in_channels
        self.backbone = build_backbone(config["Backbone"], model_type)
        in_channels = self.backbone.out_channels

        # build neck
        # Same situation as the transform: there is no neck builder.
        if config.get('Neck') is not None:
            raise NotImplementedError(
                "'Neck' is configured but no neck builder is available "
                "in this project")
        self.use_neck = False

        # build head, head is needed for det, rec and cls
        config["Head"]['in_channels'] = in_channels
        self.head = build_head(config["Head"])

        self.return_all_feats = config.get("return_all_feats", False)

    def forward(self, x, data=None):
        """
        Run backbone and head on ``x``.

        Args:
            x: network input, passed to the backbone.
            data: forwarded to the head as ``targets``.

        Returns:
            The head output, or, when ``return_all_feats`` is set, a dict
            holding ``backbone_out`` plus the head output (merged in if the
            head returns a dict, else stored under ``head_out``).
        """
        y = dict()
        # The transform/neck branches are kept for subclasses that set the
        # flags and attach the modules themselves.
        if self.use_transform:
            x = self.transform(x)
        x = self.backbone(x)
        y["backbone_out"] = x
        if self.use_neck:
            x = self.neck(x)
        y["neck_out"] = x
        x = self.head(x, targets=data)
        if isinstance(x, dict):
            y.update(x)
        else:
            y["head_out"] = x
        if self.return_all_feats:
            return y
        else:
            return x
87 |
--------------------------------------------------------------------------------
/models/backbones/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | __all__ = ["build_backbone"]
16 |
17 |
def build_backbone(config, model_type):
    """Instantiate the backbone named by ``config["name"]``.

    Args:
        config (dict): backbone configuration. ``"name"`` selects the class;
            every other entry is passed to its constructor as a keyword
            argument. The dict is not modified.
        model_type (str): model family; only ``"table"`` is supported.

    Returns:
        The constructed backbone module.

    Raises:
        NotImplementedError: if ``model_type`` is not ``"table"``.
        KeyError: if ``config`` has no ``"name"`` entry.
        AssertionError: if the named backbone is not supported.
    """
    if model_type != "table":
        raise NotImplementedError
    support_dict = ["ResNet", "MobileNetV3"]

    # Work on a copy so the caller's config keeps its "name" entry.
    params = dict(config)
    module_name = params.pop("name")
    if module_name not in support_dict:
        # Raised explicitly (not via ``assert``) so the check survives
        # ``python -O``; the exception type is kept for existing callers.
        raise AssertionError(
            "when model type is {}, backbone only support {}".format(
                model_type, support_dict))

    # Import only the requested backbone: table_resnet_vd imports paddle at
    # module level, which must not be required just to build MobileNetV3.
    if module_name == "ResNet":
        from .table_resnet_vd import ResNet as backbone_cls
    else:
        from .table_mobilenet_v3 import MobileNetV3 as backbone_cls
    return backbone_cls(**params)
32 |
--------------------------------------------------------------------------------
/models/backbones/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/June-Li/TableRecognition/a26956983ade1153ee91f9a92fface1a2cf1275e/models/backbones/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/models/backbones/__pycache__/table_mobilenet_v3.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/June-Li/TableRecognition/a26956983ade1153ee91f9a92fface1a2cf1275e/models/backbones/__pycache__/table_mobilenet_v3.cpython-38.pyc
--------------------------------------------------------------------------------
/models/backbones/__pycache__/table_resnet_vd.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/June-Li/TableRecognition/a26956983ade1153ee91f9a92fface1a2cf1275e/models/backbones/__pycache__/table_resnet_vd.cpython-38.pyc
--------------------------------------------------------------------------------
/models/backbones/table_mobilenet_v3.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import torch
6 | from torch import nn
7 | import torch.nn.functional as F
8 |
9 | __all__ = ['MobileNetV3']
10 |
11 |
def make_divisible(v, divisor=8, min_value=None):
    """Round ``v`` to a nearby multiple of ``divisor``.

    Args:
        v: value to round (typically a scaled channel count).
        divisor: the result is a multiple of this unless ``min_value`` wins.
        min_value: lower bound for the result; defaults to ``divisor``.

    Returns:
        The rounded value, bumped up by one ``divisor`` when rounding would
        otherwise lose more than 10% of ``v``.
    """
    lower_bound = divisor if min_value is None else min_value
    nearest = int(v + divisor / 2) // divisor * divisor
    result = max(lower_bound, nearest)
    if result < 0.9 * v:
        return result + divisor
    return result
19 |
20 |
class MobileNetV3(nn.Module):
    """MobileNetV3 backbone: a stem convolution followed by residual stages.

    The stages are chained into one ``nn.Sequential`` (``stages_pipline``),
    and ``forward`` returns only the output of the last stage.
    """

    def __init__(self,
                 in_channels=3,
                 model_name='large',
                 scale=0.5,
                 disable_se=False,
                 **kwargs):
        """
        Build the stem convolution and the stage pipeline.

        Args:
            in_channels (int): channel count of the input.
            model_name (str): 'large' or 'small'; selects the block table.
            scale (float): channel width multiplier; must be one of
                0.35, 0.5, 0.75, 1.0, 1.25.
            disable_se (bool): when True, no block uses the SE module.
            **kwargs: accepted and ignored.

        Raises:
            NotImplementedError: if ``model_name`` is not 'large' or 'small'.
            AssertionError: if ``scale`` is not a supported value.
        """
        super(MobileNetV3, self).__init__()

        self.disable_se = disable_se

        # Each row: kernel size, expansion channels, output channels,
        # use SE, activation name, stride.
        if model_name == "large":
            cfg = [
                # k, exp, c, se, nl, s,
                [3, 16, 16, False, 'relu', 1],
                [3, 64, 24, False, 'relu', 2],
                [3, 72, 24, False, 'relu', 1],
                [5, 72, 40, True, 'relu', 2],
                [5, 120, 40, True, 'relu', 1],
                [5, 120, 40, True, 'relu', 1],
                [3, 240, 80, False, 'hardswish', 2],
                [3, 200, 80, False, 'hardswish', 1],
                [3, 184, 80, False, 'hardswish', 1],
                [3, 184, 80, False, 'hardswish', 1],
                [3, 480, 112, True, 'hardswish', 1],
                [3, 672, 112, True, 'hardswish', 1],
                [5, 672, 160, True, 'hardswish', 2],
                [5, 960, 160, True, 'hardswish', 1],
                [5, 960, 160, True, 'hardswish', 1],
            ]
            cls_ch_squeeze = 960
        elif model_name == "small":
            cfg = [
                # k, exp, c, se, nl, s,
                [3, 16, 16, True, 'relu', 2],
                [3, 72, 24, False, 'relu', 2],
                [3, 88, 24, False, 'relu', 1],
                [5, 96, 40, True, 'hardswish', 2],
                [5, 240, 40, True, 'hardswish', 1],
                [5, 240, 40, True, 'hardswish', 1],
                [5, 120, 48, True, 'hardswish', 1],
                [5, 144, 48, True, 'hardswish', 1],
                [5, 288, 96, True, 'hardswish', 2],
                [5, 576, 96, True, 'hardswish', 1],
                [5, 576, 96, True, 'hardswish', 1],
            ]
            cls_ch_squeeze = 576
        else:
            raise NotImplementedError("mode[" + model_name +
                                      "_model] is not implemented!")

        supported_scale = [0.35, 0.5, 0.75, 1.0, 1.25]
        assert scale in supported_scale, \
            "supported scale are {} but input scale is {}".format(supported_scale, scale)
        inplanes = 16
        # conv1: stride-2 stem convolution
        self.conv = ConvBNLayer(
            in_channels=in_channels,
            out_channels=make_divisible(inplanes * scale),
            kernel_size=3,
            stride=2,
            padding=1,
            groups=1,
            if_act=True,
            act='hardswish',
            name='conv1')

        # self.stages is a plain Python list; the stage modules are
        # registered on this module through self.stages_pipline below.
        self.stages = []
        # One entry per stage: the channel count leaving that stage.
        self.out_channels = []
        block_list = []
        i = 0
        inplanes = make_divisible(inplanes * scale)
        for (k, exp, c, se, nl, s) in cfg:
            se = se and not self.disable_se
            start_idx = 2 if model_name == 'large' else 0
            # A stride-2 block past start_idx closes the current stage
            # and starts a new one.
            if s == 2 and i > start_idx:
                self.out_channels.append(inplanes)
                self.stages.append(nn.Sequential(*block_list))
                block_list = []
            block_list.append(
                ResidualUnit(
                    in_channels=inplanes,
                    mid_channels=make_divisible(scale * exp),
                    out_channels=make_divisible(scale * c),
                    kernel_size=k,
                    stride=s,
                    use_se=se,
                    act=nl,
                    name="conv" + str(i + 2)))
            inplanes = make_divisible(scale * c)
            i += 1
        # The final 1x1 convolution is appended to the last stage.
        block_list.append(
            ConvBNLayer(
                in_channels=inplanes,
                out_channels=make_divisible(scale * cls_ch_squeeze),
                kernel_size=1,
                stride=1,
                padding=0,
                groups=1,
                if_act=True,
                act='hardswish',
                name='conv_last'))
        self.stages.append(nn.Sequential(*block_list))
        # Chain all stages into a single module run by forward().
        self.stages_pipline = nn.Sequential(*self.stages)
        self.out_channels.append(make_divisible(scale * cls_ch_squeeze))
        # for i, stage in enumerate(self.stages):
        #     self.add_sublayer(sublayer=stage, name="stage{}".format(i))

    def forward(self, x):
        """Return a one-element list holding the output of the last stage."""
        x = self.conv(x)
        out_list = []
        # Per-stage outputs are not collected; only the final one is kept.
        # for stage in self.stages:
        #     x = stage(x)
        #     out_list.append(x)
        x = self.stages_pipline(x)
        out_list.append(x)
        return out_list
143 |
144 |
class ConvBNLayer(nn.Module):
    """Bias-free 2-D convolution, batch norm and an optional activation."""

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride,
                 padding,
                 groups=1,
                 if_act=True,
                 act=None,
                 name=None):
        """
        Args:
            in_channels (int): input channels of the convolution.
            out_channels (int): output channels of the convolution.
            kernel_size (int): convolution kernel size.
            stride (int): convolution stride.
            padding (int): convolution padding.
            groups (int): convolution groups.
            if_act (bool): apply the activation after batch norm.
            act (str): 'relu' or 'hardswish'; only read when ``if_act``.
            name (str): unused; accepted for call-site compatibility.
        """
        super(ConvBNLayer, self).__init__()
        self.if_act = if_act
        self.act = act
        self.conv = nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            groups=groups,
            bias=False)

        self.bn = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        """
        Apply conv -> batch norm -> activation.

        Raises:
            ValueError: if ``if_act`` is set and ``act`` is neither 'relu'
                nor 'hardswish'.
        """
        x = self.bn(self.conv(x))
        if not self.if_act:
            return x
        if self.act == "relu":
            return F.relu(x)
        if self.act == "hardswish":
            return F.hardswish(x)
        # Raise instead of print + exit(): library code must not terminate
        # the interpreter on a configuration error.
        raise ValueError(
            "The activation function({}) is selected incorrectly.".format(
                self.act))
183 |
184 |
class ResidualUnit(nn.Module):
    """Residual block: 1x1 expand, depthwise conv, optional SE, 1x1 project.

    The input is added back to the output only when the block keeps both
    the spatial size (stride 1) and the channel count.
    """

    def __init__(self,
                 in_channels,
                 mid_channels,
                 out_channels,
                 kernel_size,
                 stride,
                 use_se,
                 act=None,
                 name=''):
        """
        Args:
            in_channels (int): channels entering the block.
            mid_channels (int): channels after the 1x1 expansion.
            out_channels (int): channels leaving the block.
            kernel_size (int): kernel size of the depthwise convolution.
            stride (int): stride of the depthwise convolution.
            use_se (bool): insert an SE module after the depthwise conv.
            act (str): activation name for the first two convolutions.
            name (str): prefix for the sub-layer names.
        """
        super().__init__()
        self.if_shortcut = stride == 1 and in_channels == out_channels
        self.if_se = use_se

        # Sub-modules are created in the same order as before so that
        # parameter registration order is unchanged.
        pointwise = dict(kernel_size=1, stride=1, padding=0)
        self.expand_conv = ConvBNLayer(
            in_channels=in_channels,
            out_channels=mid_channels,
            if_act=True,
            act=act,
            name=name + "_expand",
            **pointwise)
        # groups == channels makes this convolution depthwise.
        self.bottleneck_conv = ConvBNLayer(
            in_channels=mid_channels,
            out_channels=mid_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=int((kernel_size - 1) // 2),
            groups=mid_channels,
            if_act=True,
            act=act,
            name=name + "_depthwise")
        if self.if_se:
            self.mid_se = SEModule(mid_channels, name=name + "_se")
        # Linear projection: no activation after the batch norm.
        self.linear_conv = ConvBNLayer(
            in_channels=mid_channels,
            out_channels=out_channels,
            if_act=False,
            act=None,
            name=name + "_linear",
            **pointwise)

    def forward(self, inputs):
        """Run the block and add ``inputs`` back when the shortcut applies."""
        out = self.bottleneck_conv(self.expand_conv(inputs))
        if self.if_se:
            out = self.mid_se(out)
        out = self.linear_conv(out)
        return torch.add(inputs, out) if self.if_shortcut else out
239 |
240 |
class SEModule(nn.Module):
    """Squeeze-and-excitation: rescale each channel by a learned gate."""

    def __init__(self, in_channels, reduction=4, name=""):
        """
        Args:
            in_channels (int): channels of the input feature map.
            reduction (int): divisor for the hidden channel count.
            name (str): unused; accepted for call-site compatibility.
        """
        super().__init__()
        hidden = in_channels // reduction
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.conv1 = nn.Conv2d(
            in_channels, hidden, kernel_size=1, stride=1, padding=0)
        self.conv2 = nn.Conv2d(
            hidden, in_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, inputs):
        """Return ``inputs`` multiplied by the per-channel gate."""
        gate = self.conv2(F.relu(self.conv1(self.avg_pool(inputs))))
        # relu6(1.2 * x + 3) / 6 == clip(0.2 * x + 0.5, 0, 1). This is not
        # F.hardsigmoid (slope 1/6) -- presumably kept to match the Paddle
        # weights this model was converted from; confirm before changing.
        gate = F.relu6(1.2 * gate + 3.) / 6.
        return inputs * gate
266 |
--------------------------------------------------------------------------------
/models/backbones/table_resnet_vd.py:
--------------------------------------------------------------------------------
1 | # copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from __future__ import absolute_import
16 | from __future__ import division
17 | from __future__ import print_function
18 |
19 | import paddle
20 | from paddle import ParamAttr
21 | import paddle.nn as nn
22 | import paddle.nn.functional as F
23 |
24 | __all__ = ["ResNet"]
25 |
26 |
class ConvBNLayer(nn.Layer):
    """Paddle layer: optional 2x2 average pool, bias-free conv, batch norm."""

    def __init__(
            self,
            in_channels,
            out_channels,
            kernel_size,
            stride=1,
            groups=1,
            is_vd_mode=False,
            act=None,
            name=None, ):
        """
        Args:
            in_channels (int): input channels of the convolution.
            out_channels (int): output channels of the convolution.
            kernel_size (int): kernel size; padding is (kernel_size - 1) // 2.
            stride (int): convolution stride.
            groups (int): convolution groups.
            is_vd_mode (bool): average-pool the input before the convolution.
            act (str): activation name forwarded to the batch-norm layer.
            name (str): base string for the parameter names; despite the
                ``None`` default it must be a string, because it is
                concatenated and sliced below.
        """
        super(ConvBNLayer, self).__init__()

        self.is_vd_mode = is_vd_mode
        # Created unconditionally; only used when is_vd_mode is set.
        self._pool2d_avg = nn.AvgPool2D(
            kernel_size=2, stride=2, padding=0, ceil_mode=True)

        same_padding = (kernel_size - 1) // 2
        self._conv = nn.Conv2D(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=same_padding,
            groups=groups,
            weight_attr=ParamAttr(name=name + "_weights"),
            bias_attr=False)

        # "conv1" maps to "bn_conv1"; any other name has its first three
        # characters replaced by "bn".
        bn_name = "bn_" + name if name == "conv1" else "bn" + name[3:]
        self._batch_norm = nn.BatchNorm(
            out_channels,
            act=act,
            param_attr=ParamAttr(name=bn_name + '_scale'),
            bias_attr=ParamAttr(bn_name + '_offset'),
            moving_mean_name=bn_name + '_mean',
            moving_variance_name=bn_name + '_variance')

    def forward(self, inputs):
        """Apply (pool ->) conv -> batch norm and return the result."""
        features = self._pool2d_avg(inputs) if self.is_vd_mode else inputs
        return self._batch_norm(self._conv(features))
70 |
71 |
class BottleneckBlock(nn.Layer):
    """Paddle bottleneck block: 1x1 -> 3x3 -> 1x1 (x4 channels) + shortcut."""

    def __init__(self,
                 in_channels,
                 out_channels,
                 stride,
                 shortcut=True,
                 if_first=False,
                 name=None):
        """
        Args:
            in_channels (int): channels entering the block.
            out_channels (int): channels of the first two convolutions; the
                block outputs ``out_channels * 4``.
            stride (int): stride of the 3x3 convolution.
            shortcut (bool): True adds the input unchanged; False routes it
                through a 1x1 projection first.
            if_first (bool): when True the projection skips the vd-mode
                average pool.
            name (str): prefix for the branch names.
        """
        super(BottleneckBlock, self).__init__()

        expanded = out_channels * 4
        self.conv0 = ConvBNLayer(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=1,
            act='relu',
            name=name + "_branch2a")
        self.conv1 = ConvBNLayer(
            in_channels=out_channels,
            out_channels=out_channels,
            kernel_size=3,
            stride=stride,
            act='relu',
            name=name + "_branch2b")
        self.conv2 = ConvBNLayer(
            in_channels=out_channels,
            out_channels=expanded,
            kernel_size=1,
            act=None,
            name=name + "_branch2c")

        if not shortcut:
            # Projection branch so the shortcut matches the main branch.
            self.short = ConvBNLayer(
                in_channels=in_channels,
                out_channels=expanded,
                kernel_size=1,
                stride=1,
                is_vd_mode=not if_first,
                name=name + "_branch1")

        self.shortcut = shortcut

    def forward(self, inputs):
        """Return relu(shortcut(inputs) + main_branch(inputs))."""
        main = self.conv2(self.conv1(self.conv0(inputs)))
        skip = inputs if self.shortcut else self.short(inputs)
        return F.relu(paddle.add(x=skip, y=main))
125 |
126 |
class BasicBlock(nn.Layer):
    """Paddle basic residual block: two 3x3 convolutions plus a shortcut."""

    def __init__(self,
                 in_channels,
                 out_channels,
                 stride,
                 shortcut=True,
                 if_first=False,
                 name=None):
        """
        Args:
            in_channels (int): channels entering the block.
            out_channels (int): channels leaving the block.
            stride (int): stride of the first 3x3 convolution.
            shortcut (bool): True adds the input unchanged; False routes it
                through a 1x1 projection first.
            if_first (bool): when True the projection skips the vd-mode
                average pool.
            name (str): prefix for the branch names.
        """
        super(BasicBlock, self).__init__()
        self.stride = stride
        self.conv0 = ConvBNLayer(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=3,
            stride=stride,
            act='relu',
            name=name + "_branch2a")
        self.conv1 = ConvBNLayer(
            in_channels=out_channels,
            out_channels=out_channels,
            kernel_size=3,
            act=None,
            name=name + "_branch2b")

        if not shortcut:
            # Projection branch so the shortcut matches the main branch.
            self.short = ConvBNLayer(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=1,
                stride=1,
                is_vd_mode=not if_first,
                name=name + "_branch1")

        self.shortcut = shortcut

    def forward(self, inputs):
        """Return relu(shortcut(inputs) + main_branch(inputs))."""
        main = self.conv1(self.conv0(inputs))
        skip = inputs if self.shortcut else self.short(inputs)
        return F.relu(paddle.add(x=skip, y=main))
173 |
174 |
class ResNet(nn.Layer):
    """ResNet backbone built from Paddle layers.

    NOTE(review): this class derives from ``paddle.nn.Layer`` while the
    sibling MobileNetV3 backbone is a ``torch.nn.Module`` -- presumably this
    file has not been ported to PyTorch yet; confirm before selecting it.
    """

    def __init__(self, in_channels=3, layers=50, **kwargs):
        """
        Build the three-convolution stem, the max pool and four stages.

        Args:
            in_channels (int): channel count of the input.
            layers (int): network depth; one of 18, 34, 50, 101, 152, 200.
            **kwargs: accepted and ignored.

        Raises:
            AssertionError: if ``layers`` is not a supported depth.
        """
        super(ResNet, self).__init__()

        self.layers = layers
        supported_layers = [18, 34, 50, 101, 152, 200]
        assert layers in supported_layers, \
            "supported layers are {} but input layer is {}".format(
                supported_layers, layers)

        # Number of residual blocks in each of the four stages.
        if layers == 18:
            depth = [2, 2, 2, 2]
        elif layers == 34 or layers == 50:
            depth = [3, 4, 6, 3]
        elif layers == 101:
            depth = [3, 4, 23, 3]
        elif layers == 152:
            depth = [3, 8, 36, 3]
        elif layers == 200:
            depth = [3, 12, 48, 3]
        # Input channels of the first block of each stage; the larger values
        # apply to the bottleneck variants (output is num_filters * 4).
        num_channels = [64, 256, 512,
                        1024] if layers >= 50 else [64, 64, 128, 256]
        num_filters = [64, 128, 256, 512]

        # Stem: three 3x3 convolutions (the first with stride 2), then a
        # stride-2 max pool.
        self.conv1_1 = ConvBNLayer(
            in_channels=in_channels,
            out_channels=32,
            kernel_size=3,
            stride=2,
            act='relu',
            name="conv1_1")
        self.conv1_2 = ConvBNLayer(
            in_channels=32,
            out_channels=32,
            kernel_size=3,
            stride=1,
            act='relu',
            name="conv1_2")
        self.conv1_3 = ConvBNLayer(
            in_channels=32,
            out_channels=64,
            kernel_size=3,
            stride=1,
            act='relu',
            name="conv1_3")
        self.pool2d_max = nn.MaxPool2D(kernel_size=3, stride=2, padding=1)

        # self.stages is a plain Python list; each block is registered on
        # this layer through add_sublayer below.
        self.stages = []
        # One entry per stage: the channel count leaving that stage.
        self.out_channels = []
        if layers >= 50:
            for block in range(len(depth)):
                block_list = []
                # The first block of every stage uses a projection shortcut.
                shortcut = False
                for i in range(depth[block]):
                    # Third stage of the 101/152 nets: "res4a", "res4b1",
                    # "res4b2", ...; everything else: "res2a", "res2b", ...
                    # NOTE(review): layers == 200 is not in this list, so its
                    # 48-block third stage runs chr(97 + i) past 'z'.
                    if layers in [101, 152] and block == 2:
                        if i == 0:
                            conv_name = "res" + str(block + 2) + "a"
                        else:
                            conv_name = "res" + str(block + 2) + "b" + str(i)
                    else:
                        conv_name = "res" + str(block + 2) + chr(97 + i)
                    bottleneck_block = self.add_sublayer(
                        'bb_%d_%d' % (block, i),
                        BottleneckBlock(
                            in_channels=num_channels[block]
                            if i == 0 else num_filters[block] * 4,
                            out_channels=num_filters[block],
                            # Only the first block of stages 1..3 has stride 2.
                            stride=2 if i == 0 and block != 0 else 1,
                            shortcut=shortcut,
                            if_first=block == i == 0,
                            name=conv_name))
                    shortcut = True
                    block_list.append(bottleneck_block)
                self.out_channels.append(num_filters[block] * 4)
                self.stages.append(nn.Sequential(*block_list))
        else:
            for block in range(len(depth)):
                block_list = []
                # The first block of every stage uses a projection shortcut.
                shortcut = False
                for i in range(depth[block]):
                    conv_name = "res" + str(block + 2) + chr(97 + i)
                    basic_block = self.add_sublayer(
                        'bb_%d_%d' % (block, i),
                        BasicBlock(
                            in_channels=num_channels[block]
                            if i == 0 else num_filters[block],
                            out_channels=num_filters[block],
                            # Only the first block of stages 1..3 has stride 2.
                            stride=2 if i == 0 and block != 0 else 1,
                            shortcut=shortcut,
                            if_first=block == i == 0,
                            name=conv_name))
                    shortcut = True
                    block_list.append(basic_block)
                self.out_channels.append(num_filters[block])
                self.stages.append(nn.Sequential(*block_list))

    def forward(self, inputs):
        """Return a list with the output of every stage, in order."""
        y = self.conv1_1(inputs)
        y = self.conv1_2(y)
        y = self.conv1_3(y)
        y = self.pool2d_max(y)
        out = []
        for block in self.stages:
            y = block(y)
            out.append(y)
        return out
281 |
--------------------------------------------------------------------------------
/models/heads/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | __all__ = ['build_head']
16 |
17 |
def build_head(config):
    """Instantiate the head module described by ``config``.

    Args:
        config: mapping with a ``'name'`` entry selecting the head class; all
            remaining entries are forwarded to the constructor as keyword
            arguments. The mapping itself is left unmodified.

    Returns:
        The constructed head module.

    Raises:
        KeyError: if ``config`` has no ``'name'`` entry.
        AssertionError: if the requested head is not supported.
    """
    support_dict = ['TableAttentionHead']

    # Work on a copy so the caller can reuse the same config for another build.
    kwargs = dict(config)
    module_name = kwargs.pop('name')
    if module_name not in support_dict:
        # Raised explicitly (not via ``assert``) so the check survives ``-O``.
        raise AssertionError('head only support {}'.format(support_dict))

    # table head
    from .table_att_head import TableAttentionHead

    # Explicit registry lookup instead of ``eval`` on a config-supplied string.
    head_classes = {'TableAttentionHead': TableAttentionHead}
    return head_classes[module_name](**kwargs)
28 |
--------------------------------------------------------------------------------
/models/heads/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/June-Li/TableRecognition/a26956983ade1153ee91f9a92fface1a2cf1275e/models/heads/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/models/heads/__pycache__/table_att_head.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/June-Li/TableRecognition/a26956983ade1153ee91f9a92fface1a2cf1275e/models/heads/__pycache__/table_att_head.cpython-38.pyc
--------------------------------------------------------------------------------
/recognizer_table.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time : 2022/01/25 09:52
3 | # @Author : lijun
4 | from __future__ import absolute_import
5 | from __future__ import division
6 | from __future__ import print_function
7 |
8 | import os
9 | import sys
10 | import platform
11 | import time
12 | import cv2
13 | import torch
14 | import numpy as np
15 | from tqdm import tqdm
16 |
17 | from models.architectures import build_model
18 | from utils.save_load import load_model
19 | from utils.torch_utils import select_device
20 | from utils.data import transform, create_operators
21 | from utils.postprocess import build_post_process
22 |
23 |
class Recognizer:
    """Table-structure recognizer built from a saved checkpoint.

    The checkpoint is expected to carry its training config under ``'cfg'``;
    the pre-processing pipeline and the architecture are rebuilt from it.

    Args:
        model_path: path of the checkpoint file.
        table_char_dict_path: path of the structure-token dictionary handed
            to the ``TableLabelDecode`` post-processor.
        gpu: device selector passed to ``select_device``.
        half_flag: request FP16 inference; only honoured on non-CPU devices.
    """

    def __init__(self, model_path, table_char_dict_path='utils/dict/table_structure_dict.txt',
                 gpu='0', half_flag=False):
        self.weights = model_path
        self.table_char_dict_path = table_char_dict_path
        self.gpu = gpu
        self.device = select_device(self.gpu)
        self.half_flag = half_flag
        ckpt = torch.load(self.weights, map_location='cpu')
        cfg = ckpt['cfg']

        # Reuse the training transforms, minus image decoding and label
        # encoding, and keep only the image in the transform output.
        pre_process_list = []
        for pre_process in cfg['Train']['dataset']['transforms']:
            if 'DecodeImage' in pre_process or 'TableLabelEncode' in pre_process:
                continue
            if 'KeepKeys' in pre_process:
                pre_process['KeepKeys'] = {'keep_keys': ['image']}
            pre_process_list.append(pre_process)
        self.preprocess_op = create_operators(pre_process_list)
        postprocess_params = {
            'name': 'TableLabelDecode',
            "character_type": 'en',
            "character_dict_path": self.table_char_dict_path,
        }
        self.postprocess_op = build_post_process(postprocess_params)

        self.model = build_model(cfg['Architecture'])
        load_model(self.model, model_path)
        self.model.to(self.device)
        self.half = self.device.type != 'cpu'  # half precision only supported on CUDA
        if self.half_flag and self.half:
            self.model.half()  # to FP16
        # Inference only: put normalisation/dropout layers in evaluation mode.
        self.model.eval()

    def inference(self, img):
        """Predict the table structure and cell boxes of one image.

        Args:
            img: image array; its first two dimensions are read as
                (height, width).

        Returns:
            ``(structure_str_list, res_loc_final)`` where the first item is
            the list of structure tokens wrapped in html/body/table tokens
            and the second a list of ``[left, top, right, bottom]`` integer
            boxes clipped to the image. Returns ``None`` when pre-processing
            yields no image.
        """
        ori_im = img.copy()
        data = transform({'image': img}, self.preprocess_op)
        img = data[0]
        if img is None:
            return None
        img = np.expand_dims(img, axis=0)
        img = img.copy()
        img = torch.tensor(img).to(self.device)
        if self.half_flag and self.half:
            img = img.half()
        # No gradients are needed at inference time.
        with torch.no_grad():
            preds = self.model(img)

        preds = {'structure_probs': preds['structure_probs'].cpu().detach().numpy(),
                 'loc_preds': preds['loc_preds'].cpu().detach().numpy()}
        post_result = self.postprocess_op(preds)

        structure_str_list = post_result['structure_str_list']
        res_loc = post_result['res_loc']
        imgh, imgw = ori_im.shape[0:2]
        # Boxes are scaled by the original image size (presumably they are
        # normalised to [0, 1] -- confirm against the post-processor).
        res_loc_final = []
        for x0, y0, x1, y1 in res_loc[0]:
            left = max(int(imgw * x0), 0)
            top = max(int(imgh * y0), 0)
            right = min(int(imgw * x1), imgw - 1)
            bottom = min(int(imgh * y1), imgh - 1)
            res_loc_final.append([left, top, right, bottom])
        # Drop the last decoded token, then wrap the rest in three opening
        # and three closing tokens.
        structure_str_list = structure_str_list[0][:-1]
        structure_str_list = (['<html>', '<body>', '<table>'] + structure_str_list
                              + ['</table>', '</body>', '</html>'])
        return structure_str_list, res_loc_final
88 |
89 |
90 | if __name__ == '__main__':
91 | def html_configuration():
92 | html_head = '\n' + \
93 | '\n' + \
94 | '\n' + \
95 | ' \n' + \
96 | ' Title\n' + \
97 | '\n' + \
98 | '\n' + \
99 | '\n' + \
107 | '\n' + \
109 | '\n' + \
110 | '\n'
111 | return [html_head, html_tail]
112 |
113 |
114 | # table_rec = Recognizer('inference_model/baidu/infer/table_rec/pd2pt.pt', gpu='1')
115 | table_rec = Recognizer('inference_model/juneli/finetune/table_rec/last.pt', gpu='cpu')
116 | base_path = './'
117 | in_path = base_path + 'test_data/FundScan/'
118 | out_path = base_path + 'test_out/FundScan/'
119 | if not os.path.exists(out_path):
120 | os.makedirs(out_path)
121 | # file_name_list = os.listdir(base_path + 'datasets/public/pubtabnet_subset/train/')
122 | file_name_list = os.listdir(in_path)
123 | for idx, file_name in enumerate(tqdm(file_name_list)):
124 | # if not file_name.startswith('0.jpg'):
125 | # continue
126 | # if file_name != '0.jpg':
127 | # continue
128 | img = cv2.imread(in_path + file_name)
129 | structure, loc = table_rec.inference(img)
130 | # my output->开始
131 | # print('html: \n', structure)
132 | open(out_path + file_name.replace('.jpg', '.html').replace('.png', '.html'), 'w').write(
133 | html_configuration()[0] +
134 | ''.join([i.replace('"', '"') for i in structure[3:-3]]) + '\n' + html_configuration()[1])
135 | print(loc)
136 | # print(len(structure), len(loc))
137 | show_img = img.copy()
138 | for box in loc:
139 | cv2.rectangle(show_img, tuple(box[:2]), tuple(box[2:]), (0, 0, 255), 2)
140 | cv2.imwrite(out_path + file_name, show_img)
141 | # my output->结束
142 |
143 |
144 |
--------------------------------------------------------------------------------
/script/paddle2pytorch.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import paddle
3 | import numpy as np
4 | import yaml
5 |
6 |
# Convert a Paddle parameter file into a PyTorch checkpoint.
# 0.yml is the config file, e.g. configs/train_table_mv3.yml
with open('0.yml', 'r') as config_file:  # close the handle once parsed
    config = yaml.load(config_file, Loader=yaml.FullLoader)
pd = paddle.load('/Volumes/my_disk/company/xxx/buffer_disk/a/best_accuracy.pdparams')
# ckpt is the PyTorch model; its tensor shapes serve as the reference layout
ckpt = torch.load('/Volumes/my_disk/company/xxx/buffer_disk/a/last.pt', map_location='cpu')

out_dict = {}
for pd_k in pd.keys():
    # Map Paddle parameter names onto the PyTorch module names.
    ckpt_k = pd_k.replace('stage', 'stages_pipline.').replace('_mean', 'running_mean').replace('_variance', 'running_var')
    # Transpose when the shapes disagree; the h2h weight is always transposed.
    if np.shape(pd[pd_k].numpy()) != np.shape(ckpt[ckpt_k].numpy()) or pd_k == 'head.structure_attention_cell.h2h.weight':
        pd[pd_k] = paddle.transpose(pd[pd_k], (1, 0))
    out_dict[ckpt_k] = torch.tensor(pd[pd_k].numpy())
torch.save({'state_dict': out_dict, 'cfg': config}, '/Volumes/my_disk/company/xxx/buffer_disk/a/pd2pt.pt')
--------------------------------------------------------------------------------
/script/t_0.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | import os
15 | import sys
16 |
17 | __dir__ = os.path.dirname(os.path.abspath(__file__))
18 | sys.path.append(__dir__)
19 | sys.path.append(os.path.abspath(os.path.join(__dir__, '../..')))
20 |
21 | os.environ["FLAGS_allocator_strategy"] = 'auto_growth'
22 |
23 | import cv2
24 | import numpy as np
25 | import time
26 |
27 | import tools.infer.utility as utility
28 | from ppocr.data import create_operators, transform
29 | from ppocr.postprocess import build_post_process
30 | from ppocr.utils.logging import get_logger
31 | from ppocr.utils.utility import get_image_file_list, check_and_read_gif
32 | from ppstructure.utility import parse_args
33 |
34 | logger = get_logger()
35 |
36 |
class TableStructurer(object):
    """Callable that predicts table structure with an inference predictor.

    Args:
        args: parsed options; ``table_max_len``, ``table_char_type`` and
            ``table_char_dict_path`` are read here and the whole object is
            forwarded to ``utility.create_predictor``.
    """

    def __init__(self, args):
        pre_process_list = [{
            'ResizeTableImage': {
                'max_len': args.table_max_len
            }
        }, {
            'NormalizeImage': {
                'std': [0.229, 0.224, 0.225],
                'mean': [0.485, 0.456, 0.406],
                'scale': '1./255.',
                'order': 'hwc'
            }
        }, {
            'PaddingTableImage': None
        }, {
            'ToCHWImage': None
        }, {
            'KeepKeys': {
                'keep_keys': ['image']
            }
        }]
        postprocess_params = {
            'name': 'TableLabelDecode',
            "character_type": args.table_char_type,
            "character_dict_path": args.table_char_dict_path,
        }

        self.preprocess_op = create_operators(pre_process_list)
        self.postprocess_op = build_post_process(postprocess_params)
        self.predictor, self.input_tensor, self.output_tensors, self.config = \
            utility.create_predictor(args, 'table', logger)

    def __call__(self, img):
        """Predict structure tokens and cell boxes for one image.

        Returns:
            ``((structure_str_list, res_loc_final), elapse)`` where ``elapse``
            is the wall-clock time in seconds from just before the input is
            copied to the predictor until post-processing has finished, or
            ``(None, 0)`` when pre-processing yields no image.
        """
        ori_im = img.copy()
        data = {'image': img}
        data = transform(data, self.preprocess_op)
        img = data[0]
        if img is None:
            return None, 0
        img = np.expand_dims(img, axis=0)
        img = img.copy()
        starttime = time.time()

        self.input_tensor.copy_from_cpu(img)
        self.predictor.run()
        outputs = [output_tensor.copy_to_cpu()
                   for output_tensor in self.output_tensors]

        # Output 1 holds the structure probabilities, output 0 the boxes.
        preds = {}
        preds['structure_probs'] = outputs[1]
        preds['loc_preds'] = outputs[0]

        post_result = self.postprocess_op(preds)

        structure_str_list = post_result['structure_str_list']
        res_loc = post_result['res_loc']
        imgh, imgw = ori_im.shape[0:2]
        # Scale the boxes by the original image size and clip to its bounds.
        res_loc_final = []
        for x0, y0, x1, y1 in res_loc[0]:
            left = max(int(imgw * x0), 0)
            top = max(int(imgh * y0), 0)
            right = min(int(imgw * x1), imgw - 1)
            bottom = min(int(imgh * y1), imgh - 1)
            res_loc_final.append([left, top, right, bottom])

        # Drop the last decoded token and wrap the rest in document tokens.
        structure_str_list = structure_str_list[0][:-1]
        structure_str_list = (['<html>', '<body>', '<table>'] + structure_str_list
                              + ['</table>', '</body>', '</html>'])

        elapse = time.time() - starttime
        return (structure_str_list, res_loc_final), elapse
111 |
112 |
def main(args):
    """Run table-structure prediction on every image under ``args.image_dir``.

    Each result and its prediction time are written to the module logger;
    images that cannot be loaded are reported and skipped.
    """
    image_file_list = get_image_file_list(args.image_dir)
    table_structurer = TableStructurer(args)
    processed = 0
    accumulated_time = 0
    for image_file in image_file_list:
        img, is_gif = check_and_read_gif(image_file)
        if not is_gif:
            img = cv2.imread(image_file)
        if img is None:
            logger.info("error in loading image:{}".format(image_file))
            continue

        structure_res, elapse = table_structurer(img)
        logger.info("result: {}".format(structure_res))

        # The first processed image is left out of the accumulated time.
        accumulated_time += elapse if processed > 0 else 0
        processed += 1
        logger.info("Predict time of {}: {}".format(image_file, elapse))
133 |
134 |
135 | if __name__ == "__main__":
136 | main(parse_args())
137 |
--------------------------------------------------------------------------------
/script/t_2.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import torch
6 | from torch import nn
7 | import torch.nn.functional as F
8 |
9 | __all__ = ['MobileNetV3']
10 |
11 |
def make_divisible(v, divisor=8, min_value=None):
    """Round ``v`` to a multiple of ``divisor``.

    The result is never below ``min_value`` (``divisor`` when omitted) and is
    bumped up by one ``divisor`` if rounding lost more than 10% of ``v``.
    """
    lower_bound = divisor if min_value is None else min_value
    rounded = int(v + divisor / 2) // divisor * divisor
    result = max(lower_bound, rounded)
    return result + divisor if result < 0.9 * v else result
19 |
20 |
class MobileNetV3(nn.Module):
    """MobileNetV3 backbone.

    Args:
        in_channels: channel count of the input image.
        model_name: ``'large'`` or ``'small'`` block layout.
        scale: width multiplier; one of 0.35, 0.5, 0.75, 1.0, 1.25.
        disable_se: when true, no block uses squeeze-and-excitation.
        **kwargs: accepted and ignored.

    Attributes:
        stages: list with one ``nn.Sequential`` per stage.
        out_channels: output channel count of each stage.
        stages_pipline: all stages chained into a single ``nn.Sequential``.
    """

    def __init__(self,
                 in_channels=3,
                 model_name='large',
                 scale=0.5,
                 disable_se=False,
                 **kwargs):
        super(MobileNetV3, self).__init__()

        self.disable_se = disable_se

        if model_name == "large":
            cls_ch_squeeze = 960
            cfg = [
                # k, exp, c,  se,     nl,  s,
                [3, 16, 16, False, 'relu', 1],
                [3, 64, 24, False, 'relu', 2],
                [3, 72, 24, False, 'relu', 1],
                [5, 72, 40, True, 'relu', 2],
                [5, 120, 40, True, 'relu', 1],
                [5, 120, 40, True, 'relu', 1],
                [3, 240, 80, False, 'hardswish', 2],
                [3, 200, 80, False, 'hardswish', 1],
                [3, 184, 80, False, 'hardswish', 1],
                [3, 184, 80, False, 'hardswish', 1],
                [3, 480, 112, True, 'hardswish', 1],
                [3, 672, 112, True, 'hardswish', 1],
                [5, 672, 160, True, 'hardswish', 2],
                [5, 960, 160, True, 'hardswish', 1],
                [5, 960, 160, True, 'hardswish', 1],
            ]
        elif model_name == "small":
            cls_ch_squeeze = 576
            cfg = [
                # k, exp, c,  se,     nl,  s,
                [3, 16, 16, True, 'relu', 2],
                [3, 72, 24, False, 'relu', 2],
                [3, 88, 24, False, 'relu', 1],
                [5, 96, 40, True, 'hardswish', 2],
                [5, 240, 40, True, 'hardswish', 1],
                [5, 240, 40, True, 'hardswish', 1],
                [5, 120, 48, True, 'hardswish', 1],
                [5, 144, 48, True, 'hardswish', 1],
                [5, 288, 96, True, 'hardswish', 2],
                [5, 576, 96, True, 'hardswish', 1],
                [5, 576, 96, True, 'hardswish', 1],
            ]
        else:
            raise NotImplementedError("mode[" + model_name +
                                      "_model] is not implemented!")

        supported_scale = [0.35, 0.5, 0.75, 1.0, 1.25]
        assert scale in supported_scale, \
            "supported scale are {} but input scale is {}".format(supported_scale, scale)

        # conv1
        stem_channels = make_divisible(16 * scale)
        self.conv = ConvBNLayer(
            in_channels=in_channels,
            out_channels=stem_channels,
            kernel_size=3,
            stride=2,
            padding=1,
            groups=1,
            if_act=True,
            act='hardswish',
            name='conv1')

        self.stages = []
        self.out_channels = []
        # A stride-2 block opens a new stage only past this block index.
        start_idx = 2 if model_name == 'large' else 0
        inplanes = stem_channels
        pending = []
        for i, (k, exp, c, se, nl, s) in enumerate(cfg):
            if s == 2 and i > start_idx:
                # Close the current stage before the down-sampling block.
                self.out_channels.append(inplanes)
                self.stages.append(nn.Sequential(*pending))
                pending = []
            block_out = make_divisible(scale * c)
            pending.append(
                ResidualUnit(
                    in_channels=inplanes,
                    mid_channels=make_divisible(scale * exp),
                    out_channels=block_out,
                    kernel_size=k,
                    stride=s,
                    use_se=se and not self.disable_se,
                    act=nl,
                    name="conv" + str(i + 2)))
            inplanes = block_out

        # The last stage ends with a 1x1 convolution.
        last_channels = make_divisible(scale * cls_ch_squeeze)
        pending.append(
            ConvBNLayer(
                in_channels=inplanes,
                out_channels=last_channels,
                kernel_size=1,
                stride=1,
                padding=0,
                groups=1,
                if_act=True,
                act='hardswish',
                name='conv_last'))
        self.stages.append(nn.Sequential(*pending))
        self.stages_pipline = nn.Sequential(*self.stages)
        self.out_channels.append(last_channels)

    def forward(self, x):
        """Return a one-element list holding the output of the final stage."""
        x = self.conv(x)
        x = self.stages_pipline(x)
        return [x]
143 |
144 |
class ConvBNLayer(nn.Module):
    """Bias-free ``Conv2d`` followed by ``BatchNorm2d`` and an optional activation.

    Args:
        in_channels, out_channels, kernel_size, stride, padding, groups:
            forwarded to ``nn.Conv2d``.
        if_act: apply the activation named by ``act`` after normalisation.
        act: ``"relu"`` or ``"hardswish"``; only consulted when ``if_act``.
        name: accepted for interface compatibility; not used.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride,
                 padding,
                 groups=1,
                 if_act=True,
                 act=None,
                 name=None):
        super(ConvBNLayer, self).__init__()
        self.if_act = if_act
        self.act = act
        self.conv = nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            groups=groups,
            bias=False)

        self.bn = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        """Apply convolution, batch norm and the configured activation.

        Raises:
            ValueError: if ``if_act`` is set and ``act`` is neither
                ``"relu"`` nor ``"hardswish"``.
        """
        x = self.conv(x)
        x = self.bn(x)
        if not self.if_act:
            return x
        if self.act == "relu":
            return F.relu(x)
        if self.act == "hardswish":
            return F.hardswish(x)
        # Raise instead of print + exit(): a layer must not end the process
        # (with a success status) from inside forward().
        raise ValueError(
            "The activation function({}) is selected incorrectly.".format(
                self.act))
183 |
184 |
class ResidualUnit(nn.Module):
    """Expand (1x1) -> depthwise -> optional SE -> linear projection (1x1).

    The input is added back to the output when the unit keeps both the
    spatial size (``stride == 1``) and the channel count.
    """

    def __init__(self,
                 in_channels,
                 mid_channels,
                 out_channels,
                 kernel_size,
                 stride,
                 use_se,
                 act=None,
                 name=''):
        super(ResidualUnit, self).__init__()
        keeps_shape = stride == 1 and in_channels == out_channels
        self.if_shortcut = keeps_shape
        self.if_se = use_se

        # 1x1 expansion to the hidden width.
        self.expand_conv = ConvBNLayer(
            in_channels=in_channels, out_channels=mid_channels,
            kernel_size=1, stride=1, padding=0,
            if_act=True, act=act, name=name + "_expand")
        # Depthwise convolution: one group per hidden channel.
        self.bottleneck_conv = ConvBNLayer(
            in_channels=mid_channels, out_channels=mid_channels,
            kernel_size=kernel_size, stride=stride,
            padding=int((kernel_size - 1) // 2), groups=mid_channels,
            if_act=True, act=act, name=name + "_depthwise")
        if self.if_se:
            self.mid_se = SEModule(mid_channels, name=name + "_se")
        # 1x1 projection without activation.
        self.linear_conv = ConvBNLayer(
            in_channels=mid_channels, out_channels=out_channels,
            kernel_size=1, stride=1, padding=0,
            if_act=False, act=None, name=name + "_linear")

    def forward(self, inputs):
        """Run the unit and add the skip connection when it applies."""
        out = self.bottleneck_conv(self.expand_conv(inputs))
        if self.if_se:
            out = self.mid_se(out)
        out = self.linear_conv(out)
        return torch.add(inputs, out) if self.if_shortcut else out
239 |
240 |
class SEModule(nn.Module):
    """Squeeze-and-excitation gate.

    Global average pooling feeds two 1x1 convolutions (ReLU after the first,
    hard-sigmoid after the second); the result rescales the input per channel.
    """

    def __init__(self, in_channels, reduction=4, name=""):
        super(SEModule, self).__init__()
        squeezed_channels = in_channels // reduction
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.conv1 = nn.Conv2d(
            in_channels, squeezed_channels,
            kernel_size=1, stride=1, padding=0)
        self.conv2 = nn.Conv2d(
            squeezed_channels, in_channels,
            kernel_size=1, stride=1, padding=0)

    def forward(self, inputs):
        """Return ``inputs`` multiplied by the learned channel gate."""
        gate = F.relu(self.conv1(self.avg_pool(inputs)))
        gate = F.hardsigmoid(self.conv2(gate))
        return inputs * gate
265 |
266 |
def test():
    """Smoke-test the default MobileNetV3 by printing the size of each output.

    Runs on ``cuda:0`` when CUDA is available and on the CPU otherwise.
    """
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    net = MobileNetV3()
    net.to(device)
    # The input has to live on the same device as the network.
    x = torch.randn(2, 3, 224, 224, device=device)
    # forward() returns a list of feature maps, not a single tensor.
    for y in net(x):
        print(y.size())
273 |
274 |
# test()
# NOTE(review): the two statements below run on every import of this module
# and look like leftover debugging -- confirm before removing.
a = torch.tensor([1, 2, 3])
print()
278 |
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time : 2022/01/25 09:52
3 | # @Author : lijun
4 | from __future__ import absolute_import
5 | from __future__ import division
6 | from __future__ import print_function
7 |
8 | import os
9 | import sys
10 | import platform
11 | import time
12 | import torch
13 | from tqdm import tqdm
14 |
15 | from models.architectures import build_model
16 | from utils.save_load import load_model
17 | from utils.metrics import build_metric
18 | from utils.data import build_dataloader
19 | from utils.utility import preprocess
20 |
21 |
def test():
    """
    Evaluate the checkpoint named in the config on the 'Eval' dataloader.

    :param: None
    :return: the metric dict produced by the metric class, extended with an
        'fps' entry; None when the dataloader is empty
    """
    # ---- read config ----
    config, device, logger = preprocess(is_train=True)

    # ---- build dataloader ----
    valid_dataloader = build_dataloader(config, 'Eval', device, logger)
    if len(valid_dataloader) == 0:
        logger.error("please check val_dataloader\n")
        return

    # ---- build model ----
    model = build_model(config['Architecture'])
    load_model(model, config['Global']['checkpoints'])
    model.to(device)

    # ---- build metric ----
    eval_class = build_metric(config['Metric'])

    # ---- test ----
    model.eval()
    num_batches = len(valid_dataloader)
    # On Windows the final batch is skipped.
    max_iter = num_batches - 1 if platform.system() == "Windows" else num_batches
    total_frame = 0.0
    total_time = 0.0
    with torch.no_grad():
        pbar = tqdm(total=num_batches, desc='eval model:')
        for idx, batch in enumerate(valid_dataloader):
            batch = [i.to(device) for i in batch]
            if idx >= max_iter:
                break
            images = batch[0]

            # Only the forward pass is timed.
            tic = time.time()
            preds = model(images, data=batch[1:])
            total_time += time.time() - tic

            # Evaluate the results of the current batch
            np_preds = {key: preds[key].cpu().detach().numpy()
                        for key in ('structure_probs', 'loc_preds')}
            np_batch = [item.cpu().numpy() for item in batch]
            eval_class(np_preds, np_batch)

            pbar.update(1)
            total_frame += len(images)
        # Get final metric->acc
        metric = eval_class.get_metric()

    pbar.close()
    model.train()
    metric['fps'] = total_frame / total_time
    return metric
84 |
85 |
86 | if __name__ == '__main__':
87 | metric = test()
88 | print('metric: ', metric)
89 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time : 2022/01/25 09:52
3 | # @Author : lijun
4 |
5 | from __future__ import absolute_import
6 | from __future__ import division
7 | from __future__ import print_function
8 |
9 | import os
10 | import sys
11 | import platform
12 | import yaml
13 | import time
14 | import shutil
15 | import torch
16 | from tqdm import tqdm
17 | import numpy as np
18 |
19 | from models.architectures import build_model
20 | from utils.save_load import load_model
21 | from utils.losses import build_loss
22 | from utils.optimizer import build_optimizer
23 | from utils.metrics import build_metric
24 | from utils.stats import TrainingStats
25 | from utils.save_load import save_model
26 | from utils.data import build_dataloader
27 | from utils.utility import preprocess
28 |
29 |
def eval(model,
         valid_dataloader,
         eval_class,
         device):
    """Evaluate ``model`` on ``valid_dataloader`` and return the metric dict.

    The model is switched to evaluation mode for the pass and back to
    training mode before returning. The returned dict comes from
    ``eval_class.get_metric()`` and gains an ``'fps'`` entry (samples divided
    by the time spent in the forward passes).
    """
    model.eval()
    num_batches = len(valid_dataloader)
    # On Windows the final batch is skipped.
    max_iter = num_batches - 1 if platform.system() == "Windows" else num_batches
    total_frame = 0.0
    total_time = 0.0
    with torch.no_grad():
        pbar = tqdm(total=num_batches, desc='eval model:')
        for idx, batch in enumerate(valid_dataloader):
            batch = [i.to(device) for i in batch]
            if idx >= max_iter:
                break
            images = batch[0]

            # Only the forward pass is timed.
            tic = time.time()
            preds = model(images, data=batch[1:])
            total_time += time.time() - tic

            # Evaluate the results of the current batch
            np_preds = {key: preds[key].cpu().detach().numpy()
                        for key in ('structure_probs', 'loc_preds')}
            np_batch = [item.cpu().numpy() for item in batch]
            eval_class(np_preds, np_batch)

            pbar.update(1)
            total_frame += len(images)
        # Get final metric->acc
        metric = eval_class.get_metric()

    pbar.close()
    model.train()
    metric['fps'] = total_frame / total_time
    return metric
67 |
68 |
def train():
    """
    Train the table model using the settings returned by ``preprocess``.

    Builds the train/eval dataloaders, model, loss, optimizer and metric from
    the config, then loops over the epochs. Every step runs forward/backward
    with gradient-norm clipping; statistics are logged every
    ``print_batch_step`` steps, ``last.pt`` is saved every 400 steps, the
    model is evaluated every ``eval_batch_step`` steps (saving ``best.pt``
    when the main indicator is at least as good as the best so far) and
    ``<epoch>.pt`` is saved every ``save_epoch_step`` epochs.

    :param: None
    :return: None
    """
    # ---- read config ----
    config, device, logger = preprocess(is_train=True)

    # ---- build dataloaders ----
    train_dataloader = build_dataloader(config, 'Train', device, logger)
    valid_dataloader = build_dataloader(config, 'Eval', device, logger)
    if len(train_dataloader) == 0:
        logger.error("please check train_dataloader\n")
        return
    if len(valid_dataloader) == 0:
        logger.error("please check val_dataloader\n")
        return

    # ---- build model ----
    model = build_model(config['Architecture'])
    if config['Global']['checkpoints']:
        load_model(model, config['Global']['checkpoints'])
    model.to(device)
    model = torch.nn.DataParallel(model)

    # ---- build loss ----
    loss_class = build_loss(config['Loss'])

    # ---- build optimizer ----
    optimizer, lr_scheduler = build_optimizer(
        config['Optimizer'],
        epochs=config['Global']['epoch_num'],
        step_each_epoch=len(train_dataloader),
        parameters=model.parameters())

    # ---- build metric ----
    eval_class = build_metric(config['Metric'])

    # ---- misc settings ----
    cal_metric_during_train = config['Global'].get('cal_metric_during_train', False)
    log_smooth_window = config['Global']['log_smooth_window']
    epoch_num = config['Global']['epoch_num']
    print_batch_step = config['Global']['print_batch_step']
    eval_batch_step = config['Global']['eval_batch_step']

    global_step = 0
    save_epoch_step = config['Global']['save_epoch_step']
    save_model_dir = config['Global']['save_model_dir']
    if not os.path.exists(save_model_dir):
        os.makedirs(save_model_dir)
    main_indicator = eval_class.main_indicator
    best_model_dict = {main_indicator: 0}
    train_stats = TrainingStats(log_smooth_window, ['lr'])

    # ---- train ----
    model.train()
    for epoch in range(1, epoch_num + 1):
        # The train dataloader is rebuilt every epoch, seeded with the epoch.
        train_dataloader = build_dataloader(config, 'Train', device, logger, seed=epoch)
        train_batch_cost = 0.0
        train_reader_cost = 0.0
        batch_sum = 0
        batch_start = time.time()
        # On Windows the final batch is skipped.
        max_iter = len(train_dataloader) - 1 if platform.system() == "Windows" else len(train_dataloader)
        for idx, batch in enumerate(train_dataloader):
            batch = [i.to(device) for i in batch]
            train_reader_cost += time.time() - batch_start
            if idx >= max_iter:
                break
            # Log the learning rate actually in use: ``optimizer.defaults``
            # keeps the initial value and ignores scheduler updates.
            lr = optimizer.param_groups[0]['lr']
            images = batch[0]

            preds = model(images, data=batch[1:])

            loss = loss_class(preds, batch)
            avg_loss = loss['loss']
            avg_loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=config['Optimizer']['clip_norm'])
            optimizer.step()
            optimizer.zero_grad()

            train_batch_cost += time.time() - batch_start
            batch_sum += len(images)

            # ``build_optimizer`` may hand back a plain float instead of a
            # scheduler; only a scheduler object is stepped.
            if not isinstance(lr_scheduler, float):
                lr_scheduler.step()

            stats = {k: v.cpu().detach().numpy().mean() for k, v in loss.items()}
            stats['lr'] = lr
            train_stats.update(stats)

            if cal_metric_during_train:  # only rec and cls need
                batch = [item.cpu().numpy() for item in batch]
                preds = {'structure_probs': preds['structure_probs'].cpu().detach().numpy(),
                         'loc_preds': preds['loc_preds'].cpu().detach().numpy()}
                eval_class(preds, batch)
                metric = eval_class.get_metric()
                train_stats.update(metric)

            if (global_step > 0 and global_step % print_batch_step == 0) or (idx >= len(train_dataloader) - 1):
                logs = train_stats.log()
                strs = 'epoch: [{}/{}], iter: {}, {}, reader_cost: {:.5f} s, batch_cost: {:.5f} s, samples: {}, ips: {:.5f}'.format(
                    epoch, epoch_num, global_step, logs, train_reader_cost /
                    print_batch_step, train_batch_cost / print_batch_step,
                    batch_sum, batch_sum / train_batch_cost)
                logger.info(strs)
                train_batch_cost = 0.0
                train_reader_cost = 0.0
                batch_sum = 0
            # Rolling checkpoint every 400 steps.
            if global_step % 400 == 0:
                save_model({'state_dict': model.state_dict(), 'cfg': dict(config)},
                           os.path.join(save_model_dir, 'last.pt'))
            # eval
            if global_step % eval_batch_step == 0 and global_step > 0:
                cur_metric = eval(
                    model,
                    valid_dataloader,
                    eval_class,
                    device
                )
                cur_metric_str = \
                    'cur metric, {}'.format(', '.join(['{}: {}'.format(k, v) for k, v in cur_metric.items()]))
                logger.info(cur_metric_str)

                if cur_metric[main_indicator] >= best_model_dict[main_indicator]:
                    best_model_dict.update(cur_metric)
                    best_model_dict['best_epoch'] = epoch
                    save_model({'state_dict': model.state_dict(), 'cfg': dict(config)},
                               os.path.join(save_model_dir, 'best.pt'))
                best_str = \
                    'best metric, {}'.format(', '.join(['{}: {}'.format(k, v) for k, v in best_model_dict.items()]))
                logger.info(best_str)

            global_step += 1
            batch_start = time.time()
        if epoch % save_epoch_step == 0:
            save_model({'state_dict': model.state_dict(), 'cfg': dict(config)},
                       os.path.join(save_model_dir, str(epoch) + '.pt'))
    best_str = 'best metric, {}'.format(', '.join(['{}: {}'.format(k, v) for k, v in best_model_dict.items()]))
    logger.info(best_str)
240 |
241 |
# Script entry point: start training only when this file is executed
# directly, not when it is imported as a module.
if __name__ == '__main__':
    train()
244 |
--------------------------------------------------------------------------------
/utils/__pycache__/logging.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/June-Li/TableRecognition/a26956983ade1153ee91f9a92fface1a2cf1275e/utils/__pycache__/logging.cpython-38.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/save_load.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/June-Li/TableRecognition/a26956983ade1153ee91f9a92fface1a2cf1275e/utils/__pycache__/save_load.cpython-38.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/stats.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/June-Li/TableRecognition/a26956983ade1153ee91f9a92fface1a2cf1275e/utils/__pycache__/stats.cpython-38.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/torch_utils.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/June-Li/TableRecognition/a26956983ade1153ee91f9a92fface1a2cf1275e/utils/__pycache__/torch_utils.cpython-38.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/utility.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/June-Li/TableRecognition/a26956983ade1153ee91f9a92fface1a2cf1275e/utils/__pycache__/utility.cpython-38.pyc
--------------------------------------------------------------------------------
/utils/data/__init__.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 | from __future__ import unicode_literals
5 |
6 | import os
7 | import sys
8 | import numpy as np
9 | import signal
10 | import random
11 |
12 | __dir__ = os.path.dirname(os.path.abspath(__file__))
13 | sys.path.append(os.path.abspath(os.path.join(__dir__, '../..')))
14 |
15 | import copy
16 | from torch.utils.data import DataLoader
17 | from utils.data.imaug import transform, create_operators
18 | from utils.data.pubtab_dataset import PubTabDataSet
19 |
20 | __all__ = ['build_dataloader', 'transform', 'create_operators']
21 |
22 |
def term_mp(sig_num, frame):
    """Signal handler that kills the whole process group.

    Prints the current pid and its process-group id, then sends SIGKILL
    to that group, so every process in the group (the main process and,
    per the original intent, all of its children) is terminated.

    Args:
        sig_num: number of the received signal (not used).
        frame: interrupted stack frame (not used).
    """
    main_pid = os.getpid()
    group_id = os.getpgid(os.getpid())
    message = "main proc {} exit, kill process group " "{}".format(main_pid, group_id)
    print(message)
    os.killpg(group_id, signal.SIGKILL)
30 |
31 |
def build_dataloader(config, mode, device, logger, seed=None):
    """Build a torch ``DataLoader`` for the dataset configured under ``config[mode]``.

    Args:
        config: full configuration mapping. ``config[mode]['dataset']['name']``
            selects the dataset class and ``config[mode]['loader']`` supplies
            ``batch_size_per_card``, ``drop_last``, ``shuffle`` and
            ``num_workers``. A deep copy is taken, so the caller's object
            is not modified by this function.
        mode: one of ``'Train'``, ``'Eval'`` or ``'Test'``.
        device: accepted for interface compatibility; not used here.
        logger: forwarded to the dataset constructor.
        seed: forwarded to the dataset constructor.

    Returns:
        The constructed ``DataLoader`` (``pin_memory`` is always enabled).

    Raises:
        AssertionError: if ``mode`` or the configured dataset name is not
            supported.

    Side effects:
        Installs ``term_mp`` as the handler for SIGINT and SIGTERM.
    """
    # Validate the mode before touching config[mode]; otherwise a bad mode
    # surfaces as an opaque KeyError instead of this message.
    assert mode in ['Train', 'Eval', 'Test'], "Mode should be Train, Eval or Test."

    config = copy.deepcopy(config)

    # Explicit name -> class table instead of eval() on a config string.
    dataset_classes = {'PubTabDataSet': PubTabDataSet}
    support_dict = list(dataset_classes)
    module_name = config[mode]['dataset']['name']
    assert module_name in support_dict, Exception('DataSet only support {}'.format(support_dict))

    dataset = dataset_classes[module_name](config, mode, logger, seed)
    loader_config = config[mode]['loader']
    batch_size = loader_config['batch_size_per_card']
    drop_last = loader_config['drop_last']
    shuffle = loader_config['shuffle']
    num_workers = loader_config['num_workers']

    data_loader = DataLoader(dataset=dataset,
                             batch_size=batch_size,
                             num_workers=num_workers,
                             pin_memory=True,
                             shuffle=shuffle,
                             drop_last=drop_last)

    # support exit using ctrl+c
    signal.signal(signal.SIGINT, term_mp)
    signal.signal(signal.SIGTERM, term_mp)

    return data_loader
59 |
--------------------------------------------------------------------------------
/utils/data/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/June-Li/TableRecognition/a26956983ade1153ee91f9a92fface1a2cf1275e/utils/data/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/utils/data/__pycache__/pubtab_dataset.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/June-Li/TableRecognition/a26956983ade1153ee91f9a92fface1a2cf1275e/utils/data/__pycache__/pubtab_dataset.cpython-38.pyc
--------------------------------------------------------------------------------
/utils/data/imaug/__init__.py:
--------------------------------------------------------------------------------
1 | # copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | from __future__ import absolute_import
15 | from __future__ import division
16 | from __future__ import print_function
17 | from __future__ import unicode_literals
18 |
19 | from .iaa_augment import IaaAugment
20 | from .make_border_map import MakeBorderMap
21 | from .make_shrink_map import MakeShrinkMap
22 | from .random_crop_data import EastRandomCropData, PSERandomCrop
23 |
24 | from .rec_img_aug import RecAug, RecResizeImg, ClsResizeImg, SRNRecResizeImg, NRTRRecResizeImg
25 | from .randaugment import RandAugment
26 | from .copy_paste import CopyPaste
27 | from .operators import *
28 | from .label_ops import *
29 |
30 | from .east_process import *
31 | from .sast_process import *
32 | from .pg_process import *
33 | from .gen_table_mask import *
34 |
35 |
def transform(data, ops=None):
    """Run ``data`` through every operator in ``ops``, in order.

    Args:
        data: the sample passed to the first operator.
        ops: iterable of callables; ``None`` means no operators.

    Returns:
        The output of the last operator (``data`` itself when there are
        no operators), or ``None`` as soon as any operator returns
        ``None`` -- the remaining operators are then not called.
    """
    pipeline = [] if ops is None else ops
    result = data
    for operator in pipeline:
        result = operator(result)
        if result is None:
            return None
    return result
45 |
46 |
def create_operators(op_param_list, global_config=None):
    """Instantiate the operators described by a config list.

    Args:
        op_param_list (list): list of single-key dicts; the key is the
            operator name and the value is a dict of keyword arguments
            for its constructor (or ``None`` for no arguments).
        global_config (dict, optional): extra keyword arguments merged
            into every operator's arguments (overriding equal keys).

    Returns:
        list: the constructed operators, in config order.

    Raises:
        AssertionError: if ``op_param_list`` is not a list or an entry is
            not a single-key dict.
    """
    assert isinstance(op_param_list, list), ('operator config should be a list')
    ops = []
    for operator in op_param_list:
        assert isinstance(operator,
                          dict) and len(operator) == 1, "yaml format error"
        op_name = list(operator)[0]
        raw_param = operator[op_name]
        # Work on a copy: merging global_config must not leak into the
        # caller's config, which may be reused to build operators again.
        param = {} if raw_param is None else dict(raw_param)
        if global_config is not None:
            param.update(global_config)
        # NOTE(review): eval() resolves op_name against this module's
        # namespace; the config must come from a trusted source.
        op = eval(op_name)(**param)
        ops.append(op)
    return ops
66 |
--------------------------------------------------------------------------------
/utils/data/imaug/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/June-Li/TableRecognition/a26956983ade1153ee91f9a92fface1a2cf1275e/utils/data/imaug/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/utils/data/imaug/__pycache__/copy_paste.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/June-Li/TableRecognition/a26956983ade1153ee91f9a92fface1a2cf1275e/utils/data/imaug/__pycache__/copy_paste.cpython-38.pyc
--------------------------------------------------------------------------------
/utils/data/imaug/__pycache__/east_process.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/June-Li/TableRecognition/a26956983ade1153ee91f9a92fface1a2cf1275e/utils/data/imaug/__pycache__/east_process.cpython-38.pyc
--------------------------------------------------------------------------------
/utils/data/imaug/__pycache__/gen_table_mask.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/June-Li/TableRecognition/a26956983ade1153ee91f9a92fface1a2cf1275e/utils/data/imaug/__pycache__/gen_table_mask.cpython-38.pyc
--------------------------------------------------------------------------------
/utils/data/imaug/__pycache__/iaa_augment.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/June-Li/TableRecognition/a26956983ade1153ee91f9a92fface1a2cf1275e/utils/data/imaug/__pycache__/iaa_augment.cpython-38.pyc
--------------------------------------------------------------------------------
/utils/data/imaug/__pycache__/label_ops.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/June-Li/TableRecognition/a26956983ade1153ee91f9a92fface1a2cf1275e/utils/data/imaug/__pycache__/label_ops.cpython-38.pyc
--------------------------------------------------------------------------------
/utils/data/imaug/__pycache__/make_border_map.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/June-Li/TableRecognition/a26956983ade1153ee91f9a92fface1a2cf1275e/utils/data/imaug/__pycache__/make_border_map.cpython-38.pyc
--------------------------------------------------------------------------------
/utils/data/imaug/__pycache__/make_shrink_map.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/June-Li/TableRecognition/a26956983ade1153ee91f9a92fface1a2cf1275e/utils/data/imaug/__pycache__/make_shrink_map.cpython-38.pyc
--------------------------------------------------------------------------------
/utils/data/imaug/__pycache__/operators.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/June-Li/TableRecognition/a26956983ade1153ee91f9a92fface1a2cf1275e/utils/data/imaug/__pycache__/operators.cpython-38.pyc
--------------------------------------------------------------------------------
/utils/data/imaug/__pycache__/pg_process.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/June-Li/TableRecognition/a26956983ade1153ee91f9a92fface1a2cf1275e/utils/data/imaug/__pycache__/pg_process.cpython-38.pyc
--------------------------------------------------------------------------------
/utils/data/imaug/__pycache__/randaugment.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/June-Li/TableRecognition/a26956983ade1153ee91f9a92fface1a2cf1275e/utils/data/imaug/__pycache__/randaugment.cpython-38.pyc
--------------------------------------------------------------------------------
/utils/data/imaug/__pycache__/random_crop_data.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/June-Li/TableRecognition/a26956983ade1153ee91f9a92fface1a2cf1275e/utils/data/imaug/__pycache__/random_crop_data.cpython-38.pyc
--------------------------------------------------------------------------------
/utils/data/imaug/__pycache__/rec_img_aug.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/June-Li/TableRecognition/a26956983ade1153ee91f9a92fface1a2cf1275e/utils/data/imaug/__pycache__/rec_img_aug.cpython-38.pyc
--------------------------------------------------------------------------------
/utils/data/imaug/__pycache__/sast_process.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/June-Li/TableRecognition/a26956983ade1153ee91f9a92fface1a2cf1275e/utils/data/imaug/__pycache__/sast_process.cpython-38.pyc
--------------------------------------------------------------------------------
/utils/data/imaug/copy_paste.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import cv2
3 | import random
4 | import numpy as np
5 | from PIL import Image
6 | from shapely.geometry import Polygon
7 |
8 | from utils.data.imaug.iaa_augment import IaaAugment
9 | from utils.data.imaug.random_crop_data import is_poly_outside_rect
10 | from utils.utility import get_rotate_crop_image
11 |
12 |
class CopyPaste(object):
    """Copy-paste augmentation.

    Crops non-ignored polygons out of an extra sample (``data['ext_data'][0]``),
    rotates each crop by a random angle and pastes it onto the source image,
    appending the pasted polygon and its ignore tag to the sample.
    """

    def __init__(self, objects_paste_ratio=0.2, limit_paste=True, **kwargs):
        """
        Args:
            objects_paste_ratio: fraction of the extra sample's polygons to
                try to paste (at least 1, at most 30 are selected).
            limit_paste: when true, a paste position is only accepted if
                the pasted box is outside every existing source polygon.
            **kwargs: ignored.
        """
        # Not read inside this class; presumably consumed by the dataset to
        # decide how many extra samples to attach -- TODO confirm.
        self.ext_data_num = 1
        self.objects_paste_ratio = objects_paste_ratio
        self.limit_paste = limit_paste
        augmenter_args = [{'type': 'Resize', 'args': {'size': [0.5, 3]}}]
        # Kept as an attribute; no method of this class uses it.
        self.aug = IaaAugment(augmenter_args)

    def __call__(self, data):
        """Paste crops from the extra sample onto ``data['image']``.

        Reads ``image``, ``polys``, ``ignore_tags`` and ``ext_data`` from
        ``data``; ``polys``/``ignore_tags`` must provide ``tolist()`` and the
        extra sample's ``polys``/``ignore_tags`` must support indexing with
        a list of indices (e.g. numpy arrays).

        Returns:
            ``data`` with ``image``, ``polys`` and ``ignore_tags`` replaced.
        """
        src_img = data['image']
        src_polys = data['polys'].tolist()
        src_ignores = data['ignore_tags'].tolist()
        ext_data = data['ext_data'][0]
        ext_image = ext_data['image']
        ext_polys = ext_data['polys']
        ext_ignores = ext_data['ignore_tags']

        # Only polygons that are not ignored are candidates for pasting.
        indexs = [i for i, ignored in enumerate(ext_ignores) if not ignored]
        select_num = max(
            1, min(int(self.objects_paste_ratio * len(ext_polys)), 30))

        random.shuffle(indexs)
        select_idxs = indexs[:select_num]
        select_polys = ext_polys[select_idxs]
        select_ignores = ext_ignores[select_idxs]

        src_img = cv2.cvtColor(src_img, cv2.COLOR_BGR2RGB)
        ext_image = cv2.cvtColor(ext_image, cv2.COLOR_BGR2RGB)
        src_img = Image.fromarray(src_img).convert('RGBA')
        for poly, tag in zip(select_polys, select_ignores):
            box_img = get_rotate_crop_image(ext_image, poly)

            src_img, box = self.paste_img(src_img, box_img, src_polys)
            if box is not None:
                src_polys.append(box)
                src_ignores.append(tag)
        src_img = cv2.cvtColor(np.array(src_img), cv2.COLOR_RGB2BGR)
        h, w = src_img.shape[:2]
        src_polys = np.array(src_polys)
        # With no polygons at all the array is 1-D and cannot be indexed as
        # (N, points, 2); there is nothing to clip in that case.
        if src_polys.ndim == 3:
            src_polys[:, :, 0] = np.clip(src_polys[:, :, 0], 0, w)
            src_polys[:, :, 1] = np.clip(src_polys[:, :, 1], 0, h)
        data['image'] = src_img
        data['polys'] = src_polys
        data['ignore_tags'] = np.array(src_ignores)
        return data

    def paste_img(self, src_img, box_img, src_polys):
        """Rotate ``box_img`` by a random angle and paste it onto ``src_img``.

        Args:
            src_img: PIL image that receives the paste (modified in place).
            box_img: numpy image of the crop to paste.
            src_polys: polygons already present on ``src_img``.

        Returns:
            ``(src_img, box)`` where ``box`` is the pasted quadrilateral in
            source coordinates, or ``(src_img, None)`` when the rotated crop
            does not fit or no position was found.
        """
        box_img_pil = Image.fromarray(box_img).convert('RGBA')
        src_w, src_h = src_img.size
        box_w, box_h = box_img_pil.size

        angle = np.random.randint(0, 360)
        box = np.array([[[0, 0], [box_w, 0], [box_w, box_h], [0, box_h]]])
        box = rotate_bbox(box_img, box, angle)[0]
        box_img_pil = box_img_pil.rotate(angle, expand=1)
        box_w, box_h = box_img_pil.width, box_img_pil.height
        if src_w - box_w < 0 or src_h - box_h < 0:
            return src_img, None

        paste_x, paste_y = self.select_coord(src_polys, box, src_w - box_w,
                                             src_h - box_h)
        if paste_x is None:
            return src_img, None
        box[:, 0] += paste_x
        box[:, 1] += paste_y
        # The alpha band is the paste mask, so the transparent corners that
        # the expanding rotation adds do not overwrite the source.
        alpha = box_img_pil.split()[3]
        src_img.paste(box_img_pil, (paste_x, paste_y), mask=alpha)

        return src_img, box

    def select_coord(self, src_polys, box, endx, endy):
        """Pick the top-left paste position.

        Args:
            src_polys: polygons already present on the source image.
            box: array of the box's points, indexed as ``box[:, 0]`` (x)
                and ``box[:, 1]`` (y); only read when ``limit_paste`` is set.
            endx, endy: inclusive upper bounds for the x / y position.

        Returns:
            ``(paste_x, paste_y)``; with ``limit_paste`` enabled,
            ``(None, None)`` if 50 random tries all touch an existing polygon.
        """
        if not self.limit_paste:
            paste_x = random.randint(0, endx)
            paste_y = random.randint(0, endy)
            return paste_x, paste_y

        xmin, ymin = box[:, 0].min(), box[:, 1].min()
        xmax, ymax = box[:, 0].max(), box[:, 1].max()
        for _ in range(50):
            paste_x = random.randint(0, endx)
            paste_y = random.randint(0, endy)
            xmin1 = xmin + paste_x
            xmax1 = xmax + paste_x
            ymin1 = ymin + paste_y
            ymax1 = ymax + paste_y

            # Accept the position only if every existing polygon lies
            # outside the candidate rectangle.
            if all(
                    is_poly_outside_rect(poly, xmin1, ymin1, xmax1 - xmin1,
                                         ymax1 - ymin1)
                    for poly in src_polys):
                return paste_x, paste_y
        return None, None
108 |
109 |
def get_union(pD, pG):
    """Return the area of the union of the polygons ``pD`` and ``pG``.

    Both arguments are passed to ``shapely.geometry.Polygon``.
    """
    poly_d = Polygon(pD)
    poly_g = Polygon(pG)
    merged = poly_d.union(poly_g)
    return merged.area
112 |
113 |
def get_intersection_over_union(pD, pG):
    """Return the intersection-over-union of the polygons ``pD`` and ``pG``.

    Returns 0.0 when the union has zero area (degenerate polygons)
    instead of raising ZeroDivisionError.
    """
    union_area = get_union(pD, pG)
    if union_area == 0:
        return 0.0
    return get_intersection(pD, pG) / union_area
116 |
117 |
def get_intersection(pD, pG):
    """Return the area of the intersection of the polygons ``pD`` and ``pG``.

    Both arguments are passed to ``shapely.geometry.Polygon``.
    """
    poly_d = Polygon(pD)
    poly_g = Polygon(pG)
    overlap = poly_d.intersection(poly_g)
    return overlap.area
120 |
121 |
def rotate_bbox(img, text_polys, angle, scale=1):
    """Rotate quadrilaterals the same way the image would be rotated.

    from https://github.com/WenmuZhou/DBNet.pytorch/blob/master/data_loader/modules/augment.py

    Args:
        img: np.ndarray; only its height and width (``shape[0]``,
            ``shape[1]``) are read.
        text_polys: np.ndarray N*4*2; the four corner points of each box.
        angle: rotation angle in degrees.
        scale: scale factor passed to ``cv2.getRotationMatrix2D``.

    Returns:
        np.ndarray of dtype float32 holding the four transformed corner
        points of every input box.
    """
    height, width = img.shape[0], img.shape[1]

    theta = np.deg2rad(angle)
    sin_t, cos_t = np.sin(theta), np.cos(theta)
    # Size of the axis-aligned canvas that encloses the rotated image.
    new_w = (abs(sin_t * height) + abs(cos_t * width))
    new_h = (abs(cos_t * height) + abs(sin_t * width))
    rot_mat = cv2.getRotationMatrix2D((new_w * 0.5, new_h * 0.5), angle, scale)
    # Add the translation that moves the original image to the centre of
    # the enlarged canvas.
    shift = np.dot(rot_mat,
                   np.array([(new_w - width) * 0.5, (new_h - height) * 0.5, 0]))
    rot_mat[0, 2] += shift[0]
    rot_mat[1, 2] += shift[1]

    # ---------------------- rotate box ----------------------
    rotated = []
    for quad in text_polys:
        corners = [
            np.dot(rot_mat, np.array([quad[k, 0], quad[k, 1], 1]))
            for k in range(4)
        ]
        rotated.append(corners)
    return np.array(rotated, dtype=np.float32)
154 |
--------------------------------------------------------------------------------
/utils/data/imaug/gen_table_mask.py:
--------------------------------------------------------------------------------
1 | """
2 | # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 | from __future__ import unicode_literals
21 |
22 | import sys
23 | import six
24 | import cv2
25 | import numpy as np
26 |
27 |
class GenTableMask(object):
    """Generate a mask image marking the text regions inside table cells."""

    def __init__(self, shrink_h_max, shrink_w_max, mask_type=0, **kwargs):
        """
        Args:
            shrink_h_max: upper bound, in pixels, for the vertical shrink
                applied to each region by ``shrink_bbox``.
            shrink_w_max: upper bound, in pixels, for the horizontal shrink.
            mask_type: 1 stores a single-channel mask under
                ``data['mask_img']``; any other value replaces
                ``data['image']`` with a 3-channel mask.
            **kwargs: ignored.
        """
        # Honour the configured limits (they were previously ignored and
        # hard-coded to 5).
        self.shrink_h_max = shrink_h_max
        self.shrink_w_max = shrink_w_max
        self.mask_type = mask_type

    def _row_counts(self, erosion, h, w):
        """Count, for each of the first ``h`` rows, the pixels equal to 255 in the first ``w`` columns."""
        return np.count_nonzero(erosion[:h, :w] == 255, axis=1).tolist()

    def _split_rows(self, row_counts, h, spilt_threshold=0):
        """Turn per-row foreground counts into a list of (start, end) row bands."""
        start_idx = 0  # index where the current text band started
        in_text = False  # whether the scan is currently inside a text band
        box_list = []
        for i, count in enumerate(row_counts):
            if not in_text and count > spilt_threshold:  # entering a text band
                in_text = True
                start_idx = i
            elif count <= spilt_threshold and in_text:  # entering blank rows
                in_text = False
                # bands of at most two rows are dropped
                if i - start_idx <= 2:
                    continue
                box_list.append((start_idx, i + 1))

        # a band still open at the end runs to the last row
        if in_text:
            box_list.append((start_idx, h - 1))
        return box_list

    def projection(self, erosion, h, w, spilt_threshold=0):
        """Split a binary image into row bands by horizontal projection.

        Args:
            erosion: 2-D array; pixels equal to 255 count as foreground.
            h: number of rows of ``erosion`` to scan.
            w: number of columns of ``erosion`` to scan.
            spilt_threshold: a row belongs to a band when its foreground
                count is greater than this value.

        Returns:
            ``(box_list, projection_map)``: ``box_list`` holds
            ``(start_row, end_row)`` tuples; ``projection_map`` has the
            shape of ``erosion`` and is all ones except that row ``j`` has
            its first ``count_j`` entries set to 0 (a projection histogram).
        """
        row_counts = self._row_counts(erosion, h, w)
        box_list = self._split_rows(row_counts, h, spilt_threshold)
        # draw the projection histogram
        projection_map = np.ones_like(erosion)
        for j in range(0, h):
            projection_map[j, :row_counts[j]] = 0
        return box_list, projection_map

    def projection_cx(self, box_img):
        """Split one cell image into per-text-line boxes.

        Args:
            box_img: BGR image of a single cell.

        Returns:
            List of ``[w_start, h_start, w_end, h_end]`` boxes relative to
            ``box_img``; ``[[0, 0, w, h]]`` when at most one row band is
            found.
        """
        box_gray_img = cv2.cvtColor(box_img, cv2.COLOR_BGR2GRAY)
        h, w = box_gray_img.shape
        # binarise the grayscale image (inverted: dark pixels become 255)
        ret, thresh1 = cv2.threshold(box_gray_img, 200, 255, cv2.THRESH_BINARY_INV)
        # vertical erosion
        if h < w:
            kernel = np.ones((2, 1), np.uint8)
            erode = cv2.erode(thresh1, kernel, iterations=1)
        else:
            erode = thresh1
        # horizontal dilation
        kernel = np.ones((1, 5), np.uint8)
        erosion = cv2.dilate(erode, kernel, iterations=1)
        # horizontal projection -> row bands
        box_list = self._split_rows(self._row_counts(erosion, h, w), h)

        if len(box_list) <= 1:
            return [[0, 0, w, h]]

        split_bbox_list = []
        last = len(box_list) - 1
        for i, (h_start, h_end) in enumerate(box_list):
            # the first band is extended to the top of the cell and the
            # last band to the bottom
            if i == 0:
                h_start = 0
            if i == last:
                h_end = h
            word_img = erosion[h_start:h_end + 1, :]
            word_h, word_w = word_img.shape
            # project the transposed band to find its horizontal extent
            w_split_list = self._split_rows(
                self._row_counts(word_img.T, word_w, word_h), word_w)
            if w_split_list:
                w_start, w_end = w_split_list[0][0], w_split_list[-1][1]
            else:
                # no column band survived the minimum-width filter:
                # fall back to the full cell width
                w_start, w_end = 0, w
            if h_start > 0:
                h_start -= 1
            h_end += 1
            split_bbox_list.append([w_start, h_start, w_end, h_end])
        return split_bbox_list

    def shrink_bbox(self, bbox):
        """Shrink ``[left, top, right, bottom]`` inward on every side.

        Each side moves by 10% of the box size, at least 1 pixel and at
        most ``shrink_w_max`` / ``shrink_h_max``. An axis whose shrink
        would make the box empty or inverted is left unchanged.
        """
        left, top, right, bottom = bbox
        sh_h = min(max(int((bottom - top) * 0.1), 1), self.shrink_h_max)
        sh_w = min(max(int((right - left) * 0.1), 1), self.shrink_w_max)
        left_new = left + sh_w
        right_new = right - sh_w
        top_new = top + sh_h
        bottom_new = bottom - sh_h
        if left_new >= right_new:
            left_new = left
            right_new = right
        if top_new >= bottom_new:
            top_new = top
            bottom_new = bottom
        return [left_new, top_new, right_new, bottom_new]

    def __call__(self, data):
        """Build the mask for ``data['image']`` from ``data['cells']``.

        Every cell with a ``'bbox'`` entry is split into text-line boxes,
        each box is shrunk and filled in the mask. With ``mask_type == 1``
        the mask (values 0/1) is stored in ``data['mask_img']``; otherwise
        the 3-channel mask (values 0/255) replaces ``data['image']``.

        Returns:
            The updated ``data``.
        """
        img = data['image']
        cells = data['cells']
        height, width = img.shape[0:2]
        single_channel = self.mask_type == 1
        if single_channel:
            mask_img = np.zeros((height, width), dtype=np.float32)
        else:
            mask_img = np.zeros((height, width, 3), dtype=np.float32)
        for cell in cells:
            if "bbox" not in cell:
                continue
            left, top, right, bottom = cell['bbox']
            box_img = img[top:bottom, left:right, :].copy()
            # an empty crop cannot be colour-converted; nothing to mask
            if box_img.size == 0:
                continue
            for w_start, h_start, w_end, h_end in self.projection_cx(box_img):
                # move the box from cell coordinates to image coordinates
                m_left, m_top, m_right, m_bottom = self.shrink_bbox(
                    [w_start + left, h_start + top, w_end + left, h_end + top])
                if single_channel:
                    mask_img[m_top:m_bottom, m_left:m_right] = 1.0
                else:
                    mask_img[m_top:m_bottom, m_left:m_right, :] = (255, 255, 255)
        # Store the mask once, after the loop, so it is also present when
        # no cell carries a bbox.
        if single_channel:
            data['mask_img'] = mask_img
        else:
            data['image'] = mask_img
        return data
181 |
class ResizeTableImage(object):
    """Resize a table image so its longer side equals ``max_len`` and scale the cell boxes with it."""

    def __init__(self, max_len, **kwargs):
        """Store the target length of the longer image side; extra kwargs are ignored."""
        super(ResizeTableImage, self).__init__()
        self.max_len = max_len

    def get_img_bbox(self, cells):
        """Return the ``'bbox'`` value of every cell that has one, in order."""
        collected = []
        if len(cells) == 0:
            return collected
        for idx in range(len(cells)):
            cell = cells[idx]
            if "bbox" in cell:
                collected.append(cell['bbox'])
        return collected

    def resize_img_table(self, img, bbox_list, max_len):
        """Resize ``img`` and scale the boxes by the same ratio.

        The ratio is ``max_len`` divided by the longer image side; resized
        dimensions and box coordinates are truncated to int.

        Returns:
            ``(img_new, bbox_list_new)`` -- the resized image and a new
            list of ``[left, top, right, bottom]`` boxes.
        """
        height, width = img.shape[0:2]
        longer_side = max(height, width) * 1.0
        ratio = max_len / longer_side
        target_h = int(height * ratio)
        target_w = int(width * ratio)
        img_new = cv2.resize(img, (target_w, target_h))
        bbox_list_new = []
        for box in bbox_list:
            left, top, right, bottom = (int(v * ratio) for v in box.copy())
            bbox_list_new.append([left, top, right, bottom])
        return img_new, bbox_list_new

    def __call__(self, data):
        """Resize ``data['image']``, rewrite each cell's ``'bbox'`` and record ``data['max_len']``."""
        img = data['image']
        cells = data['cells'] if 'cells' in data else []
        original_boxes = self.get_img_bbox(cells)
        img_new, scaled_boxes = self.resize_img_table(img, original_boxes, self.max_len)
        data['image'] = img_new
        # scaled_boxes only has entries for cells that carry a bbox, so it
        # is consumed with its own cursor
        cursor = 0
        for idx in range(len(cells)):
            if "bbox" in data['cells'][idx]:
                data['cells'][idx]['bbox'] = scaled_boxes[cursor]
                cursor += 1
        data['max_len'] = self.max_len
        return data
231 |
class PaddingTableImage(object):
    """Pad an image into the top-left corner of a square, zero-filled float32 canvas."""

    def __init__(self, **kwargs):
        """No configuration; keyword arguments are ignored."""
        super(PaddingTableImage, self).__init__()

    def __call__(self, data):
        """Replace ``data['image']`` with a ``(max_len, max_len, 3)`` padded copy.

        The side length is read from ``data['max_len']``; the area not
        covered by the image stays zero.
        """
        source = data['image']
        side = data['max_len']
        rows, cols = source.shape[0:2]
        canvas = np.zeros((side, side, 3), dtype=np.float32)
        canvas[:rows, :cols, :] = source.copy()
        data['image'] = canvas
        return data
244 |
--------------------------------------------------------------------------------
/utils/data/imaug/iaa_augment.py:
--------------------------------------------------------------------------------
1 | # copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """
15 | This code is refer from:
16 | https://github.com/WenmuZhou/DBNet.pytorch/blob/master/data_loader/modules/iaa_augment.py
17 | """
18 |
19 | from __future__ import absolute_import
20 | from __future__ import division
21 | from __future__ import print_function
22 | from __future__ import unicode_literals
23 |
24 | import numpy as np
25 | import imgaug
26 | import imgaug.augmenters as iaa
27 |
28 |
class AugmenterBuilder(object):
    """Build imgaug augmenters from a plain list/dict description."""

    def __init__(self):
        pass

    def build(self, args, root=True):
        """Turn a description into an ``imgaug.augmenters`` object.

        Args:
            args: ``None``/empty (no augmenter); a dict with ``'type'``
                (augmenter class name) and ``'args'`` (keyword arguments);
                or a list. At the root a list is a sequence of
                descriptions; below the root it is
                ``[class_name, positional_arg, ...]``.
            root: whether ``args`` is the top-level description.

        Returns:
            ``None`` for an empty description, otherwise the augmenter
            (an ``iaa.Sequential`` for a root list).

        Raises:
            RuntimeError: if ``args`` is neither a list nor a dict.
        """
        if args is None or len(args) == 0:
            return None

        if isinstance(args, dict):
            augmenter_cls = getattr(iaa, args['type'])
            keyword_args = {
                name: self.to_tuple_if_list(value)
                for name, value in args['args'].items()
            }
            return augmenter_cls(**keyword_args)

        if isinstance(args, list):
            if not root:
                augmenter_cls = getattr(iaa, args[0])
                positional = [self.to_tuple_if_list(item) for item in args[1:]]
                return augmenter_cls(*positional)
            children = [self.build(entry, root=False) for entry in args]
            return iaa.Sequential(children)

        raise RuntimeError('unknown augmenter arg: ' + str(args))

    def to_tuple_if_list(self, obj):
        """Return ``tuple(obj)`` for a list, any other object unchanged."""
        return tuple(obj) if isinstance(obj, list) else obj
56 |
57 |
class IaaAugment():
    """Apply an imgaug pipeline to ``data['image']`` and the same transform to ``data['polys']``."""

    def __init__(self, augmenter_args=None, **kwargs):
        """
        Args:
            augmenter_args: description passed to ``AugmenterBuilder.build``;
                ``None`` selects the default pipeline (horizontal flip with
                p=0.5, rotation in [-10, 10], resize in [0.5, 3]).
            **kwargs: ignored.
        """
        if augmenter_args is None:
            flip = {'type': 'Fliplr', 'args': {'p': 0.5}}
            rotate = {'type': 'Affine', 'args': {'rotate': [-10, 10]}}
            resize = {'type': 'Resize', 'args': {'size': [0.5, 3]}}
            augmenter_args = [flip, rotate, resize]
        self.augmenter = AugmenterBuilder().build(augmenter_args)

    def __call__(self, data):
        """Augment the image and its polygons; ``data`` is returned unchanged when no augmenter is set."""
        image = data['image']
        original_shape = image.shape

        if self.augmenter:
            # freeze the random parameters so image and polygons receive
            # exactly the same transform
            frozen = self.augmenter.to_deterministic()
            data['image'] = frozen.augment_image(image)
            data = self.may_augment_annotation(frozen, data, original_shape)
        return data

    def may_augment_annotation(self, aug, data, shape):
        """Transform every polygon in ``data['polys']`` with ``aug``.

        Returns ``data`` untouched when ``aug`` is ``None``; otherwise
        ``data['polys']`` is replaced by a numpy array of the transformed
        polygons.
        """
        if aug is None:
            return data

        transformed = [
            self.may_augment_poly(aug, shape, poly) for poly in data['polys']
        ]
        data['polys'] = np.array(transformed)
        return data

    def may_augment_poly(self, aug, img_shape, poly):
        """Transform one polygon's points with ``aug``; returns a list of ``(x, y)`` tuples."""
        points = [imgaug.Keypoint(p[0], p[1]) for p in poly]
        on_image = imgaug.KeypointsOnImage(points, shape=img_shape)
        moved = aug.augment_keypoints([on_image])[0].keypoints
        return [(kp.x, kp.y) for kp in moved]
107 |
--------------------------------------------------------------------------------
/utils/data/imaug/make_border_map.py:
--------------------------------------------------------------------------------
1 | # copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """
15 | This code is refer from:
16 | https://github.com/WenmuZhou/DBNet.pytorch/blob/master/data_loader/modules/make_border_map.py
17 | """
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 | from __future__ import unicode_literals
22 |
23 | import numpy as np
24 | import cv2
25 |
26 | np.seterr(divide='ignore', invalid='ignore')
27 | import pyclipper
28 | from shapely.geometry import Polygon
29 | import sys
30 | import warnings
31 |
32 | warnings.simplefilter("ignore")
33 |
34 | __all__ = ['MakeBorderMap']
35 |
36 |
class MakeBorderMap(object):
    """Build a threshold (border) map and its mask from polygon annotations.

    ``__call__`` reads ``data['image']``, ``data['polys']`` and
    ``data['ignore_tags']`` and stores ``data['threshold_map']`` and
    ``data['threshold_mask']``, both float32 arrays with the image's
    first two dimensions.
    """

    def __init__(self,
                 shrink_ratio=0.4,
                 thresh_min=0.3,
                 thresh_max=0.7,
                 **kwargs):
        """Store the offset ratio and the output value range.

        Extra keyword arguments are accepted and ignored.
        """
        self.shrink_ratio = shrink_ratio
        self.thresh_min = thresh_min
        self.thresh_max = thresh_max

    def __call__(self, data):
        """Draw every non-ignored polygon and rescale the result.

        The accumulated canvas (values in [0, 1]) is mapped linearly onto
        ``[thresh_min, thresh_max]`` before being stored.
        """
        img = data['image']
        text_polys = data['polys']
        ignore_tags = data['ignore_tags']

        canvas = np.zeros(img.shape[:2], dtype=np.float32)
        mask = np.zeros(img.shape[:2], dtype=np.float32)

        for i in range(len(text_polys)):
            if ignore_tags[i]:
                continue
            self.draw_border_map(text_polys[i], canvas, mask=mask)
        canvas = canvas * (self.thresh_max - self.thresh_min) + self.thresh_min

        data['threshold_map'] = canvas
        data['threshold_mask'] = mask
        return data

    def draw_border_map(self, polygon, canvas, mask):
        """Accumulate one polygon's border response into ``canvas``/``mask``.

        The polygon is offset outwards by
        ``area * (1 - shrink_ratio ** 2) / perimeter``; the offset polygon
        is filled with 1.0 in ``mask`` and, inside its bounding box,
        ``1 - min(edge distance / offset, 1)`` is merged into ``canvas``
        with an element-wise maximum. Both arrays are modified in place.
        Polygons with non-positive area, or whose offset yields no path,
        are skipped.
        """
        polygon = np.array(polygon)
        assert polygon.ndim == 2
        assert polygon.shape[1] == 2

        polygon_shape = Polygon(polygon)
        if polygon_shape.area <= 0:
            return
        distance = polygon_shape.area * (
            1 - np.power(self.shrink_ratio, 2)) / polygon_shape.length
        subject = [tuple(l) for l in polygon]
        padding = pyclipper.PyclipperOffset()
        padding.AddPath(subject, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)

        padded = padding.Execute(distance)
        if not padded:
            # Nothing came back from the offset: indexing [0] would raise.
            return
        padded_polygon = np.array(padded[0])
        cv2.fillPoly(mask, [padded_polygon.astype(np.int32)], 1.0)

        xmin = padded_polygon[:, 0].min()
        xmax = padded_polygon[:, 0].max()
        ymin = padded_polygon[:, 1].min()
        ymax = padded_polygon[:, 1].max()
        width = xmax - xmin + 1
        height = ymax - ymin + 1

        # Work in coordinates local to the padded bounding box.
        polygon[:, 0] = polygon[:, 0] - xmin
        polygon[:, 1] = polygon[:, 1] - ymin

        xs = np.broadcast_to(
            np.linspace(
                0, width - 1, num=width).reshape(1, width), (height, width))
        ys = np.broadcast_to(
            np.linspace(
                0, height - 1, num=height).reshape(height, 1), (height, width))

        # One normalised distance plane per polygon edge, then keep the
        # nearest edge for every pixel.
        distance_map = np.zeros(
            (polygon.shape[0], height, width), dtype=np.float32)
        for i in range(polygon.shape[0]):
            j = (i + 1) % polygon.shape[0]
            absolute_distance = self._distance(xs, ys, polygon[i], polygon[j])
            distance_map[i] = np.clip(absolute_distance / distance, 0, 1)
        distance_map = distance_map.min(axis=0)

        # Clamp the bounding box to the canvas before pasting.
        xmin_valid = min(max(0, xmin), canvas.shape[1] - 1)
        xmax_valid = min(max(0, xmax), canvas.shape[1] - 1)
        ymin_valid = min(max(0, ymin), canvas.shape[0] - 1)
        ymax_valid = min(max(0, ymax), canvas.shape[0] - 1)
        canvas[ymin_valid:ymax_valid + 1, xmin_valid:xmax_valid + 1] = np.fmax(
            1 - distance_map[ymin_valid - ymin:ymax_valid - ymax + height,
                             xmin_valid - xmin:xmax_valid - xmax + width],
            canvas[ymin_valid:ymax_valid + 1, xmin_valid:xmax_valid + 1])

    def _distance(self, xs, ys, point_1, point_2):
        '''
        Compute the distance from every grid position to a line segment.

        ys: coordinates in the first axis
        xs: coordinates in the second axis
        point_1, point_2: (x, y), the two ends of the segment

        Where ``cosin`` is negative the distance to the nearer end point
        is used instead of the perpendicular distance.
        '''
        square_distance_1 = np.square(xs - point_1[0]) + np.square(
            ys - point_1[1])
        square_distance_2 = np.square(xs - point_2[0]) + np.square(
            ys - point_2[1])
        square_distance = np.square(point_1[0] - point_2[0]) + np.square(
            point_1[1] - point_2[1])

        cosin = (square_distance - square_distance_1 - square_distance_2) / (
            2 * np.sqrt(square_distance_1 * square_distance_2))
        square_sin = 1 - np.square(cosin)
        # 0/0 at a grid position that coincides with an end point gives NaN.
        square_sin = np.nan_to_num(square_sin)
        result = np.sqrt(square_distance_1 * square_distance_2 * square_sin /
                         square_distance)

        outside = cosin < 0
        result[outside] = np.sqrt(
            np.fmin(square_distance_1, square_distance_2))[outside]
        return result

    def extend_line(self, point_1, point_2, result, shrink_ratio):
        """Draw an extension beyond each end of the segment onto ``result``.

        Each end is pushed away from the other end by
        ``(1 + shrink_ratio)`` times the segment vector and a 1-pixel
        anti-aliased line of value 4096.0 is drawn from the extended point
        back to the original end.

        Returns:
            The two extended points as integer ``(x, y)`` tuples.
        """

        def _extended(anchor, other):
            return (int(
                round(anchor[0] + (anchor[0] - other[0]) * (1 + shrink_ratio))),
                    int(
                        round(anchor[1] + (anchor[1] - other[1]) * (
                            1 + shrink_ratio))))

        ex_point_1 = _extended(point_1, point_2)
        cv2.line(
            result,
            tuple(ex_point_1),
            tuple(point_1),
            4096.0,
            1,
            lineType=cv2.LINE_AA,
            shift=0)
        ex_point_2 = _extended(point_2, point_1)
        cv2.line(
            result,
            tuple(ex_point_2),
            tuple(point_2),
            4096.0,
            1,
            lineType=cv2.LINE_AA,
            shift=0)
        return ex_point_1, ex_point_2
173 |
--------------------------------------------------------------------------------
/utils/data/imaug/make_shrink_map.py:
--------------------------------------------------------------------------------
1 | # copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """
15 | This code is adapted from:
16 | https://github.com/WenmuZhou/DBNet.pytorch/blob/master/data_loader/modules/make_shrink_map.py
17 | """
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 | from __future__ import unicode_literals
22 |
23 | import numpy as np
24 | import cv2
25 | from shapely.geometry import Polygon
26 | import pyclipper
27 |
28 | __all__ = ['MakeShrinkMap']
29 |
30 |
class MakeShrinkMap(object):
    r'''
    Making binary mask from detection data with ICDAR format.
    Typically following the process of class `MakeICDARData`.

    ``__call__`` reads ``data['image']``, ``data['polys']`` and
    ``data['ignore_tags']`` and stores ``data['shrink_map']`` (shrunk
    polygons filled with 1) and ``data['shrink_mask']`` (ones, with
    ignored polygons filled with 0).
    '''

    def __init__(self, min_text_size=8, shrink_ratio=0.4, **kwargs):
        """Store the size threshold and base shrink ratio.

        Extra keyword arguments are accepted and ignored.
        """
        self.min_text_size = min_text_size
        self.shrink_ratio = shrink_ratio

    def _shrink_ratios(self):
        """Return the ratios to try, in order, ending with 1.

        ``np.append`` returns a new array; the result must be kept or the
        final fallback ratio of 1 is silently lost.
        """
        ratios = np.arange(self.shrink_ratio, 1, self.shrink_ratio)
        return np.append(ratios, 1)

    def __call__(self, data):
        """Rasterise shrunk polygons and the ignore mask for one sample.

        ``ignore_tags`` entries are set to True in place for polygons that
        are too small or cannot be shrunk.
        """
        image = data['image']
        text_polys = data['polys']
        ignore_tags = data['ignore_tags']

        h, w = image.shape[:2]
        text_polys, ignore_tags = self.validate_polygons(text_polys,
                                                         ignore_tags, h, w)
        gt = np.zeros((h, w), dtype=np.float32)
        mask = np.ones((h, w), dtype=np.float32)
        for i in range(len(text_polys)):
            polygon = text_polys[i]
            height = max(polygon[:, 1]) - min(polygon[:, 1])
            width = max(polygon[:, 0]) - min(polygon[:, 0])
            if ignore_tags[i] or min(height, width) < self.min_text_size:
                cv2.fillPoly(mask,
                             polygon.astype(np.int32)[np.newaxis, :, :], 0)
                ignore_tags[i] = True
                continue

            polygon_shape = Polygon(polygon)
            subject = [tuple(l) for l in polygon]
            padding = pyclipper.PyclipperOffset()
            padding.AddPath(subject, pyclipper.JT_ROUND,
                            pyclipper.ET_CLOSEDPOLYGON)
            shrinked = []

            # Move to the next ratio whenever the offset does not come
            # back as exactly one polygon.
            for ratio in self._shrink_ratios():
                distance = polygon_shape.area * (
                    1 - np.power(ratio, 2)) / polygon_shape.length
                shrinked = padding.Execute(-distance)
                if len(shrinked) == 1:
                    break

            if not shrinked:
                cv2.fillPoly(mask,
                             polygon.astype(np.int32)[np.newaxis, :, :], 0)
                ignore_tags[i] = True
                continue

            for each_shrink in shrinked:
                shrink = np.array(each_shrink).reshape(-1, 2)
                cv2.fillPoly(gt, [shrink.astype(np.int32)], 1)

        data['shrink_map'] = gt
        data['shrink_mask'] = mask
        return data

    def validate_polygons(self, polygons, ignore_tags, h, w):
        '''
        Clip polygons to the image and normalise their vertex order.

        polygons (numpy.array, required): of shape (num_instances, num_points, 2)

        Coordinates are clipped in place to ``[0, w - 1]`` / ``[0, h - 1]``.
        A polygon whose absolute ``polygon_area`` is below 1 is tagged as
        ignored; one whose ``polygon_area`` is positive has its vertex
        order reversed.
        '''
        if len(polygons) == 0:
            return polygons, ignore_tags
        assert len(polygons) == len(ignore_tags)
        for polygon in polygons:
            polygon[:, 0] = np.clip(polygon[:, 0], 0, w - 1)
            polygon[:, 1] = np.clip(polygon[:, 1], 0, h - 1)

        for i in range(len(polygons)):
            area = self.polygon_area(polygons[i])
            if abs(area) < 1:
                ignore_tags[i] = True
            if area > 0:
                polygons[i] = polygons[i][::-1, :]
        return polygons, ignore_tags

    def polygon_area(self, polygon):
        """
        Compute the signed polygon area from consecutive vertex pairs.

        The sign depends on the vertex order.
        """
        area = 0
        q = polygon[-1]
        for p in polygon:
            area += p[0] * q[1] - p[1] * q[0]
            q = p
        return area / 2.0
123 |
--------------------------------------------------------------------------------
/utils/data/imaug/randaugment.py:
--------------------------------------------------------------------------------
1 | # copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from __future__ import absolute_import
16 | from __future__ import division
17 | from __future__ import print_function
18 | from __future__ import unicode_literals
19 |
20 | from PIL import Image, ImageEnhance, ImageOps
21 | import numpy as np
22 | import random
23 | import six
24 |
25 |
class RawRandAugment(object):
    """Apply ``num_layers`` randomly chosen PIL operations to an image.

    ``level_map`` holds the strength used for each operation (scaled by
    ``magnitude / max_level``) and ``func`` maps the same names to
    callables taking ``(img, magnitude)``.
    """

    def __init__(self,
                 num_layers=2,
                 magnitude=5,
                 fillcolor=(128, 128, 128),
                 **kwargs):
        self.num_layers = num_layers
        self.magnitude = magnitude
        self.max_level = 10

        strength = self.magnitude / self.max_level
        # Key order matters: __call__ samples from the list of keys.
        self.level_map = {
            "shearX": 0.3 * strength,
            "shearY": 0.3 * strength,
            "translateX": 150.0 / 331 * strength,
            "translateY": 150.0 / 331 * strength,
            "rotate": 30 * strength,
            "color": 0.9 * strength,
            "posterize": int(4.0 * strength),
            "solarize": 256.0 * strength,
            "contrast": 0.9 * strength,
            "sharpness": 0.9 * strength,
            "brightness": 0.9 * strength,
            "autocontrast": 0,
            "equalize": 0,
            "invert": 0
        }

        pick = random.choice

        def signed(value):
            # Flip the sign of the strength at random.
            return value * pick([-1, 1])

        # from https://stackoverflow.com/questions/5252170/
        # specify-image-filling-color-when-rotating-in-python-with-pil-and-setting-expand
        def rotate_with_fill(img, magnitude):
            # Uses a fixed (128,) * 4 fill, independent of ``fillcolor``.
            rot = img.convert("RGBA").rotate(magnitude)
            backdrop = Image.new("RGBA", rot.size, (128, ) * 4)
            return Image.composite(rot, backdrop, rot).convert(img.mode)

        def shear_x(img, magnitude):
            coeffs = (1, signed(magnitude), 0, 0, 1, 0)
            return img.transform(
                img.size, Image.AFFINE, coeffs, Image.BICUBIC,
                fillcolor=fillcolor)

        def shear_y(img, magnitude):
            coeffs = (1, 0, 0, signed(magnitude), 1, 0)
            return img.transform(
                img.size, Image.AFFINE, coeffs, Image.BICUBIC,
                fillcolor=fillcolor)

        def translate_x(img, magnitude):
            coeffs = (1, 0, signed(magnitude * img.size[0]), 0, 1, 0)
            return img.transform(
                img.size, Image.AFFINE, coeffs, fillcolor=fillcolor)

        def translate_y(img, magnitude):
            coeffs = (1, 0, 0, 0, 1, signed(magnitude * img.size[1]))
            return img.transform(
                img.size, Image.AFFINE, coeffs, fillcolor=fillcolor)

        def color(img, magnitude):
            return ImageEnhance.Color(img).enhance(1 + signed(magnitude))

        def posterize(img, magnitude):
            return ImageOps.posterize(img, magnitude)

        def solarize(img, magnitude):
            return ImageOps.solarize(img, magnitude)

        def contrast(img, magnitude):
            return ImageEnhance.Contrast(img).enhance(1 + signed(magnitude))

        def sharpness(img, magnitude):
            return ImageEnhance.Sharpness(img).enhance(1 + signed(magnitude))

        def brightness(img, magnitude):
            return ImageEnhance.Brightness(img).enhance(
                1 + signed(magnitude))

        def autocontrast(img, magnitude):
            return ImageOps.autocontrast(img)

        def equalize(img, magnitude):
            return ImageOps.equalize(img)

        def invert(img, magnitude):
            return ImageOps.invert(img)

        self.func = {
            "shearX": shear_x,
            "shearY": shear_y,
            "translateX": translate_x,
            "translateY": translate_y,
            "rotate": rotate_with_fill,
            "color": color,
            "posterize": posterize,
            "solarize": solarize,
            "contrast": contrast,
            "sharpness": sharpness,
            "brightness": brightness,
            "autocontrast": autocontrast,
            "equalize": equalize,
            "invert": invert
        }

    def __call__(self, img):
        """Apply ``num_layers`` operations drawn with ``np.random.choice``."""
        op_names = list(self.level_map.keys())
        for _ in range(self.num_layers):
            chosen = np.random.choice(op_names)
            img = self.func[chosen](img, self.level_map[chosen])
        return img
115 |
116 |
class RandAugment(RawRandAugment):
    """ RandAugment wrapper to auto fit different img types """

    def __init__(self, prob=0.5, *args, **kwargs):
        """Store ``prob`` and forward the remaining arguments to the base."""
        self.prob = prob
        # The two-argument form of super() is valid on Python 2 and 3.
        super(RandAugment, self).__init__(*args, **kwargs)

    def __call__(self, data):
        """Augment ``data['image']`` unless the random draw exceeds ``prob``.

        Non-PIL images are converted with ``Image.fromarray`` before the
        base-class augmentation, and a PIL result is converted back with
        ``np.asarray`` before being stored in ``data['image']``.
        """
        if np.random.rand() > self.prob:
            return data

        img = data['image']
        if not isinstance(img, Image.Image):
            img = Image.fromarray(np.ascontiguousarray(img))

        img = super(RandAugment, self).__call__(img)

        if isinstance(img, Image.Image):
            img = np.asarray(img)
        data['image'] = img
        return data
144 |
--------------------------------------------------------------------------------
/utils/data/imaug/random_crop_data.py:
--------------------------------------------------------------------------------
1 | # copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """
15 | This code is adapted from:
16 | https://github.com/WenmuZhou/DBNet.pytorch/blob/master/data_loader/modules/random_crop_data.py
17 | """
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 | from __future__ import unicode_literals
22 |
23 | import numpy as np
24 | import cv2
25 | import random
26 |
27 |
def is_poly_in_rect(poly, x, y, w, h):
    """Return True when no vertex of ``poly`` lies outside the rectangle.

    The rectangle spans ``[x, x + w]`` by ``[y, y + h]``; vertices on the
    boundary count as inside.
    """
    pts = np.array(poly)
    return not (pts[:, 0].min() < x or pts[:, 0].max() > x + w or
                pts[:, 1].min() < y or pts[:, 1].max() > y + h)
35 |
36 |
def is_poly_outside_rect(poly, x, y, w, h):
    """Return True when the bounding box of ``poly`` misses the rectangle.

    The polygon counts as outside when its extent lies strictly to the
    left of ``x``, right of ``x + w``, above ``y`` or below ``y + h``.
    """
    pts = np.array(poly)
    return bool(pts[:, 0].max() < x or pts[:, 0].min() > x + w or
                pts[:, 1].max() < y or pts[:, 1].min() > y + h)
44 |
45 |
def split_regions(axis):
    """Split a 1-D index array into runs of consecutive values.

    Args:
        axis: 1-D NumPy array of indices in increasing order.

    Returns:
        A list of slices of ``axis``, one per run of values that each
        increase by exactly 1. The trailing run is included, so a fully
        contiguous ``axis`` yields a single region and an empty ``axis``
        yields an empty list.
    """
    regions = []
    start = 0
    for i in range(1, axis.shape[0]):
        if axis[i] != axis[i - 1] + 1:
            regions.append(axis[start:i])
            start = i
    # The loop only closes a run when the next one begins, so the last
    # run has to be added explicitly.
    if axis.shape[0] > 0:
        regions.append(axis[start:])
    return regions
55 |
56 |
def random_select(axis, max_size):
    """Draw two values from ``axis`` and return them as ``(low, high)``.

    Both values are drawn with ``np.random.choice`` (with replacement)
    and clipped to ``[0, max_size - 1]``.
    """
    picks = np.random.choice(axis, size=2)
    upper = max_size - 1
    low = np.clip(np.min(picks), 0, upper)
    high = np.clip(np.max(picks), 0, upper)
    return low, high
64 |
65 |
def region_wise_random_select(regions, max_size):
    """Pick one value from each of two randomly chosen regions.

    Args:
        regions: sequence of 1-D arrays of candidate values.
        max_size: unused; kept for signature compatibility.

    Returns:
        ``(low, high)`` as Python ints. The two regions are chosen with
        replacement, so both values may come from the same region.
    """
    region_ids = list(np.random.choice(len(regions), 2))
    picks = []
    for idx in region_ids:
        # Index the single drawn element: calling int() on a size-1 array
        # is deprecated in recent NumPy releases.
        picks.append(int(np.random.choice(regions[idx], size=1)[0]))
    return min(picks), max(picks)
76 |
77 |
def crop_area(im, text_polys, min_crop_side_ratio, max_tries):
    """Search for a random crop whose borders do not cut through a polygon.

    Rows and columns covered by the rounded extent of any polygon are
    marked as occupied; crop borders are drawn from the free ones only.

    Args:
        im: image array; its shape is unpacked as ``(h, w, channels)``.
        text_polys: iterable of polygons, each indexable as ``[:, 0]`` /
            ``[:, 1]`` for x / y.
        min_crop_side_ratio: minimum crop side as a fraction of the image
            side.
        max_tries: number of random attempts.

    Returns:
        ``(x, y, w, h)`` of the crop. The whole image ``(0, 0, w, h)`` is
        returned when no free row or column exists or every attempt fails.
    """
    h, w, _ = im.shape
    occupied_rows = np.zeros(h, dtype=np.int32)
    occupied_cols = np.zeros(w, dtype=np.int32)
    for points in text_polys:
        pts = np.round(points, decimals=0).astype(np.int32)
        occupied_cols[np.min(pts[:, 0]):np.max(pts[:, 0])] = 1
        occupied_rows[np.min(pts[:, 1]):np.max(pts[:, 1])] = 1

    # Only rows/columns that no text spans may become crop borders.
    h_axis = np.where(occupied_rows == 0)[0]
    w_axis = np.where(occupied_cols == 0)[0]

    if len(h_axis) == 0 or len(w_axis) == 0:
        return 0, 0, w, h

    h_regions = split_regions(h_axis)
    w_regions = split_regions(w_axis)

    def pick_bounds(regions, axis, limit):
        if len(regions) > 1:
            return region_wise_random_select(regions, limit)
        return random_select(axis, limit)

    for _ in range(max_tries):
        xmin, xmax = pick_bounds(w_regions, w_axis, w)
        ymin, ymax = pick_bounds(h_regions, h_axis, h)
        crop_w = xmax - xmin
        crop_h = ymax - ymin

        if crop_w < min_crop_side_ratio * w or crop_h < min_crop_side_ratio * h:
            # Candidate area is too small.
            continue

        # Accept the crop as soon as one polygon is not entirely outside.
        if any(not is_poly_outside_rect(poly, xmin, ymin, crop_w, crop_h)
               for poly in text_polys):
            return xmin, ymin, crop_w, crop_h

    return 0, 0, w, h
124 |
125 |
class EastRandomCropData(object):
    """Randomly crop a sample around its text and resize it to ``size``.

    ``__call__`` reads ``data['image']``, ``data['polys']``,
    ``data['ignore_tags']`` and ``data['texts']`` and overwrites all four
    with their cropped counterparts.
    """

    def __init__(self,
                 size=(640, 640),
                 max_tries=10,
                 min_crop_side_ratio=0.1,
                 keep_ratio=True,
                 **kwargs):
        """Store the crop settings; extra keyword arguments are ignored.

        ``size`` is indexed as ``(width, height)`` by ``__call__``.
        """
        self.size = size
        self.max_tries = max_tries
        self.min_crop_side_ratio = min_crop_side_ratio
        self.keep_ratio = keep_ratio

    def __call__(self, data):
        """Crop the image and keep the polygons that touch the crop.

        With ``keep_ratio`` the crop is scaled uniformly and placed at the
        top-left of a zero-filled canvas of the target size; otherwise it
        is stretched to the target size.
        """
        img = data['image']
        text_polys = data['polys']
        ignore_tags = data['ignore_tags']
        texts = data['texts']
        care_polys = [
            text_polys[idx] for idx, skip in enumerate(ignore_tags)
            if not skip
        ]

        # Compute the crop region.
        crop_x, crop_y, crop_w, crop_h = crop_area(
            img, care_polys, self.min_crop_side_ratio, self.max_tries)

        # Crop the image; pad to the target size while keeping the ratio.
        scale = min(self.size[0] / crop_w, self.size[1] / crop_h)
        h = int(crop_h * scale)
        w = int(crop_w * scale)
        if self.keep_ratio:
            canvas = np.zeros((self.size[1], self.size[0], img.shape[2]),
                              img.dtype)
            patch = img[crop_y:crop_y + crop_h, crop_x:crop_x + crop_w]
            canvas[:h, :w] = cv2.resize(patch, (w, h))
            img = canvas
        else:
            patch = img[crop_y:crop_y + crop_h, crop_x:crop_x + crop_w]
            img = cv2.resize(patch, tuple(self.size))

        # Move the text boxes into the crop's coordinate frame.
        kept_polys, kept_tags, kept_texts = [], [], []
        for poly, text, tag in zip(text_polys, texts, ignore_tags):
            shifted = ((poly - (crop_x, crop_y)) * scale).tolist()
            if is_poly_outside_rect(shifted, 0, 0, w, h):
                continue
            kept_polys.append(shifted)
            kept_tags.append(tag)
            kept_texts.append(text)

        data['image'] = img
        data['polys'] = np.array(kept_polys)
        data['ignore_tags'] = kept_tags
        data['texts'] = kept_texts
        return data
180 |
181 |
class PSERandomCrop(object):
    """Crop every array in ``data['imgs']`` to the same random window.

    ``size`` is unpacked as ``(height, width)`` of the window.
    """

    def __init__(self, size, **kwargs):
        """Store the target window size; extra keyword arguments are ignored."""
        self.size = size

    def __call__(self, data):
        """Crop all arrays in ``data['imgs']`` and return ``data``.

        The sample dict is returned on every path, including when the
        first array already has the target size.
        """
        imgs = data['imgs']

        h, w = imgs[0].shape[0:2]
        th, tw = self.size
        if w == tw and h == th:
            # Nothing to crop; return the sample dict, not the bare list.
            return data

        # If the label holds text instances, crop around them with some
        # probability; the original comment calls imgs[2] the
        # threshold_label_map.
        if np.max(imgs[2]) > 0 and random.random() > 3 / 8:
            # Top-left corner of the text instances.
            tl = np.min(np.where(imgs[2] > 0), axis=1) - self.size
            tl[tl < 0] = 0
            # Bottom-right corner of the text instances.
            br = np.max(np.where(imgs[2] > 0), axis=1) - self.size
            br[br < 0] = 0
            # Leave enough room for a full window when the bottom-right
            # corner is selected.
            br[0] = min(br[0], h - th)
            br[1] = min(br[1], w - tw)

            for _ in range(50000):
                i = random.randint(tl[0], br[0])
                j = random.randint(tl[1], br[1])
                # Retry until the window over imgs[1] (shrink_label_map
                # per the original comment) contains text.
                if imgs[1][i:i + th, j:j + tw].sum() <= 0:
                    continue
                else:
                    break
        else:
            i = random.randint(0, h - th)
            j = random.randint(0, w - tw)

        for idx in range(len(imgs)):
            if len(imgs[idx].shape) == 3:
                imgs[idx] = imgs[idx][i:i + th, j:j + tw, :]
            else:
                imgs[idx] = imgs[idx][i:i + th, j:j + tw]
        data['imgs'] = imgs
        return data
226 |
--------------------------------------------------------------------------------
/utils/data/imaug/text_image_aug/__init__.py:
--------------------------------------------------------------------------------
1 | # copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from .augment import tia_perspective, tia_distort, tia_stretch
16 |
17 | __all__ = ['tia_distort', 'tia_stretch', 'tia_perspective']
18 |
--------------------------------------------------------------------------------
/utils/data/imaug/text_image_aug/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/June-Li/TableRecognition/a26956983ade1153ee91f9a92fface1a2cf1275e/utils/data/imaug/text_image_aug/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/utils/data/imaug/text_image_aug/__pycache__/augment.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/June-Li/TableRecognition/a26956983ade1153ee91f9a92fface1a2cf1275e/utils/data/imaug/text_image_aug/__pycache__/augment.cpython-38.pyc
--------------------------------------------------------------------------------
/utils/data/imaug/text_image_aug/__pycache__/warp_mls.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/June-Li/TableRecognition/a26956983ade1153ee91f9a92fface1a2cf1275e/utils/data/imaug/text_image_aug/__pycache__/warp_mls.cpython-38.pyc
--------------------------------------------------------------------------------
/utils/data/imaug/text_image_aug/augment.py:
--------------------------------------------------------------------------------
1 | # copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """
15 | This code is adapted from:
16 | https://github.com/RubanSeven/Text-Image-Augmentation-python/blob/master/augment.py
17 | """
18 | import numpy as np
19 | from .warp_mls import WarpMLS
20 |
21 |
def tia_distort(src, segment=4):
    """Warp ``src`` with randomly jittered corner and segment control points.

    The width is divided into ``segment`` parts. The four corners and the
    top/bottom points on each interior cut line are displaced by random
    amounts bounded by ``(img_w // segment) // 3``, and the image is
    warped with ``WarpMLS``.
    """
    img_h, img_w = src.shape[:2]

    cut = img_w // segment
    thresh = cut // 3
    half_thresh = thresh * 0.5

    def jitter():
        return np.random.randint(thresh)

    src_pts = [[0, 0], [img_w, 0], [img_w, img_h], [0, img_h]]
    # Corners are pulled inwards; the draws happen in list order.
    dst_pts = [
        [jitter(), jitter()],
        [img_w - jitter(), jitter()],
        [img_w - jitter(), img_h - jitter()],
        [jitter(), img_h - jitter()],
    ]

    for cut_idx in np.arange(1, segment, 1):
        x = cut * cut_idx
        src_pts.append([x, 0])
        src_pts.append([x, img_h])
        # Interior points move in both directions around their origin.
        dst_pts.append([x + jitter() - half_thresh, jitter() - half_thresh])
        dst_pts.append(
            [x + jitter() - half_thresh, img_h + jitter() - half_thresh])

    return WarpMLS(src, src_pts, dst_pts, img_w, img_h).generate()
62 |
63 |
def tia_stretch(src, segment=4):
    """Stretch ``src`` horizontally by shifting its interior cut lines.

    The corners stay fixed; each interior cut line (top and bottom point
    together) is moved horizontally by one random offset bounded by
    ``(img_w // segment) * 4 // 5``, and the image is warped with
    ``WarpMLS``.
    """
    img_h, img_w = src.shape[:2]

    cut = img_w // segment
    thresh = cut * 4 // 5
    half_thresh = thresh * 0.5

    src_pts = [[0, 0], [img_w, 0], [img_w, img_h], [0, img_h]]
    dst_pts = [[0, 0], [img_w, 0], [img_w, img_h], [0, img_h]]

    for cut_idx in np.arange(1, segment, 1):
        shift = np.random.randint(thresh) - half_thresh
        x = cut * cut_idx
        src_pts.extend([[x, 0], [x, img_h]])
        dst_pts.extend([[x + shift, 0], [x + shift, img_h]])

    return WarpMLS(src, src_pts, dst_pts, img_w, img_h).generate()
96 |
97 |
def tia_perspective(src):
    """Warp ``src`` by moving its four corners vertically at random.

    Each corner keeps its x coordinate; top corners move down and bottom
    corners move up by a random amount below ``img_h // 2``, and the
    image is warped with ``WarpMLS``.
    """
    img_h, img_w = src.shape[:2]
    thresh = img_h // 2

    def drop():
        return np.random.randint(thresh)

    src_pts = [[0, 0], [img_w, 0], [img_w, img_h], [0, img_h]]
    dst_pts = [
        [0, drop()],
        [img_w, drop()],
        [img_w, img_h - drop()],
        [0, img_h - drop()],
    ]

    return WarpMLS(src, src_pts, dst_pts, img_w, img_h).generate()
--------------------------------------------------------------------------------
/utils/data/imaug/text_image_aug/warp_mls.py:
--------------------------------------------------------------------------------
1 | # copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """
15 | This code is adapted from:
16 | https://github.com/RubanSeven/Text-Image-Augmentation-python/blob/master/warp_mls.py
17 | """
18 | import numpy as np
19 |
20 |
class WarpMLS:
    """Warp an image with moving least squares over control-point pairs.

    ``dst_pts[k]`` is a position in the output whose content is taken
    from ``src_pts[k]`` in the source. Displacements are computed on a
    coarse grid (``grid_size`` pixels) by ``calc_delta`` and bilinearly
    interpolated per pixel by ``gen_img``.
    """

    def __init__(self, src, src_pts, dst_pts, dst_w, dst_h, trans_ratio=1.):
        """Store the inputs and allocate the displacement fields.

        Args:
            src: source image array, 2-D or 3-D with channels last.
            src_pts: list of ``[x, y]`` control points in the source.
            dst_pts: list of ``[x, y]`` control points in the output.
            dst_w, dst_h: output width and height.
            trans_ratio: factor applied to the displacement in ``gen_img``.
        """
        self.src = src
        self.src_pts = src_pts
        self.dst_pts = dst_pts
        self.pt_count = len(self.dst_pts)
        self.dst_w = dst_w
        self.dst_h = dst_h
        self.trans_ratio = trans_ratio
        self.grid_size = 100
        self.rdx = np.zeros((self.dst_h, self.dst_w))
        self.rdy = np.zeros((self.dst_h, self.dst_w))

    @staticmethod
    def __bilinear_interp(x, y, v11, v12, v21, v22):
        """Blend four values with weights ``x`` (v1* to v2*) and ``y`` (v*1 to v*2)."""
        return (v11 * (1 - y) + v12 * y) * (1 - x) + (v21 *
                                                      (1 - y) + v22 * y) * x

    def generate(self):
        """Compute the displacement grid and return the warped image."""
        self.calc_delta()
        return self.gen_img()

    def calc_delta(self):
        """Fill ``rdx``/``rdy`` at the grid nodes.

        Nodes are visited every ``grid_size`` pixels; the first overshoot
        past the border is clamped to the last column/row. A node that
        coincides with a control point takes that point's source position
        exactly; every other node gets the weighted least-squares
        estimate. Does nothing with fewer than two control points.
        """
        w = np.zeros(self.pt_count, dtype=np.float32)

        if self.pt_count < 2:
            return

        i = 0
        while 1:
            if self.dst_w <= i < self.dst_w + self.grid_size - 1:
                i = self.dst_w - 1
            elif i >= self.dst_w:
                break

            j = 0
            while 1:
                if self.dst_h <= j < self.dst_h + self.grid_size - 1:
                    j = self.dst_h - 1
                elif j >= self.dst_h:
                    break

                sw = 0
                swp = np.zeros(2, dtype=np.float32)
                swq = np.zeros(2, dtype=np.float32)
                new_pt = np.zeros(2, dtype=np.float32)
                cur_pt = np.array([i, j], dtype=np.float32)

                # Track the coinciding control point explicitly. Testing
                # the loop index afterwards cannot tell "no match" apart
                # from "matched the last control point".
                matched = None
                for k in range(self.pt_count):
                    if i == self.dst_pts[k][0] and j == self.dst_pts[k][1]:
                        matched = k
                        break

                    # Inverse squared distance weight.
                    w[k] = 1. / (
                        (i - self.dst_pts[k][0]) * (i - self.dst_pts[k][0]) +
                        (j - self.dst_pts[k][1]) * (j - self.dst_pts[k][1]))

                    sw += w[k]
                    swp = swp + w[k] * np.array(self.dst_pts[k])
                    swq = swq + w[k] * np.array(self.src_pts[k])

                if matched is None:
                    # Weighted centroids of the two point sets.
                    pstar = 1 / sw * swp
                    qstar = 1 / sw * swq

                    miu_s = 0
                    for k in range(self.pt_count):
                        pt_i = self.dst_pts[k] - pstar
                        miu_s += w[k] * np.sum(pt_i * pt_i)

                    cur_pt -= pstar
                    cur_pt_j = np.array([-cur_pt[1], cur_pt[0]])

                    for k in range(self.pt_count):
                        pt_i = self.dst_pts[k] - pstar
                        pt_j = np.array([-pt_i[1], pt_i[0]])

                        tmp_pt = np.zeros(2, dtype=np.float32)
                        tmp_pt[0] = np.sum(pt_i * cur_pt) * self.src_pts[k][0] - \
                                    np.sum(pt_j * cur_pt) * self.src_pts[k][1]
                        tmp_pt[1] = -np.sum(pt_i * cur_pt_j) * self.src_pts[k][0] + \
                                    np.sum(pt_j * cur_pt_j) * self.src_pts[k][1]
                        tmp_pt *= (w[k] / miu_s)
                        new_pt += tmp_pt

                    new_pt += qstar
                else:
                    new_pt = self.src_pts[matched]

                self.rdx[j, i] = new_pt[0] - i
                self.rdy[j, i] = new_pt[1] - j

                j += self.grid_size
            i += self.grid_size

    def gen_img(self):
        """Resample ``src`` through the interpolated displacement field.

        Returns:
            A uint8 array shaped like ``src``. Channels-last images with
            any number of channels are supported.
        """
        src_h, src_w = self.src.shape[:2]
        dst = np.zeros_like(self.src, dtype=np.float32)

        for i in np.arange(0, self.dst_h, self.grid_size):
            for j in np.arange(0, self.dst_w, self.grid_size):
                ni = i + self.grid_size
                nj = j + self.grid_size
                w = h = self.grid_size
                if ni >= self.dst_h:
                    ni = self.dst_h - 1
                    h = ni - i + 1
                if nj >= self.dst_w:
                    nj = self.dst_w - 1
                    w = nj - j + 1

                di = np.reshape(np.arange(h), (-1, 1))
                dj = np.reshape(np.arange(w), (1, -1))
                delta_x = self.__bilinear_interp(
                    di / h, dj / w, self.rdx[i, j], self.rdx[i, nj],
                    self.rdx[ni, j], self.rdx[ni, nj])
                delta_y = self.__bilinear_interp(
                    di / h, dj / w, self.rdy[i, j], self.rdy[i, nj],
                    self.rdy[ni, j], self.rdy[ni, nj])
                nx = j + dj + delta_x * self.trans_ratio
                ny = i + di + delta_y * self.trans_ratio
                nx = np.clip(nx, 0, src_w - 1)
                ny = np.clip(ny, 0, src_h - 1)
                nxi = np.array(np.floor(nx), dtype=np.int32)
                nyi = np.array(np.floor(ny), dtype=np.int32)
                nxi1 = np.array(np.ceil(nx), dtype=np.int32)
                nyi1 = np.array(np.ceil(ny), dtype=np.int32)

                if len(self.src.shape) == 3:
                    # A trailing unit axis broadcasts over any channel count.
                    x = np.expand_dims(ny - nyi, axis=-1)
                    y = np.expand_dims(nx - nxi, axis=-1)
                else:
                    x = ny - nyi
                    y = nx - nxi
                dst[i:i + h, j:j + w] = self.__bilinear_interp(
                    x, y, self.src[nyi, nxi], self.src[nyi, nxi1],
                    self.src[nyi1, nxi], self.src[nyi1, nxi1])

        dst = np.clip(dst, 0, 255)
        dst = np.array(dst, dtype=np.uint8)

        return dst
--------------------------------------------------------------------------------
/utils/data/pubtab_dataset.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 | import random
4 | from torch.utils.data import Dataset
5 | import json
6 | from .imaug import transform, create_operators
7 |
8 |
class PubTabDataSet(Dataset):
    """Table-recognition dataset backed by a JSON-lines label file.

    Every line of the label file is a JSON object that provides at least
    ``filename`` and ``html`` (with ``cells`` and ``structure``; the latter
    holds a ``tokens`` list).  A sample is built by reading the raw image
    bytes from ``data_dir/filename`` and running the configured transform
    operators over them.  Samples that fail to load, or that are rejected by
    the optional probabilistic filters, are replaced by a randomly chosen
    other sample.
    """

    def __init__(self, config, mode, logger, seed=None):
        """Read the label file and build the transform pipeline.

        Args:
            config: mapping with a ``'Global'`` section and a ``mode``
                section containing ``'dataset'`` and ``'loader'`` sub-sections.
            mode: name of the config section to use; when it equals
                ``"train"`` (case-insensitive) the lines may be shuffled.
            logger: object exposing ``info`` and ``error``.
            seed: value handed to ``random.seed`` before shuffling.
        """
        super(PubTabDataSet, self).__init__()
        self.logger = logger

        global_config = config['Global']
        dataset_config = config[mode]['dataset']
        loader_config = config[mode]['loader']

        # NOTE: pop() removes the key from the caller's config mapping.
        label_file_path = dataset_config.pop('label_file_path')

        self.data_dir = dataset_config['data_dir']
        self.do_shuffle = loader_config['shuffle']

        # Optional per-image selection probabilities.
        self.do_hard_select = False
        if 'hard_select' in loader_config:
            self.do_hard_select = loader_config['hard_select']
            self.hard_prob = loader_config['hard_prob']
        if self.do_hard_select:
            # NOTE(review): load_hard_select_prob is not defined in this
            # class as shown; it must be provided elsewhere before
            # hard_select can be enabled.
            self.img_select_prob = self.load_hard_select_prob()

        # Optional down-sampling of tables that use row/column spans.
        self.table_select_type = None
        if 'table_select_type' in loader_config:
            self.table_select_type = loader_config['table_select_type']
            self.table_select_prob = loader_config['table_select_prob']

        self.seed = seed
        logger.info("Initialize indexs of datasets:%s" % label_file_path)
        # Lines are kept as raw bytes and decoded lazily in __getitem__.
        with open(label_file_path, "rb") as f:
            self.data_lines = f.readlines()
        self.data_idx_order_list = list(range(len(self.data_lines)))
        if mode.lower() == "train":
            self.shuffle_data_random()
        self.ops = create_operators(dataset_config['transforms'], global_config)

    def shuffle_data_random(self):
        """Shuffle ``data_lines`` in place when shuffling is enabled.

        Re-seeds the module-level ``random`` generator with ``self.seed``.
        """
        if self.do_shuffle:
            random.seed(self.seed)
            random.shuffle(self.data_lines)
        return

    def __getitem__(self, idx):
        """Return the transformed sample at ``idx``.

        If the sample is filtered out, cannot be loaded, or the transform
        pipeline returns ``None``, a random index is drawn and tried instead.
        Load errors are reported through ``self.logger.error``.
        """
        # Bound before the try block so the error handler can always refer
        # to it, even when the line lookup itself is what failed.
        data_line = None
        try:
            data_line = self.data_lines[idx]
            data_line = data_line.decode('utf-8').strip("\n")
            info = json.loads(data_line)
            file_name = info['filename']
            select_flag = True
            if self.do_hard_select:
                # Keep the image with the probability stored for its name.
                prob = self.img_select_prob[file_name]
                if prob < random.uniform(0, 1):
                    select_flag = False

            if self.table_select_type:
                structure = info['html']['structure']['tokens'].copy()
                structure_str = ''.join(structure)
                # A table counts as "complex" when any token mentions a span.
                table_type = "simple"
                if 'colspan' in structure_str or 'rowspan' in structure_str:
                    table_type = "complex"
                if table_type == "complex":
                    if self.table_select_prob < random.uniform(0, 1):
                        select_flag = False

            if select_flag:
                cells = info['html']['cells'].copy()
                structure = info['html']['structure'].copy()
                img_path = os.path.join(self.data_dir, file_name)
                data = {'img_path': img_path, 'cells': cells, 'structure': structure}
                if not os.path.exists(img_path):
                    raise Exception("{} does not exist!".format(img_path))
                # The image is passed on as undecoded bytes.
                with open(data['img_path'], 'rb') as f:
                    img = f.read()
                    data['image'] = img
                outs = transform(data, self.ops)
            else:
                outs = None
        except Exception as e:
            # Deliberately broad: any failure is logged and the sample is
            # replaced rather than aborting the whole run.
            if data_line is None:
                shown_line = "<index {}>".format(idx)
            else:
                shown_line = data_line
            self.logger.error(
                "When parsing line {}, error happened with msg: {}".format(
                    shown_line, e))
            outs = None
        if outs is None:
            return self.__getitem__(np.random.randint(self.__len__()))
        return outs

    def __len__(self):
        """Return the number of label lines read at construction time."""
        return len(self.data_idx_order_list)
--------------------------------------------------------------------------------
/utils/dict/ar_dict.txt:
--------------------------------------------------------------------------------
1 | a
2 | r
3 | b
4 | i
5 | c
6 | _
7 | m
8 | g
9 | /
10 | 1
11 | 0
12 | I
13 | L
14 | S
15 | V
16 | R
17 | C
18 | 2
19 | v
20 | l
21 | 6
22 | 3
23 | 9
24 | .
25 | j
26 | p
27 | ا
28 | ل
29 | م
30 | ر
31 | ج
32 | و
33 | ح
34 | ي
35 | ة
36 | 5
37 | 8
38 | 7
39 | أ
40 | ب
41 | ض
42 | 4
43 | ك
44 | س
45 | ه
46 | ث
47 | ن
48 | ط
49 | ع
50 | ت
51 | غ
52 | خ
53 | ف
54 | ئ
55 | ز
56 | إ
57 | د
58 | ص
59 | ظ
60 | ذ
61 | ش
62 | ى
63 | ق
64 | ؤ
65 | آ
66 | ء
67 | s
68 | e
69 | n
70 | w
71 | t
72 | u
73 | z
74 | d
75 | A
76 | N
77 | G
78 | h
79 | o
80 | E
81 | T
82 | H
83 | O
84 | B
85 | y
86 | F
87 | U
88 | J
89 | X
90 | W
91 | P
92 | Z
93 | M
94 | k
95 | q
96 | Y
97 | Q
98 | D
99 | f
100 | K
101 | x
102 | '
103 | %
104 | -
105 | #
106 | @
107 | !
108 | &
109 | $
110 | ,
111 | :
112 | é
113 | ?
114 | +
115 | É
116 | (
117 |
118 |
--------------------------------------------------------------------------------
/utils/dict/arabic_dict.txt:
--------------------------------------------------------------------------------
1 |
2 | !
3 | #
4 | $
5 | %
6 | &
7 | '
8 | (
9 | +
10 | ,
11 | -
12 | .
13 | /
14 | 0
15 | 1
16 | 2
17 | 3
18 | 4
19 | 5
20 | 6
21 | 7
22 | 8
23 | 9
24 | :
25 | ?
26 | @
27 | A
28 | B
29 | C
30 | D
31 | E
32 | F
33 | G
34 | H
35 | I
36 | J
37 | K
38 | L
39 | M
40 | N
41 | O
42 | P
43 | Q
44 | R
45 | S
46 | T
47 | U
48 | V
49 | W
50 | X
51 | Y
52 | Z
53 | _
54 | a
55 | b
56 | c
57 | d
58 | e
59 | f
60 | g
61 | h
62 | i
63 | j
64 | k
65 | l
66 | m
67 | n
68 | o
69 | p
70 | q
71 | r
72 | s
73 | t
74 | u
75 | v
76 | w
77 | x
78 | y
79 | z
80 | É
81 | é
82 | ء
83 | آ
84 | أ
85 | ؤ
86 | إ
87 | ئ
88 | ا
89 | ب
90 | ة
91 | ت
92 | ث
93 | ج
94 | ح
95 | خ
96 | د
97 | ذ
98 | ر
99 | ز
100 | س
101 | ش
102 | ص
103 | ض
104 | ط
105 | ظ
106 | ع
107 | غ
108 | ف
109 | ق
110 | ك
111 | ل
112 | م
113 | ن
114 | ه
115 | و
116 | ى
117 | ي
118 | ً
119 | ٌ
120 | ٍ
121 | َ
122 | ُ
123 | ِ
124 | ّ
125 | ْ
126 | ٓ
127 | ٔ
128 | ٰ
129 | ٱ
130 | ٹ
131 | پ
132 | چ
133 | ڈ
134 | ڑ
135 | ژ
136 | ک
137 | ڭ
138 | گ
139 | ں
140 | ھ
141 | ۀ
142 | ہ
143 | ۂ
144 | ۃ
145 | ۆ
146 | ۇ
147 | ۈ
148 | ۋ
149 | ی
150 | ې
151 | ے
152 | ۓ
153 | ە
154 | ١
155 | ٢
156 | ٣
157 | ٤
158 | ٥
159 | ٦
160 | ٧
161 | ٨
162 | ٩
163 |
--------------------------------------------------------------------------------
/utils/dict/be_dict.txt:
--------------------------------------------------------------------------------
1 | b
2 | e
3 | _
4 | i
5 | m
6 | g
7 | /
8 | 2
9 | 0
10 | I
11 | L
12 | S
13 | V
14 | R
15 | C
16 | 1
17 | v
18 | a
19 | l
20 | 6
21 | 9
22 | 4
23 | 3
24 | .
25 | j
26 | p
27 | п
28 | а
29 | з
30 | б
31 | у
32 | г
33 | н
34 | ц
35 | ь
36 | 8
37 | м
38 | л
39 | і
40 | о
41 | ў
42 | ы
43 | 7
44 | 5
45 | М
46 | х
47 | с
48 | р
49 | ф
50 | я
51 | е
52 | д
53 | ж
54 | ю
55 | ч
56 | й
57 | к
58 | Д
59 | в
60 | Б
61 | т
62 | І
63 | ш
64 | ё
65 | э
66 | К
67 | Л
68 | Н
69 | А
70 | Ж
71 | Г
72 | В
73 | П
74 | З
75 | Е
76 | О
77 | Р
78 | С
79 | У
80 | Ё
81 | Й
82 | Т
83 | Ч
84 | Э
85 | Ц
86 | Ю
87 | Ш
88 | Ф
89 | Х
90 | Я
91 | Ь
92 | Ы
93 | Ў
94 | s
95 | c
96 | n
97 | w
98 | M
99 | o
100 | t
101 | T
102 | E
103 | A
104 | B
105 | u
106 | h
107 | y
108 | k
109 | r
110 | H
111 | d
112 | Y
113 | O
114 | U
115 | F
116 | f
117 | x
118 | D
119 | G
120 | N
121 | K
122 | P
123 | z
124 | J
125 | X
126 | W
127 | Z
128 | Q
129 | %
130 | -
131 | q
132 | @
133 | '
134 | !
135 | #
136 | &
137 | ,
138 | :
139 | $
140 | (
141 | ?
142 | é
143 | +
144 | É
145 |
146 |
--------------------------------------------------------------------------------
/utils/dict/bg_dict.txt:
--------------------------------------------------------------------------------
1 | !
2 | #
3 | $
4 | %
5 | &
6 | '
7 | (
8 | +
9 | ,
10 | -
11 | .
12 | /
13 | 0
14 | 1
15 | 2
16 | 3
17 | 4
18 | 5
19 | 6
20 | 7
21 | 8
22 | 9
23 | :
24 | ?
25 | @
26 | A
27 | B
28 | C
29 | D
30 | E
31 | F
32 | G
33 | H
34 | I
35 | J
36 | K
37 | L
38 | M
39 | N
40 | O
41 | P
42 | Q
43 | R
44 | S
45 | T
46 | U
47 | V
48 | W
49 | X
50 | Y
51 | Z
52 | _
53 | a
54 | b
55 | c
56 | d
57 | e
58 | f
59 | g
60 | h
61 | i
62 | j
63 | k
64 | l
65 | m
66 | n
67 | o
68 | p
69 | q
70 | r
71 | s
72 | t
73 | u
74 | v
75 | w
76 | x
77 | y
78 | z
79 | É
80 | é
81 | А
82 | Б
83 | В
84 | Г
85 | Д
86 | Е
87 | Ж
88 | З
89 | И
90 | Й
91 | К
92 | Л
93 | М
94 | Н
95 | О
96 | П
97 | Р
98 | С
99 | Т
100 | У
101 | Ф
102 | Х
103 | Ц
104 | Ч
105 | Ш
106 | Щ
107 | Ъ
108 | Ю
109 | Я
110 | а
111 | б
112 | в
113 | г
114 | д
115 | е
116 | ж
117 | з
118 | и
119 | й
120 | к
121 | л
122 | м
123 | н
124 | о
125 | п
126 | р
127 | с
128 | т
129 | у
130 | ф
131 | х
132 | ц
133 | ч
134 | ш
135 | щ
136 | ъ
137 | ь
138 | ю
139 | я
140 |
141 |
--------------------------------------------------------------------------------
/utils/dict/cyrillic_dict.txt:
--------------------------------------------------------------------------------
1 |
2 | !
3 | #
4 | $
5 | %
6 | &
7 | '
8 | (
9 | +
10 | ,
11 | -
12 | .
13 | /
14 | 0
15 | 1
16 | 2
17 | 3
18 | 4
19 | 5
20 | 6
21 | 7
22 | 8
23 | 9
24 | :
25 | ?
26 | @
27 | A
28 | B
29 | C
30 | D
31 | E
32 | F
33 | G
34 | H
35 | I
36 | J
37 | K
38 | L
39 | M
40 | N
41 | O
42 | P
43 | Q
44 | R
45 | S
46 | T
47 | U
48 | V
49 | W
50 | X
51 | Y
52 | Z
53 | _
54 | a
55 | b
56 | c
57 | d
58 | e
59 | f
60 | g
61 | h
62 | i
63 | j
64 | k
65 | l
66 | m
67 | n
68 | o
69 | p
70 | q
71 | r
72 | s
73 | t
74 | u
75 | v
76 | w
77 | x
78 | y
79 | z
80 | É
81 | é
82 | Ё
83 | Є
84 | І
85 | Ј
86 | Љ
87 | Ў
88 | А
89 | Б
90 | В
91 | Г
92 | Д
93 | Е
94 | Ж
95 | З
96 | И
97 | Й
98 | К
99 | Л
100 | М
101 | Н
102 | О
103 | П
104 | Р
105 | С
106 | Т
107 | У
108 | Ф
109 | Х
110 | Ц
111 | Ч
112 | Ш
113 | Щ
114 | Ъ
115 | Ы
116 | Ь
117 | Э
118 | Ю
119 | Я
120 | а
121 | б
122 | в
123 | г
124 | д
125 | е
126 | ж
127 | з
128 | и
129 | й
130 | к
131 | л
132 | м
133 | н
134 | о
135 | п
136 | р
137 | с
138 | т
139 | у
140 | ф
141 | х
142 | ц
143 | ч
144 | ш
145 | щ
146 | ъ
147 | ы
148 | ь
149 | э
150 | ю
151 | я
152 | ё
153 | ђ
154 | є
155 | і
156 | ј
157 | љ
158 | њ
159 | ћ
160 | ў
161 | џ
162 | Ґ
163 | ґ
164 |
--------------------------------------------------------------------------------
/utils/dict/devanagari_dict.txt:
--------------------------------------------------------------------------------
1 |
2 | !
3 | #
4 | $
5 | %
6 | &
7 | '
8 | (
9 | +
10 | ,
11 | -
12 | .
13 | /
14 | 0
15 | 1
16 | 2
17 | 3
18 | 4
19 | 5
20 | 6
21 | 7
22 | 8
23 | 9
24 | :
25 | ?
26 | @
27 | A
28 | B
29 | C
30 | D
31 | E
32 | F
33 | G
34 | H
35 | I
36 | J
37 | K
38 | L
39 | M
40 | N
41 | O
42 | P
43 | Q
44 | R
45 | S
46 | T
47 | U
48 | V
49 | W
50 | X
51 | Y
52 | Z
53 | _
54 | a
55 | b
56 | c
57 | d
58 | e
59 | f
60 | g
61 | h
62 | i
63 | j
64 | k
65 | l
66 | m
67 | n
68 | o
69 | p
70 | q
71 | r
72 | s
73 | t
74 | u
75 | v
76 | w
77 | x
78 | y
79 | z
80 | É
81 | é
82 | ँ
83 | ं
84 | ः
85 | अ
86 | आ
87 | इ
88 | ई
89 | उ
90 | ऊ
91 | ऋ
92 | ए
93 | ऐ
94 | ऑ
95 | ओ
96 | औ
97 | क
98 | ख
99 | ग
100 | घ
101 | ङ
102 | च
103 | छ
104 | ज
105 | झ
106 | ञ
107 | ट
108 | ठ
109 | ड
110 | ढ
111 | ण
112 | त
113 | थ
114 | द
115 | ध
116 | न
117 | ऩ
118 | प
119 | फ
120 | ब
121 | भ
122 | म
123 | य
124 | र
125 | ऱ
126 | ल
127 | ळ
128 | व
129 | श
130 | ष
131 | स
132 | ह
133 | ़
134 | ा
135 | ि
136 | ी
137 | ु
138 | ू
139 | ृ
140 | ॅ
141 | े
142 | ै
143 | ॉ
144 | ो
145 | ौ
146 | ्
147 | ॒
148 | क़
149 | ख़
150 | ग़
151 | ज़
152 | ड़
153 | ढ़
154 | फ़
155 | ॠ
156 | ।
157 | ०
158 | १
159 | २
160 | ३
161 | ४
162 | ५
163 | ६
164 | ७
165 | ८
166 | ९
167 | ॰
168 |
--------------------------------------------------------------------------------
/utils/dict/en_dict.txt:
--------------------------------------------------------------------------------
1 | 0
2 | 1
3 | 2
4 | 3
5 | 4
6 | 5
7 | 6
8 | 7
9 | 8
10 | 9
11 | a
12 | b
13 | c
14 | d
15 | e
16 | f
17 | g
18 | h
19 | i
20 | j
21 | k
22 | l
23 | m
24 | n
25 | o
26 | p
27 | q
28 | r
29 | s
30 | t
31 | u
32 | v
33 | w
34 | x
35 | y
36 | z
37 | A
38 | B
39 | C
40 | D
41 | E
42 | F
43 | G
44 | H
45 | I
46 | J
47 | K
48 | L
49 | M
50 | N
51 | O
52 | P
53 | Q
54 | R
55 | S
56 | T
57 | U
58 | V
59 | W
60 | X
61 | Y
62 | Z
63 |
64 |
--------------------------------------------------------------------------------
/utils/dict/fa_dict.txt:
--------------------------------------------------------------------------------
1 | f
2 | a
3 | _
4 | i
5 | m
6 | g
7 | /
8 | 1
9 | 3
10 | I
11 | L
12 | S
13 | V
14 | R
15 | C
16 | 2
17 | 0
18 | v
19 | l
20 | 6
21 | 8
22 | 5
23 | .
24 | j
25 | p
26 | و
27 | د
28 | ر
29 | ك
30 | ن
31 | ش
32 | ه
33 | ا
34 | 4
35 | 9
36 | ی
37 | ج
38 | ِ
39 | 7
40 | غ
41 | ل
42 | س
43 | ز
44 | ّ
45 | ت
46 | ک
47 | گ
48 | ي
49 | م
50 | ب
51 | ف
52 | چ
53 | خ
54 | ق
55 | ژ
56 | آ
57 | ص
58 | پ
59 | َ
60 | ع
61 | ئ
62 | ح
63 | ٔ
64 | ض
65 | ُ
66 | ذ
67 | أ
68 | ى
69 | ط
70 | ظ
71 | ث
72 | ة
73 | ً
74 | ء
75 | ؤ
76 | ْ
77 | ۀ
78 | إ
79 | ٍ
80 | ٌ
81 | ٰ
82 | ٓ
83 | ٱ
84 | s
85 | c
86 | e
87 | n
88 | w
89 | N
90 | E
91 | W
92 | Y
93 | D
94 | O
95 | H
96 | A
97 | d
98 | z
99 | r
100 | T
101 | G
102 | o
103 | t
104 | x
105 | h
106 | b
107 | B
108 | M
109 | Z
110 | u
111 | P
112 | F
113 | y
114 | q
115 | U
116 | K
117 | k
118 | J
119 | Q
120 | '
121 | X
122 | #
123 | ?
124 | %
125 | $
126 | ,
127 | :
128 | &
129 | !
130 | -
131 | (
132 | É
133 | @
134 | é
135 | +
136 |
137 |
--------------------------------------------------------------------------------
/utils/dict/french_dict.txt:
--------------------------------------------------------------------------------
1 | f
2 | e
3 | n
4 | c
5 | h
6 | _
7 | i
8 | m
9 | g
10 | /
11 | r
12 | v
13 | a
14 | l
15 | t
16 | w
17 | o
18 | d
19 | 6
20 | 1
21 | .
22 | p
23 | B
24 | u
25 | 2
26 | à
27 | 3
28 | R
29 | y
30 | 4
31 | U
32 | E
33 | A
34 | 5
35 | P
36 | O
37 | S
38 | T
39 | D
40 | 7
41 | Z
42 | 8
43 | I
44 | N
45 | L
46 | G
47 | M
48 | H
49 | 0
50 | J
51 | K
52 | -
53 | 9
54 | F
55 | C
56 | V
57 | é
58 | X
59 | '
60 | s
61 | Q
62 | :
63 | è
64 | x
65 | b
66 | Y
67 | Œ
68 | É
69 | z
70 | W
71 | Ç
72 | È
73 | k
74 | Ô
75 | ô
76 | €
77 | À
78 | Ê
79 | q
80 | ù
81 | °
82 | ê
83 | î
84 | *
85 | Â
86 | j
87 | "
88 | ,
89 | â
90 | %
91 | û
92 | ç
93 | ü
94 | ?
95 | !
96 | ;
97 | ö
98 | (
99 | )
100 | ï
101 | º
102 | ó
103 | ø
104 | å
105 | +
106 | ™
107 | á
108 | Ë
109 | <
110 | ²
111 | Á
112 | Î
113 | &
114 | @
115 | œ
116 | ε
117 | Ü
118 | ë
119 | [
120 | ]
121 | í
122 | ò
123 | Ö
124 | ä
125 | ß
126 | «
127 | »
128 | ú
129 | ñ
130 | æ
131 | µ
132 | ³
133 | Å
134 | $
135 | #
136 |
137 |
--------------------------------------------------------------------------------
/utils/dict/german_dict.txt:
--------------------------------------------------------------------------------
1 |
2 | !
3 | "
4 | #
5 | $
6 | %
7 | &
8 | '
9 | (
10 | )
11 | *
12 | +
13 | ,
14 | -
15 | .
16 | /
17 | 0
18 | 1
19 | 2
20 | 3
21 | 4
22 | 5
23 | 6
24 | 7
25 | 8
26 | 9
27 | :
28 | ;
29 | =
30 | >
31 | ?
32 | @
33 | A
34 | B
35 | C
36 | D
37 | E
38 | F
39 | G
40 | H
41 | I
42 | J
43 | K
44 | L
45 | M
46 | N
47 | O
48 | P
49 | Q
50 | R
51 | S
52 | T
53 | U
54 | V
55 | W
56 | X
57 | Y
58 | Z
59 | [
60 | ]
61 | _
62 | a
63 | b
64 | c
65 | d
66 | e
67 | f
68 | g
69 | h
70 | i
71 | j
72 | k
73 | l
74 | m
75 | n
76 | o
77 | p
78 | q
79 | r
80 | s
81 | t
82 | u
83 | v
84 | w
85 | x
86 | y
87 | z
88 | £
89 | §
90 |
91 | °
92 | ´
93 | µ
94 | ·
95 | º
96 | ¿
97 | Á
98 | Ä
99 | Å
100 | É
101 | Ï
102 | Ô
103 | Ö
104 | Ü
105 | ß
106 | à
107 | á
108 | â
109 | ã
110 | ä
111 | å
112 | æ
113 | ç
114 | è
115 | é
116 | ê
117 | ë
118 | í
119 | ï
120 | ñ
121 | ò
122 | ó
123 | ô
124 | ö
125 | ø
126 | ù
127 | ú
128 | û
129 | ü
130 | ō
131 | Š
132 | Ÿ
133 | ʒ
134 | β
135 | δ
136 | з
137 | Ṡ
138 | ‘
139 | €
140 | ©
141 | ª
142 | «
143 | ¬
144 |
--------------------------------------------------------------------------------
/utils/dict/hi_dict.txt:
--------------------------------------------------------------------------------
1 |
2 | !
3 | #
4 | $
5 | %
6 | &
7 | '
8 | (
9 | +
10 | ,
11 | -
12 | .
13 | /
14 | 0
15 | 1
16 | 2
17 | 3
18 | 4
19 | 5
20 | 6
21 | 7
22 | 8
23 | 9
24 | :
25 | ?
26 | @
27 | A
28 | B
29 | C
30 | D
31 | E
32 | F
33 | G
34 | H
35 | I
36 | J
37 | K
38 | L
39 | M
40 | N
41 | O
42 | P
43 | Q
44 | R
45 | S
46 | T
47 | U
48 | V
49 | W
50 | X
51 | Y
52 | Z
53 | _
54 | a
55 | b
56 | c
57 | d
58 | e
59 | f
60 | g
61 | h
62 | i
63 | j
64 | k
65 | l
66 | m
67 | n
68 | o
69 | p
70 | q
71 | r
72 | s
73 | t
74 | u
75 | v
76 | w
77 | x
78 | y
79 | z
80 | É
81 | é
82 | ँ
83 | ं
84 | ः
85 | अ
86 | आ
87 | इ
88 | ई
89 | उ
90 | ऊ
91 | ऋ
92 | ए
93 | ऐ
94 | ऑ
95 | ओ
96 | औ
97 | क
98 | ख
99 | ग
100 | घ
101 | ङ
102 | च
103 | छ
104 | ज
105 | झ
106 | ञ
107 | ट
108 | ठ
109 | ड
110 | ढ
111 | ण
112 | त
113 | थ
114 | द
115 | ध
116 | न
117 | प
118 | फ
119 | ब
120 | भ
121 | म
122 | य
123 | र
124 | ल
125 | ळ
126 | व
127 | श
128 | ष
129 | स
130 | ह
131 | ़
132 | ा
133 | ि
134 | ी
135 | ु
136 | ू
137 | ृ
138 | ॅ
139 | े
140 | ै
141 | ॉ
142 | ो
143 | ौ
144 | ्
145 | क़
146 | ख़
147 | ग़
148 | ज़
149 | ड़
150 | ढ़
151 | फ़
152 | ०
153 | १
154 | २
155 | ३
156 | ४
157 | ५
158 | ६
159 | ७
160 | ८
161 | ९
162 | ॰
163 |
--------------------------------------------------------------------------------
/utils/dict/it_dict.txt:
--------------------------------------------------------------------------------
1 | i
2 | t
3 | _
4 | m
5 | g
6 | /
7 | 5
8 | I
9 | L
10 | S
11 | V
12 | R
13 | C
14 | 2
15 | 0
16 | 1
17 | v
18 | a
19 | l
20 | 7
21 | 8
22 | 9
23 | 6
24 | .
25 | j
26 | p
27 |
28 | e
29 | r
30 | o
31 | d
32 | s
33 | n
34 | 3
35 | 4
36 | P
37 | u
38 | c
39 | A
40 | -
41 | ,
42 | "
43 | z
44 | h
45 | f
46 | b
47 | q
48 | ì
49 | '
50 | à
51 | O
52 | è
53 | G
54 | ù
55 | é
56 | ò
57 | ;
58 | F
59 | E
60 | B
61 | N
62 | H
63 | k
64 | :
65 | U
66 | T
67 | X
68 | D
69 | K
70 | ?
71 | [
72 | M
73 |
74 | x
75 | y
76 | (
77 | )
78 | W
79 | ö
80 | º
81 | w
82 | ]
83 | Q
84 | J
85 | +
86 | ü
87 | !
88 | È
89 | á
90 | %
91 | =
92 | »
93 | ñ
94 | Ö
95 | Y
96 | ä
97 | í
98 | Z
99 | «
100 | @
101 | ó
102 | ø
103 | ï
104 | ú
105 | ê
106 | ç
107 | Á
108 | É
109 | Å
110 | ß
111 | {
112 | }
113 | &
114 | `
115 | û
116 | î
117 | #
118 | $
119 |
--------------------------------------------------------------------------------
/utils/dict/ka_dict.txt:
--------------------------------------------------------------------------------
1 | k
2 | a
3 | _
4 | i
5 | m
6 | g
7 | /
8 | 1
9 | 2
10 | I
11 | L
12 | S
13 | V
14 | R
15 | C
16 | 0
17 | v
18 | l
19 | 6
20 | 4
21 | 8
22 | .
23 | j
24 | p
25 | ಗ
26 | ು
27 | ಣ
28 | ಪ
29 | ಡ
30 | ಿ
31 | ಸ
32 | ಲ
33 | ಾ
34 | ದ
35 | ್
36 | 7
37 | 5
38 | 3
39 | ವ
40 | ಷ
41 | ಬ
42 | ಹ
43 | ೆ
44 | 9
45 | ಅ
46 | ಳ
47 | ನ
48 | ರ
49 | ಉ
50 | ಕ
51 | ಎ
52 | ೇ
53 | ಂ
54 | ೈ
55 | ೊ
56 | ೀ
57 | ಯ
58 | ೋ
59 | ತ
60 | ಶ
61 | ಭ
62 | ಧ
63 | ಚ
64 | ಜ
65 | ೂ
66 | ಮ
67 | ಒ
68 | ೃ
69 | ಥ
70 | ಇ
71 | ಟ
72 | ಖ
73 | ಆ
74 | ಞ
75 | ಫ
76 | -
77 | ಢ
78 | ಊ
79 | ಓ
80 | ಐ
81 | ಃ
82 | ಘ
83 | ಝ
84 | ೌ
85 | ಠ
86 | ಛ
87 | ಔ
88 | ಏ
89 | ಈ
90 | ಋ
91 | ೨
92 | ೦
93 | ೧
94 | ೮
95 | ೯
96 | ೪
97 | ,
98 | ೫
99 | ೭
100 | ೩
101 | ೬
102 | ಙ
103 | s
104 | c
105 | e
106 | n
107 | w
108 | o
109 | u
110 | t
111 | d
112 | E
113 | A
114 | T
115 | B
116 | Z
117 | N
118 | G
119 | O
120 | q
121 | z
122 | r
123 | x
124 | P
125 | K
126 | M
127 | J
128 | U
129 | D
130 | f
131 | F
132 | h
133 | b
134 | W
135 | Y
136 | y
137 | H
138 | X
139 | Q
140 | '
141 | #
142 | &
143 | !
144 | @
145 | $
146 | :
147 | %
148 | é
149 | É
150 | (
151 | ?
152 | +
153 |
154 |
--------------------------------------------------------------------------------
/utils/dict/latin_dict.txt:
--------------------------------------------------------------------------------
1 |
2 | !
3 | "
4 | #
5 | $
6 | %
7 | &
8 | '
9 | (
10 | )
11 | *
12 | +
13 | ,
14 | -
15 | .
16 | /
17 | 0
18 | 1
19 | 2
20 | 3
21 | 4
22 | 5
23 | 6
24 | 7
25 | 8
26 | 9
27 | :
28 | ;
29 | <
30 | =
31 | >
32 | ?
33 | @
34 | A
35 | B
36 | C
37 | D
38 | E
39 | F
40 | G
41 | H
42 | I
43 | J
44 | K
45 | L
46 | M
47 | N
48 | O
49 | P
50 | Q
51 | R
52 | S
53 | T
54 | U
55 | V
56 | W
57 | X
58 | Y
59 | Z
60 | [
61 | ]
62 | _
63 | `
64 | a
65 | b
66 | c
67 | d
68 | e
69 | f
70 | g
71 | h
72 | i
73 | j
74 | k
75 | l
76 | m
77 | n
78 | o
79 | p
80 | q
81 | r
82 | s
83 | t
84 | u
85 | v
86 | w
87 | x
88 | y
89 | z
90 | {
91 | }
92 | ¡
93 | £
94 | §
95 | ª
96 | «
97 |
98 | °
99 | ²
100 | ³
101 | ´
102 | µ
103 | ·
104 | º
105 | »
106 | ¿
107 | À
108 | Á
109 | Â
110 | Ä
111 | Å
112 | Ç
113 | È
114 | É
115 | Ê
116 | Ë
117 | Ì
118 | Í
119 | Î
120 | Ï
121 | Ò
122 | Ó
123 | Ô
124 | Õ
125 | Ö
126 | Ú
127 | Ü
128 | Ý
129 | ß
130 | à
131 | á
132 | â
133 | ã
134 | ä
135 | å
136 | æ
137 | ç
138 | è
139 | é
140 | ê
141 | ë
142 | ì
143 | í
144 | î
145 | ï
146 | ñ
147 | ò
148 | ó
149 | ô
150 | õ
151 | ö
152 | ø
153 | ù
154 | ú
155 | û
156 | ü
157 | ý
158 | ą
159 | Ć
160 | ć
161 | Č
162 | č
163 | Đ
164 | đ
165 | ę
166 | ı
167 | Ł
168 | ł
169 | ō
170 | Œ
171 | œ
172 | Š
173 | š
174 | Ÿ
175 | Ž
176 | ž
177 | ʒ
178 | β
179 | δ
180 | ε
181 | з
182 | Ṡ
183 | ‘
184 | €
185 | ™
186 |
--------------------------------------------------------------------------------
/utils/dict/mr_dict.txt:
--------------------------------------------------------------------------------
1 |
2 | !
3 | #
4 | $
5 | %
6 | &
7 | '
8 | (
9 | +
10 | ,
11 | -
12 | .
13 | /
14 | 0
15 | 1
16 | 2
17 | 3
18 | 4
19 | 5
20 | 6
21 | 7
22 | 8
23 | 9
24 | :
25 | ?
26 | @
27 | A
28 | B
29 | C
30 | D
31 | E
32 | F
33 | G
34 | H
35 | I
36 | J
37 | K
38 | L
39 | M
40 | N
41 | O
42 | P
43 | Q
44 | R
45 | S
46 | T
47 | U
48 | V
49 | W
50 | X
51 | Y
52 | Z
53 | _
54 | a
55 | b
56 | c
57 | d
58 | e
59 | f
60 | g
61 | h
62 | i
63 | j
64 | k
65 | l
66 | m
67 | n
68 | o
69 | p
70 | q
71 | r
72 | s
73 | t
74 | u
75 | v
76 | w
77 | x
78 | y
79 | z
80 | É
81 | é
82 | ँ
83 | ं
84 | ः
85 | अ
86 | आ
87 | इ
88 | ई
89 | उ
90 | ऊ
91 | ए
92 | ऐ
93 | ऑ
94 | ओ
95 | औ
96 | क
97 | ख
98 | ग
99 | घ
100 | च
101 | छ
102 | ज
103 | झ
104 | ञ
105 | ट
106 | ठ
107 | ड
108 | ढ
109 | ण
110 | त
111 | थ
112 | द
113 | ध
114 | न
115 | प
116 | फ
117 | ब
118 | भ
119 | म
120 | य
121 | र
122 | ऱ
123 | ल
124 | ळ
125 | व
126 | श
127 | ष
128 | स
129 | ह
130 | ़
131 | ा
132 | ि
133 | ी
134 | ु
135 | ू
136 | ृ
137 | ॅ
138 | े
139 | ै
140 | ॉ
141 | ो
142 | ौ
143 | ्
144 | ०
145 | १
146 | २
147 | ३
148 | ४
149 | ५
150 | ६
151 | ७
152 | ८
153 | ९
154 |
--------------------------------------------------------------------------------
/utils/dict/ne_dict.txt:
--------------------------------------------------------------------------------
1 |
2 | !
3 | #
4 | $
5 | %
6 | &
7 | '
8 | (
9 | +
10 | ,
11 | -
12 | .
13 | /
14 | 0
15 | 1
16 | 2
17 | 3
18 | 4
19 | 5
20 | 6
21 | 7
22 | 8
23 | 9
24 | :
25 | ?
26 | @
27 | A
28 | B
29 | C
30 | D
31 | E
32 | F
33 | G
34 | H
35 | I
36 | J
37 | K
38 | L
39 | M
40 | N
41 | O
42 | P
43 | Q
44 | R
45 | S
46 | T
47 | U
48 | V
49 | W
50 | X
51 | Y
52 | Z
53 | _
54 | a
55 | b
56 | c
57 | d
58 | e
59 | f
60 | g
61 | h
62 | i
63 | j
64 | k
65 | l
66 | m
67 | n
68 | o
69 | p
70 | q
71 | r
72 | s
73 | t
74 | u
75 | v
76 | w
77 | x
78 | y
79 | z
80 | É
81 | é
82 | ः
83 | अ
84 | आ
85 | इ
86 | ई
87 | उ
88 | ऊ
89 | ऋ
90 | ए
91 | ऐ
92 | ओ
93 | औ
94 | क
95 | ख
96 | ग
97 | घ
98 | ङ
99 | च
100 | छ
101 | ज
102 | झ
103 | ञ
104 | ट
105 | ठ
106 | ड
107 | ढ
108 | ण
109 | त
110 | थ
111 | द
112 | ध
113 | न
114 | ऩ
115 | प
116 | फ
117 | ब
118 | भ
119 | म
120 | य
121 | र
122 | ऱ
123 | ल
124 | व
125 | श
126 | ष
127 | स
128 | ह
129 | ़
130 | ा
131 | ि
132 | ी
133 | ु
134 | ू
135 | ृ
136 | े
137 | ै
138 | ो
139 | ौ
140 | ्
141 | ॒
142 | ॠ
143 | ।
144 | ०
145 | १
146 | २
147 | ३
148 | ४
149 | ५
150 | ६
151 | ७
152 | ८
153 | ९
154 |
--------------------------------------------------------------------------------
/utils/dict/oc_dict.txt:
--------------------------------------------------------------------------------
1 | o
2 | c
3 | _
4 | i
5 | m
6 | g
7 | /
8 | 2
9 | 0
10 | I
11 | L
12 | S
13 | V
14 | R
15 | C
16 | 1
17 | v
18 | a
19 | l
20 | 4
21 | 3
22 | .
23 | j
24 | p
25 | r
26 | e
27 | è
28 | t
29 | 9
30 | 7
31 | 5
32 | 8
33 | n
34 | '
35 | b
36 | s
37 | 6
38 | q
39 | u
40 | á
41 | d
42 | ò
43 | à
44 | h
45 | z
46 | f
47 | ï
48 | í
49 | A
50 | ç
51 | x
52 | ó
53 | é
54 | P
55 | O
56 | Ò
57 | ü
58 | k
59 | À
60 | F
61 | -
62 | ú
63 |
64 | æ
65 | Á
66 | D
67 | E
68 | w
69 | K
70 | T
71 | N
72 | y
73 | U
74 | Z
75 | G
76 | B
77 | J
78 | H
79 | M
80 | W
81 | Y
82 | X
83 | Q
84 | %
85 | $
86 | ,
87 | @
88 | &
89 | !
90 | :
91 | (
92 | #
93 | ?
94 | +
95 | É
96 |
97 |
--------------------------------------------------------------------------------
/utils/dict/pu_dict.txt:
--------------------------------------------------------------------------------
1 | p
2 | u
3 | _
4 | i
5 | m
6 | g
7 | /
8 | 8
9 | I
10 | L
11 | S
12 | V
13 | R
14 | C
15 | 2
16 | 0
17 | 1
18 | v
19 | a
20 | l
21 | 6
22 | 7
23 | 4
24 | 5
25 | .
26 | j
27 |
28 | q
29 | e
30 | s
31 | t
32 | ã
33 | o
34 | x
35 | 9
36 | c
37 | n
38 | r
39 | z
40 | ç
41 | õ
42 | 3
43 | A
44 | U
45 | d
46 | º
47 | ô
48 |
49 | ,
50 | E
51 | ;
52 | ó
53 | á
54 | b
55 | D
56 | ?
57 | ú
58 | ê
59 | -
60 | h
61 | P
62 | f
63 | à
64 | N
65 | í
66 | O
67 | M
68 | G
69 | É
70 | é
71 | â
72 | F
73 | :
74 | T
75 | Á
76 | "
77 | Q
78 | )
79 | W
80 | J
81 | B
82 | H
83 | (
84 | ö
85 | %
86 | Ö
87 | «
88 | w
89 | K
90 | y
91 | !
92 | k
93 | ]
94 | '
95 | Z
96 | +
97 | Ç
98 | Õ
99 | Y
100 | À
101 | X
102 | µ
103 | »
104 | ª
105 | Í
106 | ü
107 | ä
108 | ´
109 | è
110 | ñ
111 | ß
112 | ï
113 | Ú
114 | ë
115 | Ô
116 | Ï
117 | Ó
118 | [
119 | Ì
120 | <
121 | Â
122 | ò
123 | §
124 | ³
125 | ø
126 | å
127 | #
128 | $
129 | &
130 | @
131 |
--------------------------------------------------------------------------------
/utils/dict/rs_dict.txt:
--------------------------------------------------------------------------------
1 | r
2 | s
3 | _
4 | i
5 | m
6 | g
7 | /
8 | 1
9 | I
10 | L
11 | S
12 | V
13 | R
14 | C
15 | 2
16 | 0
17 | v
18 | a
19 | l
20 | 7
21 | 5
22 | 8
23 | 6
24 | .
25 | j
26 | p
27 |
28 | t
29 | d
30 | 9
31 | 3
32 | e
33 | š
34 | 4
35 | k
36 | u
37 | ć
38 | c
39 | n
40 | đ
41 | o
42 | z
43 | č
44 | b
45 | ž
46 | f
47 | Z
48 | T
49 | h
50 | M
51 | F
52 | O
53 | Š
54 | B
55 | H
56 | A
57 | E
58 | Đ
59 | Ž
60 | D
61 | P
62 | G
63 | Č
64 | K
65 | U
66 | N
67 | J
68 | Ć
69 | w
70 | y
71 | W
72 | x
73 | Y
74 | X
75 | q
76 | Q
77 | #
78 | &
79 | $
80 | ,
81 | -
82 | %
83 | '
84 | @
85 | !
86 | :
87 | ?
88 | (
89 | É
90 | é
91 | +
92 |
--------------------------------------------------------------------------------
/utils/dict/rsc_dict.txt:
--------------------------------------------------------------------------------
1 | r
2 | s
3 | c
4 | _
5 | i
6 | m
7 | g
8 | /
9 | 5
10 | I
11 | L
12 | S
13 | V
14 | R
15 | C
16 | 2
17 | 0
18 | 1
19 | v
20 | a
21 | l
22 | 9
23 | 7
24 | 8
25 | .
26 | j
27 | p
28 | м
29 | а
30 | с
31 | и
32 | р
33 | ћ
34 | е
35 | ш
36 | 3
37 | 4
38 | о
39 | г
40 | н
41 | з
42 | в
43 | л
44 | 6
45 | т
46 | ж
47 | у
48 | к
49 | п
50 | њ
51 | д
52 | ч
53 | С
54 | ј
55 | ф
56 | ц
57 | љ
58 | х
59 | О
60 | И
61 | А
62 | б
63 | Ш
64 | К
65 | ђ
66 | џ
67 | М
68 | В
69 | З
70 | Д
71 | Р
72 | У
73 | Н
74 | Т
75 | Б
76 | ?
77 | П
78 | Х
79 | Ј
80 | Ц
81 | Г
82 | Љ
83 | Л
84 | Ф
85 | e
86 | n
87 | w
88 | E
89 | F
90 | A
91 | N
92 | f
93 | o
94 | b
95 | M
96 | G
97 | t
98 | y
99 | W
100 | k
101 | P
102 | u
103 | H
104 | B
105 | T
106 | z
107 | h
108 | O
109 | Y
110 | d
111 | U
112 | K
113 | D
114 | x
115 | X
116 | J
117 | Z
118 | Q
119 | q
120 | '
121 | -
122 | @
123 | é
124 | #
125 | !
126 | ,
127 | %
128 | $
129 | :
130 | &
131 | +
132 | (
133 | É
134 |
135 |
--------------------------------------------------------------------------------
/utils/dict/ru_dict.txt:
--------------------------------------------------------------------------------
1 | к
2 | в
3 | а
4 | з
5 | и
6 | у
7 | р
8 | о
9 | н
10 | я
11 | х
12 | п
13 | л
14 | ы
15 | г
16 | е
17 | т
18 | м
19 | д
20 | ж
21 | ш
22 | ь
23 | с
24 | ё
25 | б
26 | й
27 | ч
28 | ю
29 | ц
30 | щ
31 | М
32 | э
33 | ф
34 | А
35 | ъ
36 | С
37 | Ф
38 | Ю
39 | В
40 | К
41 | Т
42 | Н
43 | О
44 | Э
45 | У
46 | И
47 | Г
48 | Л
49 | Р
50 | Д
51 | Б
52 | Ш
53 | П
54 | З
55 | Х
56 | Е
57 | Ж
58 | Я
59 | Ц
60 | Ч
61 | Й
62 | Щ
63 | 0
64 | 1
65 | 2
66 | 3
67 | 4
68 | 5
69 | 6
70 | 7
71 | 8
72 | 9
73 | a
74 | b
75 | c
76 | d
77 | e
78 | f
79 | g
80 | h
81 | i
82 | j
83 | k
84 | l
85 | m
86 | n
87 | o
88 | p
89 | q
90 | r
91 | s
92 | t
93 | u
94 | v
95 | w
96 | x
97 | y
98 | z
99 | A
100 | B
101 | C
102 | D
103 | E
104 | F
105 | G
106 | H
107 | I
108 | J
109 | K
110 | L
111 | M
112 | N
113 | O
114 | P
115 | Q
116 | R
117 | S
118 | T
119 | U
120 | V
121 | W
122 | X
123 | Y
124 | Z
125 |
126 |
--------------------------------------------------------------------------------
/utils/dict/ta_dict.txt:
--------------------------------------------------------------------------------
1 | t
2 | a
3 | _
4 | i
5 | m
6 | g
7 | /
8 | 3
9 | I
10 | L
11 | S
12 | V
13 | R
14 | C
15 | 2
16 | 0
17 | 1
18 | v
19 | l
20 | 9
21 | 7
22 | 8
23 | .
24 | j
25 | p
26 | ப
27 | ூ
28 | த
29 | ம
30 | ி
31 | வ
32 | ர
33 | ்
34 | ந
35 | ோ
36 | ன
37 | 6
38 | ஆ
39 | ற
40 | ல
41 | 5
42 | ள
43 | ா
44 | ொ
45 | ழ
46 | ு
47 | 4
48 | ெ
49 | ண
50 | க
51 | ட
52 | ை
53 | ே
54 | ச
55 | ய
56 | ஒ
57 | இ
58 | அ
59 | ங
60 | உ
61 | ீ
62 | ஞ
63 | எ
64 | ஓ
65 | ஃ
66 | ஜ
67 | ஷ
68 | ஸ
69 | ஏ
70 | ஊ
71 | ஹ
72 | ஈ
73 | ஐ
74 | ௌ
75 | ஔ
76 | s
77 | c
78 | e
79 | n
80 | w
81 | F
82 | T
83 | O
84 | P
85 | K
86 | A
87 | N
88 | G
89 | Y
90 | E
91 | M
92 | H
93 | U
94 | B
95 | o
96 | b
97 | D
98 | d
99 | r
100 | W
101 | u
102 | y
103 | f
104 | X
105 | k
106 | q
107 | h
108 | J
109 | z
110 | Z
111 | Q
112 | x
113 | -
114 | '
115 | $
116 | ,
117 | %
118 | @
119 | é
120 | !
121 | #
122 | +
123 | É
124 | &
125 | :
126 | (
127 | ?
128 |
129 |
--------------------------------------------------------------------------------
/utils/dict/table_dict.txt:
--------------------------------------------------------------------------------
1 | ←
2 |
3 | ☆
4 | ─
5 | α
6 |
7 |
8 | ⋅
9 | $
10 | ω
11 | ψ
12 | χ
13 | (
14 | υ
15 | ≥
16 | σ
17 | ,
18 | ρ
19 | ε
20 | 0
21 | ■
22 | 4
23 | 8
24 | ✗
25 | b
26 | <
27 | ✓
28 | Ψ
29 | Ω
30 | €
31 | D
32 | 3
33 | Π
34 | H
35 | ║
36 |
37 | L
38 | Φ
39 | Χ
40 | θ
41 | P
42 | κ
43 | λ
44 | μ
45 | T
46 | ξ
47 | X
48 | β
49 | γ
50 | δ
51 | \
52 | ζ
53 | η
54 | `
55 | d
56 |
57 | h
58 | f
59 | l
60 | Θ
61 | p
62 | √
63 | t
64 |
65 | x
66 | Β
67 | Γ
68 | Δ
69 | |
70 | ǂ
71 | ɛ
72 | j
73 | ̧
74 | ➢
75 |
76 | ̌
77 | ′
78 | «
79 | △
80 | ▲
81 | #
82 |
83 | '
84 | Ι
85 | +
86 | ¶
87 | /
88 | ▼
89 | ⇑
90 | □
91 | ·
92 | 7
93 | ▪
94 | ;
95 | ?
96 | ➔
97 | ∩
98 | C
99 | ÷
100 | G
101 | ⇒
102 | K
103 |
104 | O
105 | S
106 | С
107 | W
108 | Α
109 | [
110 | ○
111 | _
112 | ●
113 | ‡
114 | c
115 | z
116 | g
117 |
118 | o
119 |
120 | 〈
121 | 〉
122 | s
123 | ⩽
124 | w
125 | φ
126 | ʹ
127 | {
128 | »
129 | ∣
130 | ̆
131 | e
132 | ˆ
133 | ∈
134 | τ
135 | ◆
136 | ι
137 | ∅
138 | ∆
139 | ∙
140 | ∘
141 | Ø
142 | ß
143 | ✔
144 | ∞
145 | ∑
146 | −
147 | ×
148 | ◊
149 | ∗
150 | ∖
151 | ˃
152 | ˂
153 | ∫
154 | "
155 | i
156 | &
157 | π
158 | ↔
159 | *
160 | ∥
161 | æ
162 | ∧
163 | .
164 | ⁄
165 | ø
166 | Q
167 | ∼
168 | 6
169 | ⁎
170 | :
171 | ★
172 | >
173 | a
174 | B
175 | ≈
176 | F
177 | J
178 | ̄
179 | N
180 | ♯
181 | R
182 | V
183 |
184 | ―
185 | Z
186 | ♣
187 | ^
188 | ¤
189 | ¥
190 | §
191 |
192 | ¢
193 | £
194 | ≦
195 |
196 | ≤
197 | ‖
198 | Λ
199 | ©
200 | n
201 | ↓
202 | →
203 | ↑
204 | r
205 | °
206 | ±
207 | v
208 |
209 | ♂
210 | k
211 | ♀
212 | ~
213 | ᅟ
214 | ̇
215 | @
216 | ”
217 | ♦
218 | ł
219 | ®
220 | ⊕
221 | „
222 | !
223 |
224 | %
225 | ⇓
226 | )
227 | -
228 | 1
229 | 5
230 | 9
231 | =
232 | А
233 | A
234 | ‰
235 | ⋆
236 | Σ
237 | E
238 | ◦
239 | I
240 | ※
241 | M
242 | m
243 | ̨
244 | ⩾
245 | †
246 |
247 | •
248 | U
249 | Y
250 |
251 | ]
252 | ̸
253 | 2
254 | ‐
255 | –
256 | ‒
257 | ̂
258 | —
259 | ̀
260 | ́
261 | ’
262 | ‘
263 | ⋮
264 | ⋯
265 | ̊
266 | “
267 | ̈
268 | ≧
269 | q
270 | u
271 | ı
272 | y
273 |
274 |
275 | ̃
276 | }
277 | ν
278 |
--------------------------------------------------------------------------------
/utils/dict/te_dict.txt:
--------------------------------------------------------------------------------
1 | t
2 | e
3 | _
4 | i
5 | m
6 | g
7 | /
8 | 5
9 | I
10 | L
11 | S
12 | V
13 | R
14 | C
15 | 2
16 | 0
17 | 1
18 | v
19 | a
20 | l
21 | 3
22 | 4
23 | 8
24 | 9
25 | .
26 | j
27 | p
28 | త
29 | ె
30 | ర
31 | క
32 | ్
33 | ి
34 | ం
35 | చ
36 | ే
37 | ద
38 | ు
39 | 7
40 | 6
41 | ఉ
42 | ా
43 | మ
44 | ట
45 | ో
46 | వ
47 | ప
48 | ల
49 | శ
50 | ఆ
51 | య
52 | ై
53 | భ
54 | '
55 | ీ
56 | గ
57 | ూ
58 | డ
59 | ధ
60 | హ
61 | న
62 | జ
63 | స
64 | [
65 |
66 | ష
67 | అ
68 | ణ
69 | ఫ
70 | బ
71 | ఎ
72 | ;
73 | ళ
74 | థ
75 | ొ
76 | ఠ
77 | ృ
78 | ఒ
79 | ఇ
80 | ః
81 | ఊ
82 | ఖ
83 | -
84 | ఐ
85 | ఘ
86 | ౌ
87 | ఏ
88 | ఈ
89 | ఛ
90 | ,
91 | ఓ
92 | ఞ
93 | |
94 | ?
95 | :
96 | ఢ
97 | "
98 | (
99 | ”
100 | !
101 | +
102 | )
103 | *
104 | =
105 | &
106 | “
107 | €
108 | ]
109 | £
110 | $
111 | s
112 | c
113 | n
114 | w
115 | k
116 | J
117 | G
118 | u
119 | d
120 | r
121 | E
122 | o
123 | h
124 | y
125 | b
126 | f
127 | B
128 | M
129 | O
130 | T
131 | N
132 | D
133 | P
134 | A
135 | F
136 | x
137 | W
138 | Y
139 | U
140 | H
141 | K
142 | X
143 | z
144 | Z
145 | Q
146 | q
147 | É
148 | %
149 | #
150 | @
151 | é
152 |
--------------------------------------------------------------------------------
/utils/dict/ug_dict.txt:
--------------------------------------------------------------------------------
1 | u
2 | g
3 | _
4 | i
5 | m
6 | /
7 | 1
8 | I
9 | L
10 | S
11 | V
12 | R
13 | C
14 | 2
15 | 0
16 | v
17 | a
18 | l
19 | 8
20 | 5
21 | 3
22 | 6
23 | 9
24 | .
25 | j
26 | p
27 |
28 | ق
29 | ا
30 | پ
31 | ل
32 | 4
33 | 7
34 | ئ
35 | ى
36 | ش
37 | ت
38 | ي
39 | ك
40 | د
41 | ف
42 | ر
43 | و
44 | ن
45 | ب
46 | ە
47 | خ
48 | ې
49 | چ
50 | ۇ
51 | ز
52 | س
53 | م
54 | ۋ
55 | گ
56 | ڭ
57 | ۆ
58 | ۈ
59 | ج
60 | غ
61 | ھ
62 | ژ
63 | s
64 | c
65 | e
66 | n
67 | w
68 | P
69 | E
70 | D
71 | U
72 | d
73 | r
74 | b
75 | y
76 | B
77 | o
78 | O
79 | Y
80 | N
81 | T
82 | k
83 | t
84 | h
85 | A
86 | H
87 | F
88 | z
89 | W
90 | K
91 | G
92 | M
93 | f
94 | Z
95 | X
96 | Q
97 | J
98 | x
99 | q
100 | -
101 | !
102 | %
103 | #
104 | ?
105 | :
106 | $
107 | ,
108 | &
109 | '
110 | É
111 | @
112 | é
113 | (
114 | +
115 |
--------------------------------------------------------------------------------
/utils/dict/uk_dict.txt:
--------------------------------------------------------------------------------
1 | u
2 | k
3 | _
4 | i
5 | m
6 | g
7 | /
8 | 1
9 | 6
10 | I
11 | L
12 | S
13 | V
14 | R
15 | C
16 | 2
17 | 0
18 | v
19 | a
20 | l
21 | 7
22 | 9
23 | .
24 | j
25 | p
26 | в
27 | і
28 | д
29 | п
30 | о
31 | н
32 | с
33 | т
34 | ю
35 | 4
36 | 5
37 | 3
38 | а
39 | и
40 | м
41 | е
42 | р
43 | ч
44 | у
45 | Б
46 | з
47 | л
48 | к
49 | 8
50 | А
51 | В
52 | г
53 | є
54 | б
55 | ь
56 | х
57 | ґ
58 | ш
59 | ц
60 | ф
61 | я
62 | щ
63 | ж
64 | Г
65 | Х
66 | У
67 | Т
68 | Е
69 | І
70 | Н
71 | П
72 | З
73 | Л
74 | Ю
75 | С
76 | Д
77 | М
78 | К
79 | Р
80 | Ф
81 | О
82 | Ц
83 | И
84 | Я
85 | Ч
86 | Ш
87 | Ж
88 | Є
89 | Ґ
90 | Ь
91 | s
92 | c
93 | e
94 | n
95 | w
96 | A
97 | P
98 | r
99 | E
100 | t
101 | o
102 | h
103 | d
104 | y
105 | M
106 | G
107 | N
108 | F
109 | B
110 | T
111 | D
112 | U
113 | O
114 | W
115 | Z
116 | f
117 | H
118 | Y
119 | b
120 | K
121 | z
122 | x
123 | Q
124 | X
125 | q
126 | J
127 | $
128 | -
129 | '
130 | #
131 | &
132 | %
133 | ?
134 | :
135 | !
136 | ,
137 | +
138 | @
139 | (
140 | é
141 | É
142 |
143 |
--------------------------------------------------------------------------------
/utils/dict/ur_dict.txt:
--------------------------------------------------------------------------------
1 | u
2 | r
3 | _
4 | i
5 | m
6 | g
7 | /
8 | 3
9 | I
10 | L
11 | S
12 | V
13 | R
14 | C
15 | 2
16 | 0
17 | 1
18 | v
19 | a
20 | l
21 | 9
22 | 7
23 | 8
24 | .
25 | j
26 | p
27 |
28 | چ
29 | ٹ
30 | پ
31 | ا
32 | ئ
33 | ی
34 | ے
35 | 4
36 | 6
37 | و
38 | ل
39 | ن
40 | ڈ
41 | ھ
42 | ک
43 | ت
44 | ش
45 | ف
46 | ق
47 | ر
48 | د
49 | 5
50 | ب
51 | ج
52 | خ
53 | ہ
54 | س
55 | ز
56 | غ
57 | ڑ
58 | ں
59 | آ
60 | م
61 | ؤ
62 | ط
63 | ص
64 | ح
65 | ع
66 | گ
67 | ث
68 | ض
69 | ذ
70 | ۓ
71 | ِ
72 | ء
73 | ظ
74 | ً
75 | ي
76 | ُ
77 | ۃ
78 | أ
79 | ٰ
80 | ە
81 | ژ
82 | ۂ
83 | ة
84 | ّ
85 | ك
86 | ه
87 | s
88 | c
89 | e
90 | n
91 | w
92 | o
93 | d
94 | t
95 | D
96 | M
97 | T
98 | U
99 | E
100 | b
101 | P
102 | h
103 | y
104 | W
105 | H
106 | A
107 | x
108 | B
109 | O
110 | N
111 | G
112 | Y
113 | Q
114 | F
115 | k
116 | K
117 | q
118 | J
119 | Z
120 | f
121 | z
122 | X
123 | '
124 | @
125 | &
126 | !
127 | ,
128 | :
129 | $
130 | -
131 | #
132 | ?
133 | %
134 | é
135 | +
136 | (
137 | É
138 |
--------------------------------------------------------------------------------
/utils/dict/xi_dict.txt:
--------------------------------------------------------------------------------
1 | x
2 | i
3 | _
4 | m
5 | g
6 | /
7 | 1
8 | 0
9 | I
10 | L
11 | S
12 | V
13 | R
14 | C
15 | 2
16 | v
17 | a
18 | l
19 | 3
20 | 6
21 | 4
22 | 5
23 | .
24 | j
25 | p
26 |
27 | Q
28 | u
29 | e
30 | r
31 | o
32 | 8
33 | 7
34 | n
35 | c
36 | 9
37 | t
38 | b
39 | é
40 | q
41 | d
42 | ó
43 | y
44 | F
45 | s
46 | ,
47 | O
48 | í
49 | T
50 | f
51 | "
52 | U
53 | M
54 | h
55 | :
56 | P
57 | H
58 | A
59 | E
60 | D
61 | z
62 | N
63 | á
64 | ñ
65 | ú
66 | %
67 | ;
68 | è
69 | +
70 | Y
71 | -
72 | B
73 | G
74 | (
75 | )
76 | ¿
77 | ?
78 | w
79 | ¡
80 | !
81 | X
82 | É
83 | K
84 | k
85 | Á
86 | ü
87 | Ú
88 | «
89 | »
90 | J
91 | '
92 | ö
93 | W
94 | Z
95 | º
96 | Ö
97 |
98 | [
99 | ]
100 | Ç
101 | ç
102 | à
103 | ä
104 | û
105 | ò
106 | Í
107 | ê
108 | ô
109 | ø
110 | ª
111 |
--------------------------------------------------------------------------------
/utils/logging.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time : 2020/6/15 14:25
3 | # @Author : zhoujun
4 |
5 | import logging
6 |
7 | import torch.distributed as dist
8 |
# Registry of logger names that get_logger() has already configured.
logger_initialized = {}


def get_logger(name, log_file=None, log_level=logging.INFO):
    """Initialize and get a logger by name.

    The first call for a given name attaches a StreamHandler and, when
    `log_file` is given and the process rank is 0, a FileHandler. Later
    calls for the same name, or for a dotted child of an already
    initialized name, return the logger without adding handlers.

    Args:
        name (str): Logger name.
        log_file (str | None): The log filename. If specified, a FileHandler
            is added on rank 0. The file is opened in 'w' mode, so an
            existing file is truncated.
        log_level (int): Level applied to the handlers, and to the logger
            itself on rank 0. Other ranks set the logger level to ERROR.

    Returns:
        logging.Logger: The expected logger.
    """
    logger = logging.getLogger(name)
    if name in logger_initialized:
        return logger
    # Handle hierarchical names: if "a" is initialized, "a.b" skips
    # initialization because it is a child of "a" in the logging tree.
    # The '.' separator is required so that an unrelated logger such as "ab"
    # is not mistaken for a child of "a" and left without any handler.
    for logger_name in logger_initialized:
        if name.startswith(logger_name + '.'):
            return logger

    stream_handler = logging.StreamHandler()
    handlers = [stream_handler]

    # Outside of an initialized torch.distributed job, act as rank 0.
    if dist.is_available() and dist.is_initialized():
        rank = dist.get_rank()
    else:
        rank = 0

    # only rank 0 will add a FileHandler
    if rank == 0 and log_file is not None:
        file_handler = logging.FileHandler(log_file, 'w')
        handlers.append(file_handler)

    formatter = logging.Formatter(
        '%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    for handler in handlers:
        handler.setFormatter(formatter)
        handler.setLevel(log_level)
        logger.addHandler(handler)

    if rank == 0:
        logger.setLevel(log_level)
    else:
        logger.setLevel(logging.ERROR)

    logger_initialized[name] = True

    return logger
67 |
68 |
def print_log(msg, logger=None, level=logging.INFO):
    """Emit a message through the requested channel.

    Args:
        msg (str): The message to be logged.
        logger (logging.Logger | str | None): Where the message goes.
            - None: the message is written with `print()`.
            - logging.Logger: the message is logged on that object.
            - "silent": the message is dropped.
            - any other str: the logger returned by `get_logger(logger)`.
        level (int): Logging level. Ignored for None and "silent".

    Raises:
        TypeError: If `logger` is none of the accepted kinds.
    """
    if logger is None:
        print(msg)
        return
    if isinstance(logger, logging.Logger):
        logger.log(level, msg)
        return
    # "silent" is itself a str, so it has to be ruled out before the
    # generic by-name lookup below.
    if logger == 'silent':
        return
    if not isinstance(logger, str):
        raise TypeError(
            'logger should be either a logging.Logger object, str, '
            f'"silent" or None, but got {type(logger)}')
    get_logger(logger).log(level, msg)
--------------------------------------------------------------------------------
/utils/losses/__init__.py:
--------------------------------------------------------------------------------
1 | import copy
2 | from .table_att_loss import TableAttentionLoss
3 |
4 |
def build_loss(config):
    """Instantiate the loss module described by `config`.

    Args:
        config (dict): Must contain 'name' (one of the supported class
            names); every other entry is passed to the class constructor as
            a keyword argument. The caller's dict is not modified.

    Returns:
        The constructed loss object.

    Raises:
        AssertionError: If `config['name']` is not a supported loss.
    """
    support_dict = ['TableAttentionLoss']

    config = copy.deepcopy(config)
    module_name = config.pop('name')
    # Raise explicitly instead of using `assert`: assert statements are
    # removed under `python -O`, which would let an arbitrary config string
    # reach the class lookup below.
    if module_name not in support_dict:
        raise AssertionError('loss only support {}'.format(support_dict))
    # The name is whitelisted above, so a plain module-namespace lookup
    # replaces eval().
    module_class = globals()[module_name](**config)
    return module_class
14 |
--------------------------------------------------------------------------------
/utils/losses/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/June-Li/TableRecognition/a26956983ade1153ee91f9a92fface1a2cf1275e/utils/losses/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/utils/losses/__pycache__/table_att_loss.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/June-Li/TableRecognition/a26956983ade1153ee91f9a92fface1a2cf1275e/utils/losses/__pycache__/table_att_loss.cpython-38.pyc
--------------------------------------------------------------------------------
/utils/losses/table_att_loss.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import torch
6 | from torch import nn
7 | from torch.nn import functional as F
8 |
9 |
class TableAttentionLoss(nn.Module):
    """Weighted sum of a structure-token loss and a cell-box regression loss.

    The structure term is a per-position cross entropy; the location term is
    an MSE on box coordinates, optionally complemented by a GIoU term.

    Args:
        structure_weight: Multiplier applied to the structure loss.
        loc_weight: Multiplier applied to the MSE box loss.
        use_giou: When True, a GIoU box loss is added to the total.
        giou_weight: Multiplier applied to the GIoU loss.
        **kwargs: Accepted and ignored.
    """

    def __init__(self, structure_weight, loc_weight, use_giou=False, giou_weight=1.0, **kwargs):
        super(TableAttentionLoss, self).__init__()
        # reduction='none' keeps one loss value per position so that an
        # optional mask can be applied before averaging.
        self.loss_func = nn.CrossEntropyLoss(weight=None, reduction='none')
        self.structure_weight = structure_weight
        self.loc_weight = loc_weight
        self.use_giou = use_giou
        self.giou_weight = giou_weight

    def giou_loss(self, preds, bbox, eps=1e-7, reduction='mean'):
        """Compute the GIoU loss (1 - GIoU) between two sets of boxes.

        The coordinates are read from the last axis, so the boxes may be
        given as (N, 4) or with extra leading axes such as (batch, seq, 4);
        the latter is what `forward` passes in.

        :param preds: boxes as [x1, y1, x2, y2] along the last axis
        :param bbox: target boxes, same layout and shape as `preds`
        :param eps: small constant added to the union and enclosing areas
        :param reduction: 'mean' or 'sum'
        :return: reduced loss tensor
        :raises NotImplementedError: for any other `reduction`
        """
        px1, py1, px2, py2 = (preds[..., i] for i in range(4))
        bx1, by1, bx2, by2 = (bbox[..., i] for i in range(4))

        # Intersection rectangle: element-wise max of the top-left corners,
        # element-wise min of the bottom-right corners.
        ix1 = torch.where(px1 >= bx1, px1, bx1)
        iy1 = torch.where(py1 >= by1, py1, by1)
        ix2 = torch.where(px2 >= bx2, bx2, px2)
        iy2 = torch.where(py2 >= by2, by2, py2)

        iw = torch.clamp(ix2 - ix1 + 1e-3, 0., 1e10)
        ih = torch.clamp(iy2 - iy1 + 1e-3, 0., 1e10)

        # overlap
        inters = iw * ih

        # union
        uni = (px2 - px1 + 1e-3) * (py2 - py1 + 1e-3) \
            + (bx2 - bx1 + 1e-3) * (by2 - by1 + 1e-3) - inters + eps

        # ious
        ious = inters / uni

        # Smallest rectangle enclosing both boxes.
        ex1 = torch.where(px1 >= bx1, bx1, px1)
        ey1 = torch.where(py1 >= by1, by1, py1)
        ex2 = torch.where(px2 >= bx2, px2, bx2)
        ey2 = torch.where(py2 >= by2, py2, by2)
        ew = torch.clamp(ex2 - ex1 + 1e-3, 0., 1e10)
        eh = torch.clamp(ey2 - ey1 + 1e-3, 0., 1e10)

        # enclose area
        enclose = ew * eh + eps
        giou = ious - (enclose - uni) / enclose

        loss = 1 - giou

        if reduction == 'mean':
            loss = torch.mean(loss)
        elif reduction == 'sum':
            loss = torch.sum(loss)
        else:
            raise NotImplementedError
        return loss

    def forward(self, predicts, batch):
        """Compute the total loss and its components.

        Args:
            predicts (dict): Reads 'structure_probs' (class scores along the
                last axis) and 'loc_preds' (box coordinates).
            batch (sequence): Reads batch[1] (structure targets), batch[2]
                (box targets) and batch[4] (box mask). When len(batch) == 6,
                batch[5] is used as a per-position structure mask. The first
                step along axis 1 of each is dropped before use
                (presumably the start token - confirm against the dataset).

        Returns:
            dict: 'loss', 'structure_loss', 'loc_loss', plus 'loc_loss_giou'
            when `use_giou` is enabled.
        """
        structure_probs = predicts['structure_probs']
        # as_tensor accepts arrays, lists and tensors alike, avoids the copy
        # warning torch.tensor() emits for tensor input, and places the
        # targets on the same device as the predictions.
        structure_targets = torch.as_tensor(
            batch[1], dtype=torch.int64, device=structure_probs.device)
        structure_targets = structure_targets[:, 1:]
        has_structure_mask = len(batch) == 6
        if has_structure_mask:
            structure_mask = torch.as_tensor(
                batch[5], dtype=torch.int64, device=structure_probs.device)
            structure_mask = structure_mask[:, 1:]
            structure_mask = torch.reshape(structure_mask, [-1])
        structure_probs = torch.reshape(structure_probs, [-1, structure_probs.shape[-1]])
        structure_targets = torch.reshape(structure_targets, [-1])
        structure_loss = self.loss_func(structure_probs, structure_targets)

        if has_structure_mask:
            structure_loss = structure_loss * structure_mask

        # The mean runs over every position, masked ones included.
        structure_loss = torch.mean(structure_loss) * self.structure_weight

        loc_preds = predicts['loc_preds']
        loc_targets = torch.as_tensor(
            batch[2], dtype=torch.float32, device=loc_preds.device)
        loc_targets_mask = torch.as_tensor(
            batch[4], dtype=torch.float32, device=loc_preds.device)
        loc_targets = loc_targets[:, 1:, :]
        loc_targets_mask = loc_targets_mask[:, 1:, :]
        masked_preds = loc_preds * loc_targets_mask
        loc_loss = F.mse_loss(masked_preds, loc_targets) * self.loc_weight
        if self.use_giou:
            loc_loss_giou = self.giou_loss(masked_preds, loc_targets) * self.giou_weight
            total_loss = structure_loss + loc_loss + loc_loss_giou
            return {'loss': total_loss, "structure_loss": structure_loss,
                    "loc_loss": loc_loss, "loc_loss_giou": loc_loss_giou}
        else:
            total_loss = structure_loss + loc_loss
            return {'loss': total_loss, "structure_loss": structure_loss, "loc_loss": loc_loss}
--------------------------------------------------------------------------------
/utils/metrics/__init__.py:
--------------------------------------------------------------------------------
1 | # copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from __future__ import absolute_import
16 | from __future__ import division
17 | from __future__ import print_function
18 | from __future__ import unicode_literals
19 |
20 | import copy
21 |
22 | __all__ = ["build_metric"]
23 |
24 | from .table_metric import TableMetric
25 |
26 |
def build_metric(config):
    """Instantiate the metric described by `config`.

    Args:
        config (dict): Must contain "name" (one of the supported class
            names); every other entry is passed to the class constructor as
            a keyword argument. The caller's dict is not modified.

    Returns:
        The constructed metric object.

    Raises:
        AssertionError: If `config["name"]` is not a supported metric.
    """
    # Only metrics that this module actually imports are listed. Names that
    # are advertised but never imported would pass the check and then fail
    # with a NameError during the lookup.
    support_dict = ["TableMetric"]

    config = copy.deepcopy(config)
    module_name = config.pop("name")
    # Raise explicitly instead of using `assert`, which is removed under
    # `python -O`.
    if module_name not in support_dict:
        raise AssertionError("metric only support {}".format(support_dict))
    # The name is whitelisted above, so a plain module-namespace lookup
    # replaces eval().
    module_class = globals()[module_name](**config)
    return module_class
38 |
--------------------------------------------------------------------------------
/utils/metrics/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/June-Li/TableRecognition/a26956983ade1153ee91f9a92fface1a2cf1275e/utils/metrics/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/utils/metrics/__pycache__/table_metric.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/June-Li/TableRecognition/a26956983ade1153ee91f9a92fface1a2cf1275e/utils/metrics/__pycache__/table_metric.cpython-38.pyc
--------------------------------------------------------------------------------
/utils/metrics/table_metric.py:
--------------------------------------------------------------------------------
1 | # copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | import numpy as np
15 |
16 |
class TableMetric(object):
    """Whole-sequence accuracy of predicted table structure tokens.

    A sample counts as correct only when every predicted token index equals
    its label. Counts accumulate across calls until `get_metric` or `reset`.
    """

    def __init__(self, main_indicator='acc', **kwargs):
        self.main_indicator = main_indicator
        self.reset()

    def __call__(self, pred, batch, *args, **kwargs):
        """Score one batch and add it to the running totals.

        Args:
            pred (dict): Reads 'structure_probs'; the arg-max is taken along
                axis 2.
            batch (sequence): batch[1] holds the structure labels; their
                first column is dropped before comparison.

        Returns:
            dict: {'acc': fraction of fully correct samples in this batch},
            0.0 for an empty batch.
        """
        structure_probs = pred['structure_probs']
        structure_labels = batch[1]
        structure_idx = np.argmax(structure_probs, axis=2)
        structure_labels = structure_labels[:, 1:]
        all_num = structure_idx.shape[0]
        correct_num = 0
        for bno in range(all_num):
            if (structure_idx[bno] == structure_labels[bno]).all():
                correct_num += 1
        self.correct_num += correct_num
        self.all_num += all_num
        return {
            # Guard against an empty batch instead of dividing by zero.
            'acc': correct_num * 1.0 / all_num if all_num else 0.0,
        }

    def get_metric(self):
        """
        return metrics {
                 'acc': 0,
            }

        The accumulated counts are reset afterwards. When nothing has been
        accumulated the accuracy is reported as 0.0.
        """
        acc = 1.0 * self.correct_num / self.all_num if self.all_num else 0.0
        self.reset()
        return {'acc': acc}

    def reset(self):
        """Clear the accumulated counts."""
        self.correct_num = 0
        self.all_num = 0
53 |
--------------------------------------------------------------------------------
/utils/optimizer/__init__.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 | from __future__ import unicode_literals
5 | import copy
6 |
7 | __all__ = ['build_optimizer']
8 |
9 |
def build_optimizer(config, epochs, step_each_epoch, parameters):
    """Create an optimizer from a config mapping.

    Args:
        config (dict): Needs an 'lr' mapping holding 'learning_rate' and a
            'name' selecting a wrapper class from the sibling `optimizer`
            module. Remaining entries become keyword arguments of that
            wrapper. The caller's dict is left untouched.
        epochs: Not used by this function.
        step_each_epoch: Not used by this function.
        parameters: Passed to the wrapper to build the optimizer.

    Returns:
        tuple: (optimizer instance, learning rate).
    """
    from . import optimizer as optim_wrappers

    settings = copy.deepcopy(config)

    # The learning rate is handed back as-is alongside the optimizer.
    learning_rate = settings.pop('lr')['learning_rate']

    wrapper_cls = getattr(optim_wrappers, settings.pop('name'))
    make_optimizer = wrapper_cls(learning_rate=learning_rate, **settings)
    return make_optimizer(parameters), learning_rate
20 |
--------------------------------------------------------------------------------
/utils/optimizer/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/June-Li/TableRecognition/a26956983ade1153ee91f9a92fface1a2cf1275e/utils/optimizer/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/utils/optimizer/__pycache__/optimizer.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/June-Li/TableRecognition/a26956983ade1153ee91f9a92fface1a2cf1275e/utils/optimizer/__pycache__/optimizer.cpython-38.pyc
--------------------------------------------------------------------------------
/utils/optimizer/optimizer.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 | from __future__ import unicode_literals
5 |
6 | from torch import optim
7 |
8 |
class Momentum(object):
    """
    Factory for an SGD-with-momentum optimizer.
    Args:
        learning_rate (float) - The learning rate used to update parameters.
        momentum (float) - Momentum factor.
        weight_decay (float, optional) - L2 penalty; None means no decay.
        grad_clip (optional) - Stored on the instance only. torch optimizers
            take no such argument, so clipping has to be done by the caller.
    """

    def __init__(self,
                 learning_rate,
                 momentum,
                 weight_decay=None,
                 grad_clip=None,
                 **args):
        super(Momentum, self).__init__()
        self.learning_rate = learning_rate
        self.momentum = momentum
        self.weight_decay = weight_decay
        self.grad_clip = grad_clip

    def __call__(self, parameters):
        """Build the torch optimizer for `parameters`."""
        # torch.optim has no `Momentum` class and uses `lr` / `params`
        # keywords; SGD with a momentum factor is the torch equivalent.
        opt = optim.SGD(
            params=parameters,
            lr=self.learning_rate,
            momentum=self.momentum,
            weight_decay=0.0 if self.weight_decay is None else self.weight_decay)
        return opt
39 |
40 |
class Adam(object):
    """
    Factory for a torch Adam optimizer.
    Args:
        learning_rate (float) - Passed to torch as `lr`.
        beta1 (float), beta2 (float) - Passed to torch as `betas`.
        epsilon (float) - Passed to torch as `eps`.
        parameter_list, name - Stored on the instance, not used when building.
        lazy_mode (bool) - Forwarded as torch's `amsgrad` flag.
            NOTE(review): the name suggests a different meaning than
            AMSGrad - confirm this mapping is intended.
        **kwargs - Accepted and ignored.
    """

    def __init__(self,
                 learning_rate=0.001,
                 beta1=0.9,
                 beta2=0.999,
                 epsilon=1e-08,
                 parameter_list=None,
                 name=None,
                 lazy_mode=False,
                 **kwargs):
        self.learning_rate = learning_rate
        self.beta1, self.beta2 = beta1, beta2
        self.epsilon = epsilon
        self.parameter_list = parameter_list
        self.name = name
        self.lazy_mode = lazy_mode

    def __call__(self, parameters):
        """Build the torch optimizer for `parameters`."""
        torch_kwargs = {
            'lr': self.learning_rate,
            'betas': (self.beta1, self.beta2),
            'eps': self.epsilon,
            'amsgrad': self.lazy_mode,
        }
        return optim.Adam(params=parameters, **torch_kwargs)
68 |
69 |
class RMSProp(object):
    """
    Factory for a torch RMSprop optimizer.
    Args:
        learning_rate (float) - The learning rate used to update parameters.
        momentum (float) - Momentum factor.
        rho (float) - Smoothing constant, passed to torch as `alpha`.
        epsilon (float) - Passed to torch as `eps`, default is 1e-6.
            NOTE(review): frameworks differ in where this term enters the
            update - confirm equivalence if exact parity matters.
        weight_decay (float, optional) - L2 penalty; None means no decay.
        grad_clip (optional) - Stored on the instance only. torch optimizers
            take no such argument, so clipping has to be done by the caller.
    """

    def __init__(self,
                 learning_rate,
                 momentum=0.0,
                 rho=0.95,
                 epsilon=1e-6,
                 weight_decay=None,
                 grad_clip=None,
                 **args):
        super(RMSProp, self).__init__()
        self.learning_rate = learning_rate
        self.momentum = momentum
        self.rho = rho
        self.epsilon = epsilon
        self.weight_decay = weight_decay
        self.grad_clip = grad_clip

    def __call__(self, parameters):
        """Build the torch optimizer for `parameters`."""
        # The torch class is spelled `RMSprop` and uses `lr` / `alpha` /
        # `eps` / `params` keywords.
        opt = optim.RMSprop(
            params=parameters,
            lr=self.learning_rate,
            alpha=self.rho,
            eps=self.epsilon,
            weight_decay=0.0 if self.weight_decay is None else self.weight_decay,
            momentum=self.momentum)
        return opt
108 |
--------------------------------------------------------------------------------
/utils/postprocess/__init__.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 | from __future__ import unicode_literals
5 |
6 | import copy
7 |
8 | __all__ = ['build_post_process']
9 |
10 | from .rec_postprocess import TableLabelDecode
11 |
12 |
def build_post_process(config, global_config=None):
    """Instantiate the post-processing object described by `config`.

    Args:
        config (dict): Must contain 'name' (one of the supported class
            names); every other entry is passed to the class constructor as
            a keyword argument. The caller's dict is not modified.
        global_config (dict | None): Extra entries merged over `config`
            before construction.

    Returns:
        The constructed post-processing object.

    Raises:
        AssertionError: If `config['name']` is not supported.
    """
    support_dict = ['TableLabelDecode']

    config = copy.deepcopy(config)
    module_name = config.pop('name')
    if global_config is not None:
        config.update(global_config)
    # Raise explicitly instead of using `assert`: assert statements are
    # removed under `python -O`, which would let an arbitrary config string
    # reach the class lookup below.
    if module_name not in support_dict:
        raise AssertionError(
            'post process only support {}'.format(support_dict))
    # The name is whitelisted above, so a plain module-namespace lookup
    # replaces eval().
    module_class = globals()[module_name](**config)
    return module_class
24 |
--------------------------------------------------------------------------------
/utils/postprocess/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/June-Li/TableRecognition/a26956983ade1153ee91f9a92fface1a2cf1275e/utils/postprocess/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/utils/postprocess/__pycache__/rec_postprocess.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/June-Li/TableRecognition/a26956983ade1153ee91f9a92fface1a2cf1275e/utils/postprocess/__pycache__/rec_postprocess.cpython-38.pyc
--------------------------------------------------------------------------------
/utils/postprocess/rec_postprocess.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import string
3 | import torch
4 |
5 |
6 | class TableLabelDecode(object):
7 | """ """
def __init__(self,
             character_dict_path,
             **kwargs):
    """Load the dictionary file and build the index lookup tables.

    Both the character list and the element list read from
    `character_dict_path` are wrapped with the special begin/end tokens,
    then a token-to-index and an index-to-token mapping is built for each.
    Extra keyword arguments are accepted and ignored.
    """
    characters, elements = self.load_char_elem_dict(character_dict_path)
    characters = self.add_special_char(characters)
    elements = self.add_special_char(elements)

    self.dict_character = {token: pos for pos, token in enumerate(characters)}
    self.dict_idx_character = dict(enumerate(characters))

    self.dict_elem = {token: pos for pos, token in enumerate(elements)}
    self.dict_idx_elem = dict(enumerate(elements))
24 |
def load_char_elem_dict(self, character_dict_path):
    """Read the character and element vocabularies from a dictionary file.

    The first line holds two tab-separated integers: the number of
    character lines and the number of element lines that follow, in that
    order. Lines are decoded as UTF-8 with surrounding CR/LF removed.

    Returns:
        tuple: (list of characters, list of elements).
    """
    with open(character_dict_path, "rb") as fin:
        raw_lines = fin.readlines()

    def _text(line_no):
        # Stripping the CR/LF character set covers both line-ending styles.
        return raw_lines[line_no].decode('utf-8').strip("\r\n")

    counts = _text(0).split("\t")
    character_num = int(counts[0])
    elem_num = int(counts[1])

    first_elem = 1 + character_num
    list_character = [_text(no) for no in range(1, first_elem)]
    list_elem = [_text(no) for no in range(first_elem, first_elem + elem_num)]
    return list_character, list_elem
40 |
def add_special_char(self, list_character):
    """Return a new list with the begin token first and the end token last.

    Also records the two tokens on the instance as `beg_str` and `end_str`.
    """
    self.beg_str, self.end_str = "sos", "eos"
    return [self.beg_str] + list_character + [self.end_str]
46 |
47 | def __call__(self, preds):
48 | structure_probs = preds['structure_probs']
49 | loc_preds = preds['loc_preds']
50 | if isinstance(structure_probs, torch.Tensor):
51 | structure_probs = structure_probs.numpy()
52 | if isinstance(loc_preds, torch.Tensor):
53 | loc_preds = loc_preds.numpy()
54 | structure_idx = structure_probs.argmax(axis=2)
55 | structure_probs = structure_probs.max(axis=2)
56 | structure_str, structure_pos, result_score_list, result_elem_idx_list = self.decode(structure_idx,
57 | structure_probs, 'elem')
58 | res_html_code_list = []
59 | res_loc_list = []
60 | batch_num = len(structure_str)
61 | for bno in range(batch_num):
62 | res_loc = []
63 | for sno in range(len(structure_str[bno])):
64 | text = structure_str[bno][sno]
65 | if text in ['', ' | 0 and tmp_elem_idx == end_idx:
98 | break
99 | if tmp_elem_idx in ignored_tokens:
100 | continue
101 |
102 | char_list.append(current_dict[tmp_elem_idx])
103 | elem_pos_list.append(idx)
104 | score_list.append(structure_probs[batch_idx, idx])
105 | elem_idx_list.append(tmp_elem_idx)
106 | result_list.append(char_list)
107 | result_pos_list.append(elem_pos_list)
108 | result_score_list.append(score_list)
109 | result_elem_idx_list.append(elem_idx_list)
110 | return result_list, result_pos_list, result_score_list, result_elem_idx_list
111 |
def get_ignored_tokens(self, char_or_elem):
    """Return [begin index, end index] of the special tokens for the given vocabulary kind."""
    return [
        self.get_beg_end_flag_idx(flag, char_or_elem)
        for flag in ("beg", "end")
    ]
116 |
def get_beg_end_flag_idx(self, beg_or_end, char_or_elem):
    """Look up the index of the begin or end token.

    Args:
        beg_or_end (str): "beg" for the begin token, "end" for the end token.
        char_or_elem (str): "char" to use the character vocabulary, "elem"
            to use the element vocabulary.

    Returns:
        The index stored for the requested token.

    Raises:
        AssertionError: If either selector has an unsupported value.
    """
    # Errors are raised explicitly rather than through `assert False`:
    # assert statements are removed under `python -O`, which would let the
    # function fall through with no index to return.
    if char_or_elem == "char":
        token_to_idx = self.dict_character
    elif char_or_elem == "elem":
        token_to_idx = self.dict_elem
    else:
        raise AssertionError(
            "Unsupport type %s in char_or_elem" % char_or_elem)

    if beg_or_end == "beg":
        return token_to_idx[self.beg_str]
    if beg_or_end == "end":
        return token_to_idx[self.end_str]
    raise AssertionError(
        "Unsupport type %s in get_beg_end_flag_idx of %s"
        % (beg_or_end, char_or_elem))
138 |
--------------------------------------------------------------------------------
/utils/save_load.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time : 2022/01/25 09:52
3 | # @Author : lijun
4 |
5 | import torch
6 |
7 |
8 | # def load_model(model, model_path):
9 | # model.load_state_dict(torch.load(model_path, map_location='cpu')['state_dict'])
10 | #
11 | #
12 | # def save_model(save_dict, save_path):
13 | # torch.save(save_dict, save_path)
14 |
15 |
def load_model(model, model_path):
    """Load the 'state_dict' entry of a checkpoint file into `model`.

    The checkpoint is read onto the CPU. Leading 'module.' prefixes on the
    parameter names are removed before loading
    (presumably left by a DataParallel-style wrapper - confirm).

    Args:
        model: Object providing `load_state_dict`.
        model_path: Path of a checkpoint containing a 'state_dict' mapping.
    """
    prefix = 'module.'
    checkpoint = torch.load(model_path, map_location='cpu')
    state_dict = {}
    for k, v in checkpoint['state_dict'].items():
        # Strip only at the start of the key. A blanket replace() would also
        # mangle names that merely contain the text, e.g. 'se_module.weight'.
        while k.startswith(prefix):
            k = k[len(prefix):]
        state_dict[k] = v
    model.load_state_dict(state_dict)
21 |
22 |
def save_model(save_dict, save_path):
    """Write a checkpoint with leading 'module.' prefixes removed from its keys.

    Args:
        save_dict (dict): Checkpoint contents; must contain a 'state_dict'
            mapping. It is not modified: a shallow copy carrying the
            cleaned state dict is what gets written.
        save_path: Destination passed to `torch.save`.
    """
    prefix = 'module.'
    state_dict = {}
    for k, v in save_dict['state_dict'].items():
        # Strip only at the start of the key. A blanket replace() would also
        # mangle names that merely contain the text, e.g. 'se_module.weight'.
        while k.startswith(prefix):
            k = k[len(prefix):]
        state_dict[k] = v
    to_save = dict(save_dict)
    to_save['state_dict'] = state_dict
    torch.save(to_save, save_path)
29 |
--------------------------------------------------------------------------------
/utils/stats.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import collections
16 | import numpy as np
17 | import datetime
18 |
19 | __all__ = ['TrainingStats', 'Time']
20 |
21 |
class SmoothedValue:
    """Bounded history of values with a median read-out.

    Only the most recent ``window_size`` values are retained; older ones
    are dropped automatically by the underlying ``collections.deque``.
    """

    def __init__(self, window_size):
        # ``maxlen`` makes the deque discard its oldest entry once full.
        self.deque = collections.deque([], window_size)

    def add_value(self, value):
        """Append ``value`` to the window."""
        window = self.deque
        window.append(value)

    def get_median_value(self):
        """Return ``np.median`` of the values currently in the window."""
        window = self.deque
        return np.median(window)
36 |
def Time():
    """Return the current time as ``YYYY-MM-DD HH:MM:SS.ffffff``.

    The value comes from ``datetime.datetime.now()`` called without a
    timezone argument.
    """
    timestamp_format = '%Y-%m-%d %H:%M:%S.%f'
    now = datetime.datetime.now()
    return now.strftime(timestamp_format)
39 |
40 |
class TrainingStats:
    """Per-metric sliding windows that report the median of recent values.

    Each metric name maps to a ``SmoothedValue`` holding at most
    ``window_size`` entries.
    """

    def __init__(self, window_size, stats_keys):
        self.window_size = window_size
        # One window per metric name known up front; ``update`` adds more
        # on demand.
        self.smoothed_losses_and_metrics = dict(
            (name, SmoothedValue(window_size)) for name in stats_keys)

    def update(self, stats):
        """Record one new value for every metric in the ``stats`` mapping."""
        trackers = self.smoothed_losses_and_metrics
        for name, value in stats.items():
            if name not in trackers:
                trackers[name] = SmoothedValue(self.window_size)
            trackers[name].add_value(value)

    def get(self, extras=None):
        """Return an ``OrderedDict`` of ``extras`` followed by the medians.

        Entries of ``extras`` (if given and non-empty) are copied first and
        as-is; each tracked metric then contributes its window median
        rounded to 6 decimals. A metric with the same name as an extra
        overwrites the extra's value.
        """
        merged = collections.OrderedDict()
        for name, value in (extras.items() if extras else ()):
            merged[name] = value
        for name, tracker in self.smoothed_losses_and_metrics.items():
            merged[name] = round(tracker.get_median_value(), 6)
        return merged

    def log(self, extras=None):
        """Return the result of ``get(extras)`` as a ``name: value`` string.

        Entries are joined with ``', '``. Every value goes through the
        ``x<6f`` format spec (fixed-point, left-aligned, padded with ``x``
        to a minimum width of 6), so all values must be numeric.
        """
        return ', '.join(
            '{}: {:x<6f}'.format(name, value)
            for name, value in self.get(extras).items())
73 |
--------------------------------------------------------------------------------
/utils/utility.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | import imghdr
4 | import cv2
5 | import yaml
6 | import numpy as np
7 | from utils.torch_utils import select_device
8 | from argparse import ArgumentParser, RawDescriptionHelpFormatter
9 | from utils.logging import get_logger
10 | import torch
11 |
12 |
def print_dict(d, logger, delimiter=0):
    """Log a (possibly nested) dict, one key per line, sorted by key.

    Nested dicts, and lists whose first element is a dict, are expanded
    recursively with four extra spaces of indentation per level. Every
    other value is printed inline after its key.

    Args:
        d: Mapping to print.
        logger: Object exposing ``info(message)``.
        delimiter: Number of leading spaces for the current level.
    """
    indent = delimiter * " "
    for key, value in sorted(d.items()):
        if isinstance(value, dict):
            logger.info("{}{} : ".format(indent, str(key)))
            print_dict(value, logger, delimiter + 4)
            continue

        holds_dicts = (isinstance(value, list) and len(value) >= 1
                       and isinstance(value[0], dict))
        if not holds_dicts:
            logger.info("{}{} : {}".format(indent, key, value))
            continue

        logger.info("{}{} : ".format(indent, str(key)))
        for entry in value:
            print_dict(entry, logger, delimiter + 4)
28 |
29 |
def get_check_global_params(mode):
    """Return the list of global config keys to check for ``mode``.

    A fixed base list is always returned; ``"train_eval"`` appends both
    per-card batch-size keys and ``"test"`` appends only the test one.
    Any other mode yields just the base list.
    """
    # NOTE: 'image_shape' appears twice; kept as-is so the returned list
    # is unchanged.
    base = [
        'use_gpu', 'max_text_length', 'image_shape',
        'image_shape', 'character_type', 'loss_type',
    ]
    if mode == "train_eval":
        extras = ['train_batch_size_per_card', 'test_batch_size_per_card']
    elif mode == "test":
        extras = ['test_batch_size_per_card']
    else:
        extras = []
    return base + extras
39 |
40 |
def get_image_file_list(img_file):
    """Return the sorted list of image files found at ``img_file``.

    ``img_file`` may be a single file or a directory (scanned
    non-recursively). A file counts as an image when ``imghdr.what``
    reports a type contained in the accepted set.

    Raises:
        Exception: If the path is ``None``, does not exist, or no image
            file is found.
    """
    if img_file is None or not os.path.exists(img_file):
        raise Exception("not found any img file in {}".format(img_file))

    img_end = {'jpg', 'bmp', 'png', 'jpeg', 'rgb', 'tif', 'tiff', 'gif', 'GIF','webp','ppm'}

    def _is_image(path):
        # The isfile check comes first so imghdr never opens a directory.
        return os.path.isfile(path) and imghdr.what(path) in img_end

    if _is_image(img_file):
        found = [img_file]
    elif os.path.isdir(img_file):
        candidates = (os.path.join(img_file, entry)
                      for entry in os.listdir(img_file))
        found = [path for path in candidates if _is_image(path)]
    else:
        found = []

    if not found:
        raise Exception("not found any img file in {}".format(img_file))
    return sorted(found)
58 |
59 |
def check_and_read_gif(img_path):
    """Read the first frame of a GIF file, if ``img_path`` names one.

    A path is treated as a GIF when the last three characters of its
    basename are ``gif`` or ``GIF``.

    Args:
        img_path: Path of the candidate image.

    Returns:
        ``(frame, True)`` when a frame was read, with the last (channel)
        axis reversed; OpenCV decodes frames as BGR, so this presumably
        yields RGB. ``(None, False)`` when the path is not a GIF or no
        frame could be read.
    """
    if os.path.basename(img_path)[-3:] not in ['gif', 'GIF']:
        return None, False

    gif = cv2.VideoCapture(img_path)
    try:
        ret, frame = gif.read()
    finally:
        # Release the capture handle even if reading raises.
        gif.release()

    if not ret:
        logger = logging.getLogger('ppocr')
        # The path is now interpolated; the original logged a literal "{}".
        logger.info("Cannot read {}. This gif image maybe corrupted.".format(
            img_path))
        return None, False

    # Single-channel frames are expanded to three channels first.
    if len(frame.shape) == 2 or frame.shape[-1] == 1:
        frame = cv2.cvtColor(frame, cv2.COLOR_GRAY2RGB)
    imgvalue = frame[:, :, ::-1]
    return imgvalue, True
73 |
74 |
def get_rotate_crop_image(img, points):
    """Warp the quadrilateral ``points`` out of ``img`` into an upright crop.

    The four points are mapped, in order, onto the corners
    ``(0, 0)``, ``(w, 0)``, ``(w, h)`` and ``(0, h)`` of the output, where
    ``w`` and ``h`` are the longer of each pair of opposite edges
    (truncated to int). If the warped crop is at least 1.5 times taller
    than wide it is rotated with ``np.rot90`` before being returned.

    Args:
        img: Source image passed to ``cv2.warpPerspective``.
        points: Four 2-D points; presumably a float32 array of shape
            (4, 2), as required by ``cv2.getPerspectiveTransform``.

    Returns:
        The warped (and possibly rotated) crop.
    """
    assert len(points) == 4, "shape of points must be 4*2"

    # Lengths of the four sides of the quadrilateral.
    edge_01 = np.linalg.norm(points[0] - points[1])
    edge_23 = np.linalg.norm(points[2] - points[3])
    edge_03 = np.linalg.norm(points[0] - points[3])
    edge_12 = np.linalg.norm(points[1] - points[2])
    crop_w = int(max(edge_01, edge_23))
    crop_h = int(max(edge_03, edge_12))

    target = np.float32([[0, 0], [crop_w, 0],
                         [crop_w, crop_h],
                         [0, crop_h]])
    transform = cv2.getPerspectiveTransform(points, target)
    warped = cv2.warpPerspective(
        img,
        transform, (crop_w, crop_h),
        borderMode=cv2.BORDER_REPLICATE,
        flags=cv2.INTER_CUBIC)

    out_h, out_w = warped.shape[0:2]
    if out_h * 1.0 / out_w >= 1.5:
        return np.rot90(warped)
    return warped
108 |
109 |
def preprocess(is_train=False):
    """Parse the command line, load the config and set up device and logger.

    The YAML file given by ``--config`` is loaded into the global config
    and the ``--opt`` overrides are merged on top. When ``is_train`` is
    true the merged config is also written to
    ``<save_model_dir>/config.yml`` and logging goes to
    ``<save_model_dir>/train.log``.

    Args:
        is_train: Whether to persist the config and log to a file.

    Returns:
        A ``(config, device, logger)`` tuple, where ``device`` is whatever
        ``select_device(use_gpu)`` returns.

    Raises:
        AssertionError: If ``Architecture.algorithm`` is not ``'TableAttn'``.
    """
    FLAGS = ArgsParser().parse_args()
    config = load_config(FLAGS.config)
    merge_config(FLAGS.opt)

    use_gpu = config['Global']['use_gpu']

    alg = config['Architecture']['algorithm']
    assert alg in ['TableAttn']

    device = select_device(use_gpu)

    if is_train:
        # Keep a copy of the effective config next to the checkpoints.
        save_model_dir = config['Global']['save_model_dir']
        os.makedirs(save_model_dir, exist_ok=True)
        with open(os.path.join(save_model_dir, 'config.yml'), 'w') as f:
            yaml.dump(
                dict(config), f, default_flow_style=False, sort_keys=False)
        log_file = '{}/train.log'.format(save_model_dir)
    else:
        log_file = None
    logger = get_logger(name='root', log_file=log_file)

    print_dict(config, logger)
    # The version logged is torch's, so name the framework accordingly
    # (the message previously said "paddle").
    logger.info('train with torch {} and device {}'.format(torch.__version__, device))
    return config, device, logger
142 |
143 |
class ArgsParser(ArgumentParser):
    """Command-line parser exposing ``-c/--config`` and ``-o/--opt``.

    ``--opt`` takes one or more ``key=value`` items; after parsing,
    ``args.opt`` is a dict mapping each key to its YAML-decoded value.
    """

    def __init__(self):
        super(ArgsParser, self).__init__(
            formatter_class=RawDescriptionHelpFormatter)
        self.add_argument("-c", "--config", help="configuration file to use")
        self.add_argument(
            "-o", "--opt", nargs='+', help="set configuration options")

    def parse_args(self, argv=None):
        """Parse ``argv`` and convert ``--opt`` into a dict.

        Raises:
            AssertionError: If ``--config`` was not supplied.
        """
        args = super(ArgsParser, self).parse_args(argv)
        assert args.config is not None, \
            "Please specify --config=configure_file_path."
        args.opt = self._parse_opt(args.opt)
        return args

    def _parse_opt(self, opts):
        """Turn ``['a.b=1', ...]`` into ``{'a.b': 1, ...}``.

        Returns an empty dict when ``opts`` is ``None`` or empty. An item
        without ``=`` raises ``ValueError``.
        """
        config = {}
        if not opts:
            return config
        for s in opts:
            s = s.strip()
            # Split on the first '=' only, so values that themselves
            # contain '=' (paths, URLs, ...) are kept intact.
            k, v = s.split('=', 1)
            # NOTE(review): yaml.Loader can construct arbitrary Python
            # objects; acceptable for operator-supplied CLI values, but
            # prefer yaml.SafeLoader if these may come from elsewhere.
            config[k] = yaml.load(v, Loader=yaml.Loader)
        return config
168 |
169 |
class AttrDict(dict):
    """Dict whose top-level keys can also be read as attributes.

    Only the first level is exposed this way; nested dicts stay plain
    dicts.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def __getattr__(self, key):
        # Called only after normal attribute lookup fails, so real dict
        # methods are never shadowed by keys.
        if key not in self:
            raise AttributeError("object has no attribute '{}'".format(key))
        return self[key]
181 |
182 |
# Process-wide configuration store; load_config() and merge_config() write
# into this single instance, and load_config() returns it.
global_config = AttrDict()
# Defaults that load_config() merges in before reading the YAML file.
default_config = {'Global': {'debug': False, }}
185 |
186 |
def load_config(file_path):
    """Load a YAML config file into the global config.

    ``default_config`` is merged first, then the contents of the file.

    Args:
        file_path (str): Path of the config file; its extension must be
            ``.yml`` or ``.yaml``.

    Returns:
        The module-level ``global_config`` object.

    Raises:
        AssertionError: If the file extension is not supported.
    """
    merge_config(default_config)
    _, ext = os.path.splitext(file_path)
    assert ext in ['.yml', '.yaml'], "only support yaml files for now"
    # Use a context manager so the file handle is closed deterministically.
    # NOTE(review): yaml.Loader can construct arbitrary Python objects;
    # only load config files from trusted sources.
    with open(file_path, 'rb') as f:
        loaded = yaml.load(f, Loader=yaml.Loader)
    merge_config(loaded)
    return global_config
199 |
200 |
def merge_config(config):
    """Merge ``config`` into the module-level ``global_config`` in place.

    Keys without a dot are top-level sections: a dict value is merged
    (via ``dict.update``) into an existing section, anything else replaces
    or creates the entry. Dotted keys such as ``"Global.use_gpu"`` address
    a nested entry: the first component must already exist in
    ``global_config``, intermediate components are looked up, and the last
    component is assigned ``value``.

    Args:
        config (dict): Config to be merged.

    Returns:
        None. The global config is modified in place.

    Raises:
        AssertionError: If the first component of a dotted key is not
            already a key of the global config.
    """
    for key, value in config.items():
        if "." in key:
            head, *path = key.split('.')
            assert (
                head in global_config
            ), "the sub_keys can only be one of global_config: {}, but get: {}, please check your running command".format(
                global_config.keys(), head)
            # Walk down to the parent of the final component, then assign.
            node = global_config[head]
            for part in path[:-1]:
                node = node[part]
            node[path[-1]] = value
            continue

        if isinstance(value, dict) and key in global_config:
            global_config[key].update(value)
        else:
            global_config[key] = value
227 |
--------------------------------------------------------------------------------
|