├── .gitignore
├── README.md
├── config
│   ├── lite_ocr.yml
│   ├── load_conf.py
│   ├── predict
│   │   ├── det.yml
│   │   └── rec.yml
│   └── train
│       ├── det.yml
│       └── rec.yml
├── data_loader
│   ├── __init__.py
│   ├── data_det
│   │   ├── test.txt
│   │   └── test
│   │       ├── 00001.jpg
│   │       ├── 00040.jpg
│   │       ├── 00084.jpg
│   │       ├── 00159.jpg
│   │       └── 00172.jpg
│   ├── data_rec
│   │   ├── chinese_chars_6695.json
│   │   ├── train.txt
│   │   └── train
│   │       ├── 001.jpg
│   │       ├── 002.jpg
│   │       ├── 003.jpg
│   │       ├── 004.jpg
│   │       ├── 005.jpg
│   │       ├── 006.jpg
│   │       ├── 007.jpg
│   │       ├── 008.jpg
│   │       ├── 009.jpg
│   │       ├── 010.jpg
│   │       ├── 011.jpg
│   │       ├── 012.jpg
│   │       ├── 013.jpg
│   │       ├── 014.jpg
│   │       ├── 015.jpg
│   │       ├── 016.jpg
│   │       ├── 017.jpg
│   │       ├── 018.jpg
│   │       ├── 019.jpg
│   │       ├── 020.jpg
│   │       ├── 021.jpg
│   │       ├── 022.jpg
│   │       ├── 023.jpg
│   │       ├── 024.jpg
│   │       ├── 025.jpg
│   │       ├── 026.jpg
│   │       ├── 027.jpg
│   │       ├── 028.jpg
│   │       ├── 029.jpg
│   │       └── 030.jpg
│   ├── det_dataset.py
│   ├── img_aug
│   │   ├── __init__.py
│   │   ├── make_binary_map.py
│   │   ├── make_threshold_map.py
│   │   ├── operators.py
│   │   ├── random_crop_data.py
│   │   ├── rec_img_aug.py
│   │   └── text_image_aug
│   │       ├── __init__.py
│   │       ├── augment.py
│   │       └── warp_mls.py
│   └── rec_dataset.py
├── lite_ocr.py
├── logger
│   ├── __init__.py
│   ├── log_conf.py
│   └── logger.py
├── losses
│   ├── __init__.py
│   ├── ctc_loss.py
│   ├── det_loss.py
│   └── loss.py
├── metrics
│   ├── __init__.py
│   ├── det_metric.py
│   ├── eval_det_iou.py
│   └── rec_metric.py
├── nets
│   ├── __init__.py
│   ├── det
│   │   ├── __init__.py
│   │   ├── dbnet.py
│   │   ├── mobilenetv3.py
│   │   ├── params_mapping.py
│   │   └── resnet.py
│   └── rec
│       ├── __init__.py
│       ├── mobilenet_v3.py
│       └── rnn.py
├── optimizer
│   ├── __init__.py
│   ├── learning_rate.py
│   └── optim.py
├── postprocess
│   ├── __init__.py
│   ├── det_postprocess.py
│   └── rec_postprocess.py
├── predict.py
├── requirements.txt
├── train.py
└── utils
    ├── __init__.py
    └── string_utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | .pretrained_models/
2 | .idea/
3 | *.log
4 | __pycache__/
5 | output/
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ##### Introduction
2 | ocr_torch is a lightweight text detection and recognition project implemented on Torch 1.8, combining DBNet (2.2M) and CRNN (3.8M); ONNX inference is supported.
3 |
4 | ##### Environment
5 | - linux
6 |
7 | - python3.7
8 |
9 | - Torch1.8
10 |
11 | ##### Text detection model DBNET
12 |
13 | Implemented with mobilenetV3 large as the backbone network.
14 |
15 | * Training launch script
16 | ```
17 | python -m torch.distributed.launch train.py -c config/train/det.yml
18 | ```
19 |
20 | * Testing launch script
21 | ```
22 | python predict.py -c config/predict/det.yml
23 | ```
24 |
25 | ##### Text recognition model CRNN
26 |
27 | Implemented with mobilenetV3 small as the backbone network.
28 |
29 | * Training launch script
30 | ```
31 | python -m torch.distributed.launch train.py -c config/train/rec.yml
32 | ```
33 |
34 | * Testing launch script
35 | ```
36 | python predict.py -c config/predict/rec.yml
37 | ```
38 |
39 | ##### Combined text detection and recognition inference
40 |
41 | * Inference launch script
42 | ```
43 | python lite_ocr.py -c config/lite_ocr.yml
44 | ```
45 |
46 | ##### Main references and source code
47 | 1. DB [https://github.com/MhLiao/DB](https://github.com/MhLiao/DB)
48 | 2. PaddleOCR [https://github.com/PaddlePaddle/PaddleOCR](https://github.com/PaddlePaddle/PaddleOCR)
49 | 3. DBNET.pytorch [https://github.com/WenmuZhou/DBNet.pytorch](https://github.com/WenmuZhou/DBNet.pytorch)
50 | 4. 
Paper [https://arxiv.org/pdf/1911.08947.pdf](https://arxiv.org/pdf/1911.08947.pdf) -------------------------------------------------------------------------------- /config/lite_ocr.yml: -------------------------------------------------------------------------------- 1 | global: 2 | infer_det_path: ./output/model_det/dbnet.onnx 3 | infer_rec_path: ./output/model_rec/crnn.onnx 4 | res_save_dir: ./output/result_liteocr 5 | image_dir_or_path: /data/projects/task/ocr_torch/data_loader/data_det/test 6 | character_json_path: ./data_loader/data_rec/chinese_chars_6695.json 7 | 8 | det: 9 | post_process: 10 | name: DBPostProcess 11 | thresh: 0.3 12 | box_thresh: 0.7 13 | max_candidates: 1000 14 | unclip_ratio: 1.6 15 | transforms: 16 | - ResizeForTest: 17 | long_size: 960 18 | - NormalizeImage: 19 | - OutputData: 20 | keep_keys: ["image", "src_scale"] 21 | 22 | rec: 23 | post_process: 24 | name: CRnnPostProcess 25 | -------------------------------------------------------------------------------- /config/load_conf.py: -------------------------------------------------------------------------------- 1 | import os 2 | import yaml 3 | import json 4 | import codecs 5 | 6 | 7 | class ReadConfig(object): 8 | def __init__(self, yml_path): 9 | self._yml_path = yml_path 10 | self.base_conf = self._read_yml_conf() 11 | self._complement_conf() 12 | 13 | def _read_yml_conf(self): 14 | with codecs.open(self._yml_path, "r", "utf8") as f: 15 | conf = yaml.load(f.read(), Loader=yaml.FullLoader) 16 | return conf 17 | 18 | def _complement_conf(self): 19 | character_json_path = self.base_conf["global"].get("character_json_path", "") 20 | if not character_json_path: 21 | return 22 | if not os.path.exists(character_json_path): 23 | raise Exception("path {} not exists".format(character_json_path)) 24 | 25 | try: 26 | with codecs.open(character_json_path, "r", "utf8") as f: 27 | char2idx = json.loads(f.read()) 28 | except Exception as e: 29 | raise e 30 | if "" not in char2idx.keys(): 31 | raise Exception("keys is not found!") 32 | if "model_det" in self.base_conf.keys(): 33 | self.base_conf["model_det"]["classes_num"] = len(char2idx) 34 | if "post_process" in self.base_conf.keys(): 35 | self.base_conf["post_process"]["character_json_path"] = character_json_path 36 | -------------------------------------------------------------------------------- /config/predict/det.yml: -------------------------------------------------------------------------------- 1 | global: 2 | yml_type: DET # 配置类型, 不可更改, 可选值为DET, REC, 分别代表检测和识别 3 | train_model_path: ./output/model_finetune/latest.pth 4 | infer_model_path: ./output/model_finetune/latest.onnx 5 | res_save_dir: ./output/result_det 6 | use_infer_model: true 7 | 8 | 9 | model: 10 | name: DBNet 11 | inner_channel: 96 12 | k: 50 13 | backbone: 14 | name: det_mobilenet_v3 # 轻量级backbone 15 | pre_trained_dir: # 默认没有预训练模型, 因为官网提供的mobilenet的multiplier参数为1 16 | multiplier: 0.5 # 可选择为0.35, 0.5, 0.75, 1.0, 1.25 17 | use_se: false 18 | 19 | post_process: 20 | name: DBPostProcess 21 | thresh: 0.3 22 | box_thresh: 0.7 23 | max_candidates: 1000 24 | unclip_ratio: 1.6 25 | 26 | dataset: 27 | image_dir_or_path: /data/projects/task/ocr_torch/data_loader/data_det/test 28 | transforms: 29 | - ResizeForTest: 30 | long_size: 960 31 | - NormalizeImage: 32 | - OutputData: 33 | keep_keys: ["image", "src_scale"] 34 | -------------------------------------------------------------------------------- /config/predict/rec.yml: -------------------------------------------------------------------------------- 1 | global: 2 
| yml_type: REC # 配置类型, 不可更改, 可选值为DET, REC, 分别代表检测和识别 3 | train_model_path: ./output/model_rec/latest.pth 4 | infer_model_path: ./output/model_rec/crnn.onnx 5 | res_save_dir: ./output/result_rec 6 | use_infer_model: true 7 | character_json_path: ./data_loader/data_rec/chinese_chars_6695.json 8 | 9 | 10 | model: 11 | name: CRNN 12 | rnn_type: GRU # 可选值为GRU, LSTM 13 | hidden_size: 48 14 | num_layers: 2 15 | bidirectional: true 16 | backbone: 17 | name: rec_mobilenet_v3 # 轻量级backbone 18 | pre_trained_dir: # 默认没有预训练模型, 因为官网提供的mobilenet的multiplier参数为1 19 | multiplier: 0.5 # 可选择为0.35, 0.5, 0.75, 1.0, 1.25 20 | use_se: false 21 | 22 | post_process: 23 | name: CRnnPostProcess 24 | 25 | dataset: 26 | image_dir_or_path: /data/projects/task/ocr_torch/data_loader/data_rec/train 27 | transforms: 28 | - RecResizeImg: 29 | image_shape: [3, 32, 320] 30 | - OutputData: 31 | keep_keys: ['image'] -------------------------------------------------------------------------------- /config/train/det.yml: -------------------------------------------------------------------------------- 1 | global: 2 | yml_type: DET # 配置类型, 不可更改, 可选值为DET, REC, 分别代表检测和识别 3 | use_gpu: false # 是否使用GPU 4 | epochs: 200 # 训练总轮次 5 | eval_epoch: 2 # 评估间隔轮次 6 | save_pth_dir: ./output/model_det/ # pytorch模型保存地址 7 | init_pth_path: # 初始化模型地址 8 | log_iter: 2 # 每隔几个iter输出一次日志 9 | save_epoch_iter: 2 # 每隔几个epochs保存一次模型 10 | 11 | model: 12 | name: DBNet 13 | inner_channel: 96 14 | k: 50 15 | backbone: 16 | name: det_mobilenet_v3 # 轻量级backbone 17 | pre_trained_dir: # 默认没有预训练模型, 因为官网提供的mobilenet的multiplier参数为1 18 | multiplier: 0.5 # 可选择为0.35, 0.5, 0.75, 1.0, 1.25 19 | use_se: false 20 | # backbone: 21 | # name: resnet18 # 可选值为resnet18, resnet34, resnet50, resnet101, resnet152 22 | # pre_trained_dir: 23 | 24 | 25 | loss: 26 | name: L1BalanceCELoss # balance loss + l1 loss + dice loss 27 | eps: 0.000001 28 | l1_scale: 10 29 | bce_scale: 5 30 | negative_ratio: 3.0 31 | 32 | 33 | optimizer: 34 | name: OptimizerScheduler 35 | optim_method: _adam 36 | init_learning_rate: 0.001 37 | learning_schedule: 38 | name: LearningSchedule 39 | lr_method: _cosine_warmup 40 | warmup_epoch: 2 41 | 42 | metrics: 43 | name: DetMetric 44 | main_indicator: hmean 45 | 46 | post_process: 47 | name: DBPostProcess 48 | thresh: 0.3 49 | box_thresh: 0.7 50 | max_candidates: 1000 51 | unclip_ratio: 1.6 52 | 53 | train: 54 | dataset: 55 | name: DetDataSet 56 | data_base_dir: /data/projects/task/ocr_torch/data_loader/data_det/ 57 | ano_file_path: /data/projects/task/ocr_torch/data_loader/data_det/test.txt 58 | do_shuffle: true 59 | transforms: 60 | - IaaAugment: 61 | flip_prob: 0.5 62 | affine_rotate: [-5, 5] 63 | resize_scale: [0.5, 1.5] 64 | - RandomCropData: 65 | size: [960, 960] 66 | max_tries: 10 67 | min_crop_side_ratio: 0.1 68 | keep_ratio: true 69 | - MakeBorderMap: 70 | shrink_ratio: 0.4 71 | thresh_min: 0.3 72 | thresh_max: 0.7 73 | - MakeProbMap: 74 | min_text_size: 4 75 | shrink_ratio: 0.4 76 | - NormalizeImage: 77 | - OutputData: 78 | keep_keys: ['image', 'thresh_map', 'thresh_mask', 'prob_map', 'prob_mask'] 79 | dataloader: 80 | batch_size: 2 81 | num_workers: 0 82 | drop_last: true 83 | pin_memory: true 84 | 85 | 86 | validate: 87 | dataset: 88 | name: DetDataSet 89 | data_base_dir: /data/projects/task/ocr_torch/data_loader/data_det/ 90 | ano_file_path: /data/projects/task/ocr_torch/data_loader/data_det/test.txt 91 | do_shuffle: false 92 | transforms: 93 | - ResizeForTest: 94 | long_size: 960 95 | - NormalizeImage: 96 | - OutputData: 97 | keep_keys: ["image", 
"src_scale", "polys", "ignore_tags"] 98 | dataloader: 99 | batch_size: 1 100 | num_workers: 0 101 | drop_last: false 102 | pin_memory: false 103 | -------------------------------------------------------------------------------- /config/train/rec.yml: -------------------------------------------------------------------------------- 1 | global: 2 | yml_type: REC # 配置类型, 不可更改, 可选值为DET, REC, 分别代表检测和识别 3 | use_gpu: false # 是否使用GPU 4 | epochs: 200 5 | eval_epoch: 2 6 | save_pth_dir: ./output/model_rec/ # pytorch模型保存地址 7 | init_pth_path: 8 | log_iter: 2 # 每隔几个iter输出一次日志 9 | save_epoch_iter: 2 # 每隔几个epochs保存一次模型 10 | max_text_len: 25 # 字符最大長度 11 | character_json_path: ./data_loader/data_rec/chinese_chars_6695.json 12 | 13 | model: 14 | name: CRNN 15 | rnn_type: GRU # 可选值为GRU, LSTM 16 | hidden_size: 48 17 | num_layers: 2 18 | bidirectional: true 19 | backbone: 20 | name: rec_mobilenet_v3 # 轻量级backbone 21 | pre_trained_dir: # 默认没有预训练模型, 因为官网提供的mobilenet的multiplier参数为1 22 | multiplier: 0.5 # 可选择为0.35, 0.5, 0.75, 1.0, 1.25 23 | use_se: false 24 | 25 | loss: 26 | name: CTCLoss 27 | 28 | optimizer: 29 | name: OptimizerScheduler 30 | optim_method: _adam 31 | init_learning_rate: 0.001 32 | learning_schedule: 33 | name: LearningSchedule 34 | lr_method: _cosine_warmup 35 | warmup_epoch: 0 36 | 37 | metrics: 38 | name: RecMetric 39 | main_indicator: acc 40 | 41 | post_process: 42 | name: CRnnPostProcess 43 | 44 | train: 45 | dataset: 46 | name: RecDataSet 47 | data_base_dir: /data/projects/task/ocr_torch/data_loader/data_rec/ 48 | ano_file_path: /data/projects/task/ocr_torch/data_loader/data_rec/train.txt 49 | do_shuffle: true 50 | transforms: 51 | - RecAug: 52 | aug_prob: 0.4 53 | - RecResizeImg: 54 | image_shape: [3, 32, 320] 55 | - OutputData: 56 | keep_keys: ['image', 'sequence_length', 'label_idx'] 57 | dataloader: 58 | batch_size: 2 59 | num_workers: 0 60 | drop_last: true 61 | pin_memory: true 62 | 63 | 64 | validate: 65 | dataset: 66 | name: RecDataSet 67 | data_base_dir: /data/projects/task/ocr_torch/data_loader/data_rec/ 68 | ano_file_path: /data/projects/task/ocr_torch/data_loader/data_rec/train.txt 69 | do_shuffle: false 70 | transforms: 71 | - RecResizeImg: 72 | image_shape: [3, 32, 320] 73 | - OutputData: 74 | keep_keys: ['image', 'sequence_length', 'label_idx'] 75 | dataloader: 76 | batch_size: 1 77 | num_workers: 0 78 | drop_last: false 79 | pin_memory: false 80 | -------------------------------------------------------------------------------- /data_loader/__init__.py: -------------------------------------------------------------------------------- 1 | from .det_dataset import DetDataSet 2 | from .rec_dataset import RecDataSet 3 | from torch.utils.data import DistributedSampler, DataLoader 4 | 5 | 6 | __all__ = ["build_data_loader"] 7 | 8 | 9 | def build_data_loader(config, distributed, logger, mode): 10 | module_name = config[mode]["dataset"]["name"] 11 | data_set = eval(module_name)(config, logger, mode) 12 | sampler = None 13 | if distributed: 14 | sampler = DistributedSampler(data_set) 15 | 16 | loader_conf = config[mode]["dataloader"] 17 | data_loader = DataLoader( 18 | dataset=data_set, 19 | batch_size=loader_conf["batch_size"], 20 | shuffle=(sampler is None), # when distributed is True, shuffle is False 21 | drop_last=loader_conf["drop_last"], 22 | num_workers=loader_conf["num_workers"], 23 | sampler=sampler, 24 | pin_memory=loader_conf["pin_memory"], # 提高数据转移速度 25 | ) 26 | return data_loader 27 | -------------------------------------------------------------------------------- 
/data_loader/data_det/test.txt: -------------------------------------------------------------------------------- 1 | test/00001.jpg [{"points": [[132, 121], [130, 194], [439, 198], [439, 128]], "transcription": "detection"}, {"points": [[295, 215], [294, 229], [507, 231], [507, 217]], "transcription": "detection"}, {"points": [[349, 234], [349, 247], [507, 248], [507, 235]], "transcription": "detection"}, {"points": [[213, 266], [209, 283], [382, 284], [378, 268]], "transcription": "detection"}, {"points": [[207, 298], [209, 317], [364, 318], [361, 302]], "transcription": "detection"}, {"points": [[211, 333], [210, 351], [435, 352], [435, 335]], "transcription": "detection"}, {"points": [[330, 326], [333, 338], [355, 332], [351, 320]], "transcription": "detection"}, {"points": [[88, 364], [88, 384], [199, 384], [199, 364]], "transcription": "detection"}, {"points": [[209, 367], [207, 384], [253, 384], [253, 368]], "transcription": "detection"}, {"points": [[212, 400], [208, 417], [357, 418], [357, 401]], "transcription": "detection"}, {"points": [[87, 430], [87, 449], [195, 451], [195, 431]], "transcription": "detection"}, {"points": [[207, 467], [203, 481], [456, 485], [455, 468]], "transcription": "detection"}, {"points": [[86, 496], [86, 516], [192, 516], [192, 496]], "transcription": "detection"}, {"points": [[206, 499], [205, 511], [479, 513], [479, 501]], "transcription": "detection"}, {"points": [[204, 513], [205, 525], [485, 526], [487, 516]], "transcription": "detection"}, {"points": [[205, 526], [205, 538], [486, 539], [486, 527]], "transcription": "detection"}, {"points": [[205, 539], [205, 551], [487, 553], [486, 540]], "transcription": "detection"}, {"points": [[204, 551], [204, 564], [371, 564], [367, 553]], "transcription": "detection"}, {"points": [[207, 564], [207, 578], [480, 578], [480, 566]], "transcription": "detection"}, {"points": [[287, 666], [286, 684], [383, 685], [383, 667]], "transcription": "detection"}, {"points": [[26, 759], [26, 770], [218, 771], [215, 761]], "transcription": "detection"}, {"points": [[337, 757], [337, 771], [524, 771], [524, 757]], "transcription": "detection"}, {"points": [[92, 262], [92, 286], [110, 286], [110, 262]], "transcription": "detection"}, {"points": [[175, 266], [175, 285], [196, 285], [196, 266]], "transcription": "detection"}, {"points": [[89, 296], [89, 319], [110, 319], [110, 296]], "transcription": "detection"}, {"points": [[175, 299], [175, 319], [197, 319], [197, 299]], "transcription": "detection"}, {"points": [[89, 330], [89, 351], [110, 351], [110, 330]], "transcription": "detection"}, {"points": [[175, 330], [175, 353], [195, 353], [195, 330]], "transcription": "detection"}, {"points": [[87, 396], [87, 420], [194, 420], [194, 396]], "transcription": "detection"}, {"points": [[207, 433], [207, 449], [316, 449], [316, 433]], "transcription": "detection"}, {"points": [[87, 462], [87, 484], [193, 484], [193, 462]], "transcription": "detection"}, {"points": [[364, 696], [364, 716], [409, 716], [409, 696]], "transcription": "detection"}, {"points": [[416, 696], [416, 710], [431, 710], [431, 696]], "transcription": "detection"}, {"points": [[433, 698], [433, 717], [447, 717], [447, 698]], "transcription": "detection"}, {"points": [[449, 698], [449, 712], [465, 712], [465, 698]], "transcription": "detection"}, {"points": [[473, 700], [473, 714], [485, 714], [485, 700]], "transcription": "detection"}] 2 | test/00040.jpg [{"points": [[327, 152], [327, 177], [634, 183], [635, 157]], "transcription": "detection"}, {"points": 
[[462, 204], [461, 223], [797, 228], [797, 208]], "transcription": "detection"}, {"points": [[132, 249], [132, 272], [211, 272], [211, 249]], "transcription": "detection"}, {"points": [[236, 250], [235, 272], [483, 275], [483, 252]], "transcription": "detection"}, {"points": [[122, 292], [120, 314], [222, 315], [215, 294]], "transcription": "detection"}, {"points": [[235, 294], [234, 315], [273, 315], [273, 295]], "transcription": "detection"}, {"points": [[123, 336], [123, 357], [221, 357], [216, 335]], "transcription": "detection"}, {"points": [[232, 336], [232, 358], [273, 358], [273, 336]], "transcription": "detection"}, {"points": [[130, 380], [130, 401], [211, 403], [211, 381]], "transcription": "detection"}, {"points": [[233, 381], [234, 400], [271, 402], [271, 381]], "transcription": "detection"}, {"points": [[139, 422], [139, 445], [199, 445], [199, 422]], "transcription": "detection"}, {"points": [[234, 424], [234, 444], [611, 447], [611, 426]], "transcription": "detection"}, {"points": [[126, 466], [126, 489], [207, 489], [207, 466]], "transcription": "detection"}, {"points": [[232, 466], [232, 488], [612, 490], [613, 468]], "transcription": "detection"}, {"points": [[126, 508], [126, 529], [206, 531], [206, 511]], "transcription": "detection"}, {"points": [[229, 511], [229, 531], [619, 533], [619, 513]], "transcription": "detection"}, {"points": [[233, 550], [232, 571], [840, 573], [841, 555]], "transcription": "detection"}, {"points": [[232, 573], [232, 594], [825, 596], [824, 575]], "transcription": "detection"}, {"points": [[232, 596], [232, 615], [817, 617], [817, 598]], "transcription": "detection"}, {"points": [[121, 833], [121, 856], [199, 853], [199, 831]], "transcription": "detection"}, {"points": [[387, 1159], [388, 1176], [563, 1177], [561, 1156]], "transcription": "detection"}, {"points": [[385, 1190], [386, 1209], [550, 1207], [550, 1188]], "transcription": "detection"}, {"points": [[562, 1157], [562, 1176], [739, 1176], [739, 1157]], "transcription": "detection"}] 3 | test/00084.jpg [{"points": [[873, 365], [868, 515], [1490, 534], [1494, 384]], "transcription": "detection"}, {"points": [[367, 416], [367, 470], [692, 474], [693, 420]], "transcription": "detection"}, {"points": [[1924, 417], [1924, 443], [2074, 443], [2074, 417]], "transcription": "detection"}, {"points": [[1921, 445], [1924, 469], [2059, 472], [2061, 445]], "transcription": "detection"}, {"points": [[374, 478], [374, 511], [702, 517], [703, 484]], "transcription": "detection"}, {"points": [[1909, 470], [1908, 497], [2047, 501], [2048, 474]], "transcription": "detection"}, {"points": [[1910, 499], [1910, 525], [2042, 525], [2042, 499]], "transcription": "detection"}, {"points": [[1095, 535], [1094, 590], [1255, 595], [1256, 539]], "transcription": "detection"}, {"points": [[1908, 527], [1908, 553], [2058, 553], [2058, 527]], "transcription": "detection"}, {"points": [[1907, 552], [1906, 575], [1986, 579], [1987, 557]], "transcription": "detection"}, {"points": [[364, 673], [364, 709], [395, 709], [395, 673]], "transcription": "detection"}, {"points": [[558, 678], [557, 709], [828, 712], [825, 681]], "transcription": "detection"}, {"points": [[1360, 685], [1360, 727], [1532, 727], [1532, 688]], "transcription": "detection"}, {"points": [[362, 754], [362, 789], [400, 789], [400, 754]], "transcription": "detection"}, {"points": [[555, 759], [554, 789], [957, 793], [958, 767]], "transcription": "detection"}, {"points": [[1359, 768], [1358, 804], [1531, 809], [1534, 771]], "transcription": "detection"}, 
{"points": [[360, 835], [360, 871], [524, 872], [524, 836]], "transcription": "detection"}, {"points": [[1556, 852], [1555, 880], [1983, 888], [1984, 856]], "transcription": "detection"}, {"points": [[356, 912], [355, 955], [535, 959], [536, 916]], "transcription": "detection"}, {"points": [[553, 928], [551, 955], [1282, 967], [1281, 940]], "transcription": "detection"}, {"points": [[1358, 931], [1358, 967], [1399, 968], [1397, 932]], "transcription": "detection"}, {"points": [[1558, 937], [1555, 966], [2047, 975], [2046, 946]], "transcription": "detection"}, {"points": [[550, 959], [546, 983], [1268, 994], [1267, 969]], "transcription": "detection"}, {"points": [[550, 985], [550, 1018], [1277, 1025], [1277, 992]], "transcription": "detection"}, {"points": [[546, 1016], [545, 1043], [1281, 1050], [1282, 1024]], "transcription": "detection"}, {"points": [[551, 1045], [548, 1073], [1278, 1081], [1277, 1055]], "transcription": "detection"}, {"points": [[550, 1073], [550, 1100], [791, 1102], [792, 1075]], "transcription": "detection"}, {"points": [[1448, 1188], [1447, 1229], [1668, 1235], [1669, 1194]], "transcription": "detection"}, {"points": [[176, 1442], [175, 1472], [530, 1480], [524, 1452]], "transcription": "detection"}, {"points": [[938, 1453], [937, 1479], [1389, 1487], [1390, 1460]], "transcription": "detection"}, {"points": [[931, 1481], [936, 1510], [1289, 1515], [1289, 1484]], "transcription": "detection"}, {"points": [[1820, 1472], [1819, 1508], [2134, 1516], [2135, 1479]], "transcription": "detection"}, {"points": [[491, 757], [491, 794], [526, 794], [526, 757]], "transcription": "detection"}, {"points": [[553, 836], [553, 870], [635, 870], [635, 836]], "transcription": "detection"}, {"points": [[1361, 846], [1361, 891], [1526, 891], [1526, 846]], "transcription": "detection"}, {"points": [[1563, 695], [1563, 726], [1701, 726], [1701, 695]], "transcription": "detection"}, {"points": [[1559, 776], [1559, 809], [1762, 809], [1762, 776]], "transcription": "detection"}, {"points": [[1556, 967], [1556, 1000], [1583, 1000], [1583, 967]], "transcription": "detection"}, {"points": [[1490, 931], [1490, 972], [1528, 972], [1528, 931]], "transcription": "detection"}, {"points": [[544, 1466], [544, 1494], [828, 1494], [828, 1466]], "transcription": "detection"}, {"points": [[492, 675], [492, 715], [529, 715], [529, 675]], "transcription": "detection"}, {"points": [[1689, 1330], [1689, 1363], [1766, 1363], [1766, 1330]], "transcription": "detection"}, {"points": [[1767, 1320], [1767, 1367], [1803, 1367], [1803, 1320]], "transcription": "detection"}, {"points": [[1841, 1331], [1841, 1368], [1872, 1368], [1872, 1331]], "transcription": "detection"}, {"points": [[1876, 1323], [1876, 1365], [1908, 1365], [1908, 1323]], "transcription": "detection"}, {"points": [[1942, 1335], [1942, 1367], [1981, 1367], [1981, 1335]], "transcription": "detection"}, {"points": [[1986, 1327], [1986, 1362], [2015, 1362], [2015, 1327]], "transcription": "detection"}] 4 | test/00159.jpg [{"points": [[756, 340], [756, 421], [1453, 428], [1454, 346]], "transcription": "detection"}, {"points": [[500, 440], [498, 607], [1712, 618], [1713, 450]], "transcription": "detection"}, {"points": [[1341, 618], [1341, 652], [1630, 655], [1631, 621]], "transcription": "detection"}, {"points": [[491, 700], [495, 735], [1107, 738], [1107, 705]], "transcription": "detection"}, {"points": [[1174, 698], [1175, 737], [1338, 741], [1335, 702]], "transcription": "detection"}, {"points": [[495, 757], [496, 793], [1074, 797], [1073, 758]], 
"transcription": "detection"}, {"points": [[1346, 754], [1347, 801], [1888, 803], [1886, 762]], "transcription": "detection"}, {"points": [[305, 821], [307, 861], [485, 860], [488, 823]], "transcription": "detection"}, {"points": [[1347, 816], [1347, 851], [1802, 852], [1802, 818]], "transcription": "detection"}, {"points": [[307, 878], [306, 918], [486, 918], [485, 883]], "transcription": "detection"}, {"points": [[309, 941], [308, 978], [484, 981], [485, 943]], "transcription": "detection"}, {"points": [[496, 987], [496, 1006], [590, 1006], [590, 987]], "transcription": "detection"}, {"points": [[313, 999], [310, 1036], [484, 1037], [485, 1003]], "transcription": "detection"}, {"points": [[1362, 1152], [1364, 1196], [1960, 1182], [1956, 1140]], "transcription": "detection"}, {"points": [[316, 1223], [315, 1262], [476, 1266], [477, 1227]], "transcription": "detection"}, {"points": [[758, 1396], [761, 1459], [1434, 1469], [1431, 1409]], "transcription": "detection"}, {"points": [[309, 703], [309, 741], [482, 741], [482, 703]], "transcription": "detection"}, {"points": [[305, 758], [305, 800], [480, 800], [480, 758]], "transcription": "detection"}, {"points": [[491, 813], [491, 851], [589, 851], [589, 813]], "transcription": "detection"}, {"points": [[495, 871], [495, 907], [591, 907], [591, 871]], "transcription": "detection"}, {"points": [[496, 928], [496, 964], [594, 964], [594, 928]], "transcription": "detection"}, {"points": [[1348, 702], [1348, 739], [1387, 739], [1387, 702]], "transcription": "detection"}, {"points": [[1441, 700], [1441, 742], [1478, 742], [1478, 700]], "transcription": "detection"}, {"points": [[1510, 691], [1510, 726], [1566, 726], [1566, 691]], "transcription": "detection"}, {"points": [[1596, 689], [1596, 728], [1643, 728], [1643, 689]], "transcription": "detection"}, {"points": [[1178, 624], [1178, 667], [1219, 667], [1219, 624]], "transcription": "detection"}, {"points": [[1287, 626], [1287, 666], [1330, 666], [1330, 626]], "transcription": "detection"}, {"points": [[1178, 770], [1178, 812], [1334, 812], [1334, 770]], "transcription": "detection"}, {"points": [[1517, 1241], [1517, 1275], [1585, 1275], [1585, 1241]], "transcription": "detection"}, {"points": [[1615, 1238], [1615, 1279], [1652, 1279], [1652, 1238]], "transcription": "detection"}, {"points": [[1692, 1238], [1692, 1269], [1727, 1269], [1727, 1238]], "transcription": "detection"}, {"points": [[1754, 1237], [1754, 1275], [1788, 1275], [1788, 1237]], "transcription": "detection"}, {"points": [[1830, 1236], [1830, 1266], [1865, 1266], [1865, 1236]], "transcription": "detection"}, {"points": [[1894, 1235], [1894, 1270], [1924, 1270], [1924, 1235]], "transcription": "detection"}, {"points": [[950, 1233], [950, 1266], [991, 1266], [991, 1233]], "transcription": "detection"}, {"points": [[1008, 1233], [1008, 1269], [1037, 1269], [1037, 1233]], "transcription": "detection"}, {"points": [[865, 1237], [865, 1277], [894, 1277], [894, 1237]], "transcription": "detection"}, {"points": [[808, 1235], [808, 1268], [835, 1268], [835, 1235]], "transcription": "detection"}, {"points": [[621, 1232], [621, 1267], [690, 1267], [690, 1232]], "transcription": "detection"}, {"points": [[712, 1237], [712, 1277], [752, 1277], [752, 1237]], "transcription": "detection"}] 5 | test/00172.jpg [{"points": [[185, 38], [184, 65], [552, 76], [553, 49]], "transcription": "detection"}, {"points": [[143, 88], [143, 107], [406, 110], [407, 92]], "transcription": "detection"}, {"points": [[79, 134], [79, 151], [167, 152], [167, 136]], 
"transcription": "detection"}, {"points": [[182, 135], [181, 153], [437, 155], [437, 137]], "transcription": "detection"}, {"points": [[77, 178], [77, 196], [165, 198], [165, 179]], "transcription": "detection"}, {"points": [[182, 171], [182, 190], [669, 191], [669, 173]], "transcription": "detection"}, {"points": [[184, 190], [184, 207], [230, 207], [230, 190]], "transcription": "detection"}, {"points": [[185, 216], [184, 235], [675, 236], [674, 217]], "transcription": "detection"}, {"points": [[85, 224], [85, 243], [156, 243], [156, 224]], "transcription": "detection"}, {"points": [[183, 237], [182, 253], [230, 255], [231, 235]], "transcription": "detection"}, {"points": [[74, 271], [74, 290], [166, 290], [166, 271]], "transcription": "detection"}, {"points": [[186, 271], [187, 291], [239, 292], [239, 271]], "transcription": "detection"}, {"points": [[74, 320], [74, 337], [164, 337], [164, 320]], "transcription": "detection"}, {"points": [[186, 317], [183, 339], [239, 339], [239, 318]], "transcription": "detection"}, {"points": [[80, 368], [80, 386], [154, 386], [154, 368]], "transcription": "detection"}, {"points": [[182, 367], [183, 386], [343, 386], [343, 368]], "transcription": "detection"}, {"points": [[70, 416], [70, 435], [160, 435], [160, 416]], "transcription": "detection"}, {"points": [[181, 416], [181, 433], [256, 433], [256, 416]], "transcription": "detection"}, {"points": [[181, 454], [181, 474], [703, 469], [703, 448]], "transcription": "detection"}, {"points": [[67, 466], [67, 484], [161, 484], [161, 466]], "transcription": "detection"}, {"points": [[74, 516], [74, 535], [152, 535], [152, 516]], "transcription": "detection"}, {"points": [[178, 515], [179, 533], [420, 531], [420, 512]], "transcription": "detection"}, {"points": [[86, 565], [86, 589], [138, 589], [138, 565]], "transcription": "detection"}, {"points": [[176, 563], [176, 587], [291, 585], [291, 562]], "transcription": "detection"}, {"points": [[70, 617], [71, 641], [150, 639], [150, 615]], "transcription": "detection"}, {"points": [[84, 670], [84, 694], [136, 694], [136, 670]], "transcription": "detection"}, {"points": [[70, 728], [71, 749], [321, 745], [321, 725]], "transcription": "detection"}, {"points": [[447, 722], [447, 741], [651, 740], [648, 720]], "transcription": "detection"}, {"points": [[446, 742], [449, 762], [650, 761], [652, 742]], "transcription": "detection"}, {"points": [[182, 475], [182, 495], [207, 495], [207, 475]], "transcription": "detection"}] -------------------------------------------------------------------------------- /data_loader/data_det/test/00001.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jakeywu/ocr_torch/94451fac7d9e503ed4612772a6941c2034f13409/data_loader/data_det/test/00001.jpg -------------------------------------------------------------------------------- /data_loader/data_det/test/00040.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jakeywu/ocr_torch/94451fac7d9e503ed4612772a6941c2034f13409/data_loader/data_det/test/00040.jpg -------------------------------------------------------------------------------- /data_loader/data_det/test/00084.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jakeywu/ocr_torch/94451fac7d9e503ed4612772a6941c2034f13409/data_loader/data_det/test/00084.jpg 
-------------------------------------------------------------------------------- /data_loader/data_det/test/00159.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jakeywu/ocr_torch/94451fac7d9e503ed4612772a6941c2034f13409/data_loader/data_det/test/00159.jpg -------------------------------------------------------------------------------- /data_loader/data_det/test/00172.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jakeywu/ocr_torch/94451fac7d9e503ed4612772a6941c2034f13409/data_loader/data_det/test/00172.jpg -------------------------------------------------------------------------------- /data_loader/data_rec/train.txt: -------------------------------------------------------------------------------- 1 | train/001.jpg 第二类医疗器械经营备案凭证 2 | train/002.jpg 皖六食药监械经营备20160244更号 3 | train/003.jpg 安徽省兴盛医疗科技有限公司 4 | train/004.jpg 企业名称 5 | train/005.jpg 法定代表人 6 | train/006.jpg 企业负责人 7 | train/007.jpg 解晓东 8 | train/008.jpg 经营方式 9 | train/009.jpg 批零兼营 10 | train/010.jpg 住 11 | train/012.jpg 经营场所 12 | train/014.jpg 库房地址 13 | train/020.jpg 经营范围 14 | train/022.jpg 6870,6877 15 | train/025.jpg 备案部门(公章) 16 | train/026.jpg 备案日期: 17 | train/027.jpg 田圣萍 18 | train/028.jpg 所 19 | train/029.jpg 2019年3月19日 20 | train/030.jpg 备案号 -------------------------------------------------------------------------------- /data_loader/data_rec/train/001.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jakeywu/ocr_torch/94451fac7d9e503ed4612772a6941c2034f13409/data_loader/data_rec/train/001.jpg -------------------------------------------------------------------------------- /data_loader/data_rec/train/002.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jakeywu/ocr_torch/94451fac7d9e503ed4612772a6941c2034f13409/data_loader/data_rec/train/002.jpg -------------------------------------------------------------------------------- /data_loader/data_rec/train/003.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jakeywu/ocr_torch/94451fac7d9e503ed4612772a6941c2034f13409/data_loader/data_rec/train/003.jpg -------------------------------------------------------------------------------- /data_loader/data_rec/train/004.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jakeywu/ocr_torch/94451fac7d9e503ed4612772a6941c2034f13409/data_loader/data_rec/train/004.jpg -------------------------------------------------------------------------------- /data_loader/data_rec/train/005.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jakeywu/ocr_torch/94451fac7d9e503ed4612772a6941c2034f13409/data_loader/data_rec/train/005.jpg -------------------------------------------------------------------------------- /data_loader/data_rec/train/006.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jakeywu/ocr_torch/94451fac7d9e503ed4612772a6941c2034f13409/data_loader/data_rec/train/006.jpg -------------------------------------------------------------------------------- /data_loader/data_rec/train/007.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/jakeywu/ocr_torch/94451fac7d9e503ed4612772a6941c2034f13409/data_loader/data_rec/train/007.jpg -------------------------------------------------------------------------------- /data_loader/data_rec/train/008.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jakeywu/ocr_torch/94451fac7d9e503ed4612772a6941c2034f13409/data_loader/data_rec/train/008.jpg -------------------------------------------------------------------------------- /data_loader/data_rec/train/009.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jakeywu/ocr_torch/94451fac7d9e503ed4612772a6941c2034f13409/data_loader/data_rec/train/009.jpg -------------------------------------------------------------------------------- /data_loader/data_rec/train/010.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jakeywu/ocr_torch/94451fac7d9e503ed4612772a6941c2034f13409/data_loader/data_rec/train/010.jpg -------------------------------------------------------------------------------- /data_loader/data_rec/train/011.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jakeywu/ocr_torch/94451fac7d9e503ed4612772a6941c2034f13409/data_loader/data_rec/train/011.jpg -------------------------------------------------------------------------------- /data_loader/data_rec/train/012.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jakeywu/ocr_torch/94451fac7d9e503ed4612772a6941c2034f13409/data_loader/data_rec/train/012.jpg -------------------------------------------------------------------------------- /data_loader/data_rec/train/013.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jakeywu/ocr_torch/94451fac7d9e503ed4612772a6941c2034f13409/data_loader/data_rec/train/013.jpg -------------------------------------------------------------------------------- /data_loader/data_rec/train/014.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jakeywu/ocr_torch/94451fac7d9e503ed4612772a6941c2034f13409/data_loader/data_rec/train/014.jpg -------------------------------------------------------------------------------- /data_loader/data_rec/train/015.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jakeywu/ocr_torch/94451fac7d9e503ed4612772a6941c2034f13409/data_loader/data_rec/train/015.jpg -------------------------------------------------------------------------------- /data_loader/data_rec/train/016.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jakeywu/ocr_torch/94451fac7d9e503ed4612772a6941c2034f13409/data_loader/data_rec/train/016.jpg -------------------------------------------------------------------------------- /data_loader/data_rec/train/017.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jakeywu/ocr_torch/94451fac7d9e503ed4612772a6941c2034f13409/data_loader/data_rec/train/017.jpg -------------------------------------------------------------------------------- /data_loader/data_rec/train/018.jpg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/jakeywu/ocr_torch/94451fac7d9e503ed4612772a6941c2034f13409/data_loader/data_rec/train/018.jpg -------------------------------------------------------------------------------- /data_loader/data_rec/train/019.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jakeywu/ocr_torch/94451fac7d9e503ed4612772a6941c2034f13409/data_loader/data_rec/train/019.jpg -------------------------------------------------------------------------------- /data_loader/data_rec/train/020.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jakeywu/ocr_torch/94451fac7d9e503ed4612772a6941c2034f13409/data_loader/data_rec/train/020.jpg -------------------------------------------------------------------------------- /data_loader/data_rec/train/021.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jakeywu/ocr_torch/94451fac7d9e503ed4612772a6941c2034f13409/data_loader/data_rec/train/021.jpg -------------------------------------------------------------------------------- /data_loader/data_rec/train/022.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jakeywu/ocr_torch/94451fac7d9e503ed4612772a6941c2034f13409/data_loader/data_rec/train/022.jpg -------------------------------------------------------------------------------- /data_loader/data_rec/train/023.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jakeywu/ocr_torch/94451fac7d9e503ed4612772a6941c2034f13409/data_loader/data_rec/train/023.jpg -------------------------------------------------------------------------------- /data_loader/data_rec/train/024.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jakeywu/ocr_torch/94451fac7d9e503ed4612772a6941c2034f13409/data_loader/data_rec/train/024.jpg -------------------------------------------------------------------------------- /data_loader/data_rec/train/025.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jakeywu/ocr_torch/94451fac7d9e503ed4612772a6941c2034f13409/data_loader/data_rec/train/025.jpg -------------------------------------------------------------------------------- /data_loader/data_rec/train/026.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jakeywu/ocr_torch/94451fac7d9e503ed4612772a6941c2034f13409/data_loader/data_rec/train/026.jpg -------------------------------------------------------------------------------- /data_loader/data_rec/train/027.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jakeywu/ocr_torch/94451fac7d9e503ed4612772a6941c2034f13409/data_loader/data_rec/train/027.jpg -------------------------------------------------------------------------------- /data_loader/data_rec/train/028.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jakeywu/ocr_torch/94451fac7d9e503ed4612772a6941c2034f13409/data_loader/data_rec/train/028.jpg 
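Similarly, the recognition labels in `data_loader/data_rec/train.txt` above pair an image path with its transcription. The sketch below shows how such a line could be mapped to the `label_idx` and `sequence_length` fields that the `rec.yml` `OutputData` transform feeds into `CTCLoss`. It assumes `chinese_chars_6695.json` maps characters (including the blank `""` key checked in `load_conf.py`) to integer indices and simply drops characters missing from the map; `RecDataSet` itself is not included in this excerpt, so this is illustrative only.

```python
import codecs
import json

# Assumption: chinese_chars_6695.json is a {character: index} map, as implied by load_conf.py.
with codecs.open("data_loader/data_rec/chinese_chars_6695.json", "r", "utf8") as f:
    char2idx = json.load(f)

line = "train/003.jpg\t安徽省兴盛医疗科技有限公司"   # one entry from train.txt above
image_path, text = line.rstrip("\n").split("\t")

label_idx = [char2idx[c] for c in text if c in char2idx]  # unknown characters dropped here
sequence_length = len(label_idx)                          # passed alongside label_idx for CTC
print(image_path, sequence_length, label_idx)
```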
-------------------------------------------------------------------------------- /data_loader/data_rec/train/029.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jakeywu/ocr_torch/94451fac7d9e503ed4612772a6941c2034f13409/data_loader/data_rec/train/029.jpg -------------------------------------------------------------------------------- /data_loader/data_rec/train/030.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jakeywu/ocr_torch/94451fac7d9e503ed4612772a6941c2034f13409/data_loader/data_rec/train/030.jpg -------------------------------------------------------------------------------- /data_loader/det_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import json 4 | import codecs 5 | import random 6 | import numpy as np 7 | from torch.utils.data import Dataset 8 | from data_loader.img_aug import * 9 | 10 | 11 | class DetDataSet(Dataset): 12 | def __init__(self, config, logger, mode): 13 | dataset_conf = config[mode]["dataset"] 14 | self.base_dir = dataset_conf["data_base_dir"] 15 | self.mode = mode 16 | self.logger = logger 17 | self.data_lines = self.get_image_info_list(dataset_conf["ano_file_path"]) 18 | self._transforms = self._transforms_func_lst(dataset_conf["transforms"]) 19 | if dataset_conf["do_shuffle"]: 20 | random.shuffle(self.data_lines) 21 | 22 | def __len__(self): 23 | return len(self.data_lines) 24 | 25 | def get_image_info_list(self, file_path): 26 | """数据文件以\t分割""" 27 | lines = [] 28 | with codecs.open(file_path, "r", "utf8") as f: 29 | for line in f.readlines(): 30 | tmp_data = line.strip().split("\t") 31 | if len(tmp_data) != 2: 32 | self.logger.warn(f"{line}数据格式不对") 33 | continue 34 | image_path = os.path.join(self.base_dir, tmp_data[0]) 35 | if not os.path.exists(image_path): 36 | self.logger.warn(f"{image_path}图片文件不存在") 37 | continue 38 | lines.append([tmp_data[0], tmp_data[1]]) 39 | return lines 40 | 41 | @staticmethod 42 | def det_label_encoder(label_str): 43 | label = json.loads(label_str) 44 | boxes = [] 45 | ignore_tags = [] 46 | for bno in range(0, len(label)): 47 | box = label[bno]["points"] 48 | txt = label[bno]["transcription"] 49 | if txt in ["*", "###"]: # ICDAR为### 50 | ignore_tags.append(True) 51 | else: 52 | ignore_tags.append(False) 53 | boxes.append(box) 54 | 55 | boxes = np.array(boxes, dtype=np.float) 56 | ignore_tags = np.array(ignore_tags, dtype=np.bool) 57 | return boxes, ignore_tags 58 | 59 | @staticmethod 60 | def _transforms_func_lst(config): 61 | func_lst = [] 62 | for _transform in config: 63 | operator = list(_transform.keys())[0] 64 | params = dict() if _transform[operator] is None else _transform[operator] 65 | func_name = eval(operator)(**params) 66 | func_lst.append(func_name) 67 | return func_lst 68 | 69 | def __getitem__(self, index): 70 | try: 71 | data_line = self.data_lines[index] 72 | image_path = os.path.join(self.base_dir, data_line[0]) 73 | polys, ignore_tags = self.det_label_encoder(data_line[1]) 74 | image = cv2.imread(image_path, cv2.IMREAD_COLOR) # 默认BGR CHANNEL_LAST 75 | if image is None: 76 | self.logger.info(image_path) 77 | data = {"polys": polys, "image": image, "ignore_tags": ignore_tags} 78 | for _transform in self._transforms: 79 | data = _transform(data) 80 | except Exception as e: 81 | self.logger.error(e) 82 | data = [] 83 | 84 | if not data: 85 | return 
self.__getitem__(np.random.randint(self.__len__())) 86 | return data 87 | -------------------------------------------------------------------------------- /data_loader/img_aug/__init__.py: -------------------------------------------------------------------------------- 1 | from .make_threshold_map import MakeBorderMap 2 | from .make_binary_map import MakeProbMap 3 | from .operators import IaaAugment, NormalizeImage, OutputData, ResizeForTest 4 | from .random_crop_data import RandomCropData 5 | from .rec_img_aug import RecAug, RecResizeImg 6 | 7 | __all__ = [ 8 | "MakeBorderMap", 9 | "IaaAugment", 10 | "MakeProbMap", 11 | "RandomCropData", 12 | "NormalizeImage", 13 | "OutputData", 14 | "ResizeForTest", 15 | "RecAug", 16 | "RecResizeImg", 17 | ] 18 | -------------------------------------------------------------------------------- /data_loader/img_aug/make_binary_map.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import pyclipper 3 | import numpy as np 4 | from shapely.geometry import Polygon 5 | 6 | 7 | class MakeProbMap(object): 8 | def __init__( 9 | self, 10 | min_text_size=4, 11 | shrink_ratio=0.4 12 | ): 13 | """ 14 | :param min_text_size: 最短边的距离, 根据实际情况决定 15 | :param shrink_ratio: 收缩比例r 16 | """ 17 | self.min_text_size = min_text_size 18 | self.shrink_ratio = shrink_ratio 19 | 20 | def __call__(self, data): 21 | image = data['image'] 22 | text_polys = data['polys'] 23 | ignore_tags = data['ignore_tags'] 24 | 25 | h, w = image.shape[:2] 26 | text_polys, ignore_tags = self.validate_polygons(text_polys, ignore_tags, h, w) 27 | gt = np.zeros((h, w), dtype=np.float32) 28 | mask = np.ones((h, w), dtype=np.float32) 29 | for i in range(len(text_polys)): 30 | polygon = text_polys[i] 31 | height = min(np.linalg.norm(polygon[0] - polygon[3]), 32 | np.linalg.norm(polygon[1] - polygon[2])) 33 | width = min(np.linalg.norm(polygon[0] - polygon[1]), 34 | np.linalg.norm(polygon[2] - polygon[3])) 35 | if ignore_tags[i] or min(height, width) < self.min_text_size: 36 | cv2.fillPoly(mask, polygon.astype(np.int32)[np.newaxis, :, :], 0) 37 | ignore_tags[i] = True 38 | else: 39 | polygon_shape = Polygon(polygon) 40 | subject = [tuple(pg) for pg in polygon] 41 | padding = pyclipper.PyclipperOffset() 42 | padding.AddPath(subject, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON) 43 | shrinked = [] 44 | 45 | # Increase the shrink ratio every time we get multiple polygon returned back 46 | possible_ratios = np.arange(self.shrink_ratio, 1, self.shrink_ratio) 47 | np.append(possible_ratios, 1) 48 | # 这里跟官方DB有一点区别, 但个人认为这个可以更好地检测小文本 49 | for ratio in possible_ratios: 50 | # print(f"Change shrink ratio to {ratio}") 51 | distance = polygon_shape.area * ( 52 | 1 - np.power(ratio, 2)) / polygon_shape.length 53 | shrinked = padding.Execute(-distance) 54 | if shrinked: 55 | break 56 | 57 | if not shrinked: 58 | cv2.fillPoly(mask, polygon.astype(np.int32)[np.newaxis, :, :], 0) 59 | ignore_tags[i] = True 60 | continue 61 | 62 | shrink = np.array(shrinked[0]).reshape(-1, 2) 63 | cv2.fillPoly(gt, [shrink.astype(np.int32)], 1) 64 | 65 | data['prob_map'] = gt 66 | data['prob_mask'] = mask 67 | return data 68 | 69 | def validate_polygons(self, polygons, ignore_tags, h, w): 70 | if len(polygons) == 0: 71 | return polygons, ignore_tags 72 | assert len(polygons) == len(ignore_tags) 73 | for polygon in polygons: 74 | polygon[:, 0] = np.clip(polygon[:, 0], 0, w - 1) 75 | polygon[:, 1] = np.clip(polygon[:, 1], 0, h - 1) 76 | 77 | for i in range(len(polygons)): 78 | area = 
self.polygon_area(polygons[i]) 79 | if abs(area) < 1: 80 | ignore_tags[i] = True 81 | if area > 0: 82 | polygons[i] = polygons[i][::-1, :] 83 | return polygons, ignore_tags 84 | 85 | @staticmethod 86 | def polygon_area(polygon): 87 | edge = 0 88 | for i in range(polygon.shape[0]): 89 | next_index = (i + 1) % polygon.shape[0] 90 | edge += (polygon[next_index, 0] - polygon[i, 0]) * ( 91 | polygon[next_index, 1] + polygon[i, 1]) 92 | 93 | return edge / 2. 94 | 95 | 96 | if __name__ == "__main__": 97 | po = np.array([[10, 10], [20, 10], [20, 20], [10, 20]]) 98 | print(Polygon(po).area) 99 | msm = MakeProbMap() 100 | print(msm.polygon_area(po)) 101 | -------------------------------------------------------------------------------- /data_loader/img_aug/make_threshold_map.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import warnings 3 | import pyclipper 4 | import numpy as np 5 | from shapely.geometry import Polygon 6 | 7 | warnings.simplefilter("ignore") 8 | 9 | 10 | class MakeBorderMap(object): 11 | def __init__( 12 | self, 13 | shrink_ratio=0.4, 14 | thresh_min=0.3, 15 | thresh_max=0.7 16 | ): 17 | """ 18 | :param shrink_ratio: 膨胀比例 19 | :param thresh_min: 非文字区域threshold_map值 20 | :param thresh_max: 用于归一化threshold_map 进行一定的缩放,将1缩放到0.7的值,将0缩放到0.3 21 | D = Area * (1 - r**r) / L 22 | """ 23 | self.shrink_ratio = shrink_ratio 24 | self.thresh_min = thresh_min 25 | self.thresh_max = thresh_max 26 | 27 | def __call__(self, data): 28 | img = data['image'] 29 | text_polys = data['polys'] 30 | ignore_tags = data["ignore_tags"] 31 | canvas = np.zeros(img.shape[:2], dtype=np.float32) 32 | mask = np.zeros(img.shape[:2], dtype=np.float32) 33 | for i in range(len(text_polys)): 34 | if ignore_tags[i]: 35 | continue 36 | self.draw_border_map(text_polys[i], canvas, mask=mask) 37 | canvas = canvas * (self.thresh_max - self.thresh_min) + self.thresh_min 38 | data['thresh_map'] = canvas 39 | data['thresh_mask'] = mask 40 | return data 41 | 42 | def draw_border_map(self, polygon, canvas, mask): 43 | polygon = np.array(polygon) 44 | assert polygon.ndim == 2 45 | assert polygon.shape[1] == 2 46 | 47 | polygon_shape = Polygon(polygon) 48 | if polygon_shape.area <= 0: 49 | return 50 | 51 | # 计算收缩偏移量D 52 | distance = polygon_shape.area * (1 - np.power(self.shrink_ratio, 2)) / polygon_shape.length 53 | subject = [tuple(pg) for pg in polygon] 54 | padding = pyclipper.PyclipperOffset() 55 | # joinType 当扁平的路径始终无法完美的获取角度信息,他们等价于一系列的圆弧曲线 56 | padding.AddPath(subject, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON) 57 | 58 | # 1. 
首先对原始标注框G,采用上述偏移量D来进行扩充,得到的框为Gd 59 | padded_polygon = np.array(padding.Execute(distance)[0]) 60 | # 使用膨胀的padded_polygon多边形填充mask, 值为1.0 61 | cv2.fillPoly(mask, [padded_polygon.astype(np.int32)], 1.0) 62 | xmin = padded_polygon[:, 0].min() 63 | xmax = padded_polygon[:, 0].max() 64 | ymin = padded_polygon[:, 1].min() 65 | ymax = padded_polygon[:, 1].max() 66 | width = xmax - xmin + 1 67 | height = ymax - ymin + 1 68 | polygon[:, 0] = polygon[:, 0] - xmin 69 | polygon[:, 1] = polygon[:, 1] - ymin 70 | 71 | xs = np.broadcast_to( 72 | np.linspace( 73 | 0, width - 1, num=width).reshape(1, width), (height, width)) 74 | ys = np.broadcast_to( 75 | np.linspace( 76 | 0, height - 1, num=height).reshape(height, 1), (height, width)) 77 | 78 | distance_map = np.zeros( 79 | (polygon.shape[0], height, width), dtype=np.float32) 80 | for i in range(polygon.shape[0]): 81 | j = (i + 1) % polygon.shape[0] 82 | # 2.计算框Gd内所有的点到G的四条边的距离,选择最小的距离(也就是Gd框内像素离它最近的G框的边的距离,下面简称像素到G框的距离) 83 | absolute_distance = self._distance(xs, ys, polygon[i], polygon[j]) 84 | # 3. 将所求的Gd框内所有像素到G框的距离,除以偏移量D进行归一化; 距离限制在[0,1]内, 大于1取值为1, 小于0取值为0 85 | distance_map[i] = np.clip(absolute_distance / distance, 0, 1) 86 | 87 | distance_map = distance_map.min(axis=0) 88 | xmin_valid = min(max(0, xmin), canvas.shape[1] - 1) 89 | xmax_valid = min(max(0, xmax), canvas.shape[1] - 1) 90 | ymin_valid = min(max(0, ymin), canvas.shape[0] - 1) 91 | ymax_valid = min(max(0, ymax), canvas.shape[0] - 1) 92 | # 使用1减去4中得到的map,这里得到的就是Gd框和Gs框之间的像素到G框最近边的归一化距离; 距离为0,概率为1,距离为1,概率为0;因此1-d 93 | canvas[ymin_valid:ymax_valid + 1, xmin_valid:xmax_valid + 1] = np.fmax( 94 | 1 - distance_map[ 95 | ymin_valid - ymin: ymax_valid - ymax + height, 96 | xmin_valid - xmin: xmax_valid - xmax + width 97 | ], 98 | canvas[ymin_valid: ymax_valid + 1, xmin_valid: xmax_valid + 1]) 99 | 100 | @staticmethod 101 | def _distance(xs, ys, point_1, point_2): 102 | """ 103 | compute the distance from point to a line 104 | ys: coordinates in the first axis 105 | xs: coordinates in the second axis 106 | point_1, point_2: (x, y), the end of the line 107 | """ 108 | square_distance_1 = np.square(xs - point_1[0]) + np.square(ys - point_1[1]) 109 | square_distance_2 = np.square(xs - point_2[0]) + np.square(ys - point_2[1]) 110 | square_distance = np.square(point_1[0] - point_2[0]) + np.square(point_1[1] - point_2[1]) 111 | 112 | cosin = (square_distance - square_distance_1 - square_distance_2) / ( 113 | 2 * np.sqrt(square_distance_1 * square_distance_2)) 114 | square_sin = 1 - np.square(cosin) 115 | square_sin = np.nan_to_num(square_sin) 116 | result = np.sqrt(square_distance_1 * square_distance_2 * square_sin / square_distance) 117 | 118 | result[cosin < 0] = np.sqrt(np.fmin(square_distance_1, square_distance_2))[cosin < 0] 119 | return result 120 | -------------------------------------------------------------------------------- /data_loader/img_aug/operators.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import imgaug 3 | import numpy as np 4 | import imgaug.augmenters as iaa 5 | 6 | 7 | class IaaAugment(object): 8 | def __init__( 9 | self, 10 | flip_prob=0.5, 11 | affine_rotate=(-5, 5), 12 | resize_scale=(0.5, 1.5), 13 | ): 14 | """ 15 | :param flip_prob: 水平翻转概率 16 | :param affine_rotate: 放射变换角度 17 | :param resize_scale: 大小范围 18 | """ 19 | self.augmenter = iaa.Sequential([ 20 | iaa.Fliplr(p=flip_prob), 21 | iaa.Affine(rotate=affine_rotate), 22 | iaa.Resize(resize_scale) 23 | ]) 24 | 25 | def __call__(self, data): 26 | image = 
data['image'] 27 | shape = image.shape 28 | if self.augmenter: 29 | aug = self.augmenter.to_deterministic() 30 | data['image'] = aug.augment_image(image) 31 | data = self.may_augment_annotation(aug, data, shape) 32 | return data 33 | 34 | @staticmethod 35 | def may_augment_annotation(aug, data, shape): 36 | line_polys = [] 37 | for poly in data['polys']: 38 | keypoints = [imgaug.Keypoint(p[0], p[1]) for p in poly] 39 | keypoints = aug.augment_keypoints([imgaug.KeypointsOnImage(keypoints, shape=shape)])[0].keypoints 40 | line_polys.append([(p.x, p.y) for p in keypoints]) 41 | data['polys'] = np.array(line_polys) 42 | return data 43 | 44 | 45 | class NormalizeImage(object): 46 | MEAN = [0.485, 0.456, 0.406] 47 | STD = [0.229, 0.224, 0.225] 48 | 49 | def __init__(self, src_order="hwc", tgt_order="chw"): 50 | self.tgt_order = tgt_order 51 | self.src_order = src_order 52 | shape = (3, 1, 1) if self.src_order == "chw" else (1, 1, 3) 53 | self.scale = 1.0 / 255.0 54 | self.mean = np.array(self.MEAN).reshape(shape).astype('float32') 55 | self.std = np.array(self.STD).reshape(shape).astype('float32') 56 | 57 | def __call__(self, data): 58 | image = (data["image"].astype('float32') * self.scale - self.mean) / self.std 59 | if self.src_order == "hwc" and self.tgt_order == "chw": 60 | data["image"] = image.transpose((2, 0, 1)) 61 | return data 62 | 63 | 64 | class OutputData(object): 65 | def __init__(self, keep_keys): 66 | self.keep_keys = keep_keys 67 | 68 | def __call__(self, data): 69 | output = dict() 70 | for key in self.keep_keys: 71 | output[key] = data[key] 72 | return output 73 | 74 | 75 | class ResizeForTest(object): 76 | def __init__(self, long_size=960): 77 | self.max_pixes = long_size 78 | 79 | def __call__(self, data): 80 | image = data["image"] 81 | src_h, src_w, _ = image.shape 82 | 83 | if src_h > src_w: 84 | ratio = float(self.max_pixes) / src_h 85 | else: 86 | ratio = float(self.max_pixes) / src_w 87 | 88 | resize_h = int(src_h * ratio) 89 | resize_w = int(src_w * ratio) 90 | 91 | max_stride = 128 92 | resize_h = (resize_h + max_stride - 1) // max_stride * max_stride 93 | resize_w = (resize_w + max_stride - 1) // max_stride * max_stride 94 | data["image"] = cv2.resize(image, (int(resize_w), int(resize_h))) 95 | data["src_scale"] = np.array([src_h, src_w], dtype=np.int) 96 | return data 97 | -------------------------------------------------------------------------------- /data_loader/img_aug/random_crop_data.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | 4 | 5 | class RandomCropData(object): 6 | """ 7 | 参考EAST模型图片数据剪切方式 8 | """ 9 | def __init__( 10 | self, 11 | size=(960, 960), 12 | max_tries=10, 13 | min_crop_side_ratio=0.1, 14 | keep_ratio=True 15 | ): 16 | """ 17 | :param size: 最终输出图片大小 18 | :param max_tries: 检测尝试次数 19 | :param min_crop_side_ratio: 20 | :param keep_ratio: 21 | """ 22 | self.size = size 23 | self.max_tries = max_tries 24 | self.min_crop_side_ratio = min_crop_side_ratio 25 | self.keep_ratio = keep_ratio 26 | 27 | def __call__(self, data): 28 | img = data['image'] 29 | text_polys = data['polys'] 30 | ignore_tags = data['ignore_tags'] 31 | all_care_polys = [ 32 | text_polys[i] for i, tag in enumerate(ignore_tags) if not tag 33 | ] 34 | # 计算crop区域 35 | crop_x, crop_y, crop_w, crop_h = self.crop_area( 36 | img, all_care_polys, self.min_crop_side_ratio, self.max_tries) 37 | # crop 图片 保持比例填充 38 | scale_w = self.size[0] / crop_w 39 | scale_h = self.size[1] / crop_h 40 | scale = min(scale_w, 
scale_h) 41 | h = int(crop_h * scale) 42 | w = int(crop_w * scale) 43 | if self.keep_ratio: 44 | padimg = np.zeros((self.size[1], self.size[0], img.shape[2]), 45 | img.dtype) 46 | padimg[:h, :w] = cv2.resize( 47 | img[crop_y:crop_y + crop_h, crop_x:crop_x + crop_w], (w, h)) 48 | img = padimg 49 | else: 50 | img = cv2.resize( 51 | img[crop_y:crop_y + crop_h, crop_x:crop_x + crop_w], 52 | tuple(self.size)) 53 | # crop 文本框 54 | text_polys_crop = [] 55 | ignore_tags_crop = [] 56 | for poly, tag in zip(text_polys, ignore_tags): 57 | poly = ((poly - (crop_x, crop_y)) * scale).tolist() 58 | if not self.is_poly_outside_rect(poly, 0, 0, w, h): 59 | text_polys_crop.append(poly) 60 | ignore_tags_crop.append(tag) 61 | data['image'] = img 62 | data['polys'] = np.array(text_polys_crop, dtype=np.float) 63 | data['ignore_tags'] = np.array(ignore_tags_crop, dtype=np.bool) 64 | return data 65 | 66 | def crop_area(self, im, text_polys, min_crop_side_ratio, max_tries): 67 | h, w, _ = im.shape 68 | h_array = np.zeros(h, dtype=np.int32) 69 | w_array = np.zeros(w, dtype=np.int32) 70 | for points in text_polys: 71 | points = np.round(points, decimals=0).astype(np.int32) 72 | minx = np.min(points[:, 0]) 73 | maxx = np.max(points[:, 0]) 74 | w_array[minx:maxx] = 1 75 | miny = np.min(points[:, 1]) 76 | maxy = np.max(points[:, 1]) 77 | h_array[miny:maxy] = 1 78 | # ensure the cropped area not across a text 79 | h_axis = np.where(h_array == 0)[0] 80 | w_axis = np.where(w_array == 0)[0] 81 | 82 | if len(h_axis) == 0 or len(w_axis) == 0: 83 | return 0, 0, w, h 84 | 85 | h_regions = self.split_regions(h_axis) 86 | w_regions = self.split_regions(w_axis) 87 | 88 | for i in range(max_tries): 89 | if len(w_regions) > 1: 90 | xmin, xmax = self.region_wise_random_select(w_regions) 91 | else: 92 | xmin, xmax = self.random_select(w_axis, w) 93 | if len(h_regions) > 1: 94 | ymin, ymax = self.region_wise_random_select(h_regions) 95 | else: 96 | ymin, ymax = self.random_select(h_axis, h) 97 | 98 | if xmax - xmin < min_crop_side_ratio * w or ymax - ymin < min_crop_side_ratio * h: 99 | # area too small 100 | continue 101 | num_poly_in_rect = 0 102 | for poly in text_polys: 103 | if not self.is_poly_outside_rect(poly, xmin, ymin, xmax - xmin, 104 | ymax - ymin): 105 | num_poly_in_rect += 1 106 | break 107 | 108 | if num_poly_in_rect > 0: 109 | return xmin, ymin, xmax - xmin, ymax - ymin 110 | 111 | return 0, 0, w, h 112 | 113 | @staticmethod 114 | def is_poly_outside_rect(poly, x, y, w, h): 115 | poly = np.array(poly) 116 | if poly[:, 0].max() < x or poly[:, 0].min() > x + w: 117 | return True 118 | if poly[:, 1].max() < y or poly[:, 1].min() > y + h: 119 | return True 120 | return False 121 | 122 | @staticmethod 123 | def split_regions(axis): 124 | regions = [] 125 | min_axis = 0 126 | for i in range(1, axis.shape[0]): 127 | if axis[i] != axis[i - 1] + 1: 128 | region = axis[min_axis:i] 129 | min_axis = i 130 | regions.append(region) 131 | return regions 132 | 133 | @staticmethod 134 | def random_select(axis, max_size): 135 | xx = np.random.choice(axis, size=2) 136 | xmin = np.min(xx) 137 | xmax = np.max(xx) 138 | xmin = np.clip(xmin, 0, max_size - 1) 139 | xmax = np.clip(xmax, 0, max_size - 1) 140 | return xmin, xmax 141 | 142 | @staticmethod 143 | def region_wise_random_select(regions): 144 | selected_index = list(np.random.choice(len(regions), 2)) 145 | selected_values = [] 146 | for index in selected_index: 147 | axis = regions[index] 148 | xx = int(np.random.choice(axis, size=1)) 149 | selected_values.append(xx) 150 | xmin = 
min(selected_values) 151 | xmax = max(selected_values) 152 | return xmin, xmax 153 | -------------------------------------------------------------------------------- /data_loader/img_aug/rec_img_aug.py: -------------------------------------------------------------------------------- 1 | import math 2 | import cv2 3 | import numpy as np 4 | import random 5 | 6 | from .text_image_aug import tia_perspective, tia_stretch, tia_distort 7 | 8 | 9 | class RecAug(object): 10 | def __init__(self, aug_prob=0.4): 11 | self.aug_prob = aug_prob 12 | 13 | def __call__(self, data): 14 | img = data['image'] 15 | img = warp(img, 10, self.aug_prob) 16 | data['image'] = img 17 | return data 18 | 19 | 20 | class RecResizeImg(object): 21 | def __init__(self, image_shape=None): 22 | if image_shape is None: 23 | image_shape = [3, 32, 320] 24 | self.image_shape = image_shape 25 | 26 | def __call__(self, data): 27 | img = data['image'] 28 | norm_img = resize_norm_img(img, self.image_shape) 29 | data['image'] = norm_img 30 | return data 31 | 32 | 33 | def resize_norm_img(img, image_shape): 34 | img_c, img_h, img_w = image_shape 35 | h = img.shape[0] 36 | w = img.shape[1] 37 | ratio = w / float(h) 38 | if math.ceil(img_h * ratio) > img_w: 39 | resized_w = img_w 40 | else: 41 | resized_w = int(math.ceil(img_h * ratio)) 42 | resized_image = cv2.resize(img, (resized_w, img_h)) 43 | resized_image = resized_image.astype('float32') 44 | resized_image = resized_image.transpose((2, 0, 1)) / 255 45 | resized_image -= 0.5 46 | resized_image /= 0.5 47 | padding_im = np.zeros((img_c, img_h, img_w), dtype=np.float32) 48 | padding_im[:, :, 0:resized_w] = resized_image 49 | return padding_im 50 | 51 | 52 | def flag(): 53 | """ 54 | flag 55 | """ 56 | return 1 if random.random() > 0.5000001 else -1 57 | 58 | 59 | def cvtColor(img): 60 | """ 61 | cvtColor 62 | """ 63 | hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV) 64 | delta = 0.001 * random.random() * flag() 65 | hsv[:, :, 2] = hsv[:, :, 2] * (1 + delta) 66 | new_img = cv2.cvtColor(hsv, cv2.COLOR_HSV2BGR) 67 | return new_img 68 | 69 | 70 | def blur(img): 71 | """ 72 | blur 73 | """ 74 | h, w, _ = img.shape 75 | if h > 10 and w > 10: 76 | return cv2.GaussianBlur(img, (5, 5), 1) 77 | else: 78 | return img 79 | 80 | 81 | def jitter(img): 82 | """ 83 | jitter 84 | """ 85 | w, h, _ = img.shape 86 | if h > 10 and w > 10: 87 | thres = min(w, h) 88 | s = int(random.random() * thres * 0.01) 89 | src_img = img.copy() 90 | for i in range(s): 91 | img[i:, i:, :] = src_img[:w - i, :h - i, :] 92 | return img 93 | else: 94 | return img 95 | 96 | 97 | def add_gasuss_noise(image, mean=0, var=0.1): 98 | """ 99 | Gasuss noise 100 | """ 101 | 102 | noise = np.random.normal(mean, var ** 0.5, image.shape) 103 | out = image + 0.5 * noise 104 | out = np.clip(out, 0, 255) 105 | out = np.uint8(out) 106 | return out 107 | 108 | 109 | def get_crop(image): 110 | """ 111 | random crop 112 | """ 113 | h, w, _ = image.shape 114 | top_min = 1 115 | top_max = 8 116 | top_crop = int(random.randint(top_min, top_max)) 117 | top_crop = min(top_crop, h - 1) 118 | crop_img = image.copy() 119 | ratio = random.randint(0, 1) 120 | if ratio: 121 | crop_img = crop_img[top_crop:h, :, :] 122 | else: 123 | crop_img = crop_img[0:h - top_crop, :, :] 124 | return crop_img 125 | 126 | 127 | class ImageAugConf: 128 | """ 129 | Config 130 | """ 131 | 132 | def __init__(self, w, h, ang): 133 | self.anglex = random.random() * 5 * flag() 134 | self.angley = random.random() * 5 * flag() 135 | self.anglez = -1 * random.random() * int(ang) * 
flag() 136 | self.fov = 42 137 | self.r = 0 138 | self.shearx = 0 139 | self.sheary = 0 140 | self.borderMode = cv2.BORDER_REPLICATE 141 | self.w = w 142 | self.h = h 143 | 144 | self.perspective = True 145 | self.stretch = True 146 | self.distort = True 147 | self.crop = True 148 | self.reverse = True 149 | self.noise = True 150 | self.jitter = True 151 | self.blur = True 152 | self.color = True 153 | 154 | 155 | def warp(img, ang, prob=0.4): 156 | """ 157 | warp 158 | """ 159 | h, w, _ = img.shape 160 | config = ImageAugConf(w, h, ang) 161 | new_img = img 162 | 163 | if config.distort: 164 | img_height, img_width = img.shape[0:2] 165 | if random.random() <= prob and img_height >= 20 and img_width >= 20: 166 | new_img = tia_distort(new_img, random.randint(3, 6)) 167 | 168 | if config.stretch: 169 | img_height, img_width = img.shape[0:2] 170 | if random.random() <= prob and img_height >= 20 and img_width >= 20: 171 | new_img = tia_stretch(new_img, random.randint(3, 6)) 172 | 173 | if config.perspective: 174 | if random.random() <= prob: 175 | new_img = tia_perspective(new_img) 176 | 177 | if config.crop: 178 | img_height, img_width = img.shape[0:2] 179 | if random.random() <= prob and img_height >= 20 and img_width >= 20: 180 | new_img = get_crop(new_img) 181 | 182 | if config.blur: 183 | if random.random() <= prob: 184 | new_img = blur(new_img) 185 | if config.color: 186 | if random.random() <= prob: 187 | new_img = cvtColor(new_img) 188 | if config.jitter: 189 | new_img = jitter(new_img) 190 | if config.noise: 191 | if random.random() <= prob: 192 | new_img = add_gasuss_noise(new_img) 193 | if config.reverse: 194 | if random.random() <= prob: 195 | new_img = 255 - new_img 196 | return new_img 197 | -------------------------------------------------------------------------------- /data_loader/img_aug/text_image_aug/__init__.py: -------------------------------------------------------------------------------- 1 | from .augment import tia_perspective, tia_distort, tia_stretch 2 | 3 | __all__ = ['tia_distort', 'tia_stretch', 'tia_perspective'] 4 | -------------------------------------------------------------------------------- /data_loader/img_aug/text_image_aug/augment.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from .warp_mls import WarpMLS 3 | 4 | 5 | def tia_distort(src, segment=4): 6 | img_h, img_w = src.shape[:2] 7 | 8 | cut = img_w // segment 9 | thresh = cut // 3 10 | 11 | src_pts = list() 12 | dst_pts = list() 13 | 14 | src_pts.append([0, 0]) 15 | src_pts.append([img_w, 0]) 16 | src_pts.append([img_w, img_h]) 17 | src_pts.append([0, img_h]) 18 | 19 | dst_pts.append([np.random.randint(thresh), np.random.randint(thresh)]) 20 | dst_pts.append( 21 | [img_w - np.random.randint(thresh), np.random.randint(thresh)]) 22 | dst_pts.append( 23 | [img_w - np.random.randint(thresh), img_h - np.random.randint(thresh)]) 24 | dst_pts.append( 25 | [np.random.randint(thresh), img_h - np.random.randint(thresh)]) 26 | 27 | half_thresh = thresh * 0.5 28 | 29 | for cut_idx in np.arange(1, segment, 1): 30 | src_pts.append([cut * cut_idx, 0]) 31 | src_pts.append([cut * cut_idx, img_h]) 32 | dst_pts.append([ 33 | cut * cut_idx + np.random.randint(thresh) - half_thresh, 34 | np.random.randint(thresh) - half_thresh 35 | ]) 36 | dst_pts.append([ 37 | cut * cut_idx + np.random.randint(thresh) - half_thresh, 38 | img_h + np.random.randint(thresh) - half_thresh 39 | ]) 40 | 41 | trans = WarpMLS(src, src_pts, dst_pts, img_w, img_h) 42 | dst = 
trans.generate() 43 | 44 | return dst 45 | 46 | 47 | def tia_stretch(src, segment=4): 48 | img_h, img_w = src.shape[:2] 49 | 50 | cut = img_w // segment 51 | thresh = cut * 4 // 5 52 | 53 | src_pts = list() 54 | dst_pts = list() 55 | 56 | src_pts.append([0, 0]) 57 | src_pts.append([img_w, 0]) 58 | src_pts.append([img_w, img_h]) 59 | src_pts.append([0, img_h]) 60 | 61 | dst_pts.append([0, 0]) 62 | dst_pts.append([img_w, 0]) 63 | dst_pts.append([img_w, img_h]) 64 | dst_pts.append([0, img_h]) 65 | 66 | half_thresh = thresh * 0.5 67 | 68 | for cut_idx in np.arange(1, segment, 1): 69 | move = np.random.randint(thresh) - half_thresh 70 | src_pts.append([cut * cut_idx, 0]) 71 | src_pts.append([cut * cut_idx, img_h]) 72 | dst_pts.append([cut * cut_idx + move, 0]) 73 | dst_pts.append([cut * cut_idx + move, img_h]) 74 | 75 | trans = WarpMLS(src, src_pts, dst_pts, img_w, img_h) 76 | dst = trans.generate() 77 | 78 | return dst 79 | 80 | 81 | def tia_perspective(src): 82 | img_h, img_w = src.shape[:2] 83 | 84 | thresh = img_h // 2 85 | 86 | src_pts = list() 87 | dst_pts = list() 88 | 89 | src_pts.append([0, 0]) 90 | src_pts.append([img_w, 0]) 91 | src_pts.append([img_w, img_h]) 92 | src_pts.append([0, img_h]) 93 | 94 | dst_pts.append([0, np.random.randint(thresh)]) 95 | dst_pts.append([img_w, np.random.randint(thresh)]) 96 | dst_pts.append([img_w, img_h - np.random.randint(thresh)]) 97 | dst_pts.append([0, img_h - np.random.randint(thresh)]) 98 | 99 | trans = WarpMLS(src, src_pts, dst_pts, img_w, img_h) 100 | dst = trans.generate() 101 | 102 | return dst -------------------------------------------------------------------------------- /data_loader/img_aug/text_image_aug/warp_mls.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class WarpMLS: 5 | def __init__(self, src, src_pts, dst_pts, dst_w, dst_h, trans_ratio=1.): 6 | self.src = src 7 | self.src_pts = src_pts 8 | self.dst_pts = dst_pts 9 | self.pt_count = len(self.dst_pts) 10 | self.dst_w = dst_w 11 | self.dst_h = dst_h 12 | self.trans_ratio = trans_ratio 13 | self.grid_size = 100 14 | self.rdx = np.zeros((self.dst_h, self.dst_w)) 15 | self.rdy = np.zeros((self.dst_h, self.dst_w)) 16 | 17 | @staticmethod 18 | def __bilinear_interp(x, y, v11, v12, v21, v22): 19 | return (v11 * (1 - y) + v12 * y) * (1 - x) + (v21 * 20 | (1 - y) + v22 * y) * x 21 | 22 | def generate(self): 23 | self.calc_delta() 24 | return self.gen_img() 25 | 26 | def calc_delta(self): 27 | w = np.zeros(self.pt_count, dtype=np.float32) 28 | 29 | if self.pt_count < 2: 30 | return 31 | 32 | i = 0 33 | while 1: 34 | if self.dst_w <= i < self.dst_w + self.grid_size - 1: 35 | i = self.dst_w - 1 36 | elif i >= self.dst_w: 37 | break 38 | 39 | j = 0 40 | while 1: 41 | if self.dst_h <= j < self.dst_h + self.grid_size - 1: 42 | j = self.dst_h - 1 43 | elif j >= self.dst_h: 44 | break 45 | 46 | sw = 0 47 | swp = np.zeros(2, dtype=np.float32) 48 | swq = np.zeros(2, dtype=np.float32) 49 | new_pt = np.zeros(2, dtype=np.float32) 50 | cur_pt = np.array([i, j], dtype=np.float32) 51 | 52 | k = 0 53 | for k in range(self.pt_count): 54 | if i == self.dst_pts[k][0] and j == self.dst_pts[k][1]: 55 | break 56 | 57 | w[k] = 1. 
/ ( 58 | (i - self.dst_pts[k][0]) * (i - self.dst_pts[k][0]) + 59 | (j - self.dst_pts[k][1]) * (j - self.dst_pts[k][1])) 60 | 61 | sw += w[k] 62 | swp = swp + w[k] * np.array(self.dst_pts[k]) 63 | swq = swq + w[k] * np.array(self.src_pts[k]) 64 | 65 | if k == self.pt_count - 1: 66 | pstar = 1 / sw * swp 67 | qstar = 1 / sw * swq 68 | 69 | miu_s = 0 70 | for k in range(self.pt_count): 71 | if i == self.dst_pts[k][0] and j == self.dst_pts[k][1]: 72 | continue 73 | pt_i = self.dst_pts[k] - pstar 74 | miu_s += w[k] * np.sum(pt_i * pt_i) 75 | 76 | cur_pt -= pstar 77 | cur_pt_j = np.array([-cur_pt[1], cur_pt[0]]) 78 | 79 | for k in range(self.pt_count): 80 | if i == self.dst_pts[k][0] and j == self.dst_pts[k][1]: 81 | continue 82 | 83 | pt_i = self.dst_pts[k] - pstar 84 | pt_j = np.array([-pt_i[1], pt_i[0]]) 85 | 86 | tmp_pt = np.zeros(2, dtype=np.float32) 87 | tmp_pt[0] = np.sum(pt_i * cur_pt) * self.src_pts[k][0] - \ 88 | np.sum(pt_j * cur_pt) * self.src_pts[k][1] 89 | tmp_pt[1] = -np.sum(pt_i * cur_pt_j) * self.src_pts[k][0] + \ 90 | np.sum(pt_j * cur_pt_j) * self.src_pts[k][1] 91 | tmp_pt *= (w[k] / miu_s) 92 | new_pt += tmp_pt 93 | 94 | new_pt += qstar 95 | else: 96 | new_pt = self.src_pts[k] 97 | 98 | self.rdx[j, i] = new_pt[0] - i 99 | self.rdy[j, i] = new_pt[1] - j 100 | 101 | j += self.grid_size 102 | i += self.grid_size 103 | 104 | def gen_img(self): 105 | src_h, src_w = self.src.shape[:2] 106 | dst = np.zeros_like(self.src, dtype=np.float32) 107 | 108 | for i in np.arange(0, self.dst_h, self.grid_size): 109 | for j in np.arange(0, self.dst_w, self.grid_size): 110 | ni = i + self.grid_size 111 | nj = j + self.grid_size 112 | w = h = self.grid_size 113 | if ni >= self.dst_h: 114 | ni = self.dst_h - 1 115 | h = ni - i + 1 116 | if nj >= self.dst_w: 117 | nj = self.dst_w - 1 118 | w = nj - j + 1 119 | 120 | di = np.reshape(np.arange(h), (-1, 1)) 121 | dj = np.reshape(np.arange(w), (1, -1)) 122 | delta_x = self.__bilinear_interp( 123 | di / h, dj / w, self.rdx[i, j], self.rdx[i, nj], 124 | self.rdx[ni, j], self.rdx[ni, nj]) 125 | delta_y = self.__bilinear_interp( 126 | di / h, dj / w, self.rdy[i, j], self.rdy[i, nj], 127 | self.rdy[ni, j], self.rdy[ni, nj]) 128 | nx = j + dj + delta_x * self.trans_ratio 129 | ny = i + di + delta_y * self.trans_ratio 130 | nx = np.clip(nx, 0, src_w - 1) 131 | ny = np.clip(ny, 0, src_h - 1) 132 | nxi = np.array(np.floor(nx), dtype=np.int32) 133 | nyi = np.array(np.floor(ny), dtype=np.int32) 134 | nxi1 = np.array(np.ceil(nx), dtype=np.int32) 135 | nyi1 = np.array(np.ceil(ny), dtype=np.int32) 136 | 137 | if len(self.src.shape) == 3: 138 | x = np.tile(np.expand_dims(ny - nyi, axis=-1), (1, 1, 3)) 139 | y = np.tile(np.expand_dims(nx - nxi, axis=-1), (1, 1, 3)) 140 | else: 141 | x = ny - nyi 142 | y = nx - nxi 143 | dst[i:i + h, j:j + w] = self.__bilinear_interp( 144 | x, y, self.src[nyi, nxi], self.src[nyi, nxi1], 145 | self.src[nyi1, nxi], self.src[nyi1, nxi1]) 146 | 147 | dst = np.clip(dst, 0, 255) 148 | dst = np.array(dst, dtype=np.uint8) 149 | 150 | return dst -------------------------------------------------------------------------------- /data_loader/rec_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import codecs 4 | import random 5 | import numpy as np 6 | from torch.utils.data import Dataset 7 | from utils.string_utils import CharacterJson 8 | from data_loader.img_aug import * 9 | 10 | 11 | class RecDataSet(Dataset): 12 | def __init__(self, config, logger, mode="train"): 13 | 
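# Descriptive note (derived from the code below, not from external docs): the per-mode
# "dataset" section is expected to provide data_base_dir (image root), ano_file_path
# (a tab-separated "relative_image_path<TAB>label" file), do_shuffle, and a transforms
# list whose operator names are instantiated by _transforms_func_lst via eval(); the
# "global" section supplies character_json_path (the char->index map, including the ""
# padding key) and max_text_len (samples with longer labels are skipped and a random
# replacement index is drawn in __getitem__).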
dataset_config = config[mode]["dataset"] 14 | global_config = config["global"] 15 | self.base_dir = dataset_config["data_base_dir"] 16 | self.mode = mode 17 | self.logger = logger 18 | cj = CharacterJson(global_config["character_json_path"]) 19 | self.char2idx = cj.char2idx 20 | self.max_text_len = global_config["max_text_len"] 21 | self.data_lines = self.get_image_info_list(dataset_config["ano_file_path"]) 22 | self.transforms = self._transforms_func_lst(dataset_config["transforms"]) 23 | if dataset_config["do_shuffle"]: 24 | random.shuffle(self.data_lines) 25 | 26 | def __len__(self): 27 | return len(self.data_lines) 28 | 29 | @staticmethod 30 | def _transforms_func_lst(config): 31 | func_lst = [] 32 | for _transform in config: 33 | operator = list(_transform.keys())[0] 34 | params = dict() if _transform[operator] is None else _transform[operator] 35 | func_name = eval(operator)(**params) 36 | func_lst.append(func_name) 37 | return func_lst 38 | 39 | def get_image_info_list(self, file_path): 40 | """数据文件以\t分割""" 41 | lines = [] 42 | with codecs.open(file_path, "r", "utf8") as f: 43 | for line in f.readlines(): 44 | tmp_data = line.strip().split("\t") 45 | if len(tmp_data) != 2: 46 | self.logger.warn(f"{line}数据格式不对") 47 | continue 48 | image_path = os.path.join(self.base_dir, tmp_data[0]) 49 | if not os.path.exists(image_path): 50 | self.logger.warn(f"{image_path}图片文件不存在") 51 | continue 52 | lines.append([tmp_data[0], tmp_data[1]]) 53 | return lines 54 | 55 | def rec_label_encoder(self, label_str): 56 | labels = [] 57 | for char in label_str: 58 | if char not in self.char2idx.keys(): 59 | continue 60 | labels.append(self.char2idx[char]) 61 | if len(labels) > self.max_text_len: 62 | return 63 | sequence_length = len(labels) 64 | labels = labels + [self.char2idx[""]] * (self.max_text_len - len(labels)) 65 | labels = np.array(labels, dtype=np.int) 66 | return labels, sequence_length 67 | 68 | def __getitem__(self, index): 69 | try: 70 | data_line = self.data_lines[index] 71 | image_path = os.path.join(self.base_dir, data_line[0]) 72 | label_idx, sequence_length = self.rec_label_encoder(data_line[1]) 73 | image = cv2.imread(image_path, cv2.IMREAD_COLOR) # 默认BGR CHANNEL_LAST 74 | if image is None: 75 | self.logger.info(image_path) 76 | data = {"label_idx": label_idx, "sequence_length": sequence_length, "image": image} 77 | for _transform in self.transforms: 78 | data = _transform(data) 79 | except Exception as e: 80 | self.logger.error(e) 81 | data = [] 82 | 83 | if not data: 84 | return self.__getitem__(np.random.randint(self.__len__())) 85 | return data 86 | -------------------------------------------------------------------------------- /lite_ocr.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import torch 4 | import json 5 | import argparse 6 | import time 7 | import copy 8 | import codecs 9 | import numpy as np 10 | from functools import partial 11 | from config.load_conf import ReadConfig 12 | import onnxruntime as rt 13 | from postprocess import build_post_process 14 | from data_loader.img_aug import * 15 | 16 | 17 | def main(params): 18 | pt = LiteOcr(params) 19 | pt.predict() 20 | 21 | 22 | class LiteOcr(object): 23 | def __init__(self, params): 24 | self._global_param = params["global"] 25 | self._det_param = params["det"] 26 | self._rec_param = params["rec"] 27 | self.image_dir_or_path = self._global_param["image_dir_or_path"] 28 | self._image_list = self._read_images() 29 | 30 | self._det_post_process = 
build_post_process(self._det_param["post_process"]) 31 | rec_conf = self._rec_param["post_process"] 32 | rec_conf["character_json_path"] = self._global_param["character_json_path"] 33 | self._rec_post_process = build_post_process(rec_conf) 34 | self._det_transforms = self._transforms_func_lst(self._det_param["transforms"]) 35 | 36 | self.det_sess = rt.InferenceSession(self._global_param["infer_det_path"]) 37 | self.rec_sess = rt.InferenceSession(self._global_param["infer_rec_path"]) 38 | 39 | if not os.path.exists(self._global_param["res_save_dir"]): 40 | os.makedirs(self._global_param["res_save_dir"]) 41 | 42 | @staticmethod 43 | def _transforms_func_lst(config): 44 | func_lst = [] 45 | for _transform in config: 46 | operator = list(_transform.keys())[0] 47 | params = dict() if _transform[operator] is None else _transform[operator] 48 | func_name = eval(operator)(**params) 49 | func_lst.append(func_name) 50 | return func_lst 51 | 52 | def _read_images(self): 53 | imgs_lists = [] 54 | if self.image_dir_or_path is None or not os.path.exists(self.image_dir_or_path): 55 | raise Exception("not found any img file in {}".format(self.image_dir_or_path)) 56 | 57 | img_end = {'jpg', 'bmp', 'png', 'jpeg', 'rgb', 'tif', 'tiff'} 58 | if os.path.isfile(self.image_dir_or_path) and \ 59 | os.path.splitext(self.image_dir_or_path)[-1][1:].lower() in img_end: 60 | imgs_lists.append(self.image_dir_or_path) 61 | elif os.path.isdir(self.image_dir_or_path): 62 | for single_file in os.listdir(self.image_dir_or_path): 63 | file_path = os.path.join(self.image_dir_or_path, single_file) 64 | if os.path.isfile(file_path) and os.path.splitext(file_path)[-1][1:].lower() in img_end: 65 | imgs_lists.append(file_path) 66 | if len(imgs_lists) == 0: 67 | raise Exception("not found any img file in {}".format(self.image_dir_or_path)) 68 | return imgs_lists 69 | 70 | @staticmethod 71 | def _get_rotate_crop_image(img, points): 72 | left = int(np.min(points[:, 0])) 73 | right = int(np.max(points[:, 0])) 74 | top = int(np.min(points[:, 1])) 75 | bottom = int(np.max(points[:, 1])) 76 | img_crop = img[top:bottom, left:right, :].copy() 77 | points[:, 0] = points[:, 0] - left 78 | points[:, 1] = points[:, 1] - top 79 | img_crop_width = int(np.linalg.norm(points[0] - points[1])) 80 | img_crop_height = int(np.linalg.norm(points[0] - points[3])) 81 | pts_std = np.float32([[0, 0], [img_crop_width, 0], [img_crop_width, img_crop_height], [0, img_crop_height]]) 82 | 83 | M = cv2.getPerspectiveTransform(points, pts_std) 84 | dst_img = cv2.warpPerspective( 85 | img_crop, 86 | M, (img_crop_width, img_crop_height), 87 | borderMode=cv2.BORDER_REPLICATE) 88 | dst_img_height, dst_img_width = dst_img.shape[0:2] 89 | if dst_img_height * 1.0 / dst_img_width >= 2: 90 | dst_img = np.rot90(dst_img) 91 | return dst_img 92 | 93 | def predict(self): 94 | result = [] 95 | for image_path in self._image_list: 96 | image = cv2.imread(image_path, cv2.IMREAD_COLOR) # 默认BGR CHANNEL_LAST 97 | if image is None: 98 | print("reading image_path: {} failed".format(image_path)) 99 | continue 100 | data = {"image": image} 101 | for _transform in self._det_transforms: 102 | data = _transform(data) 103 | 104 | for key, val in data.items(): 105 | data[key] = np.expand_dims(val, axis=0) 106 | 107 | start_time = time.time() 108 | out = self.det_sess.run(["output"], {"input": data["image"]})[0] 109 | preds = torch.from_numpy(out) 110 | 111 | print("image: {} \texpend time: {:.4f}".format(image_path, time.time() - start_time)) 112 | boxes_batch, scores_batch = 
self._det_post_process(preds, data) 113 | 114 | results = [] 115 | for idx, (box, score) in enumerate(zip(boxes_batch[0], scores_batch)): 116 | tmp_box = copy.deepcopy(box) 117 | tmp_img = self._get_rotate_crop_image(image, tmp_box.astype(np.float32)) 118 | scale = tmp_img.shape[0] * 1.0 / 32 119 | w = int(tmp_img.shape[1] / scale) 120 | line_img = RecResizeImg(image_shape=[3, 32, w])({"image": tmp_img})["image"] 121 | preds = self.rec_sess.run(["output"], {"input": np.expand_dims(line_img, axis=0)})[0] 122 | line_text, line_score = self._rec_post_process(preds)[0] 123 | tmp = dict() 124 | tmp["file_name"] = image_path 125 | if line_text.strip() != '': 126 | tmp["text"] = line_text.replace(" ", "").replace(" ", "") 127 | bbox = tmp_box.tolist() 128 | tmp["score"] = round(float(score*line_score), 3) 129 | tmp["bbox"] = [ 130 | bbox[0][0], bbox[0][1], bbox[1][0], bbox[1][1], 131 | bbox[2][0], bbox[2][1], bbox[3][0], bbox[3][1] 132 | ] 133 | results.append(tmp) 134 | 135 | with codecs.open(os.path.join(self._global_param["res_save_dir"], "result.txt"), "a", "utf8") as f: 136 | for res in result: 137 | f.write(json.dumps(res, ensure_ascii=False)+"\n") 138 | 139 | 140 | if __name__ == "__main__": 141 | parser = argparse.ArgumentParser() 142 | parser.add_argument("-c", "--config", default="./config/lite_ocr.yml", help="配置文件路径") 143 | det_conf_path = parser.parse_args().config 144 | 145 | cus_params = ReadConfig(det_conf_path).base_conf 146 | print("预测相关参数:\n{}".format(json.dumps(cus_params, indent=2, ensure_ascii=False))) 147 | main(cus_params) 148 | -------------------------------------------------------------------------------- /logger/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jakeywu/ocr_torch/94451fac7d9e503ed4612772a6941c2034f13409/logger/__init__.py -------------------------------------------------------------------------------- /logger/log_conf.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | 4 | def logging_conf(log_path, level="INFO"): 5 | """ 6 | level可选参数为DEBUG, INFO, WARN, ERROR 7 | """ 8 | return { 9 | "loggers": { 10 | "mail": { 11 | "level": "CRITICAL", 12 | "propagate": False, 13 | "handlers": ["mail"] 14 | }, 15 | "data_det": { 16 | "level": level, 17 | "propagate": False, 18 | "handlers": ["data_det", "console"] 19 | }, 20 | "console": { 21 | "level": level, 22 | "propagate": False, 23 | "handlers": ["console"] 24 | }, 25 | }, 26 | "disable_existing_loggers": False, 27 | "handlers": { 28 | "data_det": { 29 | "formatter": "simple", 30 | "backupCount": 10, 31 | "class": "logging.handlers.RotatingFileHandler", 32 | "maxBytes": 10485760, 33 | "filename": os.path.join(log_path, "log.txt") 34 | }, 35 | "console": { 36 | "formatter": "default", 37 | "class": "logging.StreamHandler", 38 | "stream": "ext://sys.stdout" 39 | }, 40 | "mail": { 41 | "toaddrs": [""], 42 | "mailhost": ["smtp.exmail.qq.com", 25], 43 | "fromaddr": "", 44 | "level": "CRITICAL", 45 | "credentials": ["", ""], 46 | "formatter": "mail", 47 | "class": "logging.handlers.SMTPHandler", 48 | "subject": "XXXXX" 49 | } 50 | }, 51 | "formatters": { 52 | "default": { 53 | "datefmt": "%Y-%m-%d %H:%M:%S", 54 | "format": "%(asctime)s - %(levelname)s - %(module)s.%(name)s : %(message)s" 55 | }, 56 | "simple": { 57 | "format": "%(asctime)s - %(levelname)s - %(message)s" 58 | }, 59 | "mail": { 60 | "datefmt": "%Y-%m-%d %H:%M:%S", 61 | "format": "%(asctime)s : %(message)s" 62 | } 63 | }, 64 
| "version": 1 65 | } 66 | -------------------------------------------------------------------------------- /logger/logger.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | from logger import log_conf 4 | import logging 5 | import logging.config 6 | logging.basicConfig(format='%(asctime)s - %(pathname)s[line:%(lineno)d] - %(levelname)s: %(message)s', 7 | level=logging.WARNING) 8 | LOG_PATH = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 9 | 10 | 11 | def get_logger(log_name="data_det", level="INFO", log_path=LOG_PATH): 12 | """ 13 | :param log_name: 只提供data(debug) 和 mail(Critical) 14 | :param log_path: 默认目录为sc-log 15 | :param level: DEBUG, INFO, WARN, ERROR 16 | :return: 17 | """ 18 | if log_name == "data_det" and not os.path.isdir(log_path): 19 | os.makedirs(log_path) 20 | try: 21 | logging.config.dictConfig(log_conf.logging_conf(log_path, level)) 22 | except Exception as e1: 23 | print('日志初始化失败[%s]' % e1) 24 | sys.exit(1) 25 | logger = logging.getLogger(log_name) 26 | return logger 27 | -------------------------------------------------------------------------------- /losses/__init__.py: -------------------------------------------------------------------------------- 1 | from .det_loss import L1BalanceCELoss 2 | from .ctc_loss import CTCLoss 3 | 4 | 5 | __all__ = ["build_loss"] 6 | 7 | 8 | def build_loss(config): 9 | module_name = config.pop("name") 10 | support_dict = ["L1BalanceCELoss", "CTCLoss"] 11 | assert module_name in support_dict 12 | module_class = eval(module_name)(**config) 13 | return module_class 14 | -------------------------------------------------------------------------------- /losses/ctc_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class CTCLoss(nn.Module): 6 | def __init__(self): 7 | super(CTCLoss, self).__init__() 8 | self.ctc_loss = nn.CTCLoss( 9 | blank=0, 10 | reduction="mean" 11 | ) 12 | 13 | def forward(self, x, batch): 14 | """ 15 | :param x: T * N * Classes 16 | :param batch: 17 | :return: 18 | """ 19 | t, n, _ = x.shape 20 | loss = self.ctc_loss( 21 | log_probs=x, # T * N * Classes 22 | targets=batch["label_idx"], # 训练标签 23 | input_lengths=torch.tensor([t] * n), # 固定的额输出序列长度 24 | target_lengths=batch["sequence_length"] # 真实长度 25 | ) 26 | return {"loss": loss} 27 | -------------------------------------------------------------------------------- /losses/det_loss.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from collections import OrderedDict 3 | from losses.loss import MaskDiceLoss, MaskL1Loss, BalanceCrossEntropyLoss 4 | 5 | 6 | class L1BalanceCELoss(nn.Module): 7 | def __init__(self, negative_ratio=3.0, eps=1e-6, l1_scale=10, bce_scale=5): 8 | super(L1BalanceCELoss, self).__init__() 9 | self.dice_loss = MaskDiceLoss(eps=eps) 10 | self.l1_loss = MaskL1Loss(eps=eps) 11 | self.bce_loss = BalanceCrossEntropyLoss(negative_ratio=negative_ratio, eps=eps) 12 | 13 | self.l1_scale = l1_scale 14 | self.bce_scale = bce_scale 15 | 16 | def forward(self, pred, batch): 17 | pred_binary = pred[:, 0, :, :] 18 | pred_thresh = pred[:, 1, :, :] 19 | pred_binary_thresh = pred[:, 2, :, :] 20 | 21 | bce_loss = self.bce_loss(pred_binary, batch['prob_map'], batch['prob_mask']) 22 | l1_loss = self.l1_loss(pred_thresh, batch['thresh_map'], batch['thresh_mask']) 23 | dice_loss = self.dice_loss(pred_binary_thresh, batch['prob_map'], 
batch['prob_mask']) 24 | loss = dice_loss + self.l1_scale * l1_loss + bce_loss * self.bce_scale 25 | loss_dict = OrderedDict() 26 | loss_dict["loss"] = loss 27 | loss_dict["prob_loss"] = bce_loss 28 | loss_dict["thresh_loss"] = l1_loss 29 | loss_dict["binary_loss"] = dice_loss 30 | return loss_dict 31 | 32 | 33 | if __name__ == "__main__": 34 | import torch 35 | lbce = L1BalanceCELoss() 36 | input_x = torch.rand(4, 3, 160, 160) 37 | input_y = { 38 | "prob_map": torch.rand(4, 160, 160), 39 | "prob_mask": torch.rand(4, 160, 160), 40 | "thresh_map": torch.rand(4, 160, 160), 41 | "thresh_mask": torch.rand(4, 160, 160), 42 | } 43 | print(lbce(input_x, input_y)) -------------------------------------------------------------------------------- /losses/loss.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | import torch.nn.functional as F 4 | 5 | 6 | class MaskDiceLoss(nn.Module): 7 | """ 8 | 最终的thresh_binary采用DiceLoss 9 | """ 10 | def __init__(self, eps=1e-6): 11 | super(MaskDiceLoss, self).__init__() 12 | self.eps = eps 13 | 14 | def forward(self, pred, gt, mask): 15 | """ 16 | :param pred: N * H * W 17 | :param gt: N* H * W 18 | :param mask: N * H * W 19 | :return: 20 | """ 21 | intersection = (pred * gt * mask).sum() 22 | union = (pred * mask).sum() + (gt * mask).sum() + self.eps 23 | loss = 1 - 2.0 * intersection / union 24 | return loss 25 | 26 | 27 | class MaskL1Loss(nn.Module): 28 | def __init__(self, eps=1e-6): 29 | super(MaskL1Loss, self).__init__() 30 | self.eps = eps 31 | 32 | def forward(self, pred, gt, mask): 33 | """ 34 | :param pred: N * H * W 35 | :param gt: N * H * W 36 | :param mask: N * H * W 37 | :return: 38 | """ 39 | loss = (torch.abs(pred - gt) * mask).sum() / (mask.sum() + self.eps) 40 | return loss 41 | 42 | 43 | class BalanceCrossEntropyLoss(nn.Module): 44 | def __init__(self, negative_ratio=3.0, eps=1e-6): 45 | super(BalanceCrossEntropyLoss, self).__init__() 46 | self.negative_ratio = negative_ratio 47 | self.eps = eps 48 | 49 | def forward(self, pred, gt, mask): 50 | """ 51 | :param pred: N * H * W 52 | :param gt: N * H * W 53 | :param mask: N * H * W 54 | :return: 55 | """ 56 | loss = F.binary_cross_entropy(pred, gt, reduction="none") 57 | 58 | positive = (gt * mask).byte() 59 | negative = ((1-gt) * mask).byte() 60 | 61 | positive_loss = loss * positive.float() 62 | negative_loss = loss * negative.float() 63 | 64 | positive_count = int(positive.float().sum()) 65 | negative_count = min(int(negative.float().sum()), int(positive_count * self.negative_ratio)) 66 | 67 | negative_loss, _ = torch.topk(negative_loss.view(-1), negative_count) 68 | balance_loss = (positive_loss.sum() + negative_loss.sum()) / (positive_count + negative_count + self.eps) 69 | 70 | return balance_loss 71 | -------------------------------------------------------------------------------- /metrics/__init__.py: -------------------------------------------------------------------------------- 1 | from .det_metric import DetMetric 2 | from .rec_metric import RecMetric 3 | 4 | 5 | __all__ = ["build_metric"] 6 | 7 | 8 | def build_metric(config): 9 | module_name = config.pop("name") 10 | support_dict = ["DetMetric", "RecMetric"] 11 | assert module_name in support_dict 12 | module_class = eval(module_name)(**config) 13 | return module_class 14 | -------------------------------------------------------------------------------- /metrics/det_metric.py: -------------------------------------------------------------------------------- 1 
| from .eval_det_iou import DetectionIoUEvaluator 2 | 3 | 4 | class DetMetric(object): 5 | def __init__(self, main_indicator='hmean'): 6 | self.evaluator = DetectionIoUEvaluator() 7 | self.main_indicator = main_indicator 8 | self._reset() 9 | 10 | def __call__(self, post_result, batch): 11 | preds = post_result[0] 12 | gt_polygons, ignore_tags = batch["polys"], batch["ignore_tags"] 13 | for pred, gt_polyons, ignore_tags in zip(preds, gt_polygons, ignore_tags): 14 | # prepare gt 15 | gt_info_list = [{ 16 | 'points': gt_polyon, 17 | 'text': '', 18 | 'ignore': ignore_tag 19 | } for gt_polyon, ignore_tag in zip(gt_polyons, ignore_tags)] 20 | # prepare det 21 | det_info_list = [{ 22 | 'points': det_polyon, 23 | 'text': '' 24 | } for det_polyon in pred] 25 | result = self.evaluator.evaluate_image(gt_info_list, det_info_list) 26 | self.results.append(result) 27 | 28 | def get_metric(self): 29 | """ 30 | return metrics { 31 | 'precision': 0, 32 | 'recall': 0, 33 | 'hmean': 0 34 | } 35 | """ 36 | 37 | metircs = self.evaluator.combine_results(self.results) 38 | self._reset() 39 | return metircs 40 | 41 | def _reset(self): 42 | self.results = [] # clear results 43 | -------------------------------------------------------------------------------- /metrics/eval_det_iou.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from shapely.geometry import Polygon 3 | 4 | 5 | class DetectionIoUEvaluator(object): 6 | def __init__(self, iou_constraint=0.5, area_precision_constraint=0.5): 7 | self.iou_constraint = iou_constraint 8 | self.area_precision_constraint = area_precision_constraint 9 | 10 | def evaluate_image(self, gt, pred): 11 | def get_union(pd, pg): 12 | return Polygon(pd).union(Polygon(pg)).area 13 | 14 | def get_intersection_over_union(pd, pg): 15 | return get_intersection(pd, pg) / get_union(pd, pg) 16 | 17 | def get_intersection(pd, pg): 18 | return Polygon(pd).intersection(Polygon(pg)).area 19 | 20 | matched_sum = 0 21 | num_global_care_gt = 0 22 | num_global_care_det = 0 23 | det_matched = 0 24 | iou_mat = np.empty([1, 1]) 25 | gt_pols = [] 26 | det_pols = [] 27 | gt_pol_points = [] 28 | det_pol_points = [] 29 | # Array of Ground Truth Polygons' keys marked as don't Care 30 | gt_dont_care_pols_num = [] 31 | # Array of Detected Polygons' matched with a don't Care GT 32 | det_dont_care_pols_num = [] 33 | 34 | pairs = [] 35 | det_matched_nums = [] 36 | evaluation_log = "" 37 | 38 | for n in range(len(gt)): 39 | points = gt[n]['points'] 40 | dont_care = gt[n]['ignore'] 41 | if not Polygon(points).is_valid or not Polygon(points).is_simple: 42 | continue 43 | 44 | gt_pol = points 45 | gt_pols.append(gt_pol) 46 | gt_pol_points.append(points) 47 | if dont_care: 48 | gt_dont_care_pols_num.append(len(gt_pols) - 1) 49 | 50 | evaluation_log += "GT polygons: " + str(len(gt_pols)) + ( 51 | " (" + str(len(gt_dont_care_pols_num)) + " don't care)\n" 52 | if len(gt_dont_care_pols_num) > 0 else "\n") 53 | 54 | for n in range(len(pred)): 55 | points = pred[n]['points'] 56 | if not Polygon(points).is_valid or not Polygon(points).is_simple: 57 | continue 58 | 59 | det_pol = points 60 | det_pols.append(det_pol) 61 | det_pol_points.append(points) 62 | if len(gt_dont_care_pols_num) > 0: 63 | for dont_care_pol in gt_dont_care_pols_num: 64 | dont_care_pol = gt_pols[dont_care_pol] 65 | intersected_area = get_intersection(dont_care_pol, det_pol) 66 | pd_dimensions = Polygon(det_pol).area 67 | precision = 0 if pd_dimensions == 0 else intersected_area / 
pd_dimensions 68 | if precision > self.area_precision_constraint: 69 | det_dont_care_pols_num.append(len(det_pols) - 1) 70 | break 71 | 72 | evaluation_log += "DET polygons: " + str(len(det_pols)) + ( 73 | " (" + str(len(det_dont_care_pols_num)) + " don't care)\n" 74 | if len(det_dont_care_pols_num) > 0 else "\n") 75 | 76 | if len(gt_pols) > 0 and len(det_pols) > 0: 77 | # Calculate IoU and precision matrixs 78 | output_shape = [len(gt_pols), len(det_pols)] 79 | iou_mat = np.empty(output_shape) 80 | gt_rect_mat = np.zeros(len(gt_pols), np.int8) 81 | det_rect_mat = np.zeros(len(det_pols), np.int8) 82 | for gt_num in range(len(gt_pols)): 83 | for det_num in range(len(det_pols)): 84 | p_g = gt_pols[gt_num] 85 | p_d = det_pols[det_num] 86 | iou_mat[gt_num, det_num] = get_intersection_over_union(p_d, p_g) 87 | 88 | for gt_num in range(len(gt_pols)): 89 | for det_num in range(len(det_pols)): 90 | if gt_rect_mat[gt_num] == 0 and det_rect_mat[det_num] == 0 and\ 91 | gt_num not in gt_dont_care_pols_num and det_num not in det_dont_care_pols_num: 92 | if iou_mat[gt_num, det_num] > self.iou_constraint: 93 | gt_rect_mat[gt_num] = 1 94 | det_rect_mat[det_num] = 1 95 | det_matched += 1 96 | pairs.append({'gt': gt_num, 'det': det_num}) 97 | det_matched_nums.append(det_num) 98 | evaluation_log += "Match GT #" + str(gt_num) + " with Det #" + str(det_num) + "\n" 99 | 100 | num_gt_care = (len(gt_pols) - len(gt_dont_care_pols_num)) 101 | num_det_care = (len(det_pols) - len(det_dont_care_pols_num)) 102 | if num_gt_care == 0: 103 | recall = float(1) 104 | precision = float(0) if num_det_care > 0 else float(1) 105 | else: 106 | recall = float(det_matched) / num_gt_care 107 | precision = 0 if num_det_care == 0 else float(det_matched) / num_det_care 108 | 109 | hmean = 0 if (precision + recall) == 0 else 2.0 * precision * recall / (precision + recall) 110 | 111 | matched_sum += det_matched 112 | num_global_care_gt += num_gt_care 113 | num_global_care_det += num_det_care 114 | 115 | per_sample_metrics = { 116 | 'precision': precision, 117 | 'recall': recall, 118 | 'hmean': hmean, 119 | 'pairs': pairs, 120 | 'iouMat': [] if len(det_pols) > 100 else iou_mat.tolist(), 121 | 'gtPolPoints': gt_pol_points, 122 | 'detPolPoints': det_pol_points, 123 | 'gtCare': num_gt_care, 124 | 'detCare': num_det_care, 125 | 'gtDontCare': gt_dont_care_pols_num, 126 | 'detDontCare': det_dont_care_pols_num, 127 | 'detMatched': det_matched, 128 | 'evaluationLog': evaluation_log 129 | } 130 | 131 | return per_sample_metrics 132 | 133 | @staticmethod 134 | def combine_results(results): 135 | num_global_care_gt = 0 136 | num_global_care_det = 0 137 | matched_sum = 0 138 | for result in results: 139 | num_global_care_gt += result['gtCare'] 140 | num_global_care_det += result['detCare'] 141 | matched_sum += result['detMatched'] 142 | 143 | method_recall = 0 if num_global_care_gt == 0 else float( 144 | matched_sum) / num_global_care_gt 145 | method_precision = 0 if num_global_care_det == 0 else float( 146 | matched_sum) / num_global_care_det 147 | if method_recall + method_precision == 0: 148 | method_hmean = 0 149 | else: 150 | method_hmean = 2 * method_recall * method_precision / (method_recall + method_precision) 151 | method_metrics = { 152 | 'precision': method_precision, 153 | 'recall': method_recall, 154 | 'hmean': method_hmean 155 | } 156 | 157 | return method_metrics 158 | -------------------------------------------------------------------------------- /metrics/rec_metric.py: 
-------------------------------------------------------------------------------- 1 | import Levenshtein 2 | 3 | 4 | class RecMetric(object): 5 | def __init__(self, main_indicator='acc'): 6 | self.main_indicator = main_indicator 7 | self._reset() 8 | 9 | def __call__(self, post_result, batch=None): 10 | preds, labels = post_result 11 | correct_num = 0 12 | all_num = 0 13 | norm_edit_dis = 0.0 14 | for (pred, pred_score), (target, _) in zip(preds, labels): 15 | pred = pred.replace(" ", "") 16 | target = target.replace(" ", "") 17 | norm_edit_dis += Levenshtein.distance(pred, target) / max( 18 | len(pred), len(target), 1) 19 | if pred == target: 20 | correct_num += 1 21 | all_num += 1 22 | self.correct_num += correct_num 23 | self.all_num += all_num 24 | self.norm_edit_dis += norm_edit_dis 25 | return { 26 | 'acc': correct_num / all_num, 27 | 'norm_edit_dis': 1 - norm_edit_dis / all_num 28 | } 29 | 30 | def get_metric(self): 31 | """ 32 | return metrics { 33 | 'acc': 0, 34 | 'norm_edit_dis': 0, 35 | } 36 | """ 37 | acc = 1.0 * self.correct_num / self.all_num 38 | norm_edit_dis = 1 - self.norm_edit_dis / self.all_num 39 | self._reset() 40 | return {'acc': acc, 'norm_edit_dis': norm_edit_dis} 41 | 42 | def _reset(self): 43 | self.correct_num = 0 44 | self.all_num = 0 45 | self.norm_edit_dis = 0 46 | -------------------------------------------------------------------------------- /nets/__init__.py: -------------------------------------------------------------------------------- 1 | from .det.dbnet import DBNet 2 | from .rec.rnn import CRNN 3 | 4 | 5 | __all__ = ["build_model"] 6 | 7 | 8 | def build_model(config): 9 | module_name = config.pop("name") 10 | support_dict = ["DBNet", "CRNN"] 11 | assert module_name in support_dict 12 | module_class = eval(module_name)(**config) 13 | return module_class 14 | -------------------------------------------------------------------------------- /nets/det/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jakeywu/ocr_torch/94451fac7d9e503ed4612772a6941c2034f13409/nets/det/__init__.py -------------------------------------------------------------------------------- /nets/det/dbnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class DBNet(nn.Module): 6 | 7 | def __init__( 8 | self, 9 | inner_channel, 10 | k, 11 | backbone 12 | ): 13 | """ 14 | :param inner_channel: FPN对齐维度 15 | :param k: 16 | :param backbone: 17 | """ 18 | super(DBNet, self).__init__() 19 | self.inner_channel = inner_channel 20 | self.k = k 21 | self.backbone = self._call_backbone(backbone) 22 | self.channel_size_lst = self.backbone.layer_out_channels 23 | self.in5 = nn.Conv2d( 24 | in_channels=self.channel_size_lst[-1], 25 | out_channels=self.inner_channel, 26 | kernel_size=(1, 1), 27 | stride=(1, 1), 28 | bias=False 29 | ) 30 | self.in4 = nn.Conv2d( 31 | in_channels=self.channel_size_lst[-2], 32 | out_channels=self.inner_channel, 33 | kernel_size=(1, 1), 34 | stride=(1, 1), 35 | bias=False 36 | ) 37 | self.in3 = nn.Conv2d( 38 | in_channels=self.channel_size_lst[-3], 39 | out_channels=self.inner_channel, 40 | kernel_size=(1, 1), 41 | stride=(1, 1), 42 | bias=False 43 | ) 44 | self.in2 = nn.Conv2d( 45 | in_channels=self.channel_size_lst[-4], 46 | out_channels=self.inner_channel, 47 | kernel_size=(1, 1), 48 | stride=(1, 1), 49 | bias=False 50 | ) 51 | 52 | self.up5 = nn.Upsample(scale_factor=2, mode="nearest") 53 | self.up4 = 
nn.Upsample(scale_factor=2, mode="nearest") 54 | self.up3 = nn.Upsample(scale_factor=2, mode="nearest") 55 | 56 | self.out5 = nn.Sequential( 57 | nn.Conv2d(self.inner_channel, self.inner_channel//4, (3, 3), (1, 1), (1, 1), bias=False), 58 | nn.Upsample(scale_factor=8, mode="nearest") 59 | ) 60 | self.out4 = nn.Sequential( 61 | nn.Conv2d(self.inner_channel, self.inner_channel//4, (3, 3), (1, 1), (1, 1), bias=False), 62 | nn.Upsample(scale_factor=4, mode="nearest") 63 | ) 64 | self.out3 = nn.Sequential( 65 | nn.Conv2d(self.inner_channel, self.inner_channel//4, (3, 3), (1, 1), (1, 1), bias=False), 66 | nn.Upsample(scale_factor=2, mode="nearest") 67 | ) 68 | self.out2 = nn.Conv2d(self.inner_channel, self.inner_channel//4, (3, 3), (1, 1), (1, 1), bias=False) 69 | 70 | self.binarize = nn.Sequential( 71 | nn.Conv2d(self.inner_channel, self.inner_channel//4, (3, 3), padding=(1, 1), bias=False), 72 | nn.BatchNorm2d(self.inner_channel//4), 73 | nn.ReLU(inplace=True), 74 | nn.ConvTranspose2d(self.inner_channel//4, self.inner_channel//4, (2, 2), (2, 2)), 75 | nn.BatchNorm2d(self.inner_channel//4), 76 | nn.ReLU(inplace=True), 77 | nn.ConvTranspose2d(self.inner_channel//4, 1, (2, 2), (2, 2)), 78 | nn.Sigmoid() 79 | ) 80 | 81 | self.thresh = nn.Sequential( 82 | nn.Conv2d(self.inner_channel, self.inner_channel//4, (3, 3), (1, 1), (1, 1), bias=False), 83 | nn.BatchNorm2d(self.inner_channel//4), 84 | nn.ReLU(inplace=True), 85 | nn.ConvTranspose2d(self.inner_channel//4, self.inner_channel//4, (2, 2), (2, 2)), 86 | nn.BatchNorm2d(self.inner_channel//4), 87 | nn.ReLU(inplace=True), 88 | nn.ConvTranspose2d(self.inner_channel//4, 1, (2, 2), (2, 2)), 89 | nn.Sigmoid() 90 | ) 91 | self.weights_init() 92 | 93 | @staticmethod 94 | def _call_backbone(backbone): 95 | module_func = backbone.pop("name") 96 | if module_func == "det_mobilenet_v3": 97 | from nets.det.mobilenetv3 import det_mobilenet_v3 98 | module_func = eval(module_func)(**backbone) 99 | else: 100 | from .resnet import resnet18, resnet34, resnet50, resnet101, resnet152 101 | module_func = eval(module_func)(**backbone) 102 | return module_func 103 | 104 | def weights_init(self): 105 | for m in self.modules(): 106 | if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)): 107 | nn.init.kaiming_normal_(m.weight.data) 108 | elif isinstance(m, nn.BatchNorm2d): 109 | m.weight.data.fill_(1.) 
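# BatchNorm layers start close to identity: unit scale (set above) and a small
# positive bias (set below); conv / deconv weights use Kaiming-normal initialization.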
110 | m.bias.data.fill_(1e-4) 111 | 112 | def step_function(self, x, y): 113 | return torch.reciprocal((1+torch.exp(-self.k * (x-y)))) 114 | 115 | def forward(self, x): 116 | """ 117 | c2, c3, c4, c5依次为原图像H*W的1/4, 1/8, 1/16, 1/32 118 | 假设H*W为640*640 119 | """ 120 | features = self.backbone(x) 121 | c2, c3, c4, c5 = features 122 | 123 | # Channel统一 124 | in5 = self.in5(c5) # N * 256 * 20 * 20 125 | in4 = self.in4(c4) # N * 256 * 40 * 40 126 | in3 = self.in3(c3) # N * 256 * 80 * 80 127 | in2 = self.in2(c2) # N * 256 * 160 * 160 128 | 129 | # 上层特征上采样,依次和下层特征合并 130 | out5 = in5 # N * 256 * 20 * 20 131 | out4 = self.up5(in5) + in4 # N * 256 * 40 * 40 132 | out3 = self.up4(in4) + in3 # N * 256 * 80 * 80 133 | out2 = self.up3(in3) + in2 # N * 256 * 160 * 160 134 | 135 | # 降维,特征上采样 136 | p5 = self.out5(out5) # N * 64 * 160 * 160 137 | p4 = self.out4(out4) # N * 64 * 160 * 160 138 | p3 = self.out3(out3) # N * 64 * 160 * 160 139 | p2 = self.out2(out2) # N * 64 * 160 * 160 140 | 141 | # 后Concat 142 | fuse = torch.cat((p5, p4, p3, p2), dim=1) # # N * 256 * 160 * 160 143 | prob = self.binarize(fuse) 144 | if not self.training: 145 | # N * 1 * H * W 146 | return prob 147 | 148 | thresh = self.thresh(fuse) 149 | binary_thresh = self.step_function(prob, thresh) 150 | # N * 3 * H * W 151 | return torch.cat([prob, thresh, binary_thresh], dim=1) 152 | 153 | 154 | if __name__ == "__main__": 155 | model = DBNet(96, 50, backbone={"name": "det_mobilenet_v3"}) 156 | ignored_params = list(map(id, model.binarize.parameters())) 157 | print(ignored_params) 158 | # base_params = filter(lambda p: id(p) not in ignored_params, 159 | # model.parameters()) 160 | # 161 | # optimizer = torch.optim.SGD([ 162 | # {'params': base_params}, 163 | # {'params': model.fc.parameters(), 'lr': 1e-3} 164 | # ], lr=1e-2, momentum=0.9) 165 | # print(model) 166 | # for param in model.parameters(): 167 | # print(param) 168 | # import pdb 169 | # pdb.set_trace() -------------------------------------------------------------------------------- /nets/det/mobilenetv3.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.utils import model_zoo 5 | from collections import OrderedDict 6 | from nets.det.params_mapping import mobile_net_v3_mapping 7 | 8 | M_URL = "https://download.pytorch.org/models/mobilenet_v3_large-8738ca79.pth" 9 | 10 | 11 | def _weights_init(m): 12 | if isinstance(m, nn.Conv2d): 13 | torch.nn.init.xavier_uniform_(m.weight) 14 | if m.bias is not None: 15 | torch.nn.init.zeros_(m.bias) 16 | elif isinstance(m, nn.BatchNorm2d): 17 | m.weight.data.fill_(1) 18 | m.bias.data.zero_() 19 | elif isinstance(m, nn.Linear): 20 | m.weight.data.normal_(0, 0.01) 21 | m.bias.data.zero_() 22 | 23 | 24 | def _make_divisible(v, divisor=8, min_value=None): 25 | # 等比例的增加和减少通道的个数 26 | if min_value is None: 27 | min_value = divisor 28 | new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) 29 | # Make sure that round down does not go down by more than 10%. 
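# Worked example with the default divisor of 8 and multiplier 0.5 (plain arithmetic on the formula above):
#   _make_divisible(16 * 0.5)  -> max(8, int(8 + 4) // 8 * 8)   = 8
#   _make_divisible(24 * 0.5)  -> max(8, int(12 + 4) // 8 * 8)  = 16   (12 is rounded up to the next multiple of 8)
#   _make_divisible(960 * 0.5) -> max(8, int(480 + 4) // 8 * 8) = 480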
30 | if new_v < 0.9 * v: 31 | new_v += divisor 32 | return new_v 33 | 34 | 35 | class SqueezeBlock(nn.Module): 36 | def __init__(self, exp_size, divide=4): 37 | super(SqueezeBlock, self).__init__() 38 | self.dense = nn.Sequential( 39 | nn.Linear(exp_size, exp_size // divide), # Squeeze线性连接 40 | nn.ReLU(inplace=True), 41 | nn.Linear(exp_size // divide, exp_size), # Excite线性连接 42 | ) 43 | 44 | def forward(self, x): 45 | batch, channels, height, width = x.size() 46 | # 1.池化 47 | out = F.avg_pool2d(x, kernel_size=[height, width]).view(batch, -1) 48 | out = self.dense(out) 49 | out = F.hardsigmoid(out, inplace=True) 50 | # resize 51 | out = out.view(batch, channels, 1, 1) 52 | # 相乘 53 | return out * x 54 | 55 | 56 | class MobileBlock(nn.Module): 57 | def __init__(self, in_channels, out_channels, kernal_size, stride, non_linear, _se, exp_size, dropout_rate=1.0): 58 | super(MobileBlock, self).__init__() 59 | self.out_channels = out_channels 60 | self.nonLinear = non_linear 61 | self.SE = _se 62 | self.dropout_rate = dropout_rate 63 | padding = (kernal_size - 1) // 2 64 | 65 | self.use_connect = (stride == 1 and in_channels == out_channels) # 残差条件 66 | 67 | if self.nonLinear == "RE": 68 | activation = nn.ReLU 69 | else: 70 | activation = nn.Hardswish 71 | 72 | # 1*1卷积 expand 73 | self.expand_conv = nn.Sequential( 74 | nn.Conv2d(in_channels, exp_size, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0), bias=True), 75 | nn.BatchNorm2d(exp_size), 76 | activation(inplace=True) 77 | ) 78 | # 膨胀的卷积操作, 深度卷积 3* 3 或者 5*5 79 | self.depth_conv = nn.Sequential( 80 | nn.Conv2d(exp_size, exp_size, kernel_size=kernal_size, stride=stride, padding=padding, groups=exp_size), 81 | nn.BatchNorm2d(exp_size), 82 | activation(inplace=True) 83 | ) 84 | 85 | if self.SE: 86 | self.squeeze_block = SqueezeBlock(exp_size) 87 | 88 | # 1*1卷积 逐点卷积 89 | self.point_conv = nn.Sequential( 90 | nn.Conv2d(exp_size, out_channels, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0)), 91 | nn.BatchNorm2d(out_channels), 92 | nn.Identity() 93 | ) 94 | 95 | def forward(self, x): 96 | # MobileNetV2 97 | out = self.expand_conv(x) # 1*1 卷积, 由输入通道转为膨胀通道,转换通道 in->exp 98 | out = self.depth_conv(out) # 3x3或5*5卷积,膨胀通道,使用步长stride 99 | 100 | # Squeeze and Excite 101 | if self.SE: 102 | out = self.squeeze_block(out) 103 | 104 | # 1x1卷积,由膨胀通道,转换为输出通道 105 | out = self.point_conv(out) # 转换通道 exp->out 106 | 107 | # 残差结构 108 | if self.use_connect: 109 | return x + out 110 | else: 111 | return out 112 | 113 | 114 | class MobileNetV3(nn.Module): 115 | def __init__(self, multiplier=0.5, use_se=False): 116 | super(MobileNetV3, self).__init__() 117 | self._multiplier = multiplier 118 | self._use_se = use_se 119 | # in_channel out_channel kernel_size stride nl se exp_size 120 | layer1_conf = [ 121 | [16, 16, 3, 1, "RE", False, 16], 122 | [16, 24, 3, 2, "RE", False, 64], 123 | [24, 24, 3, 1, "RE", False, 72] 124 | ] 125 | 126 | layer2_conf = [ 127 | [24, 40, 5, 2, "RE", True, 72], 128 | [40, 40, 5, 1, "RE", True, 120], 129 | [40, 40, 5, 1, "RE", True, 120] 130 | ] 131 | layer3_conf = [ 132 | [40, 80, 3, 2, "HS", False, 240], 133 | [80, 80, 3, 1, "HS", False, 200], 134 | [80, 80, 3, 1, "HS", False, 184], 135 | [80, 80, 3, 1, "HS", False, 184], 136 | [80, 112, 3, 1, "HS", True, 480], 137 | [112, 112, 3, 1, "HS", True, 672], 138 | [112, 160, 5, 1, "HS", True, 672] 139 | ] 140 | layer4_conf = [ 141 | [160, 160, 5, 2, "HS", True, 960], 142 | [160, 160, 5, 1, "HS", True, 960], 143 | ] 144 | cls_ch_squeeze = _make_divisible(960 * self._multiplier) 145 | 146 | layer1_conf = 
self._multi_layer_conf(layer1_conf) 147 | layer2_conf = self._multi_layer_conf(layer2_conf) 148 | layer3_conf = self._multi_layer_conf(layer3_conf) 149 | layer4_conf = self._multi_layer_conf(layer4_conf) 150 | self.layer_out_channels = [layer1_conf[-1][1], layer2_conf[-1][1], layer3_conf[-1][1], cls_ch_squeeze] 151 | 152 | self.init_conv = nn.Sequential( 153 | nn.Conv2d( 154 | in_channels=3, out_channels=_make_divisible(16 * self._multiplier), 155 | kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)), 156 | nn.BatchNorm2d(_make_divisible(16 * self._multiplier)), 157 | nn.Hardswish(inplace=True) 158 | ) 159 | 160 | self.layer1 = self._make_layer(layer1_conf) 161 | self.layer2 = self._make_layer(layer2_conf) 162 | self.layer3 = self._make_layer(layer3_conf) 163 | layer_4 = self._make_layer(layer4_conf, False) 164 | layer_4.append( 165 | nn.Sequential( 166 | nn.Conv2d(self.layer_out_channels[-2], self.layer_out_channels[-1], kernel_size=(1, 1), stride=(1, 1)), 167 | nn.BatchNorm2d(self.layer_out_channels[-1]), 168 | nn.Hardswish(inplace=True) 169 | ) 170 | ) 171 | self.layer4 = nn.Sequential(*layer_4) 172 | self.apply(_weights_init) 173 | 174 | def _multi_layer_conf(self, layer_conf): 175 | for lc in layer_conf: 176 | lc[0] = _make_divisible(lc[0] * self._multiplier) 177 | lc[1] = _make_divisible(lc[1] * self._multiplier) 178 | lc[-1] = _make_divisible(lc[-1] * self._multiplier) 179 | return layer_conf 180 | 181 | def _make_layer(self, layer_conf, sequential=True): 182 | block_list = [] 183 | for in_channels, out_channels, kernal_size, stride, activation, se, exp_size in layer_conf: 184 | # activation  NL: 非线性激活函数;HS: H-swish激活函数,RE:ReLU激活函数 185 | # SE Squeeze and Excite结构,是否压缩和激发 186 | # exp_size: 膨胀参数 187 | # MobileBlock瓶颈层 188 | se = se and self._use_se 189 | block_list.append(MobileBlock(in_channels, out_channels, kernal_size, stride, activation, se, exp_size)) 190 | if sequential: 191 | return nn.Sequential(*block_list) 192 | return block_list 193 | 194 | def forward(self, x): 195 | # 起始部分 196 | out = self.init_conv(x) 197 | # 中间部分 198 | layer1 = self.layer1(out) 199 | layer2 = self.layer2(layer1) 200 | layer3 = self.layer3(layer2) 201 | layer4 = self.layer4(layer3) 202 | return layer1, layer2, layer3, layer4 203 | 204 | 205 | def det_mobilenet_v3(pre_trained_dir=None, multiplier=0.5, use_se=False): 206 | if not pre_trained_dir: 207 | return MobileNetV3() 208 | 209 | assert multiplier == 1.0, not use_se 210 | mn_v3_model = MobileNetV3(multiplier=1.0, use_se=False) 211 | state_dict = mn_v3_model.state_dict() 212 | pre_state_dict = model_zoo.load_url(M_URL, model_dir=pre_trained_dir) 213 | param_state_dict = OrderedDict() 214 | for mn_key in state_dict.keys(): 215 | for map_mn, map_pre in mobile_net_v3_mapping.items(): 216 | map_key = mn_key.replace(map_mn, map_pre) 217 | if map_key not in pre_state_dict.keys(): 218 | continue 219 | param_state_dict[mn_key] = pre_state_dict[map_key] 220 | 221 | mn_v3_model.load_state_dict(param_state_dict, strict=False) 222 | return mn_v3_model 223 | 224 | 225 | if __name__ == "__main__": 226 | model11 = det_mobilenet_v3() 227 | x1 = torch.randn(8, 3, 224, 224) 228 | f1, f2, f3, f4 = model11(x1) 229 | print(f1.shape, f2.shape, f3.shape, f4.shape) 230 | -------------------------------------------------------------------------------- /nets/det/params_mapping.py: -------------------------------------------------------------------------------- 1 | mobile_net_v3_mapping = { 2 | "init_conv": "features.0", 3 | 4 | "layer1.0.depth_conv": "features.1.block.0", 5 | 
"layer1.0.point_conv": "features.1.block.1", 6 | 7 | "layer1.1.expand_conv": "features.2.block.0", 8 | "layer1.1.depth_conv": "features.2.block.1", 9 | "layer1.1.point_conv": "features.2.block.2", 10 | 11 | "layer1.2.expand_conv": "features.3.block.0", 12 | "layer1.2.depth_conv": "features.3.block.1", 13 | "layer1.2.point_conv": "features.3.block.2", 14 | 15 | "layer2.0.expand_conv": "features.4.block.0", 16 | "layer2.0.depth_conv": "features.4.block.1", 17 | "layer2.0.point_conv": "features.4.block.2", 18 | 19 | "layer2.1.expand_conv": "features.5.block.0", 20 | "layer2.1.depth_conv": "features.5.block.1", 21 | "layer2.1.point_conv": "features.5.block.2", 22 | 23 | "layer2.2.expand_conv": "features.6.block.0", 24 | "layer2.2.depth_conv": "features.6.block.1", 25 | "layer2.2.point_conv": "features.6.block.2", 26 | 27 | "layer3.0.expand_conv": "features.7.block.0", 28 | "layer3.0.depth_conv": "features.7.block.1", 29 | "layer3.0.point_conv": "features.7.block.2", 30 | 31 | "layer3.1.expand_conv": "features.8.block.0", 32 | "layer3.1.depth_conv": "features.8.block.1", 33 | "layer3.1.point_conv": "features.8.block.2", 34 | 35 | "layer3.2.expand_conv": "features.9.block.0", 36 | "layer3.2.depth_conv": "features.9.block.1", 37 | "layer3.2.point_conv": "features.9.block.2", 38 | 39 | "layer3.3.expand_conv": "features.10.block.0", 40 | "layer3.3.depth_conv": "features.10.block.1", 41 | "layer3.3.point_conv": "features.10.block.2", 42 | 43 | "layer3.4.expand_conv": "features.11.block.0", 44 | "layer3.4.depth_conv": "features.11.block.1", 45 | "layer3.4.point_conv": "features.11.block.2", 46 | 47 | "layer3.5.expand_conv": "features.12.block.0", 48 | "layer3.5.depth_conv": "features.12.block.1", 49 | "layer3.5.point_conv": "features.12.block.2", 50 | 51 | "layer3.6.expand_conv": "features.13.block.0", 52 | "layer3.6.depth_conv": "features.13.block.1", 53 | "layer3.6.point_conv": "features.13.block.2", 54 | 55 | "layer4.0.expand_conv": "features.14.block.0", 56 | "layer4.0.depth_conv": "features.14.block.1", 57 | "layer4.0.point_conv": "features.14.block.2", 58 | 59 | "layer4.1.expand_conv": "features.15.block.0", 60 | "layer4.1.depth_conv": "features.15.block.1", 61 | "layer4.1.point_conv": "features.15.block.2", 62 | 63 | "layer4.2": "features.16", 64 | 65 | } 66 | -------------------------------------------------------------------------------- /nets/det/resnet.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | from torch.utils import model_zoo 5 | 6 | __all__ = ["resnet18", "resnet34", "resnet50", "resnet101", "resnet152"] 7 | 8 | 9 | M_URLS = { 10 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 11 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 12 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 13 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 14 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 15 | } 16 | 17 | 18 | class BasicBlock(nn.Module): 19 | """ 20 | 通道可以升维降维, feature map可以上采样下采样 21 | """ 22 | expansion = 1 23 | 24 | def __init__(self, inplanes, planes, stride=(1, 1), downsample=None): 25 | super().__init__() 26 | 27 | # 当stride=2时, 下采样. 
28 | # 当planes != inplanes * self.expansion时,升维 29 | self.conv1 = nn.Conv2d( 30 | in_channels=inplanes, 31 | out_channels=planes, 32 | kernel_size=(3, 3), 33 | stride=stride, 34 | padding=(1, 1), 35 | bias=False 36 | ) 37 | self.bn1 = nn.BatchNorm2d(planes) 38 | 39 | self.conv2 = nn.Conv2d( 40 | in_channels=planes, 41 | out_channels=planes, 42 | kernel_size=(3, 3), 43 | stride=(1, 1), 44 | padding=(1, 1), 45 | bias=False 46 | ) 47 | self.bn2 = nn.BatchNorm2d(planes) 48 | 49 | self.downsample = downsample 50 | self.relu = nn.ReLU(inplace=True) # inplace代表是否修改输入对象的值 51 | 52 | def forward(self, x): 53 | residual = x 54 | out = self.conv1(x) 55 | out = self.bn1(out) 56 | out = self.relu(out) 57 | 58 | out = self.conv2(out) 59 | out = self.bn2(out) 60 | 61 | if self.downsample is not None: 62 | residual = self.downsample(x) 63 | 64 | out += residual 65 | out = self.relu(out) 66 | return out 67 | 68 | 69 | class Bottleneck(nn.Module): 70 | expansion = 4 # 输出通道的倍乘 71 | 72 | def __init__(self, inplanes, planes, stride=(1, 1), downsample=None): 73 | super(Bottleneck, self).__init__() 74 | self.conv1 = nn.Conv2d( 75 | in_channels=inplanes, 76 | out_channels=planes, 77 | kernel_size=(1, 1), 78 | stride=(1, 1), 79 | padding=(0, 0), 80 | bias=False 81 | ) 82 | self.bn1 = nn.BatchNorm2d(planes) 83 | self.conv2 = nn.Conv2d( 84 | in_channels=planes, 85 | out_channels=planes, 86 | kernel_size=(3, 3), 87 | stride=stride, 88 | padding=(1, 1), 89 | bias=False 90 | ) 91 | self.bn2 = nn.BatchNorm2d(planes) 92 | self.conv3 = nn.Conv2d( 93 | in_channels=planes, 94 | out_channels=planes * self.expansion, 95 | kernel_size=(1, 1), 96 | bias=False 97 | ) 98 | self.bn3 = nn.BatchNorm2d(planes * 4) 99 | self.relu = nn.ReLU(inplace=True) 100 | self.downsample = downsample 101 | 102 | def forward(self, x): 103 | residual = x 104 | 105 | out = self.conv1(x) 106 | out = self.bn1(out) 107 | out = self.relu(out) 108 | 109 | out = self.conv2(out) 110 | out = self.bn2(out) 111 | out = self.relu(out) 112 | 113 | out = self.conv3(out) 114 | out = self.bn3(out) 115 | 116 | if self.downsample is not None: 117 | residual = self.downsample(residual) 118 | 119 | out += residual 120 | out = self.relu(out) 121 | return out 122 | 123 | 124 | class ResNet(nn.Module): 125 | def __init__(self, block, layers): 126 | super().__init__() 127 | self.inplanes = 64 # 最开始默认输入维度为64 128 | self.layer_out_channels = [64, 128, 256, 512] 129 | self.conv1 = nn.Conv2d( 130 | in_channels=3, 131 | out_channels=64, 132 | kernel_size=(7, 7), 133 | stride=(2, 2), 134 | padding=(3, 3), 135 | bias=False 136 | ) 137 | self.bn1 = nn.BatchNorm2d(64) 138 | self.relu = nn.ReLU(inplace=True) 139 | self.maxpool = nn.MaxPool2d( 140 | kernel_size=3, 141 | stride=2, 142 | padding=1 143 | ) 144 | 145 | self.layer1 = self._make_layer(block, self.layer_out_channels[0], layers[0]) 146 | self.layer2 = self._make_layer(block, self.layer_out_channels[1], layers[1], stride=(2, 2)) 147 | self.layer3 = self._make_layer(block, self.layer_out_channels[2], layers[2], stride=(2, 2)) 148 | self.layer4 = self._make_layer(block, self.layer_out_channels[3], layers[3], stride=(2, 2)) 149 | 150 | for m in self.modules(): 151 | if isinstance(m, nn.Conv2d): 152 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 153 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 154 | elif isinstance(m, nn.BatchNorm2d): 155 | m.weight.data.fill_(1) 156 | m.bias.data.zero_() 157 | 158 | def _make_layer(self, block, planes, blocks, stride=(1, 1)): 159 | downsample = None 160 | if stride != (1, 1) or self.inplanes != planes * block.expansion: 161 | downsample = nn.Sequential( 162 | nn.Conv2d( 163 | self.inplanes, 164 | planes * block.expansion, 165 | kernel_size=(1, 1), 166 | stride=stride, 167 | padding=(0, 0), 168 | bias=False 169 | ), 170 | nn.BatchNorm2d(planes * block.expansion) 171 | ) 172 | layers = [block(self.inplanes, planes, stride, downsample)] 173 | self.inplanes = planes * block.expansion 174 | for i in range(1, blocks): 175 | layers.append(block(self.inplanes, planes)) 176 | return nn.Sequential(*layers) 177 | 178 | def forward(self, x): 179 | x = self.conv1(x) 180 | x = self.bn1(x) 181 | x = self.relu(x) 182 | x = self.maxpool(x) 183 | 184 | x2 = self.layer1(x) 185 | x3 = self.layer2(x2) 186 | x4 = self.layer3(x3) 187 | x5 = self.layer4(x4) 188 | return x2, x3, x4, x5 189 | 190 | 191 | def resnet18(pre_trained_dir=None): 192 | model = ResNet(BasicBlock, [2, 2, 2, 2]) 193 | if pre_trained_dir is None: 194 | return model 195 | 196 | state_dict = model_zoo.load_url(M_URLS["resnet18"], model_dir=pre_trained_dir) 197 | model.load_state_dict(state_dict, strict=False) 198 | return model 199 | 200 | 201 | def resnet34(pre_trained_dir=None): 202 | model = ResNet(BasicBlock, [3, 4, 6, 3]) 203 | if pre_trained_dir is None: 204 | return model 205 | 206 | state_dict = model_zoo.load_url(M_URLS["resnet34"], model_dir=pre_trained_dir) 207 | model.load_state_dict(state_dict, strict=False) # strict是否关注load_state_dict在恢复过程中的部分信息 208 | return model 209 | 210 | 211 | def resnet50(pre_trained_dir=None): 212 | model = ResNet(Bottleneck, [3, 4, 6, 3]) 213 | if pre_trained_dir is None: 214 | return model 215 | state_dict = model_zoo.load_url(M_URLS["resnet50"], model_dir=pre_trained_dir) 216 | model.load_state_dict(state_dict, strict=False) 217 | return model 218 | 219 | 220 | def resnet101(pre_trained_dir=None): 221 | model = ResNet(Bottleneck, [3, 4, 23, 3]) 222 | if pre_trained_dir is None: 223 | return model 224 | 225 | state_dict = model_zoo.load_url(M_URLS["resnet101"], model_dir=pre_trained_dir) 226 | model.load_state_dict(state_dict, strict=False) 227 | return model 228 | 229 | 230 | def resnet152(pre_trained_dir=None): 231 | model = ResNet(Bottleneck, [3, 8, 38, 3]) 232 | if pre_trained_dir is None: 233 | return model 234 | 235 | state_dict = model_zoo.load_url(M_URLS["resnet101"], model_dir=pre_trained_dir) 236 | model.load_state_dict(state_dict, strict=False) 237 | return model 238 | 239 | 240 | if __name__ == "__main__": 241 | data = torch.randn(8, 3, 640, 640) 242 | rn_model = resnet18("//pretrained_models/") 243 | fs = rn_model(data) 244 | for f in fs: 245 | print(f.shape) 246 | -------------------------------------------------------------------------------- /nets/rec/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jakeywu/ocr_torch/94451fac7d9e503ed4612772a6941c2034f13409/nets/rec/__init__.py -------------------------------------------------------------------------------- /nets/rec/mobilenet_v3.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | M_URL = "https://download.pytorch.org/models/mobilenet_v3_small-047dcff4.pth" 5 | 6 | 7 | def _weights_init(m): 8 | if 
isinstance(m, nn.Conv2d): 9 | torch.nn.init.xavier_uniform_(m.weight) 10 | if m.bias is not None: 11 | torch.nn.init.zeros_(m.bias) 12 | elif isinstance(m, nn.BatchNorm2d): 13 | m.weight.data.fill_(1) 14 | m.bias.data.zero_() 15 | elif isinstance(m, nn.Linear): 16 | m.weight.data.normal_(0, 0.01) 17 | m.bias.data.zero_() 18 | 19 | 20 | def _make_divisible(v, divisor=8, min_value=None): 21 | # 等比例的增加和减少通道的个数 22 | if min_value is None: 23 | min_value = divisor 24 | new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) 25 | # Make sure that round down does not go down by more than 10%. 26 | if new_v < 0.9 * v: 27 | new_v += divisor 28 | return new_v 29 | 30 | 31 | class SqueezeBlock(nn.Module): 32 | def __init__(self, exp_size, divide=4): 33 | super(SqueezeBlock, self).__init__() 34 | self.dense = nn.Sequential( 35 | nn.Linear(exp_size, exp_size // divide), # Squeeze线性连接 36 | nn.ReLU(inplace=True), 37 | nn.Linear(exp_size // divide, exp_size), # Excite线性连接 38 | ) 39 | 40 | def forward(self, x): 41 | batch, channels, height, width = x.size() 42 | # 1.池化 43 | out = F.avg_pool2d(x, kernel_size=[height, width]).view(batch, -1) 44 | out = self.dense(out) 45 | out = F.hardsigmoid(out, inplace=True) 46 | # resize 47 | out = out.view(batch, channels, 1, 1) 48 | # 相乘 49 | return out * x 50 | 51 | 52 | class MobileBlock(nn.Module): 53 | def __init__(self, in_channels, out_channels, kernal_size, stride, non_linear, _se, exp_size, dropout_rate=1.0): 54 | super(MobileBlock, self).__init__() 55 | self.out_channels = out_channels 56 | self.nonLinear = non_linear 57 | self.SE = _se 58 | self.dropout_rate = dropout_rate 59 | padding = (kernal_size - 1) // 2 60 | 61 | self.use_connect = (stride == 1 and in_channels == out_channels) # 残差条件 62 | 63 | if self.nonLinear == "RE": 64 | activation = nn.ReLU 65 | else: 66 | activation = nn.Hardswish 67 | 68 | # 1*1卷积 expand 69 | self.expand_conv = nn.Sequential( 70 | nn.Conv2d(in_channels, exp_size, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0), bias=True), 71 | nn.BatchNorm2d(exp_size), 72 | activation(inplace=True) 73 | ) 74 | # 膨胀的卷积操作, 深度卷积 3* 3 或者 5*5 75 | self.depth_conv = nn.Sequential( 76 | nn.Conv2d(exp_size, exp_size, kernel_size=kernal_size, stride=stride, padding=padding, groups=exp_size), 77 | nn.BatchNorm2d(exp_size), 78 | activation(inplace=True) 79 | ) 80 | 81 | if self.SE: 82 | self.squeeze_block = SqueezeBlock(exp_size) 83 | 84 | # 1*1卷积 逐点卷积 85 | self.point_conv = nn.Sequential( 86 | nn.Conv2d(exp_size, out_channels, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0)), 87 | nn.BatchNorm2d(out_channels), 88 | nn.Identity() 89 | ) 90 | 91 | def forward(self, x): 92 | # MobileNetV2 93 | out = self.expand_conv(x) # 1*1 卷积, 由输入通道转为膨胀通道,转换通道 in->exp 94 | out = self.depth_conv(out) # 3x3或5*5卷积,膨胀通道,使用步长stride 95 | 96 | # Squeeze and Excite 97 | if self.SE: 98 | out = self.squeeze_block(out) 99 | 100 | # 1x1卷积,由膨胀通道,转换为输出通道 101 | out = self.point_conv(out) # 转换通道 exp->out 102 | 103 | # 残差结构 104 | if self.use_connect: 105 | return x + out 106 | else: 107 | return out 108 | 109 | 110 | class MobileNetV3(nn.Module): 111 | def __init__(self, multiplier=0.5, use_se=False): 112 | super(MobileNetV3, self).__init__() 113 | self._multiplier = multiplier 114 | self._use_se = use_se 115 | # in_channel out_channel kernel_size stride nl se exp_size 116 | cfg = [ 117 | # in_channel, out_channel, kernel_size, stride, activation, se, exp_size 118 | [16, 16, 3, (1, 1), 'RE', True, 16], 119 | [16, 24, 3, (2, 1), 'RE', False, 72], 120 | [24, 24, 3, 1, 'RE', 
False, 88], 121 | [24, 40, 5, (2, 1), 'HS', True, 96], 122 | [40, 40, 5, 1, 'HS', True, 240], 123 | [40, 40, 5, 1, 'HS', True, 240], 124 | [40, 48, 5, 1, 'HS', True, 120], 125 | [48, 48, 5, 1, 'HS', True, 144], 126 | [48, 96, 5, (2, 1), 'HS', True, 288], 127 | [96, 96, 5, 1, 'HS', True, 576], 128 | [96, 96, 5, 1, 'HS', True, 576], 129 | ] 130 | cls_ch_squeeze = _make_divisible(576 * self._multiplier) 131 | self.output_channel = cls_ch_squeeze 132 | self.init_conv = nn.Sequential( 133 | nn.Conv2d( 134 | in_channels=3, out_channels=_make_divisible(16 * self._multiplier), 135 | kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)), 136 | nn.BatchNorm2d(_make_divisible(16 * self._multiplier)), 137 | nn.Hardswish(inplace=True) 138 | ) 139 | 140 | self.layer = self._make_layer(self._multi_layer_conf(cfg)) 141 | self.last_conv = nn.Sequential( 142 | nn.Conv2d(_make_divisible(cfg[-1][1]), cls_ch_squeeze, kernel_size=(1, 1), stride=(1, 1)), 143 | nn.BatchNorm2d(cls_ch_squeeze), 144 | nn.Hardswish(inplace=True) 145 | ) 146 | self.pool = nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2)) 147 | 148 | self.apply(_weights_init) 149 | 150 | def _multi_layer_conf(self, layer_conf): 151 | for lc in layer_conf: 152 | lc[0] = _make_divisible(lc[0] * self._multiplier) 153 | lc[1] = _make_divisible(lc[1] * self._multiplier) 154 | lc[-1] = _make_divisible(lc[-1] * self._multiplier) 155 | return layer_conf 156 | 157 | def _make_layer(self, layer_conf, sequential=True): 158 | block_list = [] 159 | for in_channels, out_channels, kernel_size, stride, activation, se, exp_size in layer_conf: 160 | # activation  NL: 非线性激活函数;HS: H-swish激活函数,RE:ReLU激活函数 161 | # SE Squeeze and Excite结构,是否压缩和激发 162 | # exp_size: 膨胀参数 163 | # MobileBlock瓶颈层 164 | se = se and self._use_se 165 | block_list.append(MobileBlock(in_channels, out_channels, kernel_size, stride, activation, se, exp_size)) 166 | if sequential: 167 | return nn.Sequential(*block_list) 168 | return block_list 169 | 170 | def forward(self, x): 171 | """ 172 | :param x: N * C * H * W 173 | :return: N * 288 * 1 * 25 174 | """ 175 | # 起始部分 176 | out = self.init_conv(x) 177 | # 中间部分 178 | out = self.layer(out) 179 | out = self.last_conv(out) 180 | out = self.pool(out) 181 | return out 182 | 183 | 184 | def rec_mobilenet_v3(pre_trained_dir=None, multiplier=0.5, use_se=False): 185 | if not pre_trained_dir: 186 | return MobileNetV3(multiplier, use_se) 187 | 188 | 189 | if __name__ == "__main__": 190 | model11 = rec_mobilenet_v3() 191 | import torchvision 192 | torchvision.models.mobilenet_v3_small() 193 | x1 = torch.randn(8, 3, 32, 320) 194 | fx = model11(x1) 195 | print(fx.shape) 196 | -------------------------------------------------------------------------------- /nets/rec/rnn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class CRNN(nn.Module): 7 | def __init__( 8 | self, 9 | classes_num, 10 | rnn_type, 11 | hidden_size, 12 | num_layers, 13 | bidirectional, 14 | backbone, 15 | ): 16 | super(CRNN, self).__init__() 17 | 18 | self.backbone = self._call_backbone(backbone) 19 | self.rnn_in_channel = self.backbone.output_channel 20 | self.classes_num = classes_num 21 | self.bidirectional = bidirectional 22 | self.hidden_size = hidden_size 23 | self.num_layers = num_layers 24 | self.rnn_type = rnn_type 25 | assert self.rnn_type in ["LSTM", "GRU"] 26 | 27 | if self.rnn_type == "LSTM": 28 | self.rnn_layer = nn.GRU( 29 | input_size=self.rnn_in_channel, 30 | 
hidden_size=self.hidden_size, 31 | num_layers=self.num_layers, 32 | batch_first=False, 33 | bidirectional=bidirectional 34 | ) 35 | else: 36 | self.rnn_layer = nn.LSTM( 37 | input_size=self.rnn_in_channel, 38 | hidden_size=self.hidden_size, 39 | num_layers=self.num_layers, 40 | batch_first=False, 41 | bidirectional=self.bidirectional 42 | ) 43 | for name, params in self.rnn_layer.named_parameters(): 44 | nn.init.uniform_(params, -0.1, 0.1) 45 | 46 | rnn_out_channel = hidden_size 47 | if self.bidirectional: 48 | rnn_out_channel = hidden_size * 2 49 | 50 | self.fc = nn.Linear( 51 | in_features=rnn_out_channel, 52 | out_features=self.classes_num, 53 | bias=True 54 | ) 55 | self.apply(self._weights_init) 56 | 57 | @staticmethod 58 | def _call_backbone(backbone): 59 | module_func = backbone.pop("name") 60 | if module_func == "rec_mobilenet_v3": 61 | from nets.rec.mobilenet_v3 import rec_mobilenet_v3 62 | module_func = eval(module_func)(**backbone) 63 | else: 64 | raise Exception("backbone {} is not found".format(module_func)) 65 | return module_func 66 | 67 | @staticmethod 68 | def _weights_init(m): 69 | if isinstance(m, nn.Linear): 70 | m.weight.data.normal_(0, 0.01) 71 | m.bias.data.zero_() 72 | 73 | def forward(self, x): 74 | """ 75 | :param x: N * 3 * H * W 76 | :return: # N * T * Feature 77 | """ 78 | x = self.backbone(x) # N * C * 1 * W (mobilenet: N * 288 * 1 * 25) 79 | x = x.squeeze(axis=2) # N * C * W 80 | x = x.permute(2, 0, 1) # N * W * C 81 | x, _ = self.rnn_layer(x) 82 | x = self.fc(x) 83 | if not self.training: 84 | x = F.softmax(x, dim=2) 85 | return x 86 | 87 | 88 | if __name__ == "__main__": 89 | input_ = torch.randn(8, 3, 32, 320) 90 | se = CRNN(5000, "GRU", 48, 2, True, {"name": "rec_mobilenet_v3"}) 91 | # se.eval() 92 | print(se(input_).shape) 93 | -------------------------------------------------------------------------------- /optimizer/__init__.py: -------------------------------------------------------------------------------- 1 | from .optim import OptimizerScheduler 2 | from .learning_rate import LearningSchedule 3 | 4 | 5 | __all__ = ["build_optimizer"] 6 | 7 | 8 | def build_optimizer(parameters, epochs, step_each_epoch, config): 9 | optimizer_name = config.pop("name") 10 | support_dict = ["OptimizerScheduler"] 11 | assert optimizer_name in support_dict 12 | optimizer = eval(optimizer_name)( 13 | parameters=parameters, 14 | optim_method=config["optim_method"], 15 | init_learning_rate=config["init_learning_rate"], 16 | ).optim 17 | 18 | lr_schedule_conf = config["learning_schedule"] 19 | lr_schedule_name = lr_schedule_conf["name"] 20 | lr_schedule = eval(lr_schedule_name)( 21 | optimizer=optimizer, 22 | epochs=epochs, 23 | step_each_epoch=step_each_epoch, 24 | warmup_epoch=lr_schedule_conf["warmup_epoch"], 25 | lr_method=lr_schedule_conf["lr_method"], 26 | last_epoch=-1 27 | ).get_learning_rate 28 | return optimizer, lr_schedule 29 | -------------------------------------------------------------------------------- /optimizer/learning_rate.py: -------------------------------------------------------------------------------- 1 | import math 2 | from torch.optim.lr_scheduler import StepLR, MultiStepLR, CosineAnnealingLR, ExponentialLR, LambdaLR 3 | 4 | 5 | class LearningSchedule(object): 6 | 7 | def __init__( 8 | self, 9 | optimizer, 10 | epochs, 11 | step_each_epoch, 12 | lr_method="_cosine_warmup", 13 | warmup_epoch=2, 14 | last_epoch=-1 15 | ): 16 | self.optimizer = optimizer 17 | self._last_epoch = step_each_epoch * last_epoch if last_epoch != -1 else -1 18 | self.t_max 
= step_each_epoch * epochs 19 | self.warmup_epoch = step_each_epoch * warmup_epoch 20 | assert lr_method in ["_step_lr", "_multi_step_lr", "_exponential_lr", "_cosine_annealing_lr", 21 | "_cosine_warmup"] 22 | self._lr_method = lr_method 23 | 24 | @property 25 | def get_learning_rate(self): 26 | learning_rate = getattr(self, self._lr_method)() 27 | return learning_rate 28 | 29 | def _step_lr(self, gama=0.9, step_size=5): 30 | """ 31 | 固定步长衰减 lr = base_lr * gama ** (epoch // step_size) 32 | :gama 学习率调整倍数 33 | :step_size 指的是Epoch间隔数 34 | """ 35 | return StepLR( 36 | optimizer=self.optimizer, 37 | step_size=step_size, 38 | gamma=gama, 39 | last_epoch=self._last_epoch 40 | ) 41 | 42 | def _multi_step_lr(self, milestones=None, gama=0.9): 43 | """ 44 | 多步长衰减 45 | :param milestones: 区间epoch 46 | :param gama: 学习率调整倍数 47 | """ 48 | if not milestones: 49 | milestones = [10, 20, 50] 50 | return MultiStepLR( 51 | optimizer=self.optimizer, 52 | milestones=milestones, 53 | gamma=gama, 54 | last_epoch=self._last_epoch 55 | ) 56 | 57 | def _exponential_lr(self, gama=0.98): 58 | """ 59 | lr = base_lr * gama ** epoch 60 | 指数衰减 61 | """ 62 | return ExponentialLR( 63 | optimizer=self.optimizer, 64 | gamma=gama, 65 | last_epoch=self._last_epoch 66 | ) 67 | 68 | def _cosine_annealing_lr(self, eta_min=0): 69 | """ 70 | 余弦退火衰减 71 | :t_max 最大迭代次数, step_each_epoch * epoch 72 | """ 73 | return CosineAnnealingLR( 74 | optimizer=self.optimizer, 75 | T_max=self.t_max, 76 | eta_min=eta_min, 77 | last_epoch=self._last_epoch, 78 | ) 79 | 80 | def _cosine_warmup(self, eta_min=0): 81 | def lr_lambda(epoch): 82 | if epoch < self.warmup_epoch: 83 | return (epoch + 1) / self.warmup_epoch 84 | return eta_min + 0.5 * ( 85 | math.cos((epoch - self.warmup_epoch) / (self.t_max - self.warmup_epoch) * math.pi) + 1) 86 | 87 | return LambdaLR( 88 | optimizer=self.optimizer, 89 | lr_lambda=lr_lambda, 90 | last_epoch=self._last_epoch 91 | ) 92 | 93 | 94 | if __name__ == "__main__": 95 | from torchvision.models import resnet18 96 | from torch.optim import Adam 97 | 98 | model = resnet18(pretrained=False) 99 | _epochs = 10 100 | _step_each_epoch = 100 101 | optimizer1 = Adam(model.parameters(), lr=0.01) 102 | ls_schedule = LearningSchedule( 103 | optimizer=optimizer1, 104 | epochs=_epochs, 105 | step_each_epoch=_step_each_epoch, 106 | warmup_epoch=2 107 | ).get_learning_rate 108 | for _epoch in range(1, _epochs + 1): 109 | for _iter in range(1, _step_each_epoch + 1): 110 | lr = optimizer1.param_groups[0]["lr"] 111 | print("{}/{} {}/{} {:.8f}".format(_epoch, _epochs, _iter, _step_each_epoch, lr)) 112 | optimizer1.zero_grad() 113 | optimizer1.step() 114 | ls_schedule.step() 115 | -------------------------------------------------------------------------------- /optimizer/optim.py: -------------------------------------------------------------------------------- 1 | from torch import optim 2 | 3 | 4 | class OptimizerScheduler(object): 5 | 6 | def __init__(self, parameters, init_learning_rate=0.001, optim_method="_adam"): 7 | self._parameters = parameters 8 | self._learning_rate = init_learning_rate 9 | self._method = optim_method 10 | self.optim = self.__getattribute__(self._method)() 11 | 12 | def _adam(self): 13 | return optim.Adam( 14 | params=self._parameters, 15 | lr=self._learning_rate, 16 | weight_decay=0, 17 | amsgrad=True 18 | ) 19 | 20 | def _sgd(self): 21 | return optim.SGD( 22 | params=self._parameters, 23 | lr=self._learning_rate 24 | ) 25 | -------------------------------------------------------------------------------- 
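A minimal usage sketch of the optimizer package above: `build_optimizer` (optimizer/__init__.py) constructs the torch optimizer via `OptimizerScheduler` (optim.py) and the LR scheduler via `LearningSchedule` (learning_rate.py). The config keys below mirror the ones `build_optimizer` reads (`optim_method`, `init_learning_rate`, `learning_schedule.name/lr_method/warmup_epoch`); the concrete values are illustrative, not taken from the project's YAML files.

```
import torch.nn as nn
from optimizer import build_optimizer

# Toy module standing in for DBNet/CRNN parameters.
model = nn.Conv2d(3, 16, 3)

optimizer_conf = {
    "name": "OptimizerScheduler",
    "optim_method": "_adam",            # or "_sgd"
    "init_learning_rate": 0.001,        # illustrative value
    "learning_schedule": {
        "name": "LearningSchedule",
        "lr_method": "_cosine_warmup",  # linear warmup, then cosine decay
        "warmup_epoch": 2,
    },
}

optimizer, lr_schedule = build_optimizer(
    parameters=model.parameters(),
    epochs=10,
    step_each_epoch=100,
    config=optimizer_conf,
)
# Per training step: optimizer.step() followed by lr_schedule.step()
```

Note that `LearningSchedule` multiplies `epochs` and `warmup_epoch` by `step_each_epoch`, so the returned scheduler is meant to be stepped once per batch (as `Trainer.train` in train.py does), giving per-iteration warmup and decay rather than per-epoch jumps.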
/postprocess/__init__.py: -------------------------------------------------------------------------------- 1 | from .det_postprocess import DBPostProcess 2 | from .rec_postprocess import CRnnPostProcess 3 | 4 | 5 | __all__ = ["build_post_process"] 6 | 7 | 8 | def build_post_process(post_config): 9 | module_name = post_config.pop("name") 10 | assert module_name in ["DBPostProcess", "CRnnPostProcess"] 11 | module_class = eval(module_name)(**post_config) 12 | return module_class 13 | -------------------------------------------------------------------------------- /postprocess/det_postprocess.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | import pyclipper 4 | from shapely.geometry import Polygon 5 | 6 | 7 | class DBPostProcess(object): 8 | def __init__( 9 | self, 10 | thresh=0.3, 11 | box_thresh=0.7, 12 | max_candidates=1000, 13 | unclip_ratio=1.6 14 | ): 15 | self.min_size = 3 16 | self.thresh = thresh 17 | self.box_thresh = box_thresh 18 | self.max_candidates = max_candidates 19 | self.unclip_ratio = unclip_ratio 20 | 21 | def __call__(self, pred, batch): 22 | """ 23 | prob: text region segmentation map, with shape (N, 1, H, W) 24 | src_scale: the original shape of images. [[H1, W1], [H2, W2], [H3, W3]...] 25 | """ 26 | src_scale = batch["src_scale"] 27 | pred = pred[:, 0, :, :] # binary 28 | segmentation = pred > self.thresh 29 | boxes_batch = [] 30 | scores_batch = [] 31 | for batch_index in range(pred.size(0)): 32 | height, width = src_scale[batch_index] 33 | boxes, scores = self.boxes_from_bitmap(pred[batch_index], segmentation[batch_index], width, height) 34 | boxes_batch.append(boxes) 35 | scores_batch.append(scores) 36 | return np.array(boxes_batch), np.array(scores_batch) 37 | 38 | def boxes_from_bitmap(self, pred, _bitmap, dest_width, dest_height): 39 | """ 40 | pred: 概率图 H * W 41 | _bitmap: 初始二值图 H * W 42 | dest_width: 原始宽度 43 | dest_height: 原始高度 44 | """ 45 | assert len(_bitmap.shape) == 2 46 | bitmap = _bitmap.cpu().numpy() 47 | pred = pred.cpu().detach().numpy() 48 | height, width = bitmap.shape 49 | contours, _ = cv2.findContours( 50 | (bitmap * 255).astype(np.uint8), 51 | cv2.RETR_LIST, # 检测的轮廓不建立等级关系 52 | cv2.CHAIN_APPROX_SIMPLE # 压缩方向元素, 只保留该方向的终点坐标 53 | ) 54 | num_contours = min(len(contours), self.max_candidates) 55 | boxes = [] 56 | scores = [] 57 | for index in range(num_contours): 58 | contour = contours[index].squeeze(1) 59 | points, min_side = self.get_mini_boxes(contour) 60 | # 返回最小边框的长度 61 | if min_side < self.min_size: 62 | continue 63 | points = np.array(points) 64 | score = self.box_score_fast(pred, contour) 65 | if self.box_thresh > score: 66 | continue 67 | 68 | box = self.unclip(points, unclip_ratio=self.unclip_ratio).reshape(-1, 2) 69 | box, min_side = self.get_mini_boxes(box) 70 | if min_side < self.min_size + 2: 71 | continue 72 | box = np.array(box) 73 | if not isinstance(dest_width, int): 74 | dest_width = dest_width.item() 75 | dest_height = dest_height.item() 76 | 77 | box[:, 0] = np.clip(np.round(box[:, 0] / width * dest_width), 0, dest_width) 78 | box[:, 1] = np.clip(np.round(box[:, 1] / height * dest_height), 0, dest_height) 79 | boxes.append(box.astype(np.int16)) 80 | scores.append(score) 81 | return np.array(boxes, dtype=np.int16), scores 82 | 83 | @staticmethod 84 | def unclip(box, unclip_ratio=1.5): 85 | poly = Polygon(box) 86 | distance = poly.area * unclip_ratio / poly.length 87 | offset = pyclipper.PyclipperOffset() 88 | offset.AddPath(box, pyclipper.JT_ROUND, 
pyclipper.ET_CLOSEDPOLYGON) 89 | expanded = np.array(offset.Execute(distance)) 90 | return expanded 91 | 92 | @staticmethod 93 | def get_mini_boxes(contour): 94 | # 得到最小外接矩形的(中心(x,y), (宽,高), 旋转角度) 95 | bounding_box = cv2.minAreaRect(contour) 96 | # 排序最小外接矩形的4个顶点坐标, 从上往下排序 97 | points = sorted(list(cv2.boxPoints(bounding_box)), key=lambda x: x[0]) 98 | if points[1][1] > points[0][1]: 99 | index_1 = 0 100 | index_4 = 1 101 | else: 102 | index_1 = 1 103 | index_4 = 0 104 | if points[3][1] > points[2][1]: 105 | index_2 = 2 106 | index_3 = 3 107 | else: 108 | index_2 = 3 109 | index_3 = 2 110 | 111 | box = [points[index_1], points[index_2], points[index_3], points[index_4]] 112 | return box, min(bounding_box[1]) 113 | 114 | @staticmethod 115 | def box_score_fast(bitmap, _box): 116 | h, w = bitmap.shape[:2] 117 | box = _box.copy() 118 | xmin = np.clip(np.floor(box[:, 0].min()).astype(np.int), 0, w - 1) 119 | xmax = np.clip(np.ceil(box[:, 0].max()).astype(np.int), 0, w - 1) 120 | ymin = np.clip(np.floor(box[:, 1].min()).astype(np.int), 0, h - 1) 121 | ymax = np.clip(np.ceil(box[:, 1].max()).astype(np.int), 0, h - 1) 122 | 123 | mask = np.zeros((ymax - ymin + 1, xmax - xmin + 1), dtype=np.uint8) 124 | box[:, 0] = box[:, 0] - xmin 125 | box[:, 1] = box[:, 1] - ymin 126 | cv2.fillPoly(mask, box.reshape(1, -1, 2).astype(np.int32), 1) 127 | return cv2.mean(bitmap[ymin:ymax + 1, xmin:xmax + 1], mask)[0] 128 | -------------------------------------------------------------------------------- /postprocess/rec_postprocess.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from utils.string_utils import CharacterJson 4 | 5 | 6 | class CRnnPostProcess(object): 7 | def __init__(self, character_json_path): 8 | cj = CharacterJson(character_json_path) 9 | self._char2idx = cj.char2idx 10 | self._idx2char = cj.idx2char 11 | 12 | def decode(self, text_index, text_prob=None, is_remove_duplicate=False): 13 | """ convert text-index into text-label. 
""" 14 | result_list = [] 15 | ignored_token = self._char2idx[""] 16 | batch_size = len(text_index) 17 | for batch_idx in range(batch_size): 18 | char_list = [] 19 | conf_list = [] 20 | for idx in range(len(text_index[batch_idx])): 21 | if text_index[batch_idx][idx] == ignored_token: 22 | continue 23 | if is_remove_duplicate: 24 | # only for predict 25 | if idx > 0 and text_index[batch_idx][idx - 1] == text_index[ 26 | batch_idx][idx]: 27 | continue 28 | char_list.append(self._idx2char[int(text_index[batch_idx][idx])]) 29 | if text_prob is not None: 30 | conf_list.append(text_prob[batch_idx][idx]) 31 | else: 32 | conf_list.append(1) 33 | text = ''.join(char_list) 34 | result_list.append((text, np.mean(conf_list))) 35 | return result_list 36 | 37 | def __call__(self, preds, batch=None): 38 | if torch.is_tensor(preds): 39 | preds = preds.detach().numpy() 40 | 41 | preds_idx = preds.argmax(axis=2) 42 | preds_prob = preds.max(axis=2) 43 | preds_idx = preds_idx.transpose(1, 0) 44 | preds_prob = preds_prob.transpose(1, 0) 45 | text = self.decode(preds_idx, preds_prob, is_remove_duplicate=True) 46 | if not batch or "label_idx" not in batch.keys(): 47 | return text 48 | label = batch["label_idx"] 49 | label = self.decode(label) 50 | return text, label 51 | -------------------------------------------------------------------------------- /predict.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import torch 4 | import json 5 | import argparse 6 | import time 7 | import copy 8 | import codecs 9 | from functools import partial 10 | import numpy as np 11 | import onnx 12 | from config.load_conf import ReadConfig 13 | import onnxruntime as rt 14 | from nets import build_model 15 | from postprocess import build_post_process 16 | from data_loader.img_aug import * 17 | 18 | 19 | def main(params): 20 | model = build_model(params["model"]) 21 | post_process = build_post_process(params["post_process"]) 22 | pt = Predictor(model, post_process, params) 23 | pt.predict() 24 | 25 | 26 | class Predictor(object): 27 | def __init__(self, model, post_process, params): 28 | self._model = model 29 | self._conf = params["global"] 30 | self.image_dir_or_path = params["dataset"]["image_dir_or_path"] 31 | self._transforms = self._transforms_func_lst(params["dataset"]["transforms"]) 32 | self._post_process = post_process 33 | self._image_list = self._read_images() 34 | if not os.path.exists(self._conf["res_save_dir"]): 35 | os.makedirs(self._conf["res_save_dir"]) 36 | if self._conf["use_infer_model"]: 37 | self.sess = self._convert_train2infer() 38 | else: 39 | self.sess = self._init_pth_model() 40 | 41 | @staticmethod 42 | def _transforms_func_lst(config): 43 | func_lst = [] 44 | for _transform in config: 45 | operator = list(_transform.keys())[0] 46 | params = dict() if _transform[operator] is None else _transform[operator] 47 | func_name = eval(operator)(**params) 48 | func_lst.append(func_name) 49 | return func_lst 50 | 51 | def _convert_train2infer(self): 52 | if os.path.exists(self._conf["infer_model_path"]): 53 | return rt.InferenceSession(self._conf["infer_model_path"]) 54 | 55 | if not os.path.exists(self._conf["train_model_path"]): 56 | raise Exception("model_det {} not exists".format(self._conf["train_model_path"])) 57 | 58 | ckpt = torch.load(self._conf["train_model_path"], map_location=torch.device('cpu'))["state_dict"] 59 | self._model.load_state_dict(ckpt) 60 | self._model.eval() 61 | 62 | if self._conf["yml_type"] == "DET": 63 | x = 
torch.randn(1, 3, 224, 224, requires_grad=True) 64 | dynamic_axes = { 65 | "input": {0: "batch_size", 2: "height", 3: "width"}, 66 | "output": {0: "batch_size"} 67 | } 68 | else: 69 | x = torch.randn(1, 3, 32, 320, requires_grad=True) 70 | dynamic_axes = { 71 | "input": {0: "batch_size", 3: "width"}, 72 | "output": {0: "batch_size"} 73 | } 74 | 75 | torch.onnx.export( 76 | model=self._model, 77 | args=x, 78 | f=self._conf["infer_model_path"], 79 | export_params=True, 80 | opset_version=11, 81 | do_constant_folding=True, # 是否执行常量折叠优化 82 | input_names=["input"], # 输入名 83 | output_names=["output"], # 输出名 84 | dynamic_axes=dynamic_axes 85 | ) 86 | try: 87 | onnx_model = onnx.load(self._conf["infer_model_path"]) 88 | onnx.checker.check_model(onnx_model) 89 | except Exception as e: 90 | raise e 91 | return rt.InferenceSession(self._conf["infer_model_path"]) 92 | 93 | def _read_images(self): 94 | imgs_lists = [] 95 | if self.image_dir_or_path is None or not os.path.exists(self.image_dir_or_path): 96 | raise Exception("not found any img file in {}".format(self.image_dir_or_path)) 97 | 98 | img_end = {'jpg', 'bmp', 'png', 'jpeg', 'rgb', 'tif', 'tiff'} 99 | if os.path.isfile(self.image_dir_or_path) and \ 100 | os.path.splitext(self.image_dir_or_path)[-1][1:].lower() in img_end: 101 | imgs_lists.append(self.image_dir_or_path) 102 | elif os.path.isdir(self.image_dir_or_path): 103 | for single_file in os.listdir(self.image_dir_or_path): 104 | file_path = os.path.join(self.image_dir_or_path, single_file) 105 | if os.path.isfile(file_path) and os.path.splitext(file_path)[-1][1:].lower() in img_end: 106 | imgs_lists.append(file_path) 107 | if len(imgs_lists) == 0: 108 | raise Exception("not found any img file in {}".format(self.image_dir_or_path)) 109 | return imgs_lists 110 | 111 | def _init_pth_model(self): 112 | if not self._conf["train_model_path"]: 113 | return self._model 114 | if not os.path.exists(self._conf["train_model_path"]): 115 | print("pth path {} is not exists".format(self._conf["train_model_path"])) 116 | raise 117 | try: 118 | checkpoint = torch.load(self._conf["train_model_path"], map_location="cpu") 119 | self._model.load_state_dict(checkpoint["state_dict"], strict=False) 120 | except Exception: 121 | print("model_det init failed") 122 | raise 123 | return self._model 124 | 125 | def predict(self): 126 | self._model.eval() 127 | result = [] 128 | for image_path in self._image_list: 129 | image = cv2.imread(image_path, cv2.IMREAD_COLOR) # 默认BGR CHANNEL_LAST 130 | if image is None: 131 | print("reading image_path: {} failed".format(image_path)) 132 | continue 133 | data = {"image": image} 134 | for _transform in self._transforms: 135 | data = _transform(data) 136 | 137 | for key, val in data.items(): 138 | data[key] = np.expand_dims(val, axis=0) 139 | 140 | start_time = time.time() 141 | if self._conf["use_infer_model"]: 142 | out = self.sess.run(["output"], {"input": data["image"]})[0] 143 | preds = torch.from_numpy(out) 144 | else: 145 | images = torch.from_numpy(data["image"]) 146 | preds = self._model(images) 147 | 148 | print("image: {} \texpend time: {:.4f}".format(image_path, time.time() - start_time)) 149 | post_result = self._post_process(preds, data) 150 | dt_boxes_json = dict() 151 | dt_boxes_json["file_name"] = image_path 152 | if self._conf["yml_type"] == "DET": 153 | dt_boxes_json["bbox"] = post_result[0][0].tolist() 154 | dt_boxes_json["score"] = post_result[1][0].tolist() 155 | self._draw_det_res(image, dt_boxes_json, os.path.basename(image_path)) 156 | else: 157 | 
dt_boxes_json["text"] = post_result[0][0] 158 | dt_boxes_json["score"] = post_result[0][1] 159 | result.append(dt_boxes_json) 160 | 161 | with codecs.open(os.path.join(self._conf["res_save_dir"], "result.txt"), "a", "utf8") as f: 162 | for res in result: 163 | f.write(json.dumps(res, ensure_ascii=False)+"\n") 164 | 165 | def _draw_det_res(self, image, dt_boxes_json, img_name): 166 | cus_line = partial(cv2.line, color=(255, 255, 0), thickness=1) 167 | if len(dt_boxes_json) > 0: 168 | new_im = copy.copy(image) 169 | for i, box in enumerate(dt_boxes_json["bbox"]): 170 | score = dt_boxes_json["score"][i] 171 | cus_line(new_im, (box[0][0], box[0][1]), (box[1][0], box[1][1])) 172 | cus_line(new_im, (box[1][0], box[1][1]), (box[2][0], box[2][1])) 173 | cus_line(new_im, (box[2][0], box[2][1]), (box[3][0], box[3][1])) 174 | cus_line(new_im, (box[3][0], box[3][1]), (box[0][0], box[0][1])) 175 | cv2.putText( 176 | new_im, 177 | "{:.3f}".format(score), 178 | (box[0][0], box[0][1]), 179 | fontFace=cv2.FONT_HERSHEY_SIMPLEX, 180 | fontScale=0.3, 181 | color=(0, 0, 255)) 182 | 183 | save_path = os.path.join(self._conf["res_save_dir"], os.path.basename(img_name)) 184 | cv2.imwrite(save_path, new_im) 185 | 186 | 187 | if __name__ == "__main__": 188 | parser = argparse.ArgumentParser() 189 | parser.add_argument("-c", "--config", default="./config/predict/det.yml", help="配置文件路径") 190 | det_conf_path = parser.parse_args().config 191 | 192 | cus_params = ReadConfig(det_conf_path).base_conf 193 | print("预测相关参数:\n{}".format(json.dumps(cus_params, indent=2, ensure_ascii=False))) 194 | main(cus_params) 195 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy~=1.19.2 2 | shapely~=1.7.1 3 | tqdm~=4.60.0 4 | opencv-python~=4.1.2.30 5 | onnx~=1.9.0 6 | onnxruntime~=1.7.0 7 | imgaug~=0.2.8 8 | pyclipper~=1.2.1 9 | torch~=1.8.1 10 | pyyaml~=5.4.1 -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import time 4 | import torch 5 | import argparse 6 | from tqdm import tqdm 7 | import torch.distributed 8 | from nets import build_model 9 | from losses import build_loss 10 | from metrics import build_metric 11 | from logger.logger import get_logger 12 | from optimizer import build_optimizer 13 | from config.load_conf import ReadConfig 14 | from postprocess import build_post_process 15 | from data_loader import build_data_loader 16 | 17 | os.environ['CUDA_VISIBLE_DEVICES'] = '0,1,2' 18 | 19 | 20 | def main(conf, logger): 21 | distributed = False 22 | if not conf["global"]["use_gpu"] or not torch.cuda.is_available(): 23 | device = torch.device("cpu") 24 | else: 25 | device = torch.device("cuda:0") 26 | 27 | model = build_model(conf["model_det"]) 28 | model = model.to(device) 29 | 30 | trainer = Trainer( 31 | model=model, 32 | logger=logger, 33 | conf=conf, 34 | device=device, 35 | distributed=distributed 36 | ) 37 | 38 | logger.info("模型初始化完成....") 39 | time.sleep(2) 40 | trainer.train() 41 | 42 | 43 | class Trainer(object): 44 | def __init__( 45 | self, 46 | model, 47 | logger, 48 | conf, 49 | device, 50 | distributed 51 | ): 52 | self._model = model 53 | self._conf = conf 54 | self._logger = logger 55 | self._device = device 56 | self._global_step = 0 57 | self._last_epoch = -1 58 | self._best_epoch = 0 59 | self._distributed = 
distributed 60 | self._global_conf = self._conf["global"] 61 | self._best_indicator = 0 62 | self._indicator_name = "best_{}".format(self._conf["metrics"]["main_indicator"]) 63 | self._init_pth_model() 64 | if torch.cuda.is_available() and torch.cuda.device_count() > 1: 65 | # rank 标记主机或从机,设置为0表示主机 66 | # world_size标记使用几个主机,设为1表示1个 67 | torch.distributed.init_process_group('nccl', init_method="env://", world_size=1, rank=0) 68 | self._model = torch.nn.parallel.DistributedDataParallel(self._model) 69 | self.distributed = True 70 | 71 | self._steps_per_epoch = self._get_epoch_data(_len=True) 72 | self._validate_data = self._get_validate_data() 73 | self._optimizer, self._schedule = build_optimizer( 74 | parameters=self._model.parameters(), 75 | epochs=self._global_conf["epochs"], 76 | step_each_epoch=self._steps_per_epoch, 77 | config=self._conf["optimizer"] 78 | ) 79 | self._criterion = build_loss(self._conf["loss"]) 80 | self._metrics = build_metric(self._conf["metrics"]) 81 | self._post_process = build_post_process(self._conf["post_process"]) 82 | 83 | self._start_epoch = 1 if self._last_epoch == -1 else self._last_epoch + 1 84 | 85 | def _get_epoch_data(self, _len=False): 86 | data = build_data_loader( 87 | config=self._conf, 88 | distributed=self._distributed, 89 | logger=self._logger, 90 | mode="train" 91 | ) 92 | if _len: 93 | return len(data) 94 | return data 95 | 96 | def _get_validate_data(self): 97 | data = build_data_loader( 98 | config=self._conf, 99 | distributed=self._distributed, 100 | logger=self._logger, 101 | mode="validate" 102 | ) 103 | return data 104 | 105 | def train(self): 106 | self._model.train() 107 | self._logger.info("开始训练....") 108 | time.sleep(1) 109 | for epoch in range(self._start_epoch, self._global_conf["epochs"] + 1): 110 | log_start_time = time.time() 111 | train_loader = self._get_epoch_data() 112 | for idx, batch in enumerate(train_loader): 113 | for key, val in batch.items(): 114 | if not torch.is_tensor(val): 115 | continue 116 | batch[key] = val.to(self._device) 117 | 118 | self._global_step += 1 119 | lr = self._optimizer.param_groups[0]["lr"] 120 | preds = self._model(batch["image"]) 121 | loss_dict = self._criterion(preds, batch) 122 | self._optimizer.zero_grad() 123 | loss_dict["loss"].backward() 124 | self._optimizer.step() 125 | self._schedule.step() 126 | 127 | indicator_str = "" 128 | for key, val in loss_dict.items(): 129 | indicator_str = '{}: {:.4f},'.format(key, val.item()) 130 | 131 | if self._global_conf["yml_type"] == "REC": 132 | post_result = self._post_process(preds, batch) 133 | metrics = self._metrics(post_result) 134 | indicator_str += 'acc: {:.4f}, norm_edit_dis: {:.4f},'.format(metrics["acc"], 135 | metrics["norm_edit_dis"]) 136 | 137 | if self._global_step % self._global_conf["log_iter"] == 0: 138 | batch_time = time.time() - log_start_time 139 | info_txt = "【{}/{}】,【{}/{}】, global_step: {}, lr:{:.6}, {} speed: {:.1f} samples/sec" 140 | info_txt = info_txt.format( 141 | epoch, self._global_conf["epochs"], idx + 1, self._steps_per_epoch, self._global_step, lr, 142 | indicator_str, self._global_conf["log_iter"] * preds.size(0) / batch_time, 143 | ) 144 | self._logger.info(info_txt) 145 | log_start_time = time.time() 146 | 147 | if epoch % self._global_conf["eval_epoch"] == 0: 148 | cur_metrics = self._eval() 149 | self._logger.info( 150 | "cur metrics: {}".format(", ".join(["{}:{}".format(k, v) for k, v in cur_metrics.items()]))) 151 | if cur_metrics[self._conf["metrics"]["main_indicator"]] > self._best_indicator: 152 | 
self._best_epoch = epoch 153 | self._best_indicator = cur_metrics[self._indicator_name] 154 | self._save_pth_model(self._indicator_name, epoch, self._best_epoch, self._best_indicator) 155 | 156 | if epoch % self._global_conf["save_epoch_iter"] == 0: 157 | file_name = "iter_epoch_{}".format(epoch) 158 | self._save_pth_model(file_name, epoch, self._best_epoch, self._best_indicator) 159 | 160 | self._save_pth_model("latest", epoch, self._best_epoch, self._best_indicator) 161 | 162 | def _eval(self): 163 | self._model.eval() 164 | total_time = 0.0 165 | with tqdm(total=len(self._validate_data), desc='eval model_det:') as pbar: 166 | for batch in self._validate_data: 167 | with torch.no_grad(): 168 | # 数据进行转换和丢到gpu 169 | for key, val in batch.items(): 170 | if not torch.is_tensor(val): 171 | continue 172 | batch[key] = val.to(self._device) 173 | 174 | pbar.update(1) 175 | start = time.time() 176 | preds = self._model(batch["image"]) 177 | post_result = self._post_process(preds, batch) 178 | total_time += time.time() - start 179 | self._metrics(post_result, batch) 180 | 181 | metrics = self._metrics.get_metric() 182 | self._model.train() 183 | return metrics 184 | 185 | def _init_pth_model(self): 186 | init_pth_path = self._global_conf["init_pth_path"] 187 | if not init_pth_path: 188 | return self._model 189 | if not os.path.exists(init_pth_path): 190 | self._logger.error("pth path {} is not exists".format(init_pth_path)) 191 | raise 192 | try: 193 | checkpoint = torch.load(init_pth_path, map_location="cpu") 194 | self._last_epoch = checkpoint["epoch"] 195 | self._best_epoch = checkpoint["best_epoch"] 196 | self._best_indicator = checkpoint[self._indicator_name] 197 | 198 | self._global_step = checkpoint["global_step"] 199 | self._model.load_state_dict(checkpoint["state_dict"], strict=False) 200 | self._optimizer.load_state_dict(checkpoint["optimizer"]) 201 | self._schedule.load_state_dict(checkpoint["schedule"]) 202 | for state in self._optimizer.state.values(): 203 | for k, v in state.items(): 204 | if not torch.is_tensor(v): 205 | continue 206 | state[k] = v.to(self._device) 207 | except Exception: 208 | self._logger.error("model_det init failed") 209 | raise 210 | 211 | def _save_pth_model(self, file_name, epoch, best_epoch, best_indicator): 212 | checkpoint = { 213 | "epoch": epoch, 214 | "best_epoch": best_epoch, 215 | self._indicator_name: best_indicator, 216 | 217 | "global_step": self._global_step, 218 | "state_dict": self._model.module.state_dict() if self._distributed else self._model.state_dict(), 219 | "optimizer": self._optimizer.state_dict(), 220 | "schedule": self._schedule.state_dict(), 221 | } 222 | if not os.path.exists(self._global_conf["save_pth_dir"]): 223 | os.makedirs(self._global_conf["save_pth_dir"]) 224 | torch.save(checkpoint, os.path.join(self._global_conf["save_pth_dir"], file_name + ".pth")) 225 | 226 | 227 | if __name__ == "__main__": 228 | parser = argparse.ArgumentParser() 229 | parser.add_argument("-c", "--config", default="./config/train/det.yml", help="配置文件路径") 230 | det_conf_path = parser.parse_args().config 231 | 232 | cus_params = ReadConfig(det_conf_path).base_conf 233 | cus_logger = get_logger(log_path=cus_params["global"]["save_pth_dir"]) 234 | cus_logger.info("相关自定义参数:\n{}".format(json.dumps(cus_params, indent=2, ensure_ascii=False))) 235 | time.sleep(1) 236 | main(cus_params, cus_logger) 237 | -------------------------------------------------------------------------------- /utils/__init__.py: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/jakeywu/ocr_torch/94451fac7d9e503ed4612772a6941c2034f13409/utils/__init__.py -------------------------------------------------------------------------------- /utils/string_utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | import codecs 3 | 4 | 5 | class CharacterJson(object): 6 | def __init__(self, character_json_path): 7 | self._read_character_dict(character_json_path) 8 | 9 | def _read_character_dict(self, character_json_path): 10 | with codecs.open(character_json_path, "r", "utf8") as f: 11 | char2idx = json.loads(f.read()) 12 | 13 | self.char2idx = char2idx 14 | idx2char = {val: key for key, val in char2idx.items()} 15 | self.idx2char = idx2char 16 | self.classes_num = len(char2idx) 17 | --------------------------------------------------------------------------------
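`CharacterJson` above expects the character file (e.g. data_loader/data_rec/chinese_chars_6695.json) to be a single JSON object mapping characters to integer indices, where the empty-string key is the ignored/blank token consumed by `CRnnPostProcess.decode`. A small sketch with a toy dictionary; the toy contents and the index assignment are illustrative, not the shipped 6695-character file.

```
import json
import codecs
from utils.string_utils import CharacterJson

# Toy character map: "" is the ignored/blank token skipped during decoding.
toy = {"": 0, "一": 1, "二": 2, "三": 3}
with codecs.open("/tmp/toy_chars.json", "w", "utf8") as f:
    f.write(json.dumps(toy, ensure_ascii=False))

cj = CharacterJson("/tmp/toy_chars.json")
print(cj.classes_num)   # 4
print(cj.char2idx[""])  # 0
print(cj.idx2char[2])   # 二
```

`classes_num` (the dictionary size) is what sizes the CRNN output layer (`classes_num` passed to `CRNN` in nets/rec/rnn.py above).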