├── README.md ├── bg_img ├── 1.jpg ├── 2.jpg ├── 3.jpg ├── 4.jpg ├── 5.jpg ├── 6.jpg ├── 7.jpg ├── 8.jpg └── 9.jpg ├── config ├── det_DB_mobilev3.yaml ├── det_DB_resnet50.yaml ├── det_DB_resnet50_3_3.yaml ├── det_DB_resnet50_mul.yaml ├── det_PAN_mobilev3.yaml ├── det_PAN_resnet18.yaml ├── det_PAN_resnet18_3_3.yaml ├── det_PSE_mobilev3.yaml ├── det_PSE_resnet50.yaml ├── det_PSE_resnet50_3_3.yaml ├── det_SAST_resnet50.yaml ├── det_SAST_resnet50_3_3_ori_dataload.yaml ├── det_SAST_resnet50_ori_dataload.yaml ├── rec_CRNN_mobilev3_large_english_all.yaml ├── rec_CRNN_mobilev3_large_english_lmdb.yaml ├── rec_CRNN_mobilev3_small_english_all.yaml ├── rec_CRNN_mobilev3_small_english_lmdb.yaml ├── rec_CRNN_resnet34_english_lmdb.yaml ├── rec_CRNN_resnet_english.yaml ├── rec_CRNN_resnet_english_all.yaml └── rec_FC_resnet_english_all.yaml ├── doc ├── example │ ├── det_test_list.txt │ ├── det_train_list.txt │ ├── label.txt │ ├── rec_test_list.txt │ └── rec_train_list.txt ├── md │ ├── ocr.jpg │ ├── onnx_to_tensorrt.md │ ├── pytorch_to_onnx.md │ ├── 文本检测训练文档.md │ ├── 文本识别训练文档.md │ ├── 模型剪枝.md │ └── 模型蒸馏.md └── show │ ├── 1.jpg │ ├── 2.jpg │ ├── 3.jpg │ ├── ocr1.jpg │ └── ocr2.jpg ├── make.sh ├── onnx └── onnx-simple.sh ├── ptocr ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-36.pyc │ └── optimizer.cpython-36.pyc ├── dataloader │ ├── DetLoad │ │ ├── DBProcess.py │ │ ├── MakeBorderMap.py │ │ ├── MakeSegMap.py │ │ ├── PANProcess.py │ │ ├── PSEProcess.py │ │ ├── SASTProcess.py │ │ ├── SASTProcess_ori.py │ │ ├── SASTProcess_ori1.py │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── DBProcess.cpython-36.pyc │ │ │ ├── MakeBorderMap.cpython-36.pyc │ │ │ ├── MakeSegMap.cpython-36.pyc │ │ │ ├── SASTProcess_ori.cpython-36.pyc │ │ │ ├── SASTProcess_ori1.cpython-36.pyc │ │ │ ├── __init__.cpython-36.pyc │ │ │ └── transform_img.cpython-36.pyc │ │ └── transform_img.py │ ├── RecLoad │ │ ├── CRNNProcess.py │ │ ├── CRNNProcess1.py │ │ ├── DataAgument.py │ │ ├── __init__.py │ │ └── __pycache__ │ 
│ │ ├── CRNNProcess.cpython-36.pyc │ │ │ ├── CRNNProcess1.cpython-36.pyc │ │ │ ├── DataAgument.cpython-36.pyc │ │ │ └── __init__.cpython-36.pyc │ ├── __init__.py │ └── __pycache__ │ │ └── __init__.cpython-36.pyc ├── model │ ├── CommonFunction.py │ ├── CommonFunction_Q.py │ ├── __init__.py │ ├── __pycache__ │ │ ├── CommonFunction.cpython-36.pyc │ │ └── __init__.cpython-36.pyc │ ├── architectures │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-35.pyc │ │ │ ├── __init__.cpython-36.pyc │ │ │ ├── det_model.cpython-35.pyc │ │ │ ├── det_model.cpython-36.pyc │ │ │ ├── rec_model.cpython-36.pyc │ │ │ ├── stn.cpython-36.pyc │ │ │ ├── stn_head.cpython-36.pyc │ │ │ └── tps_spatial_transformer.cpython-36.pyc │ │ ├── det_model.py │ │ ├── det_model_q.py │ │ ├── paddle_tps.py │ │ └── rec_model.py │ ├── backbone │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-35.pyc │ │ │ ├── __init__.cpython-36.pyc │ │ │ ├── det_mobilev3.cpython-35.pyc │ │ │ ├── det_mobilev3.cpython-36.pyc │ │ │ ├── det_mobilev3_dcd.cpython-36.pyc │ │ │ ├── det_resnet.cpython-35.pyc │ │ │ ├── det_resnet.cpython-36.pyc │ │ │ ├── det_resnet_3_3.cpython-36.pyc │ │ │ ├── det_resnet_sast.cpython-35.pyc │ │ │ ├── det_resnet_sast.cpython-36.pyc │ │ │ ├── det_resnet_sast_3_3.cpython-36.pyc │ │ │ ├── det_scnet.cpython-35.pyc │ │ │ ├── rec_crnn_backbone.cpython-36.pyc │ │ │ ├── rec_mobilev3_bd.cpython-36.pyc │ │ │ ├── rec_vgg.cpython-36.pyc │ │ │ └── reg_resnet_bd.cpython-36.pyc │ │ ├── det_mobilev3.py │ │ ├── det_mobilev3_pytorch_qua.py │ │ ├── det_resnet.py │ │ ├── det_resnet_3_3.py │ │ ├── det_resnet_sast.py │ │ ├── det_resnet_sast_3_3.py │ │ ├── rec_mobilev3_bd.py │ │ ├── reg_mobilev3.py │ │ └── reg_resnet_bd.py │ ├── head │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-35.pyc │ │ │ ├── __init__.cpython-36.pyc │ │ │ ├── det_DBHead.cpython-35.pyc │ │ │ ├── det_DBHead.cpython-36.pyc │ │ │ ├── det_FPEM_FFM_Head.cpython-35.pyc │ │ │ ├── det_FPEM_FFM_Head.cpython-36.pyc 
│ │ │ ├── det_FPNHead.cpython-35.pyc │ │ │ ├── det_FPNHead.cpython-36.pyc │ │ │ ├── det_SASTHead.cpython-35.pyc │ │ │ ├── det_SASTHead.cpython-36.pyc │ │ │ ├── rec_CRNNHead.cpython-36.pyc │ │ │ └── rec_FCHead.cpython-36.pyc │ │ ├── det_DBHead.py │ │ ├── det_DBHead_Qua.py │ │ ├── det_FPEM_FFM_Head.py │ │ ├── det_FPNHead.py │ │ ├── det_SASTHead.py │ │ ├── rec_CRNNHead.py │ │ └── rec_FCHead.py │ ├── loss │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-35.pyc │ │ │ ├── __init__.cpython-36.pyc │ │ │ ├── basical_loss.cpython-35.pyc │ │ │ ├── basical_loss.cpython-36.pyc │ │ │ ├── centerloss.cpython-36.pyc │ │ │ ├── ctc_loss.cpython-36.pyc │ │ │ ├── db_loss.cpython-35.pyc │ │ │ ├── db_loss.cpython-36.pyc │ │ │ ├── fc_loss.cpython-36.pyc │ │ │ ├── pan_loss.cpython-35.pyc │ │ │ ├── pan_loss.cpython-36.pyc │ │ │ ├── pse_loss.cpython-35.pyc │ │ │ ├── pse_loss.cpython-36.pyc │ │ │ ├── sast_loss.cpython-35.pyc │ │ │ └── sast_loss.cpython-36.pyc │ │ ├── basical_loss.py │ │ ├── centerloss.py │ │ ├── ctc_loss.py │ │ ├── db_loss.py │ │ ├── fc_loss.py │ │ ├── pan_loss.py │ │ ├── pse_loss.py │ │ └── sast_loss.py │ └── segout │ │ ├── __init__.py │ │ ├── __pycache__ │ │ ├── __init__.cpython-35.pyc │ │ ├── __init__.cpython-36.pyc │ │ ├── det_DB_segout.cpython-35.pyc │ │ ├── det_DB_segout.cpython-36.pyc │ │ ├── det_PAN_segout.cpython-35.pyc │ │ ├── det_PAN_segout.cpython-36.pyc │ │ ├── det_PSE_segout.cpython-35.pyc │ │ ├── det_PSE_segout.cpython-36.pyc │ │ ├── det_SAST_segout.cpython-35.pyc │ │ └── det_SAST_segout.cpython-36.pyc │ │ ├── det_DB_segout.py │ │ ├── det_DB_segout_qua.py │ │ ├── det_PAN_segout.py │ │ ├── det_PSE_segout.py │ │ └── det_SAST_segout.py ├── optimizer.py ├── postprocess │ ├── DBpostprocess.py │ ├── PANpostprocess.py │ ├── PSEpostprocess.py │ ├── SASTpostprocess.py │ ├── __init__.py │ ├── __pycache__ │ │ ├── DBpostprocess.cpython-36.pyc │ │ ├── SASTpostprocess.cpython-36.pyc │ │ ├── __init__.cpython-36.pyc │ │ └── locality_aware_nms.cpython-36.pyc │ 
├── dbprocess │ │ ├── Makefile │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ └── __init__.cpython-36.pyc │ │ ├── cppdbprocess.cpp │ │ ├── cppdbprocess.so │ │ └── include │ │ │ ├── clipper.h │ │ │ ├── clipper │ │ │ ├── clipper.cpp │ │ │ └── clipper.hpp │ │ │ ├── postprocess_op.h │ │ │ └── pybind11 │ │ │ ├── attr.h │ │ │ ├── buffer_info.h │ │ │ ├── cast.h │ │ │ ├── chrono.h │ │ │ ├── class_support.h │ │ │ ├── common.h │ │ │ ├── complex.h │ │ │ ├── descr.h │ │ │ ├── detail │ │ │ ├── class.h │ │ │ ├── common.h │ │ │ ├── descr.h │ │ │ ├── init.h │ │ │ ├── internals.h │ │ │ └── typeid.h │ │ │ ├── eigen.h │ │ │ ├── embed.h │ │ │ ├── eval.h │ │ │ ├── functional.h │ │ │ ├── iostream.h │ │ │ ├── numpy.h │ │ │ ├── operators.h │ │ │ ├── options.h │ │ │ ├── pybind11.h │ │ │ ├── pytypes.h │ │ │ ├── stl.h │ │ │ ├── stl_bind.h │ │ │ └── typeid.h │ ├── lanms │ │ ├── .gitignore │ │ ├── .ycm_extra_conf.py │ │ ├── Makefile │ │ ├── __init__.py │ │ ├── __main__.py │ │ ├── __pycache__ │ │ │ └── __init__.cpython-36.pyc │ │ ├── adaptor.cpp │ │ ├── include │ │ │ ├── clipper │ │ │ │ ├── clipper.cpp │ │ │ │ └── clipper.hpp │ │ │ └── pybind11 │ │ │ │ ├── attr.h │ │ │ │ ├── buffer_info.h │ │ │ │ ├── cast.h │ │ │ │ ├── chrono.h │ │ │ │ ├── class_support.h │ │ │ │ ├── common.h │ │ │ │ ├── complex.h │ │ │ │ ├── descr.h │ │ │ │ ├── eigen.h │ │ │ │ ├── embed.h │ │ │ │ ├── eval.h │ │ │ │ ├── functional.h │ │ │ │ ├── numpy.h │ │ │ │ ├── operators.h │ │ │ │ ├── options.h │ │ │ │ ├── pybind11.h │ │ │ │ ├── pytypes.h │ │ │ │ ├── stl.h │ │ │ │ ├── stl_bind.h │ │ │ │ └── typeid.h │ │ └── lanms.h │ ├── locality_aware_nms.py │ └── piexlmerge │ │ ├── Makefile │ │ ├── __init__.py │ │ ├── include │ │ ├── clipper │ │ │ ├── clipper.cpp │ │ │ └── clipper.hpp │ │ └── pybind11 │ │ │ ├── attr.h │ │ │ ├── buffer_info.h │ │ │ ├── cast.h │ │ │ ├── chrono.h │ │ │ ├── class_support.h │ │ │ ├── common.h │ │ │ ├── complex.h │ │ │ ├── descr.h │ │ │ ├── detail │ │ │ ├── class.h │ │ │ ├── common.h │ │ │ ├── descr.h │ │ │ ├── 
init.h │ │ │ ├── internals.h │ │ │ └── typeid.h │ │ │ ├── eigen.h │ │ │ ├── embed.h │ │ │ ├── eval.h │ │ │ ├── functional.h │ │ │ ├── iostream.h │ │ │ ├── numpy.h │ │ │ ├── operators.h │ │ │ ├── options.h │ │ │ ├── pybind11.h │ │ │ ├── pytypes.h │ │ │ ├── stl.h │ │ │ ├── stl_bind.h │ │ │ └── typeid.h │ │ ├── lanms.h │ │ ├── pixelmerge.cpp │ │ └── pixelmerge.so └── utils │ ├── __init__.py │ ├── __pycache__ │ ├── __init__.cpython-36.pyc │ ├── cal_iou_acc.cpython-36.pyc │ ├── gen_teacher_model.cpython-36.pyc │ ├── logger.cpython-36.pyc │ ├── metrics.cpython-36.pyc │ ├── prune_script.cpython-36.pyc │ ├── transform_label.cpython-36.pyc │ └── util_function.cpython-36.pyc │ ├── cal_iou_acc.py │ ├── gen_teacher_model.py │ ├── logger.py │ ├── metrics.py │ ├── prune_script.py │ ├── transform_label.py │ └── util_function.py ├── requirement.txt ├── script ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-36.pyc │ └── onnx_to_tensorrt.cpython-36.pyc ├── create_lmdb.py ├── create_lmdb_multiprocessing.py ├── get_key_label.py ├── get_train_list.py ├── onnx_to_tensorrt.py ├── pytorch_to_onnx.py └── warp_polar.py ├── to_onnx.sh ├── to_tensorrt.sh └── tools ├── __init__.py ├── __pycache__ ├── MarginLoss.cpython-36.pyc └── __init__.cpython-36.pyc ├── cal.py ├── cal_rescall ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-35.pyc │ ├── __init__.cpython-36.pyc │ ├── cal_det.cpython-36.pyc │ ├── cal_iou.cpython-36.pyc │ ├── rrc_evaluation_funcs.cpython-35.pyc │ ├── rrc_evaluation_funcs.cpython-36.pyc │ ├── script.cpython-35.pyc │ └── script.cpython-36.pyc ├── cal_det.py ├── cal_iou.py ├── rrc_evaluation_funcs.py └── script.py ├── det_infer.py ├── det_sast.py ├── det_train.py ├── det_train_qua.py ├── pruned ├── __init__.py ├── prune_model_all.py └── prune_model_backbone.py ├── rec_infer.py ├── rec_infer_bk1.py ├── rec_train.py ├── rec_train_bk1.py ├── rec_train_bk2.py └── rec_train_bk3.py /bg_img/1.jpg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/BADBADBADBOY/pytorchOCR/efae5454651e96794dd7a460bd305dc9ccd95359/bg_img/1.jpg -------------------------------------------------------------------------------- /bg_img/2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BADBADBADBOY/pytorchOCR/efae5454651e96794dd7a460bd305dc9ccd95359/bg_img/2.jpg -------------------------------------------------------------------------------- /bg_img/3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BADBADBADBOY/pytorchOCR/efae5454651e96794dd7a460bd305dc9ccd95359/bg_img/3.jpg -------------------------------------------------------------------------------- /bg_img/4.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BADBADBADBOY/pytorchOCR/efae5454651e96794dd7a460bd305dc9ccd95359/bg_img/4.jpg -------------------------------------------------------------------------------- /bg_img/5.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BADBADBADBOY/pytorchOCR/efae5454651e96794dd7a460bd305dc9ccd95359/bg_img/5.jpg -------------------------------------------------------------------------------- /bg_img/6.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BADBADBADBOY/pytorchOCR/efae5454651e96794dd7a460bd305dc9ccd95359/bg_img/6.jpg -------------------------------------------------------------------------------- /bg_img/7.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BADBADBADBOY/pytorchOCR/efae5454651e96794dd7a460bd305dc9ccd95359/bg_img/7.jpg 
-------------------------------------------------------------------------------- /bg_img/8.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BADBADBADBOY/pytorchOCR/efae5454651e96794dd7a460bd305dc9ccd95359/bg_img/8.jpg -------------------------------------------------------------------------------- /bg_img/9.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BADBADBADBOY/pytorchOCR/efae5454651e96794dd7a460bd305dc9ccd95359/bg_img/9.jpg -------------------------------------------------------------------------------- /config/det_DB_mobilev3.yaml: -------------------------------------------------------------------------------- 1 | base: 2 | gpu_id: '1' 3 | algorithm: DB 4 | pretrained: True 5 | in_channels: [24, 40, 48, 96] 6 | inner_channels: 96 7 | k: 50 8 | adaptive: True 9 | crop_shape: [640,640] 10 | shrink_ratio: 0.4 11 | n_epoch: 1200 12 | start_val: 700 13 | show_step: 20 14 | checkpoints: ./checkpoint 15 | save_epoch: 100 16 | restore: False 17 | restore_file : ./checkpoint/DB_best.pth.tar 18 | 19 | backbone: 20 | function: ptocr.model.backbone.det_mobilev3,mobilenet_v3_small 21 | 22 | head: 23 | function: ptocr.model.head.det_DBHead,DB_Head 24 | # function: ptocr.model.head.det_FPEM_FFM_Head,FPEM_FFM_Head 25 | # function: ptocr.model.head.det_FPNHead,FPN_Head 26 | 27 | segout: 28 | function: ptocr.model.segout.det_DB_segout,SegDetector 29 | 30 | architectures: 31 | model_function: ptocr.model.architectures.det_model,DetModel 32 | loss_function: ptocr.model.architectures.det_model,DetLoss 33 | 34 | loss: 35 | function: ptocr.model.loss.db_loss,DBLoss 36 | l1_scale: 10 37 | bce_scale: 1 38 | 39 | #optimizer: 40 | # function: ptocr.optimizer,AdamDecay 41 | # base_lr: 0.002 42 | # beta1: 0.9 43 | # beta2: 0.999 44 | 45 | optimizer: 46 | function: ptocr.optimizer,SGDDecay 47 | base_lr: 0.002 48 | momentum: 0.99 49 | 
weight_decay: 0.00005 50 | 51 | optimizer_decay: 52 | function: ptocr.optimizer,adjust_learning_rate_poly 53 | factor: 0.9 54 | 55 | #optimizer_decay: 56 | # function: ptocr.optimizer,adjust_learning_rate 57 | # schedule: [1,2] 58 | # gama: 0.1 59 | 60 | trainload: 61 | function: ptocr.dataloader.DetLoad.DBProcess,DBProcessTrain 62 | train_file: /src/notebooks/detect_text/icdar2015/train_list.txt 63 | num_workers: 10 64 | batch_size: 16 65 | 66 | testload: 67 | function: ptocr.dataloader.DetLoad.DBProcess,DBProcessTest 68 | test_file: /src/notebooks/detect_text/icdar2015/test_list.txt 69 | test_gt_path: /src/notebooks/detect_text/icdar2015/ch4_test_gts/ 70 | test_size: 736 71 | stride: 32 72 | num_workers: 5 73 | batch_size: 4 74 | 75 | postprocess: 76 | function: ptocr.postprocess.DBpostprocess,DBPostProcess 77 | is_poly: False 78 | thresh: 0.5 79 | box_thresh: 0.6 80 | max_candidates: 1000 81 | unclip_ratio: 2 82 | min_size: 3 83 | 84 | infer: 85 | model_path: './checkpoint/DB_best.pth.tar' 86 | path: '/src/notebooks/detect_text/icdar2015/ch4_test_images' 87 | save_path: './result' 88 | -------------------------------------------------------------------------------- /config/det_DB_resnet50.yaml: -------------------------------------------------------------------------------- 1 | base: 2 | gpu_id: '0' # 设置训练的gpu id,多卡训练设置为 '0,1,2' 3 | algorithm: DB # 算法名称 4 | pretrained: True # 是否加载预训练 5 | in_channels: [256, 512, 1024, 2048] # 6 | inner_channels: 256 # 7 | k: 50 8 | adaptive: True 9 | crop_shape: [640,640] #训练时crop图片的大小 10 | shrink_ratio: 0.4 # kernel向内收缩比率 11 | n_epoch: 1200 # 训练的epoch 12 | start_val: 400 #开始验证的epoch,如果不想验证直接设置数值大于n_epoch 13 | show_step: 20 #设置迭代多少次输出一次loss 14 | checkpoints: ./checkpoint #保存模型地址 15 | save_epoch: 100 #设置每多少个epoch保存一次模型 16 | restore: False #是否恢复训练 17 | restore_file : ./DB.pth.tar #恢复训练所需加载模型的地址 18 | 19 | backbone: 20 | function: ptocr.model.backbone.det_resnet,resnet50 21 | 22 | head: 23 | function: 
ptocr.model.head.det_DBHead,DB_Head 24 | # function: ptocr.model.head.det_FPEM_FFM_Head,FPEM_FFM_Head 25 | # function: ptocr.model.head.det_FPNHead,FPN_Head 26 | 27 | segout: 28 | function: ptocr.model.segout.det_DB_segout,SegDetector 29 | 30 | architectures: 31 | model_function: ptocr.model.architectures.det_model,DetModel 32 | loss_function: ptocr.model.architectures.det_model,DetLoss 33 | 34 | loss: 35 | function: ptocr.model.loss.db_loss,DBLoss 36 | l1_scale: 10 37 | bce_scale: 1 38 | 39 | #optimizer: 40 | # function: ptocr.optimizer,AdamDecay 41 | # base_lr: 0.002 42 | # beta1: 0.9 43 | # beta2: 0.999 44 | 45 | optimizer: 46 | function: ptocr.optimizer,SGDDecay 47 | base_lr: 0.002 48 | momentum: 0.99 49 | weight_decay: 0.0005 50 | 51 | optimizer_decay: 52 | function: ptocr.optimizer,adjust_learning_rate_poly 53 | factor: 0.9 54 | 55 | #optimizer_decay: 56 | # function: ptocr.optimizer,adjust_learning_rate 57 | # schedule: [1,2] 58 | # gama: 0.1 59 | 60 | trainload: 61 | function: ptocr.dataloader.DetLoad.DBProcess,DBProcessTrain 62 | train_file: /src/notebooks/detect_text/icdar2015/train_list.txt 63 | num_workers: 10 64 | batch_size: 8 65 | 66 | testload: 67 | function: ptocr.dataloader.DetLoad.DBProcess,DBProcessTest 68 | test_file: /src/notebooks/detect_text/icdar2015/test_list.txt 69 | test_gt_path: /src/notebooks/detect_text/icdar2015/ch4_test_gts/ 70 | test_size: 736 71 | stride: 32 72 | num_workers: 5 73 | batch_size: 4 74 | 75 | postprocess: 76 | function: ptocr.postprocess.DBpostprocess,DBPostProcess 77 | is_poly: False #测试时,检测弯曲文本设置成 True,否则就是输出矩形框 78 | thresh: 0.5 79 | box_thresh: 0.6 80 | max_candidates: 1000 81 | unclip_ratio: 2 82 | min_size: 3 83 | 84 | infer: 85 | model_path: './checkpoint/ag_DB_bb_resnet50_he_DB_Head_bs_8_ep_1200_bk/DB_best.pth.tar' 86 | path: '/src/notebooks/detect_text/icdar2015/ch4_test_images' 87 | save_path: './result' 88 | -------------------------------------------------------------------------------- 
/config/det_DB_resnet50_3_3.yaml: -------------------------------------------------------------------------------- 1 | base: 2 | gpu_id: '0' 3 | algorithm: DB 4 | pretrained: False 5 | in_channels: [256, 512, 1024, 2048] 6 | inner_channels: 256 7 | k: 50 8 | adaptive: True 9 | crop_shape: [640,640] 10 | shrink_ratio: 0.4 11 | n_epoch: 600 12 | start_val: 6000 13 | show_step: 20 14 | checkpoints: ./checkpoint 15 | save_epoch: 100 16 | restore: False 17 | restore_file : ./DB.pth.tar 18 | 19 | backbone: 20 | function: ptocr.model.backbone.det_resnet_3_3,resnet50 21 | 22 | head: 23 | # function: ptocr.model.head.det_DBHead,DB_Head 24 | function: ptocr.model.head.det_FPEM_FFM_Head,FPEM_FFM_Head 25 | # function: ptocr.model.head.det_FPNHead,FPN_Head 26 | 27 | segout: 28 | function: ptocr.model.segout.det_DB_segout,SegDetector 29 | 30 | architectures: 31 | model_function: ptocr.model.architectures.det_model,DetModel 32 | loss_function: ptocr.model.architectures.det_model,DetLoss 33 | 34 | loss: 35 | function: ptocr.model.loss.db_loss,DBLoss 36 | l1_scale: 10 37 | bce_scale: 1 38 | 39 | #optimizer: 40 | # function: ptocr.optimizer,AdamDecay 41 | # base_lr: 0.002 42 | # beta1: 0.9 43 | # beta2: 0.999 44 | 45 | optimizer: 46 | function: ptocr.optimizer,SGDDecay 47 | base_lr: 0.002 48 | momentum: 0.99 49 | weight_decay: 0.0005 50 | 51 | optimizer_decay: 52 | function: ptocr.optimizer,adjust_learning_rate_poly 53 | factor: 0.9 54 | 55 | #optimizer_decay: 56 | # function: ptocr.optimizer,adjust_learning_rate 57 | # schedule: [1,2] 58 | # gama: 0.1 59 | 60 | trainload: 61 | function: ptocr.dataloader.DetLoad.DBProcess,DBProcessTrain 62 | train_file: /src/notebooks/MyworkData/huayandang/train_list.txt 63 | num_workers: 10 64 | batch_size: 8 65 | 66 | testload: 67 | function: ptocr.dataloader.DetLoad.DBProcess,DBProcessTest 68 | test_file: /src/notebooks/detect_text/icdar2015/test_list.txt 69 | test_gt_path: /src/notebooks/detect_text/icdar2015/ch4_test_gts/ 70 | test_size: 736 71 
| stride: 32 72 | num_workers: 5 73 | batch_size: 4 74 | 75 | postprocess: 76 | function: ptocr.postprocess.DBpostprocess,DBPostProcess 77 | is_poly: False 78 | thresh: 0.2 79 | box_thresh: 0.3 80 | max_candidates: 1000 81 | unclip_ratio: 1.5 82 | min_size: 3 83 | 84 | infer: 85 | model_path: './checkpoint/ag_DB_bb_resnet50_he_DB_Head_bs_8_ep_1200_bk/DB_best.pth.tar' 86 | path: '/src/notebooks/detect_text/icdar2015/ch4_test_images' 87 | save_path: './result' 88 | -------------------------------------------------------------------------------- /config/det_DB_resnet50_mul.yaml: -------------------------------------------------------------------------------- 1 | base: 2 | gpu_id: '1' # 设置训练的gpu id,多卡训练设置为 '0,1,2' 3 | algorithm: DB # 算法名称 4 | pretrained: True # 是否加载预训练 5 | in_channels: [256, 512, 1024, 2048] # 6 | inner_channels: 256 # 7 | k: 50 8 | n_class: 3 9 | adaptive: True 10 | crop_shape: [640,640] #训练时crop图片的大小 11 | shrink_ratio: 0.4 # kernel向内收缩比率 12 | n_epoch: 1200 # 训练的epoch 13 | start_val: 400 #开始验证的epoch,如果不想验证直接设置数值大于n_epoch 14 | show_step: 20 #设置迭代多少次输出一次loss 15 | checkpoints: ./checkpoint #保存模型地址 16 | save_epoch: 100 #设置每多少个epoch保存一次模型 17 | restore: False #是否恢复训练 18 | restore_file : ./DB.pth.tar #恢复训练所需加载模型的地址 19 | 20 | backbone: 21 | function: ptocr.model.backbone.det_resnet,resnet50 22 | 23 | head: 24 | function: ptocr.model.head.det_DBHead,DB_Head 25 | # function: ptocr.model.head.det_FPEM_FFM_Head,FPEM_FFM_Head 26 | # function: ptocr.model.head.det_FPNHead,FPN_Head 27 | 28 | segout: 29 | function: ptocr.model.segout.det_DB_segout,SegDetectorMul 30 | 31 | architectures: 32 | model_function: ptocr.model.architectures.det_model,DetModel 33 | loss_function: ptocr.model.architectures.det_model,DetLoss 34 | 35 | loss: 36 | function: ptocr.model.loss.db_loss,DBLossMul 37 | l1_scale: 10 38 | bce_scale: 1 39 | class_scale: 1 40 | 41 | #optimizer: 42 | # function: ptocr.optimizer,AdamDecay 43 | # base_lr: 0.002 44 | # beta1: 0.9 45 | # beta2: 0.999 46 | 47 | 
optimizer: 48 | function: ptocr.optimizer,SGDDecay 49 | base_lr: 0.002 50 | momentum: 0.99 51 | weight_decay: 0.0005 52 | 53 | optimizer_decay: 54 | function: ptocr.optimizer,adjust_learning_rate_poly 55 | factor: 0.9 56 | 57 | #optimizer_decay: 58 | # function: ptocr.optimizer,adjust_learning_rate 59 | # schedule: [1,2] 60 | # gama: 0.1 61 | 62 | trainload: 63 | function: ptocr.dataloader.DetLoad.DBProcess,DBProcessTrainMul 64 | train_file: /src/notebooks/fangxuwei_96/TextGenerator-master/output/train/train_list.txt 65 | num_workers: 10 66 | batch_size: 8 67 | 68 | testload: 69 | function: ptocr.dataloader.DetLoad.DBProcess,DBProcessTest 70 | test_file: /src/notebooks/detect_text/icdar2015/test_list.txt 71 | test_gt_path: /src/notebooks/detect_text/icdar2015/ch4_test_gts/ 72 | test_size: 736 73 | stride: 32 74 | num_workers: 5 75 | batch_size: 4 76 | 77 | postprocess: 78 | function: ptocr.postprocess.DBpostprocess,DBPostProcessMul 79 | is_poly: False #测试时,检测弯曲文本设置成 True,否则就是输出矩形框 80 | thresh: 0.5 81 | box_thresh: 0.6 82 | max_candidates: 1000 83 | unclip_ratio: 2 84 | min_size: 3 85 | 86 | infer: 87 | model_path: './checkpoint/ag_DB_bb_resnet50_he_DB_Head_bs_8_ep_601_train_mul/DB_best.pth.tar' 88 | path: '/src/notebooks/fangxuwei_96/TextGenerator-master/output/img/' 89 | save_path: './result' 90 | -------------------------------------------------------------------------------- /config/det_PAN_mobilev3.yaml: -------------------------------------------------------------------------------- 1 | base: 2 | gpu_id: '0' 3 | algorithm: PAN 4 | pretrained: False 5 | in_channels: [24, 40, 48, 96] 6 | inner_channels: 96 7 | classes: 6 8 | crop_shape: [640,640] 9 | shrink_ratio: 0.4 10 | n_epoch: 600 11 | show_step: 20 12 | start_val: 300 13 | save_epoch: 100 14 | checkpoints: ./checkpoint 15 | restore: False 16 | restore_file: ./PAN.pth.tar 17 | 18 | backbone: 19 | function: ptocr.model.backbone.det_mobilev3,mobilenet_v3_small 20 | 21 | head: 22 | # function: 
ptocr.model.head.det_DBHead,DB_Head 23 | function: ptocr.model.head.det_FPEM_FFM_Head,FPEM_FFM_Head 24 | # function: ptocr.model.head.det_FPNHead,FPN_Head 25 | segout: 26 | function: ptocr.model.segout.det_PAN_segout,SegDetector 27 | 28 | architectures: 29 | model_function: ptocr.model.architectures.det_model,DetModel 30 | loss_function: ptocr.model.architectures.det_model,DetLoss 31 | 32 | loss: 33 | function: ptocr.model.loss.pan_loss,PANLoss 34 | kernel_rate: 0.5 35 | agg_dis_rate: 0.25 36 | 37 | #optimizer: 38 | # function: ptocr.optimizer,AdamDecay 39 | # base_lr: 0.002 40 | # beta1: 0.9 41 | # beta2: 0.999 42 | 43 | optimizer: 44 | function: ptocr.optimizer,SGDDecay 45 | base_lr: 0.001 46 | momentum: 0.99 47 | weight_decay: 0.00005 48 | 49 | # optimizer_decay: 50 | # function: ptocr.optimizer,adjust_learning_rate_poly 51 | # factor: 0.9 52 | 53 | optimizer_decay: 54 | function: ptocr.optimizer,adjust_learning_rate 55 | schedule: [200,400,500] 56 | gama: 0.1 57 | 58 | 59 | trainload: 60 | function: ptocr.dataloader.DetLoad.PANProcess,PANProcessTrain 61 | train_file: /src/notebooks/detect_text/icdar2015/train_list.txt 62 | num_workers: 10 63 | batch_size: 16 64 | 65 | testload: 66 | function: ptocr.dataloader.DetLoad.PANProcess,PANProcessTest 67 | test_file: /src/notebooks/detect_text/icdar2015/test_list.txt 68 | test_gt_path: /src/notebooks/detect_text/icdar2015/ch4_test_gts/ 69 | test_size: 736 70 | stride: 32 71 | num_workers: 5 72 | batch_size: 2 73 | 74 | postprocess: 75 | function: ptocr.postprocess.PANpostprocess,PANPostProcess 76 | is_poly: False 77 | bin_th: 1 78 | scale: 1 79 | min_kernel_area: 8 80 | min_text_area: 50 81 | min_score: 0.93 82 | dis_thresh: 6 -------------------------------------------------------------------------------- /config/det_PAN_resnet18.yaml: -------------------------------------------------------------------------------- 1 | base: 2 | gpu_id: '1' 3 | algorithm: PAN 4 | pretrained: True 5 | in_channels: [64, 128, 256, 512] 6 
| inner_channels: 128 7 | classes: 6 8 | crop_shape: [640,640] 9 | shrink_ratio: 0.5 10 | n_epoch: 600 11 | show_step: 20 12 | start_val: 300 13 | save_epoch: 50 14 | checkpoints: ./checkpoint 15 | restore: False 16 | restore_file: ./checkpoint/ag_PAN_bb_resnet18_he_FPEM_FFM_Head_bs_16_ep_600/PAN_200.pth.tar 17 | 18 | backbone: 19 | function: ptocr.model.backbone.det_resnet,resnet18 20 | 21 | head: 22 | # function: ptocr.model.head.det_DBHead,DB_Head 23 | function: ptocr.model.head.det_FPEM_FFM_Head,FPEM_FFM_Head 24 | # function: ptocr.model.head.det_FPNHead,FPN_Head 25 | segout: 26 | function: ptocr.model.segout.det_PAN_segout,SegDetector 27 | 28 | architectures: 29 | model_function: ptocr.model.architectures.det_model,DetModel 30 | loss_function: ptocr.model.architectures.det_model,DetLoss 31 | 32 | loss: 33 | function: ptocr.model.loss.pan_loss,PANLoss 34 | kernel_rate: 0.5 35 | agg_dis_rate: 0.25 36 | 37 | #optimizer: 38 | # function: ptocr.optimizer,AdamDecay 39 | # base_lr: 0.002 40 | # beta1: 0.9 41 | # beta2: 0.999 42 | 43 | optimizer: 44 | function: ptocr.optimizer,SGDDecay 45 | base_lr: 0.001 46 | momentum: 0.99 47 | weight_decay: 0.00005 48 | 49 | # optimizer_decay: 50 | # function: ptocr.optimizer,adjust_learning_rate_poly 51 | # factor: 0.9 52 | 53 | optimizer_decay: 54 | function: ptocr.optimizer,adjust_learning_rate 55 | schedule: [200,400,500] 56 | gama: 0.1 57 | 58 | 59 | trainload: 60 | function: ptocr.dataloader.DetLoad.PANProcess,PANProcessTrain 61 | train_file: /src/notebooks/detect_text/icdar2015/train_list.txt 62 | num_workers: 10 63 | batch_size: 16 64 | 65 | testload: 66 | function: ptocr.dataloader.DetLoad.PANProcess,PANProcessTest 67 | test_file: /src/notebooks/detect_text/icdar2015/test_list.txt 68 | test_gt_path: /src/notebooks/detect_text/icdar2015/ch4_test_gts/ 69 | test_size: 736 70 | stride: 32 71 | num_workers: 5 72 | batch_size: 2 73 | 74 | postprocess: 75 | function: ptocr.postprocess.PANpostprocess,PANPostProcess 76 | is_poly: 
False 77 | bin_th: 1 78 | scale: 1 79 | min_kernel_area: 4 80 | min_text_area: 300 81 | min_score: 0.90 82 | dis_thresh: 1 83 | 84 | infer: 85 | model_path: './checkpoint/ag_PAN_bb_resnet18_he_FPEM_FFM_Head_bs_16_ep_600/PAN_best.pth.tar' 86 | path: '/src/notebooks/detect_text/icdar2015/ch4_test_images' 87 | save_path: './result' -------------------------------------------------------------------------------- /config/det_PAN_resnet18_3_3.yaml: -------------------------------------------------------------------------------- 1 | base: 2 | gpu_id: '0' 3 | algorithm: PAN 4 | pretrained: True 5 | in_channels: [64, 128, 256, 512] 6 | inner_channels: 128 7 | classes: 6 8 | crop_shape: [640,640] 9 | shrink_ratio: 0.5 10 | n_epoch: 601 11 | show_step: 20 12 | start_val: 200 13 | save_epoch: 100 14 | checkpoints: ./checkpoint 15 | restore: True 16 | restore_file: ./checkpoint/ag_PAN_bb_resnet18_he_FPEM_FFM_Head_bs_14_ep_601/PAN_best.pth.tar 17 | 18 | backbone: 19 | function: ptocr.model.backbone.det_resnet_3*3,resnet18 20 | 21 | head: 22 | # function: ptocr.model.head.det_DBHead,DB_Head 23 | function: ptocr.model.head.det_FPEM_FFM_Head,FPEM_FFM_Head 24 | # function: ptocr.model.head.det_FPNHead,FPN_Head 25 | segout: 26 | function: ptocr.model.segout.det_PAN_segout,SegDetector 27 | 28 | architectures: 29 | model_function: ptocr.model.architectures.det_model,DetModel 30 | loss_function: ptocr.model.architectures.det_model,DetLoss 31 | 32 | loss: 33 | function: ptocr.model.loss.pan_loss,PANLoss 34 | kernel_rate: 0.5 35 | agg_dis_rate: 0.25 36 | 37 | #optimizer: 38 | # function: ptocr.optimizer,AdamDecay 39 | # base_lr: 0.002 40 | # beta1: 0.9 41 | # beta2: 0.999 42 | 43 | optimizer: 44 | function: ptocr.optimizer,SGDDecay 45 | base_lr: 0.001 46 | momentum: 0.99 47 | weight_decay: 0.00005 48 | 49 | optimizer_decay: 50 | function: ptocr.optimizer,adjust_learning_rate_poly 51 | factor: 0.9 52 | 53 | # optimizer_decay: 54 | # function: ptocr.optimizer,adjust_learning_rate 55 | # 
schedule: [200,400,500] 56 | # gama: 0.1 57 | 58 | 59 | trainload: 60 | function: ptocr.dataloader.DetLoad.PANProcess,PANProcessTrain 61 | train_file: /src/notebooks/detect_text/icdar2015/train_list.txt 62 | num_workers: 10 63 | batch_size: 14 64 | 65 | testload: 66 | function: ptocr.dataloader.DetLoad.PANProcess,PANProcessTest 67 | test_file: /src/notebooks/detect_text/icdar2015/test_list.txt 68 | test_gt_path: /src/notebooks/detect_text/icdar2015/ch4_test_gts/ 69 | test_size: 736 70 | stride: 32 71 | num_workers: 5 72 | batch_size: 2 73 | 74 | postprocess: 75 | function: ptocr.postprocess.PANpostprocess,PANPostProcess 76 | is_poly: False 77 | bin_th: 1 78 | scale: 1 79 | min_kernel_area: 4 80 | min_text_area: 300 81 | min_score: 0.90 82 | dis_thresh: 1 83 | 84 | infer: 85 | model_path: './checkpoint/ag_PAN_bb_resnet18_he_FPEM_FFM_Head_bs_14_ep_601/PAN_best.pth.tar' 86 | path: '/src/notebooks/detect_text/icdar2015/ch4_test_images' 87 | save_path: './result' -------------------------------------------------------------------------------- /config/det_PSE_mobilev3.yaml: -------------------------------------------------------------------------------- 1 | base: 2 | gpu_id: '1' 3 | algorithm: PSE 4 | pretrained: True 5 | in_channels: [24, 40, 48, 96] 6 | inner_channels: 96 7 | classes: 7 8 | crop_shape: [640,640] 9 | shrink_ratio: 0.4 10 | n_epoch: 1200 11 | show_step: 5 12 | checkpoints: ./checkpoint 13 | restore: False 14 | restore_file: ./PSE.pth.tar 15 | 16 | backbone: 17 | function: ptocr.model.backbone.det_mobilev3,mobilenet_v3_small 18 | 19 | head: 20 | # function: ptocr.model.head.det_DBHead,DB_Head 21 | # function: ptocr.model.head.det_FPEM_FFM_Head,FPEM_FFM_Head 22 | function: ptocr.model.head.det_FPNHead,FPN_Head 23 | 24 | segout: 25 | function: ptocr.model.segout.det_PSE_segout,SegDetector 26 | 27 | architectures: 28 | model_function: ptocr.model.architectures.det_model,DetModel 29 | loss_function: ptocr.model.architectures.det_model,DetLoss 30 | 31 | loss: 
32 | function: ptocr.model.loss.pse_loss,PSELoss 33 | text_tatio: 0.7 34 | 35 | #optimizer: 36 | # function: ptocr.optimizer,AdamDecay 37 | # base_lr: 0.002 38 | # beta1: 0.9 39 | # beta2: 0.999 40 | 41 | optimizer: 42 | function: ptocr.optimizer,SGDDecay 43 | base_lr: 0.002 44 | momentum: 0.99 45 | 46 | optimizer_decay: 47 | function: ptocr.optimizer,adjust_learning_rate_poly 48 | factor: 0.9 49 | 50 | trainload: 51 | function: ptocr.dataloader.DetLoad.PSEProcess,PSEProcessTrain 52 | train_file: /src/notebooks/detect_text/icdar2015/train_list.txt 53 | num_workers: 0 54 | batch_size: 2 55 | 56 | testload: 57 | function: ptocr.dataloader.DetLoad.PSEProcess,PSEProcessTest 58 | test_file: /src/notebooks/detect_text/icdar2015/test_list.txt 59 | test_gt_path: /src/notebooks/detect_text/icdar2015/ch4_test_gts/ 60 | test_size: 2240 61 | stride: 32 62 | num_workers: 5 63 | batch_size: 2 64 | 65 | postprocess: 66 | function: ptocr.postprocess.PSEpostprocess,PSEPostProcess 67 | is_poly: True 68 | binary_th: 1 69 | scale: 1 70 | min_kernel_area: 5 71 | min_text_area: 800 72 | min_score: 0.93 73 | 74 | infer: 75 | model_path: './checkpoint/ag_PSE_bb_mobilenet_v3_small_he_FPN_Head_bs_16_ep_1200/PSE_400.pth.tar' 76 | path: '/src/notebooks/detect_text/icdar2015/ch4_test_images' 77 | save_path: './result' 78 | -------------------------------------------------------------------------------- /config/det_PSE_resnet50.yaml: -------------------------------------------------------------------------------- 1 | base: 2 | gpu_id: '0' 3 | algorithm: PSE 4 | pretrained: True 5 | in_channels: [256, 512, 1024, 2048] 6 | inner_channels: 256 7 | classes: 7 8 | crop_shape: [640,640] 9 | shrink_ratio: 0.4 10 | n_epoch: 600 11 | show_step: 20 12 | start_val: 400 13 | save_epoch: 100 14 | checkpoints: ./checkpoint 15 | restore: False 16 | restore_file: ./checkpoint/ag_PSE_bb_resnet50_he_FPN_Head_bs_8_ep_600/PSE_400.pth.tar 17 | 18 | backbone: 19 | function: ptocr.model.backbone.det_resnet,resnet50 
20 | 21 | head: 22 | # function: ptocr.model.head.det_DBHead,DB_Head 23 | # function: ptocr.model.head.det_FPEM_FFM_Head,FPEM_FFM_Head 24 | function: ptocr.model.head.det_FPNHead,FPN_Head 25 | 26 | segout: 27 | function: ptocr.model.segout.det_PSE_segout,SegDetector 28 | 29 | architectures: 30 | model_function: ptocr.model.architectures.det_model,DetModel 31 | loss_function: ptocr.model.architectures.det_model,DetLoss 32 | 33 | loss: 34 | function: ptocr.model.loss.pse_loss,PSELoss 35 | text_tatio: 0.7 36 | 37 | #optimizer: 38 | # function: ptocr.optimizer,AdamDecay 39 | # base_lr: 0.002 40 | # beta1: 0.9 41 | # beta2: 0.999 42 | 43 | optimizer: 44 | function: ptocr.optimizer,SGDDecay 45 | base_lr: 0.001 46 | momentum: 0.99 47 | weight_decay: 0.00005 48 | 49 | # optimizer_decay: 50 | # function: ptocr.optimizer,adjust_learning_rate_poly 51 | # factor: 0.9 52 | 53 | optimizer_decay: 54 | function: ptocr.optimizer,adjust_learning_rate 55 | schedule: [200,400,500] 56 | gama: 0.1 57 | 58 | trainload: 59 | function: ptocr.dataloader.DetLoad.PSEProcess,PSEProcessTrain 60 | train_file: /src/notebooks/detect_text/icdar2015/train_list.txt 61 | num_workers: 12 62 | batch_size: 8 63 | 64 | testload: 65 | function: ptocr.dataloader.DetLoad.PSEProcess,PSEProcessTest 66 | test_file: /src/notebooks/detect_text/icdar2015/test_list.txt 67 | test_gt_path: /src/notebooks/detect_text/icdar2015/ch4_test_gts/ 68 | test_size: 2240 69 | stride: 32 70 | num_workers: 5 71 | batch_size: 1 72 | 73 | postprocess: 74 | function: ptocr.postprocess.PSEpostprocess,PSEPostProcess 75 | is_poly: False 76 | binary_th: 1 77 | scale: 1 78 | min_kernel_area: 5 79 | min_text_area: 800 80 | min_score: 0.93 81 | 82 | infer: 83 | model_path: './checkpoint/ag_PSE_bb_resnet50_he_FPN_Head_bs_8_ep_600/PSE_best.pth.tar' 84 | path: '/src/notebooks/detect_text/icdar2015/ch4_test_images' 85 | save_path: './result' 86 | 87 | -------------------------------------------------------------------------------- 
/config/det_PSE_resnet50_3_3.yaml: -------------------------------------------------------------------------------- 1 | base: 2 | gpu_id: '0' 3 | algorithm: PSE 4 | pretrained: True 5 | in_channels: [256, 512, 1024, 2048] 6 | inner_channels: 256 7 | classes: 7 8 | crop_shape: [640,640] 9 | shrink_ratio: 0.4 10 | n_epoch: 601 11 | show_step: 20 12 | start_val: 400 13 | save_epoch: 100 14 | checkpoints: ./checkpoint 15 | restore: False 16 | restore_file: ./checkpoint/ag_PSE_bb_resnet50_he_FPN_Head_bs_8_ep_601/PSE_best.pth.tar 17 | 18 | backbone: 19 | function: ptocr.model.backbone.det_resnet_3*3,resnet50 20 | 21 | head: 22 | # function: ptocr.model.head.det_DBHead,DB_Head 23 | # function: ptocr.model.head.det_FPEM_FFM_Head,FPEM_FFM_Head 24 | function: ptocr.model.head.det_FPNHead,FPN_Head 25 | 26 | segout: 27 | function: ptocr.model.segout.det_PSE_segout,SegDetector 28 | 29 | architectures: 30 | model_function: ptocr.model.architectures.det_model,DetModel 31 | loss_function: ptocr.model.architectures.det_model,DetLoss 32 | 33 | loss: 34 | function: ptocr.model.loss.pse_loss,PSELoss 35 | text_tatio: 0.7 36 | 37 | #optimizer: 38 | # function: ptocr.optimizer,AdamDecay 39 | # base_lr: 0.002 40 | # beta1: 0.9 41 | # beta2: 0.999 42 | 43 | optimizer: 44 | function: ptocr.optimizer,SGDDecay 45 | base_lr: 0.001 46 | momentum: 0.99 47 | weight_decay: 0.00005 48 | 49 | # optimizer_decay: 50 | # function: ptocr.optimizer,adjust_learning_rate_poly 51 | # factor: 0.9 52 | 53 | optimizer_decay: 54 | function: ptocr.optimizer,adjust_learning_rate 55 | schedule: [200,400,500] 56 | gama: 0.1 57 | 58 | trainload: 59 | function: ptocr.dataloader.DetLoad.PSEProcess,PSEProcessTrain 60 | train_file: /src/notebooks/detect_text/icdar2015/train_list.txt 61 | num_workers: 12 62 | batch_size: 8 63 | 64 | testload: 65 | function: ptocr.dataloader.DetLoad.PSEProcess,PSEProcessTest 66 | test_file: /src/notebooks/detect_text/icdar2015/test_list.txt 67 | test_gt_path: 
/src/notebooks/detect_text/icdar2015/ch4_test_gts/ 68 | test_size: 2240 69 | stride: 32 70 | num_workers: 5 71 | batch_size: 1 72 | 73 | postprocess: 74 | function: ptocr.postprocess.PSEpostprocess,PSEPostProcess 75 | is_poly: False 76 | binary_th: 1 77 | scale: 1 78 | min_kernel_area: 5 79 | min_text_area: 800 80 | min_score: 0.93 81 | 82 | infer: 83 | model_path: './checkpoint/ag_PSE_bb_resnet50_he_FPN_Head_bs_8_ep_600/PSE_best.pth.tar' 84 | path: '/src/notebooks/detect_text/icdar2015/ch4_test_images' 85 | save_path: './result' 86 | 87 | -------------------------------------------------------------------------------- /config/det_SAST_resnet50.yaml: -------------------------------------------------------------------------------- 1 | base: 2 | gpu_id: '0' 3 | algorithm: SAST 4 | pretrained: True 5 | with_attention: True 6 | crop_shape: [512,512] 7 | n_epoch: 900 8 | start_val: 600 9 | show_step: 20 10 | checkpoints: ./checkpoint 11 | save_epoch: 100 12 | restore: False 13 | restore_file : ./checkpoint/ag_SAST_bb_resnet50_he_SASTHead_bs_12_ep_2000/SAST_best.pth.tar 14 | 15 | backbone: 16 | function: ptocr.model.backbone.det_resnet_sast,resnet50 17 | 18 | head: 19 | function: ptocr.model.head.det_SASTHead,SASTHead 20 | # function: ptocr.model.head.det_FPEM_FFM_Head,FPEM_FFM_Head 21 | # function: ptocr.model.head.det_FPNHead,FPN_Head 22 | 23 | segout: 24 | function: ptocr.model.segout.det_SAST_segout,SegDetector 25 | 26 | architectures: 27 | model_function: ptocr.model.architectures.det_model,DetModel 28 | loss_function: ptocr.model.architectures.det_model,DetLoss 29 | 30 | loss: 31 | function: ptocr.model.loss.sast_loss,SASTLoss 32 | tvo_lw: 1.5 33 | tco_lw: 1.5 34 | score_lw: 1.0 35 | border_lw: 1.0 36 | 37 | # optimizer: 38 | # function: ptocr.optimizer,AdamDecay 39 | # base_lr: 0.001 40 | # beta1: 0.9 41 | # beta2: 0.999 42 | # weight_decay: 0.00005 43 | 44 | optimizer: 45 | function: ptocr.optimizer,RMSPropDecay 46 | base_lr: 0.001 47 | momentum: 0 48 | alpha: 
0.95 49 | weight_decay: 0.00005 50 | 51 | 52 | # optimizer: 53 | # function: ptocr.optimizer,SGDDecay 54 | # weight_decay: 0.00005 55 | # base_lr: 0.005 56 | # momentum: 0.95 57 | 58 | # optimizer_decay: 59 | # function: ptocr.optimizer,adjust_learning_rate_poly 60 | # factor: 0.9 61 | 62 | optimizer_decay: 63 | function: ptocr.optimizer,adjust_learning_rate 64 | schedule: [300,600,800,850] 65 | gama: 0.3 66 | 67 | 68 | trainload: 69 | function: ptocr.dataloader.DetLoad.SASTProcess,SASTProcessTrain 70 | train_file: /src/notebooks/detect_text/icdar2015/train_list.txt 71 | num_workers: 12 72 | batch_size: 8 73 | min_crop_side_ratio: 0.3 74 | min_crop_size: 24 75 | min_text_size: 4 76 | 77 | 78 | testload: 79 | function: ptocr.dataloader.DetLoad.SASTProcess_ori,SASTProcessTest 80 | test_file: /src/notebooks/detect_text/icdar2015/test_list.txt 81 | test_gt_path: /src/notebooks/detect_text/icdar2015/ch4_test_gts/ 82 | test_size: 1536 83 | stride: 128 84 | num_workers: 5 85 | batch_size: 4 86 | 87 | postprocess: 88 | function: ptocr.postprocess.SASTpostprocess,SASTPostProcess 89 | is_poly: False 90 | score_thresh: 0.5 91 | nms_thresh: 0.2 92 | sample_pts_num: 2 93 | shrink_ratio_of_width: 0.3 94 | expand_scale: 1.0 95 | tcl_map_thresh: 0.7 96 | 97 | infer: 98 | model_path: './checkpoint/ag_SAST_bb_resnet50_he_SASTHead_bs_8_ep_1000/SAST_best.pth.tar' 99 | path: '/src/notebooks/detect_text/icdar2015/ch4_test_images' 100 | save_path: './result' 101 | -------------------------------------------------------------------------------- /config/det_SAST_resnet50_3_3_ori_dataload.yaml: -------------------------------------------------------------------------------- 1 | base: 2 | gpu_id: '1' 3 | algorithm: SAST 4 | pretrained: True 5 | with_attention: True 6 | crop_shape: [512,512] 7 | n_epoch: 901 8 | start_val: 5000 9 | show_step: 20 10 | checkpoints: ./checkpoint 11 | save_epoch: 100 12 | restore: False 13 | restore_file : 
./checkpoint/ag_SAST_bb_resnet50_he_SASTHead_bs_12_ep_2000/SAST_best.pth.tar 14 | 15 | backbone: 16 | function: ptocr.model.backbone.det_resnet_sast_3_3,resnet50 17 | 18 | head: 19 | function: ptocr.model.head.det_SASTHead,SASTHead 20 | # function: ptocr.model.head.det_FPEM_FFM_Head,FPEM_FFM_Head 21 | # function: ptocr.model.head.det_FPNHead,FPN_Head 22 | 23 | segout: 24 | function: ptocr.model.segout.det_SAST_segout,SegDetector 25 | 26 | architectures: 27 | model_function: ptocr.model.architectures.det_model,DetModel 28 | loss_function: ptocr.model.architectures.det_model,DetLoss 29 | 30 | loss: 31 | function: ptocr.model.loss.sast_loss,SASTLoss 32 | tvo_lw: 1.5 33 | tco_lw: 1.5 34 | score_lw: 1.0 35 | border_lw: 1.0 36 | 37 | # optimizer: 38 | # function: ptocr.optimizer,AdamDecay 39 | # base_lr: 0.001 40 | # beta1: 0.9 41 | # beta2: 0.999 42 | # weight_decay: 0.00005 43 | 44 | optimizer: 45 | function: ptocr.optimizer,RMSPropDecay 46 | base_lr: 0.001 47 | momentum: 0 48 | alpha: 0.95 49 | weight_decay: 0.00005 50 | 51 | 52 | # optimizer: 53 | # function: ptocr.optimizer,SGDDecay 54 | # weight_decay: 0.00005 55 | # base_lr: 0.005 56 | # momentum: 0.95 57 | 58 | # optimizer_decay: 59 | # function: ptocr.optimizer,adjust_learning_rate_poly 60 | # factor: 0.9 61 | 62 | optimizer_decay: 63 | function: ptocr.optimizer,adjust_learning_rate 64 | schedule: [300,600,800,850] 65 | gama: 0.3 66 | 67 | 68 | trainload: 69 | function: ptocr.dataloader.DetLoad.SASTProcess_ori,SASTProcessTrain 70 | train_file: /src/notebooks/MyworkData/huayandang/train_list.txt 71 | num_workers: 12 72 | batch_size: 8 73 | min_crop_side_ratio: 0.3 74 | min_crop_size: 24 75 | min_text_size: 4 76 | 77 | 78 | testload: 79 | function: ptocr.dataloader.DetLoad.SASTProcess_ori,SASTProcessTest 80 | test_file: /src/notebooks/detect_text/icdar2015/test_list.txt 81 | test_gt_path: /src/notebooks/detect_text/icdar2015/ch4_test_gts/ 82 | test_size: 1536 83 | stride: 128 84 | num_workers: 5 85 | batch_size: 4 
86 | 87 | postprocess: 88 | function: ptocr.postprocess.SASTpostprocess,SASTPostProcess 89 | is_poly: False 90 | score_thresh: 0.5 91 | nms_thresh: 0.2 92 | sample_pts_num: 2 93 | shrink_ratio_of_width: 0.3 94 | expand_scale: 1.0 95 | tcl_map_thresh: 0.7 96 | 97 | infer: 98 | model_path: './checkpoint/ag_SAST_bb_resnet50_he_SASTHead_bs_8_ep_901/SAST_400.pth.tar' 99 | path: '/src/notebooks/MyworkData/huayandang/train' 100 | save_path: './result' 101 | -------------------------------------------------------------------------------- /config/det_SAST_resnet50_ori_dataload.yaml: -------------------------------------------------------------------------------- 1 | base: 2 | gpu_id: '2' 3 | algorithm: SAST 4 | pretrained: True 5 | with_attention: True 6 | crop_shape: [512,512] 7 | n_epoch: 1000 8 | start_val: 600 9 | show_step: 20 10 | checkpoints: ./checkpoint 11 | save_epoch: 100 12 | restore: False 13 | restore_file : ./checkpoint/ag_SAST_bb_resnet50_he_SASTHead_bs_12_ep_2000/SAST_best.pth.tar 14 | 15 | backbone: 16 | function: ptocr.model.backbone.det_resnet_sast,resnet50 17 | 18 | head: 19 | function: ptocr.model.head.det_SASTHead,SASTHead 20 | # function: ptocr.model.head.det_FPEM_FFM_Head,FPEM_FFM_Head 21 | # function: ptocr.model.head.det_FPNHead,FPN_Head 22 | 23 | segout: 24 | function: ptocr.model.segout.det_SAST_segout,SegDetector 25 | 26 | architectures: 27 | model_function: ptocr.model.architectures.det_model,DetModel 28 | loss_function: ptocr.model.architectures.det_model,DetLoss 29 | 30 | loss: 31 | function: ptocr.model.loss.sast_loss,SASTLoss 32 | tvo_lw: 1.5 33 | tco_lw: 1.5 34 | score_lw: 1.0 35 | border_lw: 1.0 36 | 37 | # optimizer: 38 | # function: ptocr.optimizer,AdamDecay 39 | # base_lr: 0.001 40 | # beta1: 0.9 41 | # beta2: 0.999 42 | # weight_decay: 0.00005 43 | 44 | optimizer: 45 | function: ptocr.optimizer,RMSPropDecay 46 | base_lr: 0.001 47 | momentum: 0 48 | alpha: 0.95 49 | weight_decay: 0.00005 50 | 51 | 52 | # optimizer: 53 | # function: 
ptocr.optimizer,SGDDecay 54 | # weight_decay: 0.00005 55 | # base_lr: 0.005 56 | # momentum: 0.95 57 | 58 | # optimizer_decay: 59 | # function: ptocr.optimizer,adjust_learning_rate_poly 60 | # factor: 0.9 61 | 62 | optimizer_decay: 63 | function: ptocr.optimizer,adjust_learning_rate 64 | schedule: [300,600,800,850] 65 | gama: 0.3 66 | 67 | 68 | # trainload: 69 | # function: ptocr.dataloader.DetLoad.SASTProcess_ori,SASTProcessTrain 70 | # train_file: /src/notebooks/detect_text/icdar2015/train_list.txt 71 | # num_workers: 12 72 | # batch_size: 8 73 | # min_crop_side_ratio: 0.3 74 | # min_crop_size: 24 75 | # min_text_size: 4 76 | 77 | trainload: 78 | function: ptocr.dataloader.DetLoad.SASTProcess_ori1,SASTProcessTrain 79 | data_dir: /src/notebooks/dataset_for_sast 80 | train_file_target: icdar2015/train_label_json.txt 81 | train_file_extre: [icdar17_mlt_latin/train_label_json.txt,coco_text_icdar_4pts/train_label_json.txt,icdar2013/train_label_json.txt] 82 | train_file_ratio: 0.5 83 | num_workers: 12 84 | batch_size: 8 85 | min_crop_side_ratio: 0.3 86 | min_crop_size: 24 87 | min_text_size: 4 88 | 89 | testload: 90 | function: ptocr.dataloader.DetLoad.SASTProcess_ori,SASTProcessTest 91 | test_file: /src/notebooks/detect_text/icdar2015/test_list.txt 92 | test_gt_path: /src/notebooks/detect_text/icdar2015/ch4_test_gts/ 93 | test_size: 1536 94 | stride: 128 95 | num_workers: 5 96 | batch_size: 4 97 | 98 | postprocess: 99 | function: ptocr.postprocess.SASTpostprocess,SASTPostProcess 100 | is_poly: False 101 | score_thresh: 0.5 102 | nms_thresh: 0.2 103 | sample_pts_num: 2 104 | shrink_ratio_of_width: 0.3 105 | expand_scale: 1.0 106 | tcl_map_thresh: 0.7 107 | 108 | infer: 109 | model_path: './checkpoint/ag_SAST_bb_resnet50_he_SASTHead_bs_8_ep_1000/SAST_best.pth.tar' 110 | path: '/src/notebooks/detect_text/icdar2015/ch4_test_images' 111 | save_path: './result' 112 | -------------------------------------------------------------------------------- 
/config/rec_CRNN_mobilev3_large_english_all.yaml: -------------------------------------------------------------------------------- 1 | base: 2 | gpu_id: '0,1' 3 | algorithm: CRNN 4 | pretrained: False 5 | inchannel: 960 6 | hiddenchannel: 96 7 | img_shape: [32,100] 8 | is_gray: True 9 | use_conv: False 10 | use_attention: False 11 | use_lstm: True 12 | lstm_num: 2 13 | classes: 1000 14 | max_iters: 300000 15 | eval_iter: 10000 16 | show_step: 100 17 | checkpoints: ./checkpoint 18 | save_epoch: 1 19 | show_num: 10 20 | restore: False 21 | finetune: False 22 | restore_file : ./checkpoint/ag_CRNN_bb_rec_crnn_backbone_he_CRNN_Head_bs_256_ep_20_20210207English/CRNN_best.pth.tar 23 | 24 | backbone: 25 | function: ptocr.model.backbone.rec_mobilev3_bd,mobilenet_v3_large 26 | 27 | head: 28 | function: ptocr.model.head.rec_CRNNHead,CRNN_Head 29 | 30 | architectures: 31 | model_function: ptocr.model.architectures.rec_model,RecModel 32 | loss_function: ptocr.model.architectures.rec_model,RecLoss 33 | 34 | loss: 35 | function: ptocr.model.loss.ctc_loss,CTCLoss 36 | use_ctc_weight: False 37 | reduction: 'mean' 38 | center_function: ptocr.model.loss.centerloss,CenterLoss 39 | use_center: False 40 | center_lr: 0.5 41 | label_score: 0.95 42 | # min_score: 0.01 43 | weight_center: 0.000001 44 | 45 | 46 | optimizer: 47 | function: ptocr.optimizer,AdamDecay 48 | base_lr: 0.001 49 | beta1: 0.9 50 | beta2: 0.999 51 | weight_decay: 0.00005 52 | 53 | # optimizer: 54 | # function: ptocr.optimizer,SGDDecay 55 | # base_lr: 0.002 56 | # momentum: 0.99 57 | # weight_decay: 0.00005 58 | 59 | # optimizer_decay: 60 | # function: ptocr.optimizer,adjust_learning_rate_poly 61 | # factor: 0.9 62 | 63 | optimizer_decay: 64 | function: ptocr.optimizer,adjust_learning_rate 65 | schedule: [100000,200000] 66 | gama: 0.1 67 | 68 | optimizer_decay_center: 69 | function: ptocr.optimizer,adjust_learning_rate_center 70 | schedule: [100000,200000] 71 | gama: 0.1 72 | 73 | trainload: 74 | function: 
ptocr.dataloader.RecLoad.CRNNProcess1,GetDataLoad 75 | train_file: ['/src/notebooks/MyworkData/EnglishCrnnData/train_lmdb/SynthText/','/src/notebooks/MyworkData/EnglishCrnnData/train_lmdb/MJSynth'] 76 | batch_ratio: [0.5,0.5] 77 | key_file: /src/notebooks/MyworkData/EnglishCrnnData/key_new.txt 78 | bg_path: ./bg_img/ 79 | num_workers: 16 80 | batch_size: 512 81 | 82 | testload: 83 | function: ptocr.dataloader.RecLoad.CRNNProcess1,CRNNProcessTest 84 | test_file: /src/notebooks/MyworkData/EnglishCrnnData/val_new.txt 85 | num_workers: 8 86 | batch_size: 256 87 | 88 | 89 | label_transform: 90 | function: ptocr.utils.transform_label,strLabelConverter 91 | 92 | transform: 93 | function: ptocr.dataloader.RecLoad.DataAgument,transform_label 94 | t_type: lower 95 | char_type: En 96 | 97 | infer: 98 | # model_path: './checkpoint/ag_CRNN_bb_rec_crnn_backbone_he_CRNN_Head_bs_256_ep_10_synthtext/CRNN_best.pth.tar' 99 | model_path: './checkpoint/ag_CRNN_bb_mobilenet_v3_large_he_CRNN_Head_bs_512_ep_300000_mobilev2_alldata/CRNN_210000.pth.tar' 100 | # path: '/src/notebooks/MyworkData/EnglishCrnnData/image/2697/6/107_Ramification_62303.jpg' 101 | path: './english_val_img/' 102 | save_path: '' 103 | -------------------------------------------------------------------------------- /config/rec_CRNN_mobilev3_large_english_lmdb.yaml: -------------------------------------------------------------------------------- 1 | base: 2 | gpu_id: '0' 3 | algorithm: CRNN 4 | pretrained: False 5 | inchannel: 960 6 | hiddenchannel: 96 7 | img_shape: [32,100] 8 | is_gray: True 9 | use_attention: False 10 | use_lstm: True 11 | lstm_num: 2 12 | n_epoch: 8 13 | start_val: 0 14 | show_step: 50 15 | checkpoints: ./checkpoint 16 | save_epoch: 1 17 | show_num: 10 18 | restore: False 19 | finetune: False 20 | restore_file : ./checkpoint/ 21 | 22 | backbone: 23 | function: ptocr.model.backbone.rec_mobilev3_bd,mobilenet_v3_large 24 | 25 | head: 26 | function: ptocr.model.head.rec_CRNNHead,CRNN_Head 27 | 28 | 
architectures: 29 | model_function: ptocr.model.architectures.rec_model,RecModel 30 | loss_function: ptocr.model.architectures.rec_model,RecLoss 31 | 32 | loss: 33 | function: ptocr.model.loss.ctc_loss,CTCLoss 34 | ctc_type: 'warpctc' # torchctc 35 | use_ctc_weight: False 36 | loss_title: ['ctc_loss'] 37 | 38 | optimizer: 39 | function: ptocr.optimizer,AdamDecay 40 | base_lr: 0.001 41 | beta1: 0.9 42 | beta2: 0.999 43 | weight_decay: 0.00005 44 | 45 | 46 | optimizer_decay: 47 | function: ptocr.optimizer,adjust_learning_rate 48 | schedule: [4,6] 49 | gama: 0.1 50 | 51 | 52 | trainload: 53 | function: ptocr.dataloader.RecLoad.CRNNProcess,CRNNProcessLmdbLoad 54 | train_file: '/src/notebooks/MyworkData/EnglishCrnnData/train_lmdb/SynthText/' 55 | key_file: /src/notebooks/MyworkData/EnglishCrnnData/key_new.txt 56 | bg_path: ./bg_img/ 57 | num_workers: 10 58 | batch_size: 512 59 | 60 | valload: 61 | function: ptocr.dataloader.RecLoad.CRNNProcess,CRNNProcessLmdbLoad 62 | val_file: '/src/notebooks/IIIT5k_3000/lmdb/' 63 | num_workers: 5 64 | batch_size: 256 65 | 66 | label_transform: 67 | function: ptocr.utils.transform_label,strLabelConverter 68 | label_function: ptocr.dataloader.RecLoad.DataAgument,transform_label 69 | t_type: lower 70 | char_type: En 71 | 72 | infer: 73 | # model_path: './checkpoint/ag_CRNN_bb_rec_crnn_backbone_he_CRNN_Head_bs_256_ep_10_synthtext/CRNN_best.pth.tar' 74 | model_path: './checkpoint/ag_CRNN_bb_resnet34_he_CRNN_Head_bs_512_ep_8_center_loss/CRNN_best.pth.tar' 75 | # path: '/src/notebooks/MyworkData/EnglishCrnnData/image/2697/6/107_Ramification_62303.jpg' 76 | path: './english_val_img/SVT/image/' 77 | save_path: '' 78 | -------------------------------------------------------------------------------- /config/rec_CRNN_mobilev3_small_english_all.yaml: -------------------------------------------------------------------------------- 1 | base: 2 | gpu_id: '1' 3 | algorithm: CRNN 4 | pretrained: False 5 | inchannel: 576 6 | hiddenchannel: 48 7 | 
img_shape: [32,100] 8 | is_gray: True 9 | use_conv: False 10 | use_attention: False 11 | use_lstm: True 12 | lstm_num: 2 13 | classes: 1000 14 | max_iters: 300000 15 | eval_iter: 10000 16 | show_step: 100 17 | checkpoints: ./checkpoint 18 | save_epoch: 1 19 | show_num: 10 20 | restore: False 21 | finetune: False 22 | restore_file : ./checkpoint/ag_CRNN_bb_rec_crnn_backbone_he_CRNN_Head_bs_256_ep_20_20210207English/CRNN_best.pth.tar 23 | 24 | backbone: 25 | function: ptocr.model.backbone.rec_mobilev3_bd,mobilenet_v3_small 26 | 27 | head: 28 | function: ptocr.model.head.rec_CRNNHead,CRNN_Head 29 | 30 | architectures: 31 | model_function: ptocr.model.architectures.rec_model,RecModel 32 | loss_function: ptocr.model.architectures.rec_model,RecLoss 33 | 34 | loss: 35 | function: ptocr.model.loss.ctc_loss,CTCLoss 36 | use_ctc_weight: False 37 | reduction: 'mean' 38 | center_function: ptocr.model.loss.centerloss,CenterLoss 39 | use_center: False 40 | center_lr: 0.5 41 | label_score: 0.95 42 | # min_score: 0.01 43 | weight_center: 0.000001 44 | 45 | 46 | optimizer: 47 | function: ptocr.optimizer,AdamDecay 48 | base_lr: 0.001 49 | beta1: 0.9 50 | beta2: 0.999 51 | weight_decay: 0.00005 52 | 53 | # optimizer: 54 | # function: ptocr.optimizer,SGDDecay 55 | # base_lr: 0.002 56 | # momentum: 0.99 57 | # weight_decay: 0.00005 58 | 59 | # optimizer_decay: 60 | # function: ptocr.optimizer,adjust_learning_rate_poly 61 | # factor: 0.9 62 | 63 | optimizer_decay: 64 | function: ptocr.optimizer,adjust_learning_rate 65 | schedule: [100000,200000] 66 | gama: 0.1 67 | 68 | optimizer_decay_center: 69 | function: ptocr.optimizer,adjust_learning_rate_center 70 | schedule: [100000,200000] 71 | gama: 0.1 72 | 73 | trainload: 74 | function: ptocr.dataloader.RecLoad.CRNNProcess1,GetDataLoad 75 | train_file: ['/src/notebooks/MyworkData/EnglishCrnnData/train_lmdb/SynthText/','/src/notebooks/MyworkData/EnglishCrnnData/train_lmdb/MJSynth'] 76 | batch_ratio: [0.5,0.5] 77 | key_file: 
/src/notebooks/MyworkData/EnglishCrnnData/key_new.txt 78 | bg_path: ./bg_img/ 79 | num_workers: 16 80 | batch_size: 512 81 | 82 | valload: 83 | function: ptocr.dataloader.RecLoad.CRNNProcess1,GetValDataLoad 84 | root: '/src/notebooks/pytorchOCR-master/english_val_img' 85 | dir: ['CUTE80','IC03_867','IC13_1015','IC13_857','IC15_1811','IIIT5k_3000','SVT','SVTP','IC15_2077'] 86 | test_file: /src/notebooks/MyworkData/EnglishCrnnData/val_new.txt 87 | num_workers: 2 88 | batch_size: 1 89 | 90 | 91 | label_transform: 92 | function: ptocr.utils.transform_label,strLabelConverter 93 | 94 | transform: 95 | function: ptocr.dataloader.RecLoad.DataAgument,transform_label 96 | t_type: lower 97 | char_type: En 98 | 99 | infer: 100 | # model_path: './checkpoint/ag_CRNN_bb_rec_crnn_backbone_he_CRNN_Head_bs_256_ep_10_synthtext/CRNN_best.pth.tar' 101 | model_path: './checkpoint/ag_CRNN_bb_mobilenet_v3_small_he_CRNN_Head_bs_512_ep_300000_mobilev2_small_alldata/CRNN_210000.pth.tar' 102 | # path: '/src/notebooks/MyworkData/EnglishCrnnData/image/2697/6/107_Ramification_62303.jpg' 103 | path: './english_val_img/' 104 | save_path: '' 105 | -------------------------------------------------------------------------------- /config/rec_CRNN_mobilev3_small_english_lmdb.yaml: -------------------------------------------------------------------------------- 1 | base: 2 | gpu_id: '0' 3 | algorithm: CRNN 4 | pretrained: False 5 | inchannel: 576 6 | hiddenchannel: 48 7 | img_shape: [32,100] 8 | is_gray: True 9 | use_attention: False 10 | use_lstm: True 11 | lstm_num: 2 12 | n_epoch: 8 13 | start_val: 0 14 | show_step: 50 15 | checkpoints: ./checkpoint 16 | save_epoch: 1 17 | show_num: 10 18 | restore: False 19 | finetune: False 20 | restore_file : ./checkpoint/ 21 | 22 | backbone: 23 | function: ptocr.model.backbone.rec_mobilev3_bd,mobilenet_v3_small 24 | 25 | head: 26 | function: ptocr.model.head.rec_CRNNHead,CRNN_Head 27 | 28 | architectures: 29 | model_function: 
ptocr.model.architectures.rec_model,RecModel 30 | loss_function: ptocr.model.architectures.rec_model,RecLoss 31 | 32 | loss: 33 | function: ptocr.model.loss.ctc_loss,CTCLoss 34 | ctc_type: 'warpctc' # torchctc 35 | use_ctc_weight: False 36 | loss_title: ['ctc_loss'] 37 | 38 | optimizer: 39 | function: ptocr.optimizer,AdamDecay 40 | base_lr: 0.001 41 | beta1: 0.9 42 | beta2: 0.999 43 | weight_decay: 0.00005 44 | 45 | 46 | optimizer_decay: 47 | function: ptocr.optimizer,adjust_learning_rate 48 | schedule: [4,6] 49 | gama: 0.1 50 | 51 | 52 | trainload: 53 | function: ptocr.dataloader.RecLoad.CRNNProcess,CRNNProcessLmdbLoad 54 | train_file: '/src/notebooks/MyworkData/EnglishCrnnData/train_lmdb/SynthText/' 55 | key_file: /src/notebooks/MyworkData/EnglishCrnnData/key_new.txt 56 | bg_path: ./bg_img/ 57 | num_workers: 10 58 | batch_size: 512 59 | 60 | valload: 61 | function: ptocr.dataloader.RecLoad.CRNNProcess,CRNNProcessLmdbLoad 62 | val_file: '/src/notebooks/IIIT5k_3000/lmdb/' 63 | num_workers: 5 64 | batch_size: 256 65 | 66 | label_transform: 67 | function: ptocr.utils.transform_label,strLabelConverter 68 | label_function: ptocr.dataloader.RecLoad.DataAgument,transform_label 69 | t_type: lower 70 | char_type: En 71 | 72 | infer: 73 | # model_path: './checkpoint/ag_CRNN_bb_rec_crnn_backbone_he_CRNN_Head_bs_256_ep_10_synthtext/CRNN_best.pth.tar' 74 | model_path: './checkpoint/ag_CRNN_bb_resnet34_he_CRNN_Head_bs_512_ep_8_center_loss/CRNN_best.pth.tar' 75 | # path: '/src/notebooks/MyworkData/EnglishCrnnData/image/2697/6/107_Ramification_62303.jpg' 76 | path: './english_val_img/SVT/image/' 77 | save_path: '' 78 | -------------------------------------------------------------------------------- /config/rec_CRNN_resnet34_english_lmdb.yaml: -------------------------------------------------------------------------------- 1 | base: 2 | gpu_id: '0' 3 | algorithm: CRNN 4 | pretrained: False 5 | inchannel: 512 6 | hiddenchannel: 128 7 | img_shape: [32,100] 8 | is_gray: True 9 | 
use_attention: False 10 | use_lstm: True 11 | lstm_num: 2 12 | n_epoch: 8 13 | start_val: 0 14 | show_step: 50 15 | checkpoints: ./checkpoint 16 | save_epoch: 1 17 | show_num: 10 18 | restore: True 19 | finetune: False 20 | restore_file : ./checkpoint/ag_CRNN_bb_resnet34_he_CRNN_Head_bs_512_ep_8_no_attention_no_weight/CRNN_best.pth.tar 21 | 22 | backbone: 23 | function: ptocr.model.backbone.reg_resnet_bd,resnet34 24 | 25 | head: 26 | function: ptocr.model.head.rec_CRNNHead,CRNN_Head 27 | 28 | architectures: 29 | model_function: ptocr.model.architectures.rec_model,RecModel 30 | loss_function: ptocr.model.architectures.rec_model,RecLoss 31 | 32 | loss: 33 | function: ptocr.model.loss.ctc_loss,CTCLoss 34 | ctc_type: 'warpctc' # torchctc 35 | use_ctc_weight: False 36 | loss_title: ['ctc_loss'] 37 | 38 | optimizer: 39 | function: ptocr.optimizer,AdamDecay 40 | base_lr: 0.001 41 | beta1: 0.9 42 | beta2: 0.999 43 | weight_decay: 0.00005 44 | 45 | 46 | optimizer_decay: 47 | function: ptocr.optimizer,adjust_learning_rate 48 | schedule: [4,6] 49 | gama: 0.1 50 | 51 | 52 | trainload: 53 | function: ptocr.dataloader.RecLoad.CRNNProcess,CRNNProcessLmdbLoad 54 | train_file: '/src/notebooks/MyworkData/EnglishCrnnData/train_lmdb/SynthText/' 55 | key_file: /src/notebooks/MyworkData/EnglishCrnnData/key_new.txt 56 | bg_path: ./bg_img/ 57 | num_workers: 16 58 | batch_size: 512 59 | 60 | valload: 61 | function: ptocr.dataloader.RecLoad.CRNNProcess,CRNNProcessLmdbLoad 62 | val_file: '/src/notebooks/IIIT5k_3000/lmdb/' 63 | num_workers: 5 64 | batch_size: 256 65 | 66 | label_transform: 67 | function: ptocr.utils.transform_label,strLabelConverter 68 | label_function: ptocr.dataloader.RecLoad.DataAgument,transform_label 69 | t_type: lower 70 | char_type: En 71 | 72 | infer: 73 | # model_path: './checkpoint/ag_CRNN_bb_rec_crnn_backbone_he_CRNN_Head_bs_256_ep_10_synthtext/CRNN_best.pth.tar' 74 | model_path: 
'./checkpoint/ag_CRNN_bb_resnet34_he_CRNN_Head_bs_512_ep_8_center_loss/CRNN_best.pth.tar' 75 | # path: '/src/notebooks/MyworkData/EnglishCrnnData/image/2697/6/107_Ramification_62303.jpg' 76 | path: './english_val_img/SVT/image/' 77 | save_path: '' 78 | -------------------------------------------------------------------------------- /config/rec_CRNN_resnet_english.yaml: -------------------------------------------------------------------------------- 1 | base: 2 | gpu_id: '0,1' 3 | algorithm: CRNN 4 | pretrained: False 5 | inchannel: 512 6 | hiddenchannel: 128 7 | img_shape: [32,100] 8 | is_gray: True 9 | use_conv: False 10 | use_attention: False 11 | use_lstm: True 12 | lstm_num: 2 13 | classes: 1000 14 | n_epoch: 8 15 | start_val: 0 16 | show_step: 100 17 | checkpoints: ./checkpoint 18 | save_epoch: 1 19 | show_num: 10 20 | restore: True 21 | finetune: True 22 | restore_file : ./checkpoint/ag_CRNN_bb_resnet34_he_CRNN_Head_bs_256_ep_20_no_channel_timestep_rnn/CRNN_best.pth.tar 23 | 24 | backbone: 25 | function: ptocr.model.backbone.reg_resnet_bd,resnet34 26 | 27 | head: 28 | function: ptocr.model.head.rec_CRNNHead,CRNN_Head 29 | 30 | architectures: 31 | model_function: ptocr.model.architectures.rec_model,RecModel 32 | loss_function: ptocr.model.architectures.rec_model,RecLoss 33 | 34 | loss: 35 | function: ptocr.model.loss.ctc_loss,CTCLoss 36 | use_ctc_weight: True 37 | reduction: 'none' 38 | center_function: ptocr.model.loss.centerloss,CenterLoss 39 | use_center: True 40 | center_lr: 0.5 41 | label_score: 0.95 42 | # min_score: 0.01 43 | weight_center: 0.001 44 | 45 | 46 | optimizer: 47 | function: ptocr.optimizer,AdamDecay 48 | base_lr: 0.001 49 | beta1: 0.9 50 | beta2: 0.999 51 | weight_decay: 0.00005 52 | 53 | # optimizer: 54 | # function: ptocr.optimizer,SGDDecay 55 | # base_lr: 0.002 56 | # momentum: 0.99 57 | # weight_decay: 0.00005 58 | 59 | # optimizer_decay: 60 | # function: ptocr.optimizer,adjust_learning_rate_poly 61 | # factor: 0.9 62 | 63 | 
optimizer_decay: 64 | function: ptocr.optimizer,adjust_learning_rate 65 | schedule: [4,6] 66 | gama: 0.1 67 | 68 | optimizer_decay_center: 69 | function: ptocr.optimizer,adjust_learning_rate_center 70 | schedule: [4,6] 71 | gama: 0.1 72 | 73 | trainload: 74 | function: ptocr.dataloader.RecLoad.CRNNProcess,CRNNProcessTrainLmdb 75 | train_file: '/src/notebooks/MyworkData/EnglishCrnnData/train_lmdb/SynthText/' 76 | key_file: /src/notebooks/MyworkData/EnglishCrnnData/key_new.txt 77 | bg_path: ./bg_img/ 78 | num_workers: 10 79 | batch_size: 512 80 | 81 | testload: 82 | function: ptocr.dataloader.RecLoad.CRNNProcess,CRNNProcessTest 83 | test_file: /src/notebooks/MyworkData/EnglishCrnnData/val_new.txt 84 | num_workers: 5 85 | batch_size: 256 86 | 87 | 88 | label_transform: 89 | function: ptocr.utils.transform_label,strLabelConverter 90 | 91 | infer: 92 | # model_path: './checkpoint/ag_CRNN_bb_rec_crnn_backbone_he_CRNN_Head_bs_256_ep_10_synthtext/CRNN_best.pth.tar' 93 | model_path: './checkpoint/ag_CRNN_bb_resnet34_he_CRNN_Head_bs_512_ep_8_center_loss/CRNN_best.pth.tar' 94 | # path: '/src/notebooks/MyworkData/EnglishCrnnData/image/2697/6/107_Ramification_62303.jpg' 95 | path: './english_val_img/SVT/image/' 96 | save_path: '' 97 | -------------------------------------------------------------------------------- /config/rec_CRNN_resnet_english_all.yaml: -------------------------------------------------------------------------------- 1 | base: 2 | gpu_id: '0,1' 3 | algorithm: CRNN 4 | pretrained: False 5 | inchannel: 512 6 | hiddenchannel: 256 7 | img_shape: [32,100] 8 | is_gray: True 9 | use_conv: False 10 | use_attention: False 11 | use_lstm: True 12 | lstm_num: 2 13 | classes: 1000 14 | max_iters: 200000 15 | eval_iter: 10000 16 | show_step: 100 17 | checkpoints: ./checkpoint 18 | save_epoch: 1 19 | show_num: 10 20 | restore: False 21 | finetune: False 22 | restore_file : ./checkpoint/ag_CRNN_bb_rec_crnn_backbone_he_CRNN_Head_bs_256_ep_20_20210207English/CRNN_best.pth.tar 
23 | 24 | backbone: 25 | function: ptocr.model.backbone.reg_resnet_bd,resnet34 26 | 27 | head: 28 | function: ptocr.model.head.rec_CRNNHead,CRNN_Head 29 | 30 | architectures: 31 | model_function: ptocr.model.architectures.rec_model,RecModel 32 | loss_function: ptocr.model.architectures.rec_model,RecLoss 33 | 34 | loss: 35 | function: ptocr.model.loss.ctc_loss,CTCLoss 36 | use_ctc_weight: False 37 | reduction: 'mean' 38 | center_function: ptocr.model.loss.centerloss,CenterLoss 39 | use_center: False 40 | center_lr: 0.5 41 | label_score: 0.95 42 | # min_score: 0.01 43 | weight_center: 0.000001 44 | 45 | 46 | optimizer: 47 | function: ptocr.optimizer,AdamDecay 48 | base_lr: 0.001 49 | beta1: 0.9 50 | beta2: 0.999 51 | weight_decay: 0.00005 52 | 53 | # optimizer: 54 | # function: ptocr.optimizer,SGDDecay 55 | # base_lr: 0.002 56 | # momentum: 0.99 57 | # weight_decay: 0.00005 58 | 59 | # optimizer_decay: 60 | # function: ptocr.optimizer,adjust_learning_rate_poly 61 | # factor: 0.9 62 | 63 | optimizer_decay: 64 | function: ptocr.optimizer,adjust_learning_rate 65 | schedule: [80000,160000] 66 | gama: 0.1 67 | 68 | optimizer_decay_center: 69 | function: ptocr.optimizer,adjust_learning_rate_center 70 | schedule: [80000,160000] 71 | gama: 0.1 72 | 73 | trainload: 74 | function: ptocr.dataloader.RecLoad.CRNNProcess1,GetDataLoad 75 | train_file: ['/src/notebooks/MyworkData/EnglishCrnnData/train_lmdb/SynthText/','/src/notebooks/MyworkData/EnglishCrnnData/train_lmdb/MJSynth'] 76 | batch_ratio: [0.5,0.5] 77 | key_file: /src/notebooks/MyworkData/EnglishCrnnData/key_new.txt 78 | bg_path: ./bg_img/ 79 | num_workers: 16 80 | batch_size: 512 81 | 82 | testload: 83 | function: ptocr.dataloader.RecLoad.CRNNProcess1,CRNNProcessTest 84 | test_file: /src/notebooks/MyworkData/EnglishCrnnData/val_new.txt 85 | num_workers: 8 86 | batch_size: 256 87 | 88 | 89 | label_transform: 90 | function: ptocr.utils.transform_label,strLabelConverter 91 | 92 | transform: 93 | function: 
ptocr.dataloader.RecLoad.DataAgument,transform_label 94 | t_type: lower 95 | char_type: En 96 | 97 | infer: 98 | # model_path: './checkpoint/ag_CRNN_bb_rec_crnn_backbone_he_CRNN_Head_bs_256_ep_10_synthtext/CRNN_best.pth.tar' 99 | model_path: './checkpoint/ag_CRNN_bb_resnet34_he_CRNN_Head_bs_512_ep_200000_alldata/CRNN_120000.pth.tar' 100 | # path: '/src/notebooks/MyworkData/EnglishCrnnData/image/2697/6/107_Ramification_62303.jpg' 101 | path: './english_val_img/' 102 | save_path: '' 103 | -------------------------------------------------------------------------------- /config/rec_FC_resnet_english_all.yaml: -------------------------------------------------------------------------------- 1 | base: 2 | gpu_id: '0,1' 3 | algorithm: FC 4 | pretrained: False 5 | in_channels: 2048 6 | out_channels: 1024 7 | ignore_index: 37 8 | max_length: 25 9 | img_shape: [32,100] 10 | is_gray: True 11 | use_conv: False 12 | use_attention: False 13 | use_lstm: True 14 | lstm_num: 2 15 | num_class: 36 16 | start_iters: 0 17 | max_iters: 300000 18 | eval_iter: 10000 19 | show_step: 100 20 | checkpoints: ./checkpoint 21 | save_epoch: 1 22 | show_num: 10 23 | restore: False 24 | finetune: False 25 | restore_file : ./checkpoint/ag_CRNN_bb_rec_crnn_backbone_he_CRNN_Head_bs_256_ep_20_20210207English/CRNN_best.pth.tar 26 | 27 | backbone: 28 | function: ptocr.model.backbone.reg_resnet_bd,resnet50 29 | 30 | head: 31 | function: ptocr.model.head.rec_FCHead,FC_Head 32 | 33 | architectures: 34 | model_function: ptocr.model.architectures.rec_model,RecModel 35 | loss_function: ptocr.model.architectures.rec_model,RecLoss 36 | 37 | loss: 38 | function: ptocr.model.loss.fc_loss,FCLoss 39 | use_ctc_weight: False 40 | reduction: 'mean' 41 | center_function: ptocr.model.loss.centerloss,CenterLoss 42 | use_center: False 43 | center_lr: 0.5 44 | label_score: 0.95 45 | # min_score: 0.01 46 | weight_center: 0.000001 47 | 48 | 49 | optimizer: 50 | function: ptocr.optimizer,AdamDecay 51 | base_lr: 0.001 52 | 
beta1: 0.9 53 | beta2: 0.999 54 | weight_decay: 0.00005 55 | 56 | # optimizer: 57 | # function: ptocr.optimizer,SGDDecay 58 | # base_lr: 0.002 59 | # momentum: 0.99 60 | # weight_decay: 0.00005 61 | 62 | # optimizer_decay: 63 | # function: ptocr.optimizer,adjust_learning_rate_poly 64 | # factor: 0.9 65 | 66 | optimizer_decay: 67 | function: ptocr.optimizer,adjust_learning_rate 68 | schedule: [100000,200000] 69 | gama: 0.1 70 | 71 | optimizer_decay_center: 72 | function: ptocr.optimizer,adjust_learning_rate_center 73 | schedule: [80000,160000] 74 | gama: 0.1 75 | 76 | trainload: 77 | function: ptocr.dataloader.RecLoad.CRNNProcess1,GetDataLoad 78 | train_file: ['/src/notebooks/MyworkData/EnglishCrnnData/train_lmdb/SynthText/','/src/notebooks/MyworkData/EnglishCrnnData/train_lmdb/MJSynth'] 79 | batch_ratio: [0.5,0.5] 80 | key_file: /src/notebooks/MyworkData/EnglishCrnnData/key_new.txt 81 | bg_path: ./bg_img/ 82 | num_workers: 16 83 | batch_size: 256 84 | 85 | valload: 86 | function: ptocr.dataloader.RecLoad.CRNNProcess1,GetValDataLoad 87 | root: '/src/notebooks/pytorchOCR-master/english_val_img' 88 | dir: ['CUTE80','IC03_867','IC13_1015','IC13_857','IC15_1811','IIIT5k_3000','SVT','SVTP','IC15_2077'] 89 | test_file: /src/notebooks/MyworkData/EnglishCrnnData/val_new.txt 90 | num_workers: 2 91 | batch_size: 1 92 | 93 | 94 | label_transform: 95 | function: ptocr.utils.transform_label,FCConverter 96 | 97 | transform: 98 | function: ptocr.dataloader.RecLoad.DataAgument,transform_label 99 | t_type: lower 100 | char_type: En 101 | 102 | infer: 103 | # model_path: './checkpoint/ag_CRNN_bb_rec_crnn_backbone_he_CRNN_Head_bs_256_ep_10_synthtext/CRNN_best.pth.tar' 104 | model_path: './checkpoint/ag_FC_bb_resnet34_he_FC_Head_bs_128_ep_200000_FC/FC_190000.pth.tar' 105 | # path: '/src/notebooks/MyworkData/EnglishCrnnData/image/2697/6/107_Ramification_62303.jpg' 106 | path: './english_val_img/' 107 | save_path: '' 108 | 
-------------------------------------------------------------------------------- /doc/example/det_test_list.txt: -------------------------------------------------------------------------------- 1 | /src/notebooks/detect_text/icdar2015/image/img_1000.jpg 2 | /src/notebooks/detect_text/icdar2015/image/img_100.jpg 3 | /src/notebooks/detect_text/icdar2015/image/img_101.jpg 4 | /src/notebooks/detect_text/icdar2015/image/img_102.jpg 5 | /src/notebooks/detect_text/icdar2015/image/img_103.jpg 6 | /src/notebooks/detect_text/icdar2015/image/img_104.jpg 7 | /src/notebooks/detect_text/icdar2015/image/img_105.jpg 8 | /src/notebooks/detect_text/icdar2015/image/img_106.jpg 9 | /src/notebooks/detect_text/icdar2015/image/img_107.jpg -------------------------------------------------------------------------------- /doc/example/det_train_list.txt: -------------------------------------------------------------------------------- 1 | /src/notebooks/detect_text/icdar2015/image/img_1000.jpg /src/notebooks/detect_text/icdar2015/label/gt_img_1000.txt 2 | /src/notebooks/detect_text/icdar2015/image/img_100.jpg /src/notebooks/detect_text/icdar2015/label/gt_img_100.txt 3 | /src/notebooks/detect_text/icdar2015/image/img_101.jpg /src/notebooks/detect_text/icdar2015/label/gt_img_101.txt 4 | /src/notebooks/detect_text/icdar2015/image/img_102.jpg /src/notebooks/detect_text/icdar2015/label/gt_img_102.txt 5 | /src/notebooks/detect_text/icdar2015/image/img_103.jpg /src/notebooks/detect_text/icdar2015/label/gt_img_103.txt 6 | /src/notebooks/detect_text/icdar2015/image/img_104.jpg /src/notebooks/detect_text/icdar2015/label/gt_img_104.txt 7 | /src/notebooks/detect_text/icdar2015/image/img_105.jpg /src/notebooks/detect_text/icdar2015/label/gt_img_105.txt 8 | /src/notebooks/detect_text/icdar2015/image/img_106.jpg /src/notebooks/detect_text/icdar2015/label/gt_img_106.txt 9 | /src/notebooks/detect_text/icdar2015/image/img_107.jpg /src/notebooks/detect_text/icdar2015/label/gt_img_107.txt 
-------------------------------------------------------------------------------- /doc/example/label.txt: -------------------------------------------------------------------------------- 1 | 367,87,426,86,433,140,375,141,### 2 | 381,212,431,217,434,240,384,236,text 3 | 386,261,447,265,450,287,389,283,text 4 | -------------------------------------------------------------------------------- /doc/example/rec_test_list.txt: -------------------------------------------------------------------------------- 1 | /src/notebooks/detect_text/icdar2015/recognize/test_img/word_1.png JOINT 2 | /src/notebooks/detect_text/icdar2015/recognize/test_img/word_2.png yourself 3 | /src/notebooks/detect_text/icdar2015/recognize/test_img/word_3.png 154 4 | /src/notebooks/detect_text/icdar2015/recognize/test_img/word_4.png 197 5 | /src/notebooks/detect_text/icdar2015/recognize/test_img/word_5.png 727 6 | /src/notebooks/detect_text/icdar2015/recognize/test_img/word_6.png 198 7 | /src/notebooks/detect_text/icdar2015/recognize/test_img/word_7.png 20029 8 | /src/notebooks/detect_text/icdar2015/recognize/test_img/word_8.png Free 9 | /src/notebooks/detect_text/icdar2015/recognize/test_img/word_9.png from 10 | /src/notebooks/detect_text/icdar2015/recognize/test_img/word_10.png PAIN 11 | /src/notebooks/detect_text/icdar2015/recognize/test_img/word_11.png BLOCK 12 | /src/notebooks/detect_text/icdar2015/recognize/test_img/word_12.png 441B 13 | /src/notebooks/detect_text/icdar2015/recognize/test_img/word_13.png STOREY -------------------------------------------------------------------------------- /doc/example/rec_train_list.txt: -------------------------------------------------------------------------------- 1 | /src/notebooks/detect_text/icdar2015/recognize/image/word_1.png Genaxis Theatre 2 | /src/notebooks/detect_text/icdar2015/recognize/image/word_2.png [06] 3 | /src/notebooks/detect_text/icdar2015/recognize/image/word_3.png 62-03 4 | /src/notebooks/detect_text/icdar2015/recognize/image/word_4.png 
Carpark 5 | /src/notebooks/detect_text/icdar2015/recognize/image/word_5.png EXIT 6 | /src/notebooks/detect_text/icdar2015/recognize/image/word_6.png I2R 7 | /src/notebooks/detect_text/icdar2015/recognize/image/word_7.png fusionopolis 8 | /src/notebooks/detect_text/icdar2015/recognize/image/word_8.png fusionopolis 9 | /src/notebooks/detect_text/icdar2015/recognize/image/word_9.png Reserve 10 | /src/notebooks/detect_text/icdar2015/recognize/image/word_10.png CAUTION -------------------------------------------------------------------------------- /doc/md/ocr.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BADBADBADBOY/pytorchOCR/efae5454651e96794dd7a460bd305dc9ccd95359/doc/md/ocr.jpg -------------------------------------------------------------------------------- /doc/md/onnx_to_tensorrt.md: -------------------------------------------------------------------------------- 1 | ### onnx转tensorrt 2 | 本项目使用tensorrt版本:TensorRT-7.0.0.11 3 | 4 | 5 | 6 | #### 参数解释 7 | pytorch_to_onnx.py 文件参数 8 | |参数|含义| 9 | |-|-| 10 | |config|算法的配置文件| 11 | |model_path|训练好的模型文件| 12 | |img_path|测试的图片| 13 | |save_path|onnx保存文件| 14 | |batch_size|设置测试batch| 15 | |max_size|设置最长边| 16 | |algorithm|算法名称| 17 | |add_padding|是否将短边padding到和长边一样| 18 | 19 | 20 | 21 | onnx_to_tensorrt.py 文件参数 22 | |参数|含义| 23 | |-|-| 24 | |onnx_path|生成的onnx文件| 25 | |trt_engine_path|保存的engine文件路径| 26 | |img_path|测试的图片| 27 | |batch_size|设置测试batch| 28 | |max_size|设置最长边| 29 | |algorithm|算法名称| 30 | |add_padding|是否将短边padding到和长边一样| 31 | 32 | #### 单张调用 33 | 1. 生成onnx文件 34 | 35 | 36 | - DB算法调用 37 | ``` 38 | python3 ./script/pytorch_to_onnx.py --config ./config/det_DB_mobilev3.yaml --model_path ./checkpoint/DB_best.pth.tar --img_path /src/notebooks/detect_text/icdar2015/ch4_test_images/img_10.jpg --save_path ./onnx/DB.onnx --batch_size 1 --max_size 1536 --algorithm DB --add_padding 39 | ``` 40 | 2. 
simple onnx文件 41 | 42 | ``` 43 | sh onnx-simple.sh DB.onnx DB-simple.onnx 44 | ``` 45 | 46 | 3. 生成tensorrt engine 47 | - DB算法调用 48 | ``` 49 | CUDA_VISIBLE_DEVICES=2 python3 ./script/onnx_to_tensorrt.py --onnx_path ./onnx/DB-simple.onnx --trt_engine_path ./onnx/DB.engine --img_path /src/notebooks/detect_text/icdar2015/ch4_test_images/img_10.jpg --batch_size 1 --algorithm DB --max_size 1536 --add_padding 50 | ``` 51 | 52 | 4. infer 调用 53 | 54 | ``` 55 | python3 ./tools/det_infer.py --config ./config/det_DB_mobilev3.yaml --img_path /src/notebooks/detect_text/icdar2015/ch4_test_images --result_save_path ./result --trt_path ./onnx/DB.engine --batch_size 1 --max_size 1536 --add_padding 56 | ``` 57 | 58 | #### batch 调用 59 | 60 | 操作同上,和单张调用一样,只是要把batch_size设置大于1 61 | 62 | - 提示:其余算法类似 63 | -------------------------------------------------------------------------------- /doc/md/pytorch_to_onnx.md: -------------------------------------------------------------------------------- 1 | ### pytorch 转onnx 2 | 3 | #### 1. 运行根目录to_onnx.sh文件: 4 | 5 | ``` 6 | sh to_onnx.sh 7 | ``` 8 | 里面有四个参数,需要对应修改,参数解释如下: 9 | |参数|解释| 10 | |-|-| 11 | |config|对应算法的config文件| 12 | |model_path|对应算法的模型文件| 13 | |img_path|测试图片| 14 | |save_path|保存onnx文件| 15 | - 提示:这里onnx文件建议生成在项目中onnx文件夹下 16 | #### 2. 使用onnx文件夹下的 onnx-simple.sh对生成的onnx文件进行精简,运行: 17 | 18 | ``` 19 | sh onnx-simple.sh 生成的onnx文件 精简后的onnx文件 20 | ``` 21 | 例如: 22 | 23 | ``` 24 | sh onnx-simple.sh DBnet.onnx DBnet-simple.onnx 25 | ``` 26 | #### 3. 
onnx调用 27 | 运行: 28 | 29 | ``` 30 | python3 ./tools/det_infer.py --config ./config/det_DB_mobilev3.yaml --model_path ./checkpoint/ag_DB_bb_mobilenet_v3_small_he_DB_Head_bs_16_ep_400/DB_64.pth.tar --img_path /src/notebooks/detect_text/icdar2015/ch4_test_images/img_10.jpg --result_save_path ./result --onnx_path ./onnx/DBnet-simple.onnx 31 | ``` 32 | - 提示:这里如果加上--onnx_path就是onnx调用,否则是pytorch调用 33 | 34 | 35 | -------------------------------------------------------------------------------- /doc/md/文本检测训练文档.md: -------------------------------------------------------------------------------- 1 | ## 训练文档 2 | *** 3 | 4 | ### step 1 环境安装和预训练模型 5 | 1. 编译c++后处理文件 6 | 7 | ``` 8 | sh make.sh 9 | ``` 10 | 2. 下载预训练模型 11 | 预训练模型地址:[下载链接](https://pan.baidu.com/s/1zONYFPsS3szaf5BHeQh5ZA)(code:fxw6) 12 | 13 | 3. 下载icdar2015测试模型(不做测试可跳过这一步) 14 | 测试模型地址:[下载链接](https://pan.baidu.com/s/1zONYFPsS3szaf5BHeQh5ZA)(code:fxw6) 15 | 16 | 4. 将下载下来的pre_model和checkpoint文件夹分别替换项目中的同名文件夹 17 | *** 18 | 19 | ### step 2 准备训练所需文件 20 | 1. 准备一个train_list.txt[示例](https://github.com/BADBADBADBOY/pytorchOCR/blob/master/doc/example/det_train_list.txt),格式是:图片文件绝对位置+图片文件标注txt文件绝对位置,用 \t 分隔. 参照下面修改成你自己的地址。 21 | 22 | ``` 23 | python3 ./script/get_train_list.py --img_path /src/train/image --label_path /src/train/label --save_path /src/train 24 | ``` 25 | 运行后在/src/train会生成一个train_list.txt。 26 | - 提示:你的图片和label文件要同名。如果你要在训练时做验证要生成一个test_list.txt[示例](https://github.com/BADBADBADBOY/pytorchOCR/blob/master/doc/example/det_test_list.txt),在yaml中start_val可设置多少epoch后开始做验证。 27 | - 制作label文件说明:照着icdar2015的格式, x1,y1,x2,y2,x3,y3,x4,y4,label,其中不参与训练文本(例如模糊文本),label设置为###,代表不参与训练,除此之外表示参与训练,可以像我这样用text,或者别的也行。[label格式参照这里](https://github.com/BADBADBADBOY/pytorchOCR/blob/master/doc/example/label.txt) 28 | *** 29 | ### step 3 模型训练 30 | #### 正常模型训练 31 | 32 | 33 | 1. 修改./config中对应算法的yaml中参数,基本上只需修改数据路径即可。 34 | 2. 
运行下面命令 35 | 36 | ``` 37 | python3 tools/det_train.py --config ./config/det_DB_mobilev3.yaml --log_str train_log --n_epoch 1200 --start_val 600 --base_lr 0.002 --gpu_id 2 38 | ``` 39 | 40 | - 提示:在./config/det_DB_resnet50.yaml里有[参数解释](https://github.com/BADBADBADBOY/pytorchOCR/blob/master/config/det_DB_resnet50.yaml),其他的yaml文件都是类似的,可以参照里面的。log_str是为了多次训练的时候结果可以保存在不同文件夹 41 | 42 | #### 断点恢复训练 43 | 将yaml文件中base下的restore置为True,restore_file填上恢复训练的模型地址,运行: 44 | ``` 45 | python3 tools/det_train.py --config ./config/det_DB_mobilev3.yaml --log_str train_log --n_epoch 1200 --start_val 600 --base_lr 0.002 --gpu_id 2 46 | ``` 47 | 48 | 49 | *** 50 | 51 | ### step 4 模型测试 52 | 1. 修改infer.sh中的参数 53 | 2. 运行下面命令 54 | 55 | ``` 56 | sh infer.sh 57 | ``` 58 | - 提示:测试时infer.sh文件中的 --img_path 既可以是文件夹也可以是文件 59 | 60 | 61 | 62 | 63 | -------------------------------------------------------------------------------- /doc/md/文本识别训练文档.md: -------------------------------------------------------------------------------- 1 | ### 文本识别 2 | #### 数据准备 3 | 4 | 需要一个train_list.txt[示例](https://github.com/BADBADBADBOY/pytorchOCR/blob/master/doc/example/rec_train_list.txt) , 格式:图片绝对路径+\t+label。 具体可参照项目中doc/example中的例子。 5 | 如果训练过程中需要做验证,还需要按相同的数据格式制作一个test_list.txt[示例](https://github.com/BADBADBADBOY/pytorchOCR/blob/master/doc/example/rec_test_list.txt)。 6 | 7 | #### 训练模型 8 | 1. 修改./config中对应算法的yaml中参数,基本上只需修改数据路径即可。 9 | 2. 在./tools/rec_train.py最下面打开不同的config中的yaml对应不同的算法 10 | 3. 运行下面命令 11 | 12 | ``` 13 | python3 ./tools/rec_train.py 14 | ``` 15 | #### 测试模型 16 | 1. 
运行下面命令 17 | 18 | ``` 19 | python3 ./tools/rec_infer.py 20 | ``` -------------------------------------------------------------------------------- /doc/md/模型剪枝.md: -------------------------------------------------------------------------------- 1 | ### 模型剪枝 2 | 3 | 4 | 5 | 这里暂时支持对mobilev3 DBnet进行剪枝。尝试了对backbone和对整个模型两种方式压缩。 6 | 7 | #### 参数解释 8 | prune_model_all.py 文件参数 9 | |参数|含义|额外说明| 10 | |-|-|-| 11 | |config|算法的配置文件|| 12 | |cut_percent|剪枝比率|由于类似resnet的跨层连接,剪枝比率不完全等于这里设置的,可能偏小| 13 | |base_num|保证剪完后的channel是base_num的倍数|除去剪完后为1的,其余是base_num的倍数| 14 | |checkpoint|稀疏训练好的模型文件|| 15 | |save_prune_model_path|剪完后保存的模型文件地址|这里会生成两个文件(其实可以合成一个保存)| 16 | |img_file|测试的图片|| 17 | 18 | #### 如何操作 19 | 20 | 1. 稀疏训练 21 | ``` 22 | python3 tools/det_train.py --config ./config/det_DB_mobilev3.yaml --log_str train_pruned --sr_lr 0.00007 --n_epoch 1200 --start_val 600 --base_lr 0.001 --gpu_id 2 23 | ``` 24 | 25 | 2. 模型压缩 26 | ``` 27 | python3 tools/pruned/prune_model_all.py --config ./config/det_DB_mobilev3.yaml --base_num 2 --cut_percent 0.6 --checkpoint ./checkpoint/DB_best.pth.tar --save_prune_model_path ./checkpoint/pruned/ --img_file ./icdar2015/test/img_108.jpg 28 | ``` 29 | 30 | 3. 
剪枝后finetune 31 | ``` 32 | python3 tools/det_train.py --config ./config/det_DB_mobilev3.yaml --log_str total_prune_finetune --pruned_model_dict_path ./checkpoint/pruned/pruned_dict.dict --prune_model_path ./checkpoint/pruned/pruned_dict.pth --prune_type total --n_epoch 200 --start_val 30 --base_lr 0.0008 --gpu_id 2 33 | ``` 34 | 35 | -------------------------------------------------------------------------------- /doc/md/模型蒸馏.md: -------------------------------------------------------------------------------- 1 | ### 模型蒸馏 2 | 3 | 这里使用的是通过大模型的前向输出作为soft label和生成的label同时监督训练,这里目前只对DBnet做了尝试,其余算法应该类似, 4 | 其中soft label监督只作用于binary和thresh_binary,忽略了thresh,测试下来这种方式效果最好。 5 | 6 | 7 | 8 | #### 如何训练(对压缩后模型进行蒸馏) 9 | 10 | 11 | ``` 12 | python3 tools/det_train.py 13 | --config ./config/det_DB_mobilev3.yaml 14 | --log_str total_prune_20201015_distil3 15 | --pruned_model_dict_path ./checkpoint/ag_DB_bb_mobilenet_v3_small_he_DB_Head_bs_16_ep_1200_mobile_slim_all/pruned/pruned_dict.dict 16 | --prune_model_path ./checkpoint/ag_DB_bb_mobilenet_v3_small_he_DB_Head_bs_16_ep_1200_mobile_slim_all/pruned/pruned_dict.pth 17 | --prune_type total 18 | --n_epoch 200 19 | --start_val 30 20 | --base_lr 0.0008 21 | --gpu_id 2 22 | --t_ratio 0.1 23 | --t_model_path ./checkpoint/ag_DB_bb_resnet50_he_DB_Head_bs_8_ep_1201/DB_best.pth.tar 24 | --t_config ./config/det_DB_resnet50_3_3.yaml 25 | ``` 26 | -------------------------------------------------------------------------------- /doc/show/1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BADBADBADBOY/pytorchOCR/efae5454651e96794dd7a460bd305dc9ccd95359/doc/show/1.jpg -------------------------------------------------------------------------------- /doc/show/2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BADBADBADBOY/pytorchOCR/efae5454651e96794dd7a460bd305dc9ccd95359/doc/show/2.jpg 
-------------------------------------------------------------------------------- /doc/show/3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BADBADBADBOY/pytorchOCR/efae5454651e96794dd7a460bd305dc9ccd95359/doc/show/3.jpg -------------------------------------------------------------------------------- /doc/show/ocr1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BADBADBADBOY/pytorchOCR/efae5454651e96794dd7a460bd305dc9ccd95359/doc/show/ocr1.jpg -------------------------------------------------------------------------------- /doc/show/ocr2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BADBADBADBOY/pytorchOCR/efae5454651e96794dd7a460bd305dc9ccd95359/doc/show/ocr2.jpg -------------------------------------------------------------------------------- /make.sh: -------------------------------------------------------------------------------- 1 | # make nms 2 | cd ptocr/postprocess/lanms/ 3 | make clean 4 | make 5 | 6 | # make pan,pse 7 | cd ../piexlmerge 8 | make clean 9 | make 10 | 11 | # make dbprocess 12 | cd ../dbprocess 13 | make clean 14 | make -------------------------------------------------------------------------------- /onnx/onnx-simple.sh: -------------------------------------------------------------------------------- 1 | python3 -m onnxsim $1 $2 -------------------------------------------------------------------------------- /ptocr/__init__.py: -------------------------------------------------------------------------------- 1 | #-*- coding:utf-8 _*- 2 | """ 3 | @author:fxw 4 | @file: __init__.py.py 5 | @time: 2020/08/07 6 | """ -------------------------------------------------------------------------------- /ptocr/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/BADBADBADBOY/pytorchOCR/efae5454651e96794dd7a460bd305dc9ccd95359/ptocr/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /ptocr/__pycache__/optimizer.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BADBADBADBOY/pytorchOCR/efae5454651e96794dd7a460bd305dc9ccd95359/ptocr/__pycache__/optimizer.cpython-36.pyc -------------------------------------------------------------------------------- /ptocr/dataloader/DetLoad/__init__.py: -------------------------------------------------------------------------------- 1 | #-*- coding:utf-8 _*- 2 | """ 3 | @author:fxw 4 | @file: __init__.py.py 5 | @time: 2020/08/11 6 | """ 7 | -------------------------------------------------------------------------------- /ptocr/dataloader/DetLoad/__pycache__/DBProcess.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BADBADBADBOY/pytorchOCR/efae5454651e96794dd7a460bd305dc9ccd95359/ptocr/dataloader/DetLoad/__pycache__/DBProcess.cpython-36.pyc -------------------------------------------------------------------------------- /ptocr/dataloader/DetLoad/__pycache__/MakeBorderMap.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BADBADBADBOY/pytorchOCR/efae5454651e96794dd7a460bd305dc9ccd95359/ptocr/dataloader/DetLoad/__pycache__/MakeBorderMap.cpython-36.pyc -------------------------------------------------------------------------------- /ptocr/dataloader/DetLoad/__pycache__/MakeSegMap.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BADBADBADBOY/pytorchOCR/efae5454651e96794dd7a460bd305dc9ccd95359/ptocr/dataloader/DetLoad/__pycache__/MakeSegMap.cpython-36.pyc 
-------------------------------------------------------------------------------- /ptocr/dataloader/DetLoad/__pycache__/SASTProcess_ori.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BADBADBADBOY/pytorchOCR/efae5454651e96794dd7a460bd305dc9ccd95359/ptocr/dataloader/DetLoad/__pycache__/SASTProcess_ori.cpython-36.pyc -------------------------------------------------------------------------------- /ptocr/dataloader/DetLoad/__pycache__/SASTProcess_ori1.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BADBADBADBOY/pytorchOCR/efae5454651e96794dd7a460bd305dc9ccd95359/ptocr/dataloader/DetLoad/__pycache__/SASTProcess_ori1.cpython-36.pyc -------------------------------------------------------------------------------- /ptocr/dataloader/DetLoad/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BADBADBADBOY/pytorchOCR/efae5454651e96794dd7a460bd305dc9ccd95359/ptocr/dataloader/DetLoad/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /ptocr/dataloader/DetLoad/__pycache__/transform_img.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BADBADBADBOY/pytorchOCR/efae5454651e96794dd7a460bd305dc9ccd95359/ptocr/dataloader/DetLoad/__pycache__/transform_img.cpython-36.pyc -------------------------------------------------------------------------------- /ptocr/dataloader/RecLoad/__init__.py: -------------------------------------------------------------------------------- 1 | #-*- coding:utf-8 _*- 2 | """ 3 | @author:fxw 4 | @file: __init__.py.py 5 | @time: 2020/08/11 6 | """ 7 | -------------------------------------------------------------------------------- 
/ptocr/dataloader/RecLoad/__pycache__/CRNNProcess.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BADBADBADBOY/pytorchOCR/efae5454651e96794dd7a460bd305dc9ccd95359/ptocr/dataloader/RecLoad/__pycache__/CRNNProcess.cpython-36.pyc -------------------------------------------------------------------------------- /ptocr/dataloader/RecLoad/__pycache__/CRNNProcess1.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BADBADBADBOY/pytorchOCR/efae5454651e96794dd7a460bd305dc9ccd95359/ptocr/dataloader/RecLoad/__pycache__/CRNNProcess1.cpython-36.pyc -------------------------------------------------------------------------------- /ptocr/dataloader/RecLoad/__pycache__/DataAgument.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BADBADBADBOY/pytorchOCR/efae5454651e96794dd7a460bd305dc9ccd95359/ptocr/dataloader/RecLoad/__pycache__/DataAgument.cpython-36.pyc -------------------------------------------------------------------------------- /ptocr/dataloader/RecLoad/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BADBADBADBOY/pytorchOCR/efae5454651e96794dd7a460bd305dc9ccd95359/ptocr/dataloader/RecLoad/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /ptocr/dataloader/__init__.py: -------------------------------------------------------------------------------- 1 | #-*- coding:utf-8 _*- 2 | """ 3 | @author:fxw 4 | @file: __init__.py.py 5 | @time: 2020/08/11 6 | """ 7 | -------------------------------------------------------------------------------- /ptocr/dataloader/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/BADBADBADBOY/pytorchOCR/efae5454651e96794dd7a460bd305dc9ccd95359/ptocr/dataloader/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /ptocr/model/CommonFunction.py: -------------------------------------------------------------------------------- 1 | #-*- coding:utf-8 _*- 2 | """ 3 | @author:fxw 4 | @file: CommonFunction.py 5 | @time: 2020/08/07 6 | """ 7 | import torch.nn.functional as F 8 | import torch.nn as nn 9 | 10 | class DeConvBnRelu(nn.Module): 11 | def __init__(self,in_channels,out_channels,kernel_size=4,stride=2,with_relu=False,padding=1,bias=False): 12 | """ ConvTranspose2d -> BatchNorm2d, with an optional trailing ReLU. 13 | :param in_channels: input channels of the transposed convolution 14 | :param out_channels: output channels of the transposed convolution (also the BatchNorm2d width) 15 | :param kernel_size: kernel size of the transposed convolution (default 4) 16 | :param stride: stride of the transposed convolution (default 2) 17 | :param padding: padding of the transposed convolution (default 1) 18 | :param bias: whether the transposed convolution has a bias term (default False) 19 | :param with_relu: apply ReLU after BatchNorm2d in forward() when True (default False) """ 20 | super(DeConvBnRelu,self).__init__() 21 | self.conv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride, padding=padding,bias=bias) # upsampling conv; the defaults (kernel 4, stride 2, padding 1) double H and W 22 | self.bn = nn.BatchNorm2d(out_channels) 23 | self.relu = nn.ReLU(inplace=True) 24 | self.with_relu = with_relu 25 | 26 | def forward(self, x): 27 | x = self.conv(x) 28 | x = self.bn(x) 29 | if self.with_relu: 30 | x = self.relu(x) 31 | return x 32 | 33 | class ConvBnRelu(nn.Module): 34 | def __init__(self,in_channels,out_channels,kernel_size,stride,padding,with_relu=True,bias=False): 35 | """ Conv2d -> BatchNorm2d, with an optional trailing ReLU. 36 | :param in_channels: input channels of the convolution 37 | :param out_channels: output channels of the convolution (also the BatchNorm2d width) 38 | :param kernel_size: kernel size of the convolution 39 | :param stride: stride of the convolution 40 | :param padding: padding of the convolution 41 | :param bias: whether the convolution has a bias term (default False) 42 | :param with_relu: apply ReLU after BatchNorm2d in forward() when True (default True) """ 43 | super(ConvBnRelu,self).__init__() 44 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding,bias=bias) 45 | self.bn = nn.BatchNorm2d(out_channels) 46 | self.relu = nn.ReLU(inplace=True) 47 | self.with_relu = with_relu 48 | def forward(self, x): 49 | x = self.conv(x) 50 | x = self.bn(x) 51 | if self.with_relu: 52 | x = self.relu(x) 53 | return x 54 | 55 | class DWBlock(nn.Module): # grouped conv followed by 1x1 conv -> BatchNorm2d -> ReLU 56 | def 
__init__(self,in_channels,out_channels,kernel_size,stride,bias=False): 57 | super(DWBlock,self).__init__() 58 | self.dw_conv = nn.Conv2d(in_channels,out_channels,kernel_size,stride,padding=kernel_size//2,groups=out_channels,bias=bias) # NOTE(review): groups=out_channels, so Conv2d requires in_channels to be divisible by out_channels 59 | self.point_conv = nn.Conv2d(out_channels,out_channels,kernel_size=1,stride=1,padding=0,bias=bias) # 1x1 convolution over the grouped-conv output 60 | self.point_bn = nn.BatchNorm2d(out_channels) 61 | self.point_relu = nn.ReLU() 62 | 63 | def forward(self, x): 64 | x = self.dw_conv(x) 65 | x = self.point_relu(self.point_bn(self.point_conv(x))) 66 | return x 67 | 68 | 69 | def upsample(x, y, scale=1): # nearest-neighbour resize of x to (H // scale, W // scale), where H and W are the last two dims of the 4-D tensor y 70 | _, _, H, W = y.size() 71 | # return F.upsample(x, size=(H // scale, W // scale), mode='nearest') 72 | return F.interpolate(x, size=(H // scale, W // scale), mode='nearest') 73 | 74 | def upsample_add(x, y): # nearest-neighbour resize of x to y's (H, W), then element-wise add y 75 | _, _, H, W = y.size() 76 | # return F.upsample(x, size=(H, W), mode='nearest') + y 77 | return F.interpolate(x, size=(H, W), mode='nearest') + y -------------------------------------------------------------------------------- /ptocr/model/CommonFunction_Q.py: -------------------------------------------------------------------------------- 1 | #-*- coding:utf-8 _*- 2 | """ 3 | @author:fxw 4 | @file: CommonFunction_Q.py 5 | @time: 2020/11/02 6 | """ 7 | import torch.nn as nn 8 | 9 | class ConvBnRelu(nn.Module): 10 | def __init__(self,in_channels,out_channels,kernel_size,stride,padding,groups,bias=False): 11 | """ Conv2d (with groups) -> BatchNorm2d -> ReLU; the ReLU is always applied. 12 | :param in_channels: input channels of the convolution 13 | :param out_channels: output channels of the convolution (also the BatchNorm2d width) 14 | :param kernel_size: kernel size of the convolution 15 | :param stride: stride of the convolution 16 | :param padding: padding of the convolution 17 | :param bias: whether the convolution has a bias term (default False) 18 | :param groups: number of groups passed to Conv2d """ 19 | super(ConvBnRelu,self).__init__() 20 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding,groups=groups,bias=bias) 21 | self.bn = nn.BatchNorm2d(out_channels) 22 | self.relu = nn.ReLU(inplace=False) # NOTE(review): not in-place here, presumably for the quantization flow - confirm 23 | 24 | def forward(self, x): 25 | x = self.conv(x) 26 | x = self.bn(x) 27 | x = self.relu(x) 28 | return x 29 | 30 | 31 | class ConvBn(nn.Module): 32 | def 
__init__(self,in_channels,out_channels,kernel_size,stride,padding,groups,bias=False): 33 | """ Conv2d (with groups) -> BatchNorm2d, no activation. 34 | :param in_channels: input channels of the convolution 35 | :param out_channels: output channels of the convolution (also the BatchNorm2d width) 36 | :param kernel_size: kernel size of the convolution 37 | :param stride: stride of the convolution 38 | :param padding: padding of the convolution 39 | :param bias: whether the convolution has a bias term (default False) 40 | :param groups: number of groups passed to Conv2d """ 41 | super(ConvBn,self).__init__() 42 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding,groups=groups,bias=bias) 43 | self.bn = nn.BatchNorm2d(out_channels) 44 | 45 | def forward(self, x): 46 | x = self.conv(x) 47 | x = self.bn(x) 48 | return x -------------------------------------------------------------------------------- /ptocr/model/__init__.py: -------------------------------------------------------------------------------- 1 | #-*- coding:utf-8 _*- 2 | """ 3 | @author:fxw 4 | @file: __init__.py.py 5 | @time: 2020/08/07 6 | """ 7 | from ..utils.util_function import create_module -------------------------------------------------------------------------------- /ptocr/model/__pycache__/CommonFunction.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BADBADBADBOY/pytorchOCR/efae5454651e96794dd7a460bd305dc9ccd95359/ptocr/model/__pycache__/CommonFunction.cpython-36.pyc -------------------------------------------------------------------------------- /ptocr/model/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BADBADBADBOY/pytorchOCR/efae5454651e96794dd7a460bd305dc9ccd95359/ptocr/model/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /ptocr/model/architectures/__init__.py: -------------------------------------------------------------------------------- 1 | #-*- coding:utf-8 _*- 2 | """ 3 | @author:fxw 4 | @file: __init__.py.py 5 | @time: 2020/08/07 6 | """ 7 | -------------------------------------------------------------------------------- 
/ptocr/model/architectures/__pycache__/__init__.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BADBADBADBOY/pytorchOCR/efae5454651e96794dd7a460bd305dc9ccd95359/ptocr/model/architectures/__pycache__/__init__.cpython-35.pyc -------------------------------------------------------------------------------- /ptocr/model/architectures/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BADBADBADBOY/pytorchOCR/efae5454651e96794dd7a460bd305dc9ccd95359/ptocr/model/architectures/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /ptocr/model/architectures/__pycache__/det_model.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BADBADBADBOY/pytorchOCR/efae5454651e96794dd7a460bd305dc9ccd95359/ptocr/model/architectures/__pycache__/det_model.cpython-35.pyc -------------------------------------------------------------------------------- /ptocr/model/architectures/__pycache__/det_model.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BADBADBADBOY/pytorchOCR/efae5454651e96794dd7a460bd305dc9ccd95359/ptocr/model/architectures/__pycache__/det_model.cpython-36.pyc -------------------------------------------------------------------------------- /ptocr/model/architectures/__pycache__/rec_model.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BADBADBADBOY/pytorchOCR/efae5454651e96794dd7a460bd305dc9ccd95359/ptocr/model/architectures/__pycache__/rec_model.cpython-36.pyc -------------------------------------------------------------------------------- /ptocr/model/architectures/__pycache__/stn.cpython-36.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/BADBADBADBOY/pytorchOCR/efae5454651e96794dd7a460bd305dc9ccd95359/ptocr/model/architectures/__pycache__/stn.cpython-36.pyc -------------------------------------------------------------------------------- /ptocr/model/architectures/__pycache__/stn_head.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BADBADBADBOY/pytorchOCR/efae5454651e96794dd7a460bd305dc9ccd95359/ptocr/model/architectures/__pycache__/stn_head.cpython-36.pyc -------------------------------------------------------------------------------- /ptocr/model/architectures/__pycache__/tps_spatial_transformer.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BADBADBADBOY/pytorchOCR/efae5454651e96794dd7a460bd305dc9ccd95359/ptocr/model/architectures/__pycache__/tps_spatial_transformer.cpython-36.pyc -------------------------------------------------------------------------------- /ptocr/model/architectures/det_model_q.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 _*- 2 | """ 3 | @author:fxw 4 | @file: det_model.py 5 | @time: 2020/08/07 6 | """ 7 | import torch 8 | import torch.nn as nn 9 | from .. 
class DetLoss(nn.Module):
    """Build and apply the detection loss selected by ``config['base']['algorithm']``.

    The loss object is produced by calling the factory returned from
    ``create_module(config['loss']['function'])`` with hyper-parameters read
    from ``config['loss']``.  Which keys are read depends on the algorithm
    (see ``_LOSS_ARG_KEYS``).

    :param config: nested mapping providing ``config['base']['algorithm']``
        and a ``config['loss']`` section with ``'function'`` plus the keys
        listed for the chosen algorithm.
    :raises AssertionError: if the algorithm is not one of the supported names.
    """

    # Keys of ``config['loss']`` passed positionally, in this order, to the
    # loss factory.  The key names are looked up in the user's config, so they
    # are kept verbatim (including the spelling of 'text_tatio').
    _LOSS_ARG_KEYS = {
        'DB': ('l1_scale', 'bce_scale'),
        'PAN': ('kernel_rate', 'agg_dis_rate'),
        'PSE': ('text_tatio',),
        'SAST': ('tvo_lw', 'tco_lw', 'score_lw', 'border_lw'),
    }

    def __init__(self, config):
        super(DetLoss, self).__init__()
        self.algorithm = config['base']['algorithm']

        arg_keys = self._LOSS_ARG_KEYS.get(self.algorithm)
        if arg_keys is None:
            # Raised explicitly instead of via an ``assert`` statement:
            # asserts are removed under ``python -O``, which would leave
            # ``self.loss`` undefined and only fail later inside forward().
            # AssertionError is kept so existing callers see the same type.
            raise AssertionError('not support this algorithm !!!')

        loss_cfg = config['loss']
        loss_factory = create_module(loss_cfg['function'])
        self.loss = loss_factory(*[loss_cfg[key] for key in arg_keys])

    def forward(self, pre_batch, gt_batch):
        """Delegate to the configured loss and return whatever it returns.

        :param pre_batch: model predictions, passed through unchanged.
        :param gt_batch: ground-truth batch, passed through unchanged.
        """
        return self.loss(pre_batch, gt_batch)
/ptocr/model/backbone/__pycache__/det_mobilev3.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BADBADBADBOY/pytorchOCR/efae5454651e96794dd7a460bd305dc9ccd95359/ptocr/model/backbone/__pycache__/det_mobilev3.cpython-35.pyc -------------------------------------------------------------------------------- /ptocr/model/backbone/__pycache__/det_mobilev3.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BADBADBADBOY/pytorchOCR/efae5454651e96794dd7a460bd305dc9ccd95359/ptocr/model/backbone/__pycache__/det_mobilev3.cpython-36.pyc -------------------------------------------------------------------------------- /ptocr/model/backbone/__pycache__/det_mobilev3_dcd.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BADBADBADBOY/pytorchOCR/efae5454651e96794dd7a460bd305dc9ccd95359/ptocr/model/backbone/__pycache__/det_mobilev3_dcd.cpython-36.pyc -------------------------------------------------------------------------------- /ptocr/model/backbone/__pycache__/det_resnet.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BADBADBADBOY/pytorchOCR/efae5454651e96794dd7a460bd305dc9ccd95359/ptocr/model/backbone/__pycache__/det_resnet.cpython-35.pyc -------------------------------------------------------------------------------- /ptocr/model/backbone/__pycache__/det_resnet.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BADBADBADBOY/pytorchOCR/efae5454651e96794dd7a460bd305dc9ccd95359/ptocr/model/backbone/__pycache__/det_resnet.cpython-36.pyc -------------------------------------------------------------------------------- /ptocr/model/backbone/__pycache__/det_resnet_3_3.cpython-36.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/BADBADBADBOY/pytorchOCR/efae5454651e96794dd7a460bd305dc9ccd95359/ptocr/model/backbone/__pycache__/det_resnet_3_3.cpython-36.pyc -------------------------------------------------------------------------------- /ptocr/model/backbone/__pycache__/det_resnet_sast.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BADBADBADBOY/pytorchOCR/efae5454651e96794dd7a460bd305dc9ccd95359/ptocr/model/backbone/__pycache__/det_resnet_sast.cpython-35.pyc -------------------------------------------------------------------------------- /ptocr/model/backbone/__pycache__/det_resnet_sast.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BADBADBADBOY/pytorchOCR/efae5454651e96794dd7a460bd305dc9ccd95359/ptocr/model/backbone/__pycache__/det_resnet_sast.cpython-36.pyc -------------------------------------------------------------------------------- /ptocr/model/backbone/__pycache__/det_resnet_sast_3_3.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BADBADBADBOY/pytorchOCR/efae5454651e96794dd7a460bd305dc9ccd95359/ptocr/model/backbone/__pycache__/det_resnet_sast_3_3.cpython-36.pyc -------------------------------------------------------------------------------- /ptocr/model/backbone/__pycache__/det_scnet.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BADBADBADBOY/pytorchOCR/efae5454651e96794dd7a460bd305dc9ccd95359/ptocr/model/backbone/__pycache__/det_scnet.cpython-35.pyc -------------------------------------------------------------------------------- /ptocr/model/backbone/__pycache__/rec_crnn_backbone.cpython-36.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/BADBADBADBOY/pytorchOCR/efae5454651e96794dd7a460bd305dc9ccd95359/ptocr/model/backbone/__pycache__/rec_crnn_backbone.cpython-36.pyc -------------------------------------------------------------------------------- /ptocr/model/backbone/__pycache__/rec_mobilev3_bd.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BADBADBADBOY/pytorchOCR/efae5454651e96794dd7a460bd305dc9ccd95359/ptocr/model/backbone/__pycache__/rec_mobilev3_bd.cpython-36.pyc -------------------------------------------------------------------------------- /ptocr/model/backbone/__pycache__/rec_vgg.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BADBADBADBOY/pytorchOCR/efae5454651e96794dd7a460bd305dc9ccd95359/ptocr/model/backbone/__pycache__/rec_vgg.cpython-36.pyc -------------------------------------------------------------------------------- /ptocr/model/backbone/__pycache__/reg_resnet_bd.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BADBADBADBOY/pytorchOCR/efae5454651e96794dd7a460bd305dc9ccd95359/ptocr/model/backbone/__pycache__/reg_resnet_bd.cpython-36.pyc -------------------------------------------------------------------------------- /ptocr/model/head/__init__.py: -------------------------------------------------------------------------------- 1 | #-*- coding:utf-8 _*- 2 | """ 3 | @author:fxw 4 | @file: __init__.py.py 5 | @time: 2020/08/07 6 | """ 7 | -------------------------------------------------------------------------------- /ptocr/model/head/__pycache__/__init__.cpython-35.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/BADBADBADBOY/pytorchOCR/efae5454651e96794dd7a460bd305dc9ccd95359/ptocr/model/head/__pycache__/__init__.cpython-35.pyc -------------------------------------------------------------------------------- /ptocr/model/head/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BADBADBADBOY/pytorchOCR/efae5454651e96794dd7a460bd305dc9ccd95359/ptocr/model/head/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /ptocr/model/head/__pycache__/det_DBHead.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BADBADBADBOY/pytorchOCR/efae5454651e96794dd7a460bd305dc9ccd95359/ptocr/model/head/__pycache__/det_DBHead.cpython-35.pyc -------------------------------------------------------------------------------- /ptocr/model/head/__pycache__/det_DBHead.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BADBADBADBOY/pytorchOCR/efae5454651e96794dd7a460bd305dc9ccd95359/ptocr/model/head/__pycache__/det_DBHead.cpython-36.pyc -------------------------------------------------------------------------------- /ptocr/model/head/__pycache__/det_FPEM_FFM_Head.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BADBADBADBOY/pytorchOCR/efae5454651e96794dd7a460bd305dc9ccd95359/ptocr/model/head/__pycache__/det_FPEM_FFM_Head.cpython-35.pyc -------------------------------------------------------------------------------- /ptocr/model/head/__pycache__/det_FPEM_FFM_Head.cpython-36.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/BADBADBADBOY/pytorchOCR/efae5454651e96794dd7a460bd305dc9ccd95359/ptocr/model/head/__pycache__/det_FPEM_FFM_Head.cpython-36.pyc -------------------------------------------------------------------------------- /ptocr/model/head/__pycache__/det_FPNHead.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BADBADBADBOY/pytorchOCR/efae5454651e96794dd7a460bd305dc9ccd95359/ptocr/model/head/__pycache__/det_FPNHead.cpython-35.pyc -------------------------------------------------------------------------------- /ptocr/model/head/__pycache__/det_FPNHead.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BADBADBADBOY/pytorchOCR/efae5454651e96794dd7a460bd305dc9ccd95359/ptocr/model/head/__pycache__/det_FPNHead.cpython-36.pyc -------------------------------------------------------------------------------- /ptocr/model/head/__pycache__/det_SASTHead.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BADBADBADBOY/pytorchOCR/efae5454651e96794dd7a460bd305dc9ccd95359/ptocr/model/head/__pycache__/det_SASTHead.cpython-35.pyc -------------------------------------------------------------------------------- /ptocr/model/head/__pycache__/det_SASTHead.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BADBADBADBOY/pytorchOCR/efae5454651e96794dd7a460bd305dc9ccd95359/ptocr/model/head/__pycache__/det_SASTHead.cpython-36.pyc -------------------------------------------------------------------------------- /ptocr/model/head/__pycache__/rec_CRNNHead.cpython-36.pyc: -------------------------------------------------------------------------------- 
class DB_Head(nn.Module):
    """Top-down feature-fusion head used by the DB detector.

    Four backbone feature maps are projected to ``inner_channels`` each,
    merged top-down with ``upsample_add``, reduced to ``inner_channels // 4``
    channels per level, resized to the finest level and concatenated along
    dim 1.
    """

    def __init__(self, in_channels, inner_channels, bias=False):
        """
        :param in_channels: sequence of backbone channel counts; entries -1
            (used for the c5 input) down to -4 (used for the c2 input) are read
        :param inner_channels: output channels of every lateral projection
        :param bias: forwarded as the ``bias`` keyword of every ConvBnRelu
        """
        super(DB_Head, self).__init__()

        # Lateral projections, registered in the order in5, in4, in3, in2
        # (the positional 1, 1, 0 are presumably kernel/stride/padding — TODO confirm).
        for offset, attr in enumerate(('in5', 'in4', 'in3', 'in2'), start=1):
            setattr(self, attr,
                    ConvBnRelu(in_channels[-offset], inner_channels, 1, 1, 0, bias=bias))

        # Per-level output convolutions, registered in the order out5 ... out2.
        for attr in ('out5', 'out4', 'out3', 'out2'):
            setattr(self, attr,
                    ConvBnRelu(inner_channels, inner_channels // 4, 3, 1, 1, bias=bias))

        # Re-initialise every conv / batch-norm contained in the head.
        for layer in self.modules():
            if isinstance(layer, nn.BatchNorm2d):
                layer.weight.data.fill_(1.)
                layer.bias.data.fill_(1e-4)
            elif isinstance(layer, nn.Conv2d):
                nn.init.kaiming_normal_(layer.weight.data)

    def forward(self, x):
        """Fuse the four feature maps unpacked from ``x`` as ``c2, c3, c4, c5``.

        :param x: iterable of exactly four feature maps
        :return: concatenation (dim 1) of the four per-level outputs
        """
        c2, c3, c4, c5 = x
        lat5 = self.in5(c5)
        lat4 = self.in4(c4)
        lat3 = self.in3(c3)
        lat2 = self.in2(c2)

        merged4 = upsample_add(lat5, lat4)     # 1/16
        merged3 = upsample_add(merged4, lat3)  # 1/8
        merged2 = upsample_add(merged3, lat2)  # 1/4

        # Every branch except the finest one is resized to merged2's size.
        branches = (
            upsample(self.out5(lat5), merged2),
            upsample(self.out4(merged4), merged2),
            upsample(self.out3(merged3), merged2),
            self.out2(merged2),
        )
        return torch.cat(branches, 1)
self.in2 = ConvBnRelu(in_channels[-4], inner_channels, 1, 1, 0,groups=1, bias=bias) 37 | 38 | self.out5 = ConvBnRelu(inner_channels, inner_channels // 4, 3, 1, 1, groups=1,bias=bias) 39 | self.out4 = ConvBnRelu(inner_channels, inner_channels // 4, 3, 1, 1,groups=1, bias=bias) 40 | self.out3 = ConvBnRelu(inner_channels, inner_channels // 4, 3, 1, 1, groups=1,bias=bias) 41 | self.out2 = ConvBnRelu(inner_channels, inner_channels // 4, 3, 1, 1, groups=1,bias=bias) 42 | 43 | for m in self.modules(): 44 | if isinstance(m, nn.Conv2d): 45 | nn.init.kaiming_normal_(m.weight.data) 46 | elif isinstance(m, nn.BatchNorm2d): 47 | m.weight.data.fill_(1.) 48 | m.bias.data.fill_(1e-4) 49 | self.fuse_model() 50 | 51 | def forward(self, x1,x2,x3,x4): 52 | in5 = self.in5(x1) 53 | in4 = self.in4(x2) 54 | in3 = self.in3(x3) 55 | in2 = self.in2(x4) 56 | 57 | out4 = upsample_add(in5,in4,self.skip) # 1/16 58 | out3 = upsample_add(out4,in3,self.skip) # 1/8 59 | out2 = upsample_add(out3,in2,self.skip) # 1/4 60 | 61 | p5 = upsample(self.out5(in5),out2) 62 | p4 = upsample(self.out4(out4),out2) 63 | p3 = upsample(self.out3(out3),out2) 64 | p2 = self.out2(out2) 65 | fuse = self.skip.cat((p5, p4, p3, p2), 1) 66 | return fuse 67 | 68 | def fuse_model(self): 69 | for m in self.modules(): 70 | if type(m) == ConvBnRelu: 71 | fuse_modules(m, ['conv', 'bn', 'relu'], inplace=True) 72 | 73 | -------------------------------------------------------------------------------- /ptocr/model/head/det_FPNHead.py: -------------------------------------------------------------------------------- 1 | #-*- coding:utf-8 _*- 2 | """ 3 | @author:fxw 4 | @file: det_FPNHead.py 5 | @time: 2020/08/07 6 | """ 7 | import torch 8 | import torch.nn as nn 9 | from ptocr.model.CommonFunction import ConvBnRelu,upsample_add,upsample 10 | 11 | class FPN_Head(nn.Module): 12 | def __init__(self, in_channels, inner_channels,bias=False): 13 | """ 14 | :param in_channels: 15 | :param inner_channels: 16 | :param bias: 17 | """ 18 | 
super(FPN_Head, self).__init__() 19 | # Top layer 20 | self.toplayer = ConvBnRelu(in_channels[-1], inner_channels, kernel_size=1, stride=1,padding=0,bias=bias) # Reduce channels 21 | # Smooth layers 22 | self.smooth1 = ConvBnRelu(inner_channels, inner_channels, kernel_size=3, stride=1, padding=1,bias=bias) 23 | self.smooth2 = ConvBnRelu(inner_channels, inner_channels, kernel_size=3, stride=1, padding=1,bias=bias) 24 | self.smooth3 = ConvBnRelu(inner_channels, inner_channels, kernel_size=3, stride=1, padding=1,bias=bias) 25 | # Lateral layers 26 | self.latlayer1 = ConvBnRelu(in_channels[-2], inner_channels, kernel_size=1, stride=1, padding=0,bias=bias) 27 | self.latlayer2 = ConvBnRelu(in_channels[-3], inner_channels, kernel_size=1, stride=1, padding=0,bias=bias) 28 | self.latlayer3 = ConvBnRelu(in_channels[-4], inner_channels, kernel_size=1, stride=1, padding=0,bias=bias) 29 | # Out map 30 | self.conv_out = ConvBnRelu(inner_channels * 4, inner_channels, kernel_size=3, stride=1, padding=1,bias=bias) 31 | 32 | for m in self.modules(): 33 | if isinstance(m, nn.Conv2d): 34 | nn.init.kaiming_normal_(m.weight.data) 35 | elif isinstance(m, nn.BatchNorm2d): 36 | m.weight.data.fill_(1.) 
class FCModule(nn.Module):
    """Linear layer with an optional activation and optional dropout.

    Args:
        in_channels: input feature size of the ``nn.Linear``.
        out_channels: output feature size of the ``nn.Linear``.
        bias: whether the linear layer has a bias term.
        activation: ``'relu'``, ``'tanh'`` or ``None`` (no activation).
        inplace: forwarded to ``nn.ReLU(inplace=...)``; unused for tanh.
        dropout: dropout probability, or ``None`` for no dropout.
        order: ``('fc', 'act')`` or ``('act', 'fc')``; a list with the same
            items is accepted and stored as a tuple.

    Raises:
        ValueError: if ``activation`` or ``order`` is not supported.
    """

    _SUPPORTED_ORDERS = (('fc', 'act'), ('act', 'fc'))

    def __init__(self,
                 in_channels,
                 out_channels,
                 bias=True,
                 activation='relu',
                 inplace=True,
                 dropout=None,
                 order=('fc', 'act')):
        super(FCModule, self).__init__()

        # forward() compares ``order`` against tuples.  A list (for example
        # one loaded from a config file) used to match neither branch, so the
        # linear layer and activation were silently skipped.  Normalise and
        # validate up front instead.
        if isinstance(order, (list, tuple)):
            order = tuple(order)
        if order not in self._SUPPORTED_ORDERS:
            raise ValueError('order {} is currently not supported.'.format(order))

        self.order = order
        self.activation = activation
        self.inplace = inplace

        # NOTE: the attribute name (with its historical spelling) is kept
        # because it is part of the object's visible state.
        self.with_activatation = activation is not None
        self.with_dropout = dropout is not None

        self.fc = nn.Linear(in_channels, out_channels, bias)

        # build activation layer
        if self.with_activatation:
            # TODO: introduce `act_cfg` and supports more activation layers
            if self.activation not in ['relu', 'tanh']:
                raise ValueError('{} is currently not supported.'.format(
                    self.activation))
            if self.activation == 'relu':
                self.activate = nn.ReLU(inplace=inplace)
            elif self.activation == 'tanh':
                self.activate = nn.Tanh()

        if self.with_dropout:
            self.dropout = nn.Dropout(p=dropout)

    def forward(self, x):
        """Apply fc/activation in the configured order, then dropout if any."""
        if self.order == ('fc', 'act'):
            x = self.fc(x)
            if self.with_activatation:
                x = self.activate(x)
        else:  # ('act', 'fc') — guaranteed by the check in __init__
            if self.with_activatation:
                x = self.activate(x)
            x = self.fc(x)

        if self.with_dropout:
            x = self.dropout(x)

        return x
-------------------------------------------------------------------------------- /ptocr/model/loss/__pycache__/__init__.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BADBADBADBOY/pytorchOCR/efae5454651e96794dd7a460bd305dc9ccd95359/ptocr/model/loss/__pycache__/__init__.cpython-35.pyc -------------------------------------------------------------------------------- /ptocr/model/loss/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BADBADBADBOY/pytorchOCR/efae5454651e96794dd7a460bd305dc9ccd95359/ptocr/model/loss/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /ptocr/model/loss/__pycache__/basical_loss.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BADBADBADBOY/pytorchOCR/efae5454651e96794dd7a460bd305dc9ccd95359/ptocr/model/loss/__pycache__/basical_loss.cpython-35.pyc -------------------------------------------------------------------------------- /ptocr/model/loss/__pycache__/basical_loss.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BADBADBADBOY/pytorchOCR/efae5454651e96794dd7a460bd305dc9ccd95359/ptocr/model/loss/__pycache__/basical_loss.cpython-36.pyc -------------------------------------------------------------------------------- /ptocr/model/loss/__pycache__/centerloss.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BADBADBADBOY/pytorchOCR/efae5454651e96794dd7a460bd305dc9ccd95359/ptocr/model/loss/__pycache__/centerloss.cpython-36.pyc -------------------------------------------------------------------------------- 
/ptocr/model/loss/__pycache__/ctc_loss.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BADBADBADBOY/pytorchOCR/efae5454651e96794dd7a460bd305dc9ccd95359/ptocr/model/loss/__pycache__/ctc_loss.cpython-36.pyc -------------------------------------------------------------------------------- /ptocr/model/loss/__pycache__/db_loss.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BADBADBADBOY/pytorchOCR/efae5454651e96794dd7a460bd305dc9ccd95359/ptocr/model/loss/__pycache__/db_loss.cpython-35.pyc -------------------------------------------------------------------------------- /ptocr/model/loss/__pycache__/db_loss.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BADBADBADBOY/pytorchOCR/efae5454651e96794dd7a460bd305dc9ccd95359/ptocr/model/loss/__pycache__/db_loss.cpython-36.pyc -------------------------------------------------------------------------------- /ptocr/model/loss/__pycache__/fc_loss.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BADBADBADBOY/pytorchOCR/efae5454651e96794dd7a460bd305dc9ccd95359/ptocr/model/loss/__pycache__/fc_loss.cpython-36.pyc -------------------------------------------------------------------------------- /ptocr/model/loss/__pycache__/pan_loss.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BADBADBADBOY/pytorchOCR/efae5454651e96794dd7a460bd305dc9ccd95359/ptocr/model/loss/__pycache__/pan_loss.cpython-35.pyc -------------------------------------------------------------------------------- /ptocr/model/loss/__pycache__/pan_loss.cpython-36.pyc: -------------------------------------------------------------------------------- 
class CTCLoss(nn.Module):
    """CTC loss wrapper selecting between ``warpctc_pytorch`` and ``torch.nn``.

    :param config: nested mapping; reads ``config['loss']['ctc_type']``
        (``'warpctc'`` selects warpctc_pytorch, anything else selects
        ``torch.nn.CTCLoss(reduction='none')``), ``config['loss']['use_ctc_weight']``
        and ``config['trainload']['batch_size']``.
    """

    def __init__(self, config):
        super(CTCLoss, self).__init__()
        self.config = config
        # Imported lazily so that warpctc_pytorch is only required when it is
        # actually requested.
        if config['loss']['ctc_type'] == 'warpctc':
            from warpctc_pytorch import CTCLoss as PytorchCTCLoss
            self.criterion = PytorchCTCLoss()
        else:
            from torch.nn import CTCLoss as PytorchCTCLoss
            self.criterion = PytorchCTCLoss(reduction='none')

    def forward(self, pre_batch, gt_batch):
        """Compute the summed CTC loss divided by the configured batch size.

        :param pre_batch: mapping with ``'preds'`` and ``'preds_size'``
            (``preds`` is assumed to be laid out (T, N, C), since log-softmax
            is taken over dim 2 — TODO confirm against the caller).
        :param gt_batch: mapping with ``'labels'`` and ``'labels_len'``; also
            ``'ctc_loss_weight'`` when ``use_ctc_weight`` is enabled.
        :return: ``loss.sum() / config['trainload']['batch_size']``
        """
        loss_cfg = self.config['loss']
        preds, preds_size = pre_batch['preds'], pre_batch['preds_size']
        labels, labels_len = gt_batch['labels'], gt_batch['labels_len']

        if loss_cfg['ctc_type'] != 'warpctc':
            # torch.nn.CTCLoss expects log-probabilities.
            preds = preds.log_softmax(2).requires_grad_()

        loss = self.criterion(preds, labels, preds_size, labels_len)

        if loss_cfg['use_ctc_weight']:
            weight = gt_batch['ctc_loss_weight']
            # Follow the weight's device instead of hard-coding ``.cuda()``,
            # which raised on CPU-only machines.  When the weight lives on the
            # GPU the result is the same as before.
            target_device = getattr(weight, 'device', None)
            if target_device is not None:
                loss = loss.to(target_device)
            else:
                loss = loss.cuda()
            loss = weight * loss

        loss = loss.sum()
        # Normalised by the configured batch size, not the actual batch length.
        return loss / self.config['trainload']['batch_size']
class DBLossMul(nn.Module):
    """DB detection loss with an additional multi-class segmentation term.

    The total loss is built from four components, each delegated to a helper
    from ``basical_loss``: balanced BCE on the binary map, masked L1 on the
    threshold map, dice on the approximate binary map and ``MulClassLoss`` on
    the class map.

    Args:
        n_class: number of classes, forwarded to ``MulClassLoss`` at call time.
        l1_scale: weight of the threshold-map L1 term.
        bce_scale: weight of the BCE term (used when 'thresh' is predicted).
        class_scale: weight of the multi-class term.
        eps: value forwarded to ``DiceLoss``.
    """

    def __init__(self, n_class, l1_scale=10, bce_scale=1, class_scale=1, eps=1e-6):
        super(DBLossMul, self).__init__()
        self.dice_loss = DiceLoss(eps)
        self.l1_loss = MaskL1Loss()
        self.bce_loss = BalanceCrossEntropyLoss()
        self.class_loss = MulClassLoss()
        self.l1_scale = l1_scale
        self.bce_scale = bce_scale
        self.class_scale = class_scale
        self.n_class = n_class

    def forward(self, pred_bach, gt_batch):
        """Compute the combined loss and a dict of its individual terms.

        Args:
            pred_bach: dict with 'binary' and 'binary_class'; when it also
                holds 'thresh' and 'thresh_binary' the full DB loss is used.
            gt_batch: dict with 'gt', 'mask', 'gt_class' and, for the full
                loss, 'thresh_map' and 'thresh_mask'.

        Returns:
            ``(loss, metrics)`` where ``metrics`` always contains 'loss_bce'
            and 'loss_class', plus 'loss_thresh' and the L1 metrics when the
            threshold branch is present.
        """
        bce_loss = self.bce_loss(pred_bach['binary'][:, 0], gt_batch['gt'], gt_batch['mask'])
        class_loss = self.class_loss(pred_bach['binary_class'], gt_batch['gt_class'], self.n_class)
        metrics = dict(loss_bce=bce_loss)
        if 'thresh' in pred_bach:
            l1_loss, l1_metric = self.l1_loss(pred_bach['thresh'][:, 0], gt_batch['thresh_map'], gt_batch['thresh_mask'])
            dice_loss = self.dice_loss(pred_bach['thresh_binary'][:, 0], gt_batch['gt'], gt_batch['mask'])
            metrics['loss_thresh'] = dice_loss
            metrics['loss_class'] = class_loss
            loss = dice_loss + self.l1_scale * l1_loss + bce_loss * self.bce_scale + class_loss * self.class_scale
            metrics.update(**l1_metric)
        else:
            # The class term was previously computed and then discarded on
            # this path; keep it in both the loss and the reported metrics so
            # the class head is never silently ignored.
            metrics['loss_class'] = class_loss
            loss = bce_loss + class_loss * self.class_scale
        return loss, metrics
class PANLoss(nn.Module):
    """Loss for the PAN text detector.

    Combines a dice loss on the text map, a dice loss on the kernel map and
    the aggregation / discrimination losses on the similarity vectors.

    Args:
        kernel_rate: weight of the kernel dice loss.
        agg_dis_rate: weight applied to ``loss_agg + loss_dis``.
        eps: value forwarded to ``DiceLoss``.
    """

    def __init__(self, kernel_rate=0.5, agg_dis_rate=0.25, eps=1e-6):
        super(PANLoss, self).__init__()
        self.kernel_rate = kernel_rate
        self.agg_dis_rate = agg_dis_rate
        self.dice_loss = DiceLoss(eps)
        self.agg_loss = Agg_loss()
        self.dis_loss = Dis_loss()

    def GetKernelLoss(self, pre_text, pre_kernel, gt_kernel, train_mask):
        """Dice loss on the kernel map, restricted to confident text pixels.

        A pixel is selected when the predicted text score and the training
        mask are both above 0.5.
        """
        # Build the mask on the predictions' own device. The previous
        # NumPy round-trip followed by an unconditional ``.cuda()`` broke
        # CPU runs on machines that merely have CUDA available.
        text_hit = pre_text.detach() > 0.5
        mask_hit = train_mask.detach().to(pre_text.device) > 0.5
        selected_masks = (text_hit & mask_hit).float()
        return self.dice_loss(pre_kernel, gt_kernel, selected_masks)

    def GetTextLoss(self, pre_text, gt_text, train_mask):
        """Dice loss on the text map using the mask chosen by ``ohem_batch``."""
        selected_masks = ohem_batch(pre_text, gt_text, train_mask)
        # Follow the predictions' device instead of guessing from
        # ``torch.cuda.is_available()``.
        selected_masks = selected_masks.to(pre_text.device)
        return self.dice_loss(pre_text, gt_text, selected_masks)

    def forward(self, pred_bach, gt_batch):
        """Compute the total PAN loss and its individual terms.

        Args:
            pred_bach: dict with raw (pre-sigmoid) 'pre_text' and
                'pre_kernel' maps and 'similarity_vector'.
            gt_batch: dict with 'gt_text', 'gt_text_key', 'gt_kernel',
                'gt_kernel_key' and 'train_mask'.

        Returns:
            ``(loss, metrics)`` with metrics keys 'loss_text', 'loss_kernel',
            'loss_agg' and 'loss_dis'.
        """
        pre_text = torch.sigmoid(pred_bach['pre_text'])
        pre_kernel = torch.sigmoid(pred_bach['pre_kernel'])
        gt_text = gt_batch['gt_text']
        gt_text_key = gt_batch['gt_text_key']
        gt_kernel = gt_batch['gt_kernel']
        gt_kernel_key = gt_batch['gt_kernel_key']
        train_mask = gt_batch['train_mask']
        similarity_vector = pred_bach['similarity_vector']

        # Zero out instance keys wherever the prediction is not confident,
        # so the agg/dis terms only see pixels predicted as text / kernel.
        pre_text_select = (pre_text > 0.5).float()
        pre_kernel_select = (pre_kernel > 0.5).float()
        gt_text_key = gt_text_key * pre_text_select
        gt_kernel_key = gt_kernel_key * pre_kernel_select

        loss_kernel = self.GetKernelLoss(pre_text, pre_kernel, gt_kernel, train_mask)
        loss_text = self.GetTextLoss(pre_text, gt_text, train_mask)
        loss_agg = self.agg_loss.cal_agg_batch(similarity_vector, gt_kernel_key, gt_text_key, train_mask)
        loss_dis = self.dis_loss.cal_Ldis_batch(similarity_vector, gt_kernel_key, train_mask)
        loss = loss_text + self.kernel_rate * loss_kernel + self.agg_dis_rate * (loss_agg + loss_dis)
        metrics = dict(loss_text=loss_text)
        metrics['loss_kernel'] = loss_kernel
        metrics['loss_agg'] = loss_agg
        metrics['loss_dis'] = loss_dis
        return loss, metrics
torch.cuda.is_available(): 25 | selected_masks = selected_masks.cuda() 26 | loss_kernels = [] 27 | for i in range(pre_kernel.shape[1]): 28 | loss_kernel = self.dice_loss(torch.sigmoid(pre_kernel[:,i]), gt_kernel[:,i], selected_masks) 29 | loss_kernels.append(loss_kernel) 30 | return sum(loss_kernels)/len(loss_kernels) 31 | 32 | def GetTextLoss(self, pre_text, gt_text, train_mask): 33 | selected_masks = ohem_batch(pre_text, gt_text, train_mask) 34 | selected_masks = Variable(selected_masks) 35 | if torch.cuda.is_available(): 36 | selected_masks = selected_masks.cuda() 37 | loss_text = self.dice_loss(pre_text, gt_text, selected_masks) 38 | return loss_text 39 | 40 | def forward(self, pred_bach,gt_batch): 41 | pre_text = torch.sigmoid(pred_bach['pre_text']) 42 | pre_kernel = pred_bach['pre_kernel'] 43 | gt_text = gt_batch['gt_text'] 44 | gt_kernel = gt_batch['gt_kernel'] 45 | train_mask = gt_batch['train_mask'] 46 | 47 | loss_text = self.GetTextLoss(pre_text,gt_text,train_mask) 48 | loss_kernel = self.GetKernelLoss(pre_text,pre_kernel,gt_kernel,train_mask) 49 | loss = self.text_tatio*loss_text + (1 - self.text_tatio)*loss_kernel 50 | metrics = dict(loss_text=loss_text) 51 | metrics['loss_kernel'] = loss_kernel 52 | return loss,metrics -------------------------------------------------------------------------------- /ptocr/model/segout/__init__.py: -------------------------------------------------------------------------------- 1 | #-*- coding:utf-8 _*- 2 | """ 3 | @author:fxw 4 | @file: __init__.py.py 5 | @time: 2020/08/07 6 | """ 7 | -------------------------------------------------------------------------------- /ptocr/model/segout/__pycache__/__init__.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BADBADBADBOY/pytorchOCR/efae5454651e96794dd7a460bd305dc9ccd95359/ptocr/model/segout/__pycache__/__init__.cpython-35.pyc 
-------------------------------------------------------------------------------- /ptocr/model/segout/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BADBADBADBOY/pytorchOCR/efae5454651e96794dd7a460bd305dc9ccd95359/ptocr/model/segout/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /ptocr/model/segout/__pycache__/det_DB_segout.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BADBADBADBOY/pytorchOCR/efae5454651e96794dd7a460bd305dc9ccd95359/ptocr/model/segout/__pycache__/det_DB_segout.cpython-35.pyc -------------------------------------------------------------------------------- /ptocr/model/segout/__pycache__/det_DB_segout.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BADBADBADBOY/pytorchOCR/efae5454651e96794dd7a460bd305dc9ccd95359/ptocr/model/segout/__pycache__/det_DB_segout.cpython-36.pyc -------------------------------------------------------------------------------- /ptocr/model/segout/__pycache__/det_PAN_segout.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BADBADBADBOY/pytorchOCR/efae5454651e96794dd7a460bd305dc9ccd95359/ptocr/model/segout/__pycache__/det_PAN_segout.cpython-35.pyc -------------------------------------------------------------------------------- /ptocr/model/segout/__pycache__/det_PAN_segout.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BADBADBADBOY/pytorchOCR/efae5454651e96794dd7a460bd305dc9ccd95359/ptocr/model/segout/__pycache__/det_PAN_segout.cpython-36.pyc -------------------------------------------------------------------------------- 
/ptocr/model/segout/__pycache__/det_PSE_segout.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BADBADBADBOY/pytorchOCR/efae5454651e96794dd7a460bd305dc9ccd95359/ptocr/model/segout/__pycache__/det_PSE_segout.cpython-35.pyc -------------------------------------------------------------------------------- /ptocr/model/segout/__pycache__/det_PSE_segout.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BADBADBADBOY/pytorchOCR/efae5454651e96794dd7a460bd305dc9ccd95359/ptocr/model/segout/__pycache__/det_PSE_segout.cpython-36.pyc -------------------------------------------------------------------------------- /ptocr/model/segout/__pycache__/det_SAST_segout.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BADBADBADBOY/pytorchOCR/efae5454651e96794dd7a460bd305dc9ccd95359/ptocr/model/segout/__pycache__/det_SAST_segout.cpython-35.pyc -------------------------------------------------------------------------------- /ptocr/model/segout/__pycache__/det_SAST_segout.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BADBADBADBOY/pytorchOCR/efae5454651e96794dd7a460bd305dc9ccd95359/ptocr/model/segout/__pycache__/det_SAST_segout.cpython-36.pyc -------------------------------------------------------------------------------- /ptocr/model/segout/det_DB_segout_qua.py: -------------------------------------------------------------------------------- 1 | #-*- coding:utf-8 _*- 2 | """ 3 | @author:fxw 4 | @file: det_DB_segout.py 5 | @time: 2020/08/07 6 | """ 7 | from collections import OrderedDict 8 | import torch 9 | import torch.nn as nn 10 | from ..CommonFunction_Q import ConvBnRelu,ConvBn 11 | from torch.quantization import QuantStub, DeQuantStub, fuse_modules 12 | 13 | 
class SegDetector(nn.Module):
    """Quantization-oriented DB segmentation head.

    Holds two parallel branches, ``binarize`` and ``thresh``. Each one is a
    ConvBnRelu, a ConvBn down to a single channel, a x4 ``nn.Upsample`` and a
    sigmoid. Conv/BN(/ReLU) groups are fused at construction time.
    """

    def __init__(self, inner_channels=256, bias=False,
                 *args, **kwargs):
        '''
        inner_channels: channel count of the incoming fused feature map.
        bias: Whether conv layers have bias or not.
        Extra positional / keyword arguments are accepted and ignored.
        '''
        super(SegDetector, self).__init__()
        reduced = inner_channels // 4

        self.binarize = nn.Sequential(
            ConvBnRelu(inner_channels, reduced, 3, 1, padding=1, groups=1, bias=bias),
            ConvBn(reduced, 1, 1, 1, 0, groups=1),
            nn.Upsample(scale_factor=4),
            nn.Sigmoid(),
        )

        self.thresh = nn.Sequential(
            ConvBnRelu(inner_channels, reduced, 3, stride=1, padding=1, groups=1, bias=bias),
            ConvBn(reduced, 1, 1, 1, 0, groups=1, bias=bias),
            nn.Upsample(scale_factor=4),
            nn.Sigmoid(),
        )

        # Only the binarize branch receives the custom initialisation.
        self.binarize.apply(self.weights_init)
        self.fuse_model()

    def weights_init(self, m):
        """Kaiming-init Conv2d weights; set BatchNorm weight to 1, bias to 1e-4."""
        layer_name = m.__class__.__name__
        if 'Conv2d' in layer_name:
            nn.init.kaiming_normal_(m.weight.data)
            return
        if 'BatchNorm' in layer_name:
            m.weight.data.fill_(1.)
            m.bias.data.fill_(1e-4)

    def forward(self, fuse):
        """Run both branches on ``fuse`` and return ``(thresh, binary)``."""
        binary_map = self.binarize(fuse)
        thresh_map = self.thresh(fuse)
        return thresh_map, binary_map

    def fuse_model(self):
        """Fuse conv/bn(/relu) inside every ConvBnRelu and ConvBn submodule."""
        fusion_plan = (
            (ConvBnRelu, ('conv', 'bn', 'relu')),
            (ConvBn, ('conv', 'bn')),
        )
        for m in self.modules():
            for layer_cls, member_names in fusion_plan:
                if type(m) == layer_cls:
                    fuse_modules(m, list(member_names), inplace=True)
class SASTHead1(nn.Module):
    """First SAST head: predicts the score map and the border map.

    Both branches are four stacked ConvBnRelu layers (128 -> 64 -> 64 -> 128
    -> out) whose last layer is built with ``with_relu=False``. The score
    branch ends in 1 channel followed by a sigmoid; the border branch ends in
    4 channels and is returned as-is.
    """

    def __init__(self):
        super(SASTHead1, self).__init__()
        # Score branch.
        self.f_score_conv1 = ConvBnRelu(128, 64, 1, 1, 0)
        self.f_score_conv2 = ConvBnRelu(64, 64, 3, 1, 1)
        self.f_score_conv3 = ConvBnRelu(64, 128, 1, 1, 0)
        self.f_score_conv4 = ConvBnRelu(128, 1, 3, 1, 1, with_relu=False)

        # Border branch.
        self.f_border_conv1 = ConvBnRelu(128, 64, 1, 1, 0)
        self.f_border_conv2 = ConvBnRelu(64, 64, 3, 1, 1)
        self.f_border_conv3 = ConvBnRelu(64, 128, 1, 1, 0)
        self.f_border_conv4 = ConvBnRelu(128, 4, 3, 1, 1, with_relu=False)

        # Re-initialise every conv / batch-norm layer created above.
        for layer in self.modules():
            if isinstance(layer, nn.Conv2d):
                nn.init.kaiming_normal_(layer.weight.data)
                continue
            if isinstance(layer, nn.BatchNorm2d):
                layer.weight.data.fill_(1.)
                layer.bias.data.fill_(1e-4)

    def forward(self, x):
        """Return ``(f_score, f_border)`` computed from feature map ``x``."""
        score_layers = (self.f_score_conv1, self.f_score_conv2,
                        self.f_score_conv3, self.f_score_conv4)
        border_layers = (self.f_border_conv1, self.f_border_conv2,
                         self.f_border_conv3, self.f_border_conv4)

        f_score = x
        for layer in score_layers:
            f_score = layer(f_score)
        f_score = torch.sigmoid(f_score)

        f_border = x
        for layer in border_layers:
            f_border = layer(f_border)

        return f_score, f_border
def AdamDecay(config, parameters):
    """Create a ``torch.optim.Adam`` optimizer from ``config['optimizer']``.

    Reads 'base_lr', 'beta1', 'beta2' and 'weight_decay' from the optimizer
    section of ``config`` and returns the optimizer built over ``parameters``.
    """
    opt_cfg = config['optimizer']
    learning_rate = opt_cfg['base_lr']
    betas = (opt_cfg['beta1'], opt_cfg['beta2'])
    decay = opt_cfg['weight_decay']
    return torch.optim.Adam(parameters, lr=learning_rate, betas=betas, weight_decay=decay)
def SGDR(lr_max, lr_min, T_cur, T_m, ratio=0.3):
    """Cosine-annealed learning rate with periodic restarts.

    :param lr_max: maximum learning rate of the current cycle
    :param lr_min: minimum learning rate
    :param T_cur: current epoch or iteration
    :param T_m: length of one cycle (a restart happens every ``T_m`` steps)
    :param ratio: fraction by which ``lr_max`` is reduced at each restart
    :return: ``(lr, lr_max)`` - the learning rate for this step and the
        (possibly decayed) maximum to pass in on the next call
    """
    # optimizer.py only imports torch at module level, so ``math`` has to be
    # brought into scope here; without it the first call raised NameError.
    import math

    # At every cycle boundary except step 0, shrink the peak learning rate.
    if T_cur % T_m == 0 and T_cur != 0:
        lr_max = lr_max - lr_max * ratio
    # Cosine interpolation between lr_max and lr_min over the cycle position.
    lr = lr_min + 1 / 2 * (lr_max - lr_min) * (1 + math.cos((T_cur % T_m / T_m) * math.pi))
    return lr, lr_max
-------------------------------------------------------------------------------- /ptocr/postprocess/__pycache__/DBpostprocess.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BADBADBADBOY/pytorchOCR/efae5454651e96794dd7a460bd305dc9ccd95359/ptocr/postprocess/__pycache__/DBpostprocess.cpython-36.pyc -------------------------------------------------------------------------------- /ptocr/postprocess/__pycache__/SASTpostprocess.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BADBADBADBOY/pytorchOCR/efae5454651e96794dd7a460bd305dc9ccd95359/ptocr/postprocess/__pycache__/SASTpostprocess.cpython-36.pyc -------------------------------------------------------------------------------- /ptocr/postprocess/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BADBADBADBOY/pytorchOCR/efae5454651e96794dd7a460bd305dc9ccd95359/ptocr/postprocess/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /ptocr/postprocess/__pycache__/locality_aware_nms.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BADBADBADBOY/pytorchOCR/efae5454651e96794dd7a460bd305dc9ccd95359/ptocr/postprocess/__pycache__/locality_aware_nms.cpython-36.pyc -------------------------------------------------------------------------------- /ptocr/postprocess/dbprocess/Makefile: -------------------------------------------------------------------------------- 1 | CXXFLAGS = -I include -std=c++11 -O3 $(shell python3-config --cflags) 2 | 3 | DEPS = $(shell find include -xtype f) 4 | CXX_SOURCES = cppdbprocess.cpp include/clipper/clipper.cpp 5 | OPENCV = `pkg-config --cflags --libs opencv` 6 | 7 | LIB_SO = cppdbprocess.so 8 | 9 
def cpp_boxes_from_bitmap(pred, bitmap, box_thresh=0.6, det_db_unclip_ratio=2.0):
    """Extract boxes from a DB bitmap using the compiled ``db_cpp`` routine.

    ``bitmap`` is cast to ``uint8`` before being handed, together with
    ``pred`` and the two thresholds, to the C++ extension. The extension's
    result is returned unchanged.
    """
    # Imported lazily so the extension is only loaded when actually used.
    from .cppdbprocess import db_cpp

    bitmap_u8 = bitmap.astype(np.uint8)
    return db_cpp(pred, bitmap_u8, box_thresh, det_db_unclip_ratio)
https://raw.githubusercontent.com/BADBADBADBOY/pytorchOCR/efae5454651e96794dd7a460bd305dc9ccd95359/ptocr/postprocess/dbprocess/include/clipper/clipper.cpp -------------------------------------------------------------------------------- /ptocr/postprocess/dbprocess/include/postprocess_op.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 
14 | 15 | #pragma once 16 | 17 | #include "opencv2/core.hpp" 18 | #include "opencv2/imgcodecs.hpp" 19 | #include "opencv2/imgproc.hpp" 20 | #include 21 | #include 22 | #include 23 | #include 24 | #include 25 | 26 | #include 27 | #include 28 | #include 29 | 30 | #include "clipper.h" 31 | // #include "utility.h" 32 | 33 | using namespace std; 34 | 35 | namespace cppdbprocess { 36 | 37 | class PostProcessor { 38 | public: 39 | void GetContourArea(const std::vector> &box, 40 | float unclip_ratio, float &distance); 41 | 42 | cv::RotatedRect UnClip(std::vector> box, 43 | const float &unclip_ratio); 44 | 45 | float **Mat2Vec(cv::Mat mat); 46 | 47 | std::vector> 48 | OrderPointsClockwise(std::vector> pts); 49 | 50 | std::vector> GetMiniBoxes(cv::RotatedRect box, 51 | float &ssid); 52 | 53 | float BoxScoreFast(std::vector> box_array, cv::Mat pred); 54 | 55 | std::vector>> 56 | BoxesFromBitmap(const cv::Mat pred, const cv::Mat bitmap, 57 | const float &box_thresh, const float &det_db_unclip_ratio); 58 | 59 | std::vector>> 60 | FilterTagDetRes(std::vector>> boxes, 61 | float ratio_h, float ratio_w, cv::Mat srcimg); 62 | 63 | private: 64 | static bool XsortInt(std::vector a, std::vector b); 65 | 66 | static bool XsortFp32(std::vector a, std::vector b); 67 | 68 | std::vector> Mat2Vector(cv::Mat mat); 69 | 70 | inline int _max(int a, int b) { return a >= b ? a : b; } 71 | 72 | inline int _min(int a, int b) { return a >= b ? 
b : a; } 73 | 74 | template inline T clamp(T x, T min, T max) { 75 | if (x > max) 76 | return max; 77 | if (x < min) 78 | return min; 79 | return x; 80 | } 81 | 82 | inline float clampf(float x, float min, float max) { 83 | if (x > max) 84 | return max; 85 | if (x < min) 86 | return min; 87 | return x; 88 | } 89 | }; 90 | 91 | } // namespace PaddleOCR 92 | -------------------------------------------------------------------------------- /ptocr/postprocess/dbprocess/include/pybind11/common.h: -------------------------------------------------------------------------------- 1 | #include "detail/common.h" 2 | #warning "Including 'common.h' is deprecated. It will be removed in v3.0. Use 'pybind11.h'." 3 | -------------------------------------------------------------------------------- /ptocr/postprocess/dbprocess/include/pybind11/complex.h: -------------------------------------------------------------------------------- 1 | /* 2 | pybind11/complex.h: Complex number support 3 | 4 | Copyright (c) 2016 Wenzel Jakob 5 | 6 | All rights reserved. Use of this source code is governed by a 7 | BSD-style license that can be found in the LICENSE file. 
8 | */ 9 | 10 | #pragma once 11 | 12 | #include "pybind11.h" 13 | #include 14 | 15 | /// glibc defines I as a macro which breaks things, e.g., boost template names 16 | #ifdef I 17 | # undef I 18 | #endif 19 | 20 | NAMESPACE_BEGIN(PYBIND11_NAMESPACE) 21 | 22 | template struct format_descriptor, detail::enable_if_t::value>> { 23 | static constexpr const char c = format_descriptor::c; 24 | static constexpr const char value[3] = { 'Z', c, '\0' }; 25 | static std::string format() { return std::string(value); } 26 | }; 27 | 28 | #ifndef PYBIND11_CPP17 29 | 30 | template constexpr const char format_descriptor< 31 | std::complex, detail::enable_if_t::value>>::value[3]; 32 | 33 | #endif 34 | 35 | NAMESPACE_BEGIN(detail) 36 | 37 | template struct is_fmt_numeric, detail::enable_if_t::value>> { 38 | static constexpr bool value = true; 39 | static constexpr int index = is_fmt_numeric::index + 3; 40 | }; 41 | 42 | template class type_caster> { 43 | public: 44 | bool load(handle src, bool convert) { 45 | if (!src) 46 | return false; 47 | if (!convert && !PyComplex_Check(src.ptr())) 48 | return false; 49 | Py_complex result = PyComplex_AsCComplex(src.ptr()); 50 | if (result.real == -1.0 && PyErr_Occurred()) { 51 | PyErr_Clear(); 52 | return false; 53 | } 54 | value = std::complex((T) result.real, (T) result.imag); 55 | return true; 56 | } 57 | 58 | static handle cast(const std::complex &src, return_value_policy /* policy */, handle /* parent */) { 59 | return PyComplex_FromDoubles((double) src.real(), (double) src.imag()); 60 | } 61 | 62 | PYBIND11_TYPE_CASTER(std::complex, _("complex")); 63 | }; 64 | NAMESPACE_END(detail) 65 | NAMESPACE_END(PYBIND11_NAMESPACE) 66 | -------------------------------------------------------------------------------- /ptocr/postprocess/dbprocess/include/pybind11/detail/descr.h: -------------------------------------------------------------------------------- 1 | /* 2 | pybind11/detail/descr.h: Helper type for concatenating type signatures at compile 
time 3 | 4 | Copyright (c) 2016 Wenzel Jakob 5 | 6 | All rights reserved. Use of this source code is governed by a 7 | BSD-style license that can be found in the LICENSE file. 8 | */ 9 | 10 | #pragma once 11 | 12 | #include "common.h" 13 | 14 | NAMESPACE_BEGIN(PYBIND11_NAMESPACE) 15 | NAMESPACE_BEGIN(detail) 16 | 17 | #if !defined(_MSC_VER) 18 | # define PYBIND11_DESCR_CONSTEXPR static constexpr 19 | #else 20 | # define PYBIND11_DESCR_CONSTEXPR const 21 | #endif 22 | 23 | /* Concatenate type signatures at compile time */ 24 | template 25 | struct descr { 26 | char text[N + 1]; 27 | 28 | constexpr descr() : text{'\0'} { } 29 | constexpr descr(char const (&s)[N+1]) : descr(s, make_index_sequence()) { } 30 | 31 | template 32 | constexpr descr(char const (&s)[N+1], index_sequence) : text{s[Is]..., '\0'} { } 33 | 34 | template 35 | constexpr descr(char c, Chars... cs) : text{c, static_cast(cs)..., '\0'} { } 36 | 37 | static constexpr std::array types() { 38 | return {{&typeid(Ts)..., nullptr}}; 39 | } 40 | }; 41 | 42 | template 43 | constexpr descr plus_impl(const descr &a, const descr &b, 44 | index_sequence, index_sequence) { 45 | return {a.text[Is1]..., b.text[Is2]...}; 46 | } 47 | 48 | template 49 | constexpr descr operator+(const descr &a, const descr &b) { 50 | return plus_impl(a, b, make_index_sequence(), make_index_sequence()); 51 | } 52 | 53 | template 54 | constexpr descr _(char const(&text)[N]) { return descr(text); } 55 | constexpr descr<0> _(char const(&)[1]) { return {}; } 56 | 57 | template struct int_to_str : int_to_str { }; 58 | template struct int_to_str<0, Digits...> { 59 | static constexpr auto digits = descr(('0' + Digits)...); 60 | }; 61 | 62 | // Ternary description (like std::conditional) 63 | template 64 | constexpr enable_if_t> _(char const(&text1)[N1], char const(&)[N2]) { 65 | return _(text1); 66 | } 67 | template 68 | constexpr enable_if_t> _(char const(&)[N1], char const(&text2)[N2]) { 69 | return _(text2); 70 | } 71 | 72 | template 73 | 
constexpr enable_if_t _(const T1 &d, const T2 &) { return d; } 74 | template 75 | constexpr enable_if_t _(const T1 &, const T2 &d) { return d; } 76 | 77 | template auto constexpr _() -> decltype(int_to_str::digits) { 78 | return int_to_str::digits; 79 | } 80 | 81 | template constexpr descr<1, Type> _() { return {'%'}; } 82 | 83 | constexpr descr<0> concat() { return {}; } 84 | 85 | template 86 | constexpr descr concat(const descr &descr) { return descr; } 87 | 88 | template 89 | constexpr auto concat(const descr &d, const Args &...args) 90 | -> decltype(std::declval>() + concat(args...)) { 91 | return d + _(", ") + concat(args...); 92 | } 93 | 94 | template 95 | constexpr descr type_descr(const descr &descr) { 96 | return _("{") + descr + _("}"); 97 | } 98 | 99 | NAMESPACE_END(detail) 100 | NAMESPACE_END(PYBIND11_NAMESPACE) 101 | -------------------------------------------------------------------------------- /ptocr/postprocess/dbprocess/include/pybind11/detail/typeid.h: -------------------------------------------------------------------------------- 1 | /* 2 | pybind11/detail/typeid.h: Compiler-independent access to type identifiers 3 | 4 | Copyright (c) 2016 Wenzel Jakob 5 | 6 | All rights reserved. Use of this source code is governed by a 7 | BSD-style license that can be found in the LICENSE file. 
8 | */ 9 | 10 | #pragma once 11 | 12 | #include 13 | #include 14 | 15 | #if defined(__GNUG__) 16 | #include 17 | #endif 18 | 19 | NAMESPACE_BEGIN(PYBIND11_NAMESPACE) 20 | NAMESPACE_BEGIN(detail) 21 | /// Erase all occurrences of a substring 22 | inline void erase_all(std::string &string, const std::string &search) { 23 | for (size_t pos = 0;;) { 24 | pos = string.find(search, pos); 25 | if (pos == std::string::npos) break; 26 | string.erase(pos, search.length()); 27 | } 28 | } 29 | 30 | PYBIND11_NOINLINE inline void clean_type_id(std::string &name) { 31 | #if defined(__GNUG__) 32 | int status = 0; 33 | std::unique_ptr res { 34 | abi::__cxa_demangle(name.c_str(), nullptr, nullptr, &status), std::free }; 35 | if (status == 0) 36 | name = res.get(); 37 | #else 38 | detail::erase_all(name, "class "); 39 | detail::erase_all(name, "struct "); 40 | detail::erase_all(name, "enum "); 41 | #endif 42 | detail::erase_all(name, "pybind11::"); 43 | } 44 | NAMESPACE_END(detail) 45 | 46 | /// Return a string representation of a C++ type 47 | template static std::string type_id() { 48 | std::string name(typeid(T).name()); 49 | detail::clean_type_id(name); 50 | return name; 51 | } 52 | 53 | NAMESPACE_END(PYBIND11_NAMESPACE) 54 | -------------------------------------------------------------------------------- /ptocr/postprocess/dbprocess/include/pybind11/functional.h: -------------------------------------------------------------------------------- 1 | /* 2 | pybind11/functional.h: std::function<> support 3 | 4 | Copyright (c) 2016 Wenzel Jakob 5 | 6 | All rights reserved. Use of this source code is governed by a 7 | BSD-style license that can be found in the LICENSE file. 
8 | */ 9 | 10 | #pragma once 11 | 12 | #include "pybind11.h" 13 | #include 14 | 15 | NAMESPACE_BEGIN(PYBIND11_NAMESPACE) 16 | NAMESPACE_BEGIN(detail) 17 | 18 | template 19 | struct type_caster> { 20 | using type = std::function; 21 | using retval_type = conditional_t::value, void_type, Return>; 22 | using function_type = Return (*) (Args...); 23 | 24 | public: 25 | bool load(handle src, bool convert) { 26 | if (src.is_none()) { 27 | // Defer accepting None to other overloads (if we aren't in convert mode): 28 | if (!convert) return false; 29 | return true; 30 | } 31 | 32 | if (!isinstance(src)) 33 | return false; 34 | 35 | auto func = reinterpret_borrow(src); 36 | 37 | /* 38 | When passing a C++ function as an argument to another C++ 39 | function via Python, every function call would normally involve 40 | a full C++ -> Python -> C++ roundtrip, which can be prohibitive. 41 | Here, we try to at least detect the case where the function is 42 | stateless (i.e. function pointer or lambda function without 43 | captured variables), in which case the roundtrip can be avoided. 44 | */ 45 | if (auto cfunc = func.cpp_function()) { 46 | auto c = reinterpret_borrow(PyCFunction_GET_SELF(cfunc.ptr())); 47 | auto rec = (function_record *) c; 48 | 49 | if (rec && rec->is_stateless && 50 | same_type(typeid(function_type), *reinterpret_cast(rec->data[1]))) { 51 | struct capture { function_type f; }; 52 | value = ((capture *) &rec->data)->f; 53 | return true; 54 | } 55 | } 56 | 57 | value = [func](Args... 
args) -> Return { 58 | gil_scoped_acquire acq; 59 | object retval(func(std::forward(args)...)); 60 | /* Visual studio 2015 parser issue: need parentheses around this expression */ 61 | return (retval.template cast()); 62 | }; 63 | return true; 64 | } 65 | 66 | template 67 | static handle cast(Func &&f_, return_value_policy policy, handle /* parent */) { 68 | if (!f_) 69 | return none().inc_ref(); 70 | 71 | auto result = f_.template target(); 72 | if (result) 73 | return cpp_function(*result, policy).release(); 74 | else 75 | return cpp_function(std::forward(f_), policy).release(); 76 | } 77 | 78 | PYBIND11_TYPE_CASTER(type, _("Callable[[") + concat(make_caster::name...) + _("], ") 79 | + make_caster::name + _("]")); 80 | }; 81 | 82 | NAMESPACE_END(detail) 83 | NAMESPACE_END(PYBIND11_NAMESPACE) 84 | -------------------------------------------------------------------------------- /ptocr/postprocess/dbprocess/include/pybind11/options.h: -------------------------------------------------------------------------------- 1 | /* 2 | pybind11/options.h: global settings that are configurable at runtime. 3 | 4 | Copyright (c) 2016 Wenzel Jakob 5 | 6 | All rights reserved. Use of this source code is governed by a 7 | BSD-style license that can be found in the LICENSE file. 8 | */ 9 | 10 | #pragma once 11 | 12 | #include "detail/common.h" 13 | 14 | NAMESPACE_BEGIN(PYBIND11_NAMESPACE) 15 | 16 | class options { 17 | public: 18 | 19 | // Default RAII constructor, which leaves settings as they currently are. 20 | options() : previous_state(global_state()) {} 21 | 22 | // Class is non-copyable. 23 | options(const options&) = delete; 24 | options& operator=(const options&) = delete; 25 | 26 | // Destructor, which restores settings that were in effect before. 
27 | ~options() { 28 | global_state() = previous_state; 29 | } 30 | 31 | // Setter methods (affect the global state): 32 | 33 | options& disable_user_defined_docstrings() & { global_state().show_user_defined_docstrings = false; return *this; } 34 | 35 | options& enable_user_defined_docstrings() & { global_state().show_user_defined_docstrings = true; return *this; } 36 | 37 | options& disable_function_signatures() & { global_state().show_function_signatures = false; return *this; } 38 | 39 | options& enable_function_signatures() & { global_state().show_function_signatures = true; return *this; } 40 | 41 | // Getter methods (return the global state): 42 | 43 | static bool show_user_defined_docstrings() { return global_state().show_user_defined_docstrings; } 44 | 45 | static bool show_function_signatures() { return global_state().show_function_signatures; } 46 | 47 | // This type is not meant to be allocated on the heap. 48 | void* operator new(size_t) = delete; 49 | 50 | private: 51 | 52 | struct state { 53 | bool show_user_defined_docstrings = true; //< Include user-supplied texts in docstrings. 54 | bool show_function_signatures = true; //< Include auto-generated function signatures in docstrings. 55 | }; 56 | 57 | static state &global_state() { 58 | static state instance; 59 | return instance; 60 | } 61 | 62 | state previous_state; 63 | }; 64 | 65 | NAMESPACE_END(PYBIND11_NAMESPACE) 66 | -------------------------------------------------------------------------------- /ptocr/postprocess/dbprocess/include/pybind11/typeid.h: -------------------------------------------------------------------------------- 1 | /* 2 | pybind11/typeid.h: Compiler-independent access to type identifiers 3 | 4 | Copyright (c) 2016 Wenzel Jakob 5 | 6 | All rights reserved. Use of this source code is governed by a 7 | BSD-style license that can be found in the LICENSE file. 
8 | */ 9 | 10 | #pragma once 11 | 12 | #include 13 | #include 14 | 15 | #if defined(__GNUG__) 16 | #include 17 | #endif 18 | 19 | NAMESPACE_BEGIN(pybind11) 20 | NAMESPACE_BEGIN(detail) 21 | /// Erase all occurrences of a substring 22 | inline void erase_all(std::string &string, const std::string &search) { 23 | for (size_t pos = 0;;) { 24 | pos = string.find(search, pos); 25 | if (pos == std::string::npos) break; 26 | string.erase(pos, search.length()); 27 | } 28 | } 29 | 30 | PYBIND11_NOINLINE inline void clean_type_id(std::string &name) { 31 | #if defined(__GNUG__) 32 | int status = 0; 33 | std::unique_ptr res { 34 | abi::__cxa_demangle(name.c_str(), nullptr, nullptr, &status), std::free }; 35 | if (status == 0) 36 | name = res.get(); 37 | #else 38 | detail::erase_all(name, "class "); 39 | detail::erase_all(name, "struct "); 40 | detail::erase_all(name, "enum "); 41 | #endif 42 | detail::erase_all(name, "pybind11::"); 43 | } 44 | NAMESPACE_END(detail) 45 | 46 | /// Return a string representation of a C++ type 47 | template static std::string type_id() { 48 | std::string name(typeid(T).name()); 49 | detail::clean_type_id(name); 50 | return name; 51 | } 52 | 53 | NAMESPACE_END(pybind11) 54 | -------------------------------------------------------------------------------- /ptocr/postprocess/lanms/.gitignore: -------------------------------------------------------------------------------- 1 | adaptor.so 2 | -------------------------------------------------------------------------------- /ptocr/postprocess/lanms/Makefile: -------------------------------------------------------------------------------- 1 | CXXFLAGS = -I include -std=c++11 -O3 $(shell python3-config --cflags) 2 | LDFLAGS = $(shell python3-config --ldflags) 3 | 4 | DEPS = lanms.h $(shell find include -xtype f) 5 | CXX_SOURCES = adaptor.cpp include/clipper/clipper.cpp 6 | 7 | LIB_SO = adaptor.so 8 | 9 | $(LIB_SO): $(CXX_SOURCES) $(DEPS) 10 | $(CXX) -o $@ $(CXXFLAGS) $(LDFLAGS) $(CXX_SOURCES) --shared 
-fPIC 11 | 12 | clean: 13 | rm -rf $(LIB_SO) 14 | -------------------------------------------------------------------------------- /ptocr/postprocess/lanms/__init__.py: -------------------------------------------------------------------------------- 1 | import subprocess 2 | import os 3 | import numpy as np 4 | 5 | BASE_DIR = os.path.dirname(os.path.realpath(__file__)) 6 | 7 | if subprocess.call(['make', '-C', BASE_DIR]) != 0: # return value 8 | raise RuntimeError('Cannot compile lanms: {}'.format(BASE_DIR)) 9 | 10 | 11 | def merge_quadrangle_n9(polys, thres=0.3, precision=10000): 12 | from .adaptor import merge_quadrangle_n9 as nms_impl 13 | if len(polys) == 0: 14 | return np.array([], dtype='float32') 15 | p = polys.copy() 16 | p[:,:8] *= precision 17 | ret = np.array(nms_impl(p, thres), dtype='float32') 18 | ret[:,:8] /= precision 19 | return ret 20 | 21 | -------------------------------------------------------------------------------- /ptocr/postprocess/lanms/__main__.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | from . 
import merge_quadrangle_n9 5 | 6 | if __name__ == '__main__': 7 | # unit square with confidence 1 8 | q = np.array([0, 0, 0, 1, 1, 1, 1, 0, 1], dtype='float32') 9 | 10 | print(merge_quadrangle_n9(np.array([q, q + 0.1, q + 2]))) 11 | -------------------------------------------------------------------------------- /ptocr/postprocess/lanms/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BADBADBADBOY/pytorchOCR/efae5454651e96794dd7a460bd305dc9ccd95359/ptocr/postprocess/lanms/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /ptocr/postprocess/lanms/adaptor.cpp: -------------------------------------------------------------------------------- 1 | #include "pybind11/pybind11.h" 2 | #include "pybind11/numpy.h" 3 | #include "pybind11/stl.h" 4 | #include "pybind11/stl_bind.h" 5 | 6 | #include "lanms.h" 7 | 8 | namespace py = pybind11; 9 | 10 | 11 | namespace lanms_adaptor { 12 | 13 | std::vector> polys2floats(const std::vector &polys) { 14 | std::vector> ret; 15 | for (size_t i = 0; i < polys.size(); i ++) { 16 | auto &p = polys[i]; 17 | auto &poly = p.poly; 18 | ret.emplace_back(std::vector{ 19 | float(poly[0].X), float(poly[0].Y), 20 | float(poly[1].X), float(poly[1].Y), 21 | float(poly[2].X), float(poly[2].Y), 22 | float(poly[3].X), float(poly[3].Y), 23 | float(p.score), 24 | }); 25 | } 26 | 27 | return ret; 28 | } 29 | 30 | 31 | /** 32 | * 33 | * \param quad_n9 an n-by-9 numpy array, where first 8 numbers denote the 34 | * quadrangle, and the last one is the score 35 | * \param iou_threshold two quadrangles with iou score above this threshold 36 | * will be merged 37 | * 38 | * \return an n-by-9 numpy array, the merged quadrangles 39 | */ 40 | std::vector> merge_quadrangle_n9( 41 | py::array_t quad_n9, 42 | float iou_threshold) { 43 | auto pbuf = quad_n9.request(); 44 | if (pbuf.ndim != 2 || 
pbuf.shape[1] != 9) 45 | throw std::runtime_error("quadrangles must have a shape of (n, 9)"); 46 | auto n = pbuf.shape[0]; 47 | auto ptr = static_cast(pbuf.ptr); 48 | return polys2floats(lanms::merge_quadrangle_n9(ptr, n, iou_threshold)); 49 | } 50 | 51 | } 52 | 53 | PYBIND11_PLUGIN(adaptor) { 54 | py::module m("adaptor", "NMS"); 55 | 56 | m.def("merge_quadrangle_n9", &lanms_adaptor::merge_quadrangle_n9, 57 | "merge quadrangels"); 58 | 59 | return m.ptr(); 60 | } 61 | 62 | -------------------------------------------------------------------------------- /ptocr/postprocess/lanms/include/clipper/clipper.cpp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BADBADBADBOY/pytorchOCR/efae5454651e96794dd7a460bd305dc9ccd95359/ptocr/postprocess/lanms/include/clipper/clipper.cpp -------------------------------------------------------------------------------- /ptocr/postprocess/lanms/include/pybind11/complex.h: -------------------------------------------------------------------------------- 1 | /* 2 | pybind11/complex.h: Complex number support 3 | 4 | Copyright (c) 2016 Wenzel Jakob 5 | 6 | All rights reserved. Use of this source code is governed by a 7 | BSD-style license that can be found in the LICENSE file. 
8 | */ 9 | 10 | #pragma once 11 | 12 | #include "pybind11.h" 13 | #include 14 | 15 | /// glibc defines I as a macro which breaks things, e.g., boost template names 16 | #ifdef I 17 | # undef I 18 | #endif 19 | 20 | NAMESPACE_BEGIN(pybind11) 21 | 22 | template struct format_descriptor, detail::enable_if_t::value>> { 23 | static constexpr const char c = format_descriptor::c; 24 | static constexpr const char value[3] = { 'Z', c, '\0' }; 25 | static std::string format() { return std::string(value); } 26 | }; 27 | 28 | template constexpr const char format_descriptor< 29 | std::complex, detail::enable_if_t::value>>::value[3]; 30 | 31 | NAMESPACE_BEGIN(detail) 32 | 33 | template struct is_fmt_numeric, detail::enable_if_t::value>> { 34 | static constexpr bool value = true; 35 | static constexpr int index = is_fmt_numeric::index + 3; 36 | }; 37 | 38 | template class type_caster> { 39 | public: 40 | bool load(handle src, bool convert) { 41 | if (!src) 42 | return false; 43 | if (!convert && !PyComplex_Check(src.ptr())) 44 | return false; 45 | Py_complex result = PyComplex_AsCComplex(src.ptr()); 46 | if (result.real == -1.0 && PyErr_Occurred()) { 47 | PyErr_Clear(); 48 | return false; 49 | } 50 | value = std::complex((T) result.real, (T) result.imag); 51 | return true; 52 | } 53 | 54 | static handle cast(const std::complex &src, return_value_policy /* policy */, handle /* parent */) { 55 | return PyComplex_FromDoubles((double) src.real(), (double) src.imag()); 56 | } 57 | 58 | PYBIND11_TYPE_CASTER(std::complex, _("complex")); 59 | }; 60 | NAMESPACE_END(detail) 61 | NAMESPACE_END(pybind11) 62 | -------------------------------------------------------------------------------- /ptocr/postprocess/lanms/include/pybind11/functional.h: -------------------------------------------------------------------------------- 1 | /* 2 | pybind11/functional.h: std::function<> support 3 | 4 | Copyright (c) 2016 Wenzel Jakob 5 | 6 | All rights reserved. 
Use of this source code is governed by a 7 | BSD-style license that can be found in the LICENSE file. 8 | */ 9 | 10 | #pragma once 11 | 12 | #include "pybind11.h" 13 | #include 14 | 15 | NAMESPACE_BEGIN(pybind11) 16 | NAMESPACE_BEGIN(detail) 17 | 18 | template 19 | struct type_caster> { 20 | using type = std::function; 21 | using retval_type = conditional_t::value, void_type, Return>; 22 | using function_type = Return (*) (Args...); 23 | 24 | public: 25 | bool load(handle src, bool convert) { 26 | if (src.is_none()) { 27 | // Defer accepting None to other overloads (if we aren't in convert mode): 28 | if (!convert) return false; 29 | return true; 30 | } 31 | 32 | if (!isinstance(src)) 33 | return false; 34 | 35 | auto func = reinterpret_borrow(src); 36 | 37 | /* 38 | When passing a C++ function as an argument to another C++ 39 | function via Python, every function call would normally involve 40 | a full C++ -> Python -> C++ roundtrip, which can be prohibitive. 41 | Here, we try to at least detect the case where the function is 42 | stateless (i.e. function pointer or lambda function without 43 | captured variables), in which case the roundtrip can be avoided. 44 | */ 45 | if (auto cfunc = func.cpp_function()) { 46 | auto c = reinterpret_borrow(PyCFunction_GET_SELF(cfunc.ptr())); 47 | auto rec = (function_record *) c; 48 | 49 | if (rec && rec->is_stateless && 50 | same_type(typeid(function_type), *reinterpret_cast(rec->data[1]))) { 51 | struct capture { function_type f; }; 52 | value = ((capture *) &rec->data)->f; 53 | return true; 54 | } 55 | } 56 | 57 | value = [func](Args... 
args) -> Return { 58 | gil_scoped_acquire acq; 59 | object retval(func(std::forward(args)...)); 60 | /* Visual studio 2015 parser issue: need parentheses around this expression */ 61 | return (retval.template cast()); 62 | }; 63 | return true; 64 | } 65 | 66 | template 67 | static handle cast(Func &&f_, return_value_policy policy, handle /* parent */) { 68 | if (!f_) 69 | return none().inc_ref(); 70 | 71 | auto result = f_.template target(); 72 | if (result) 73 | return cpp_function(*result, policy).release(); 74 | else 75 | return cpp_function(std::forward(f_), policy).release(); 76 | } 77 | 78 | PYBIND11_TYPE_CASTER(type, _("Callable[[") + 79 | argument_loader::arg_names() + _("], ") + 80 | make_caster::name() + 81 | _("]")); 82 | }; 83 | 84 | NAMESPACE_END(detail) 85 | NAMESPACE_END(pybind11) 86 | -------------------------------------------------------------------------------- /ptocr/postprocess/lanms/include/pybind11/options.h: -------------------------------------------------------------------------------- 1 | /* 2 | pybind11/options.h: global settings that are configurable at runtime. 3 | 4 | Copyright (c) 2016 Wenzel Jakob 5 | 6 | All rights reserved. Use of this source code is governed by a 7 | BSD-style license that can be found in the LICENSE file. 8 | */ 9 | 10 | #pragma once 11 | 12 | #include "common.h" 13 | 14 | NAMESPACE_BEGIN(pybind11) 15 | 16 | class options { 17 | public: 18 | 19 | // Default RAII constructor, which leaves settings as they currently are. 20 | options() : previous_state(global_state()) {} 21 | 22 | // Class is non-copyable. 23 | options(const options&) = delete; 24 | options& operator=(const options&) = delete; 25 | 26 | // Destructor, which restores settings that were in effect before. 
27 | ~options() { 28 | global_state() = previous_state; 29 | } 30 | 31 | // Setter methods (affect the global state): 32 | 33 | options& disable_user_defined_docstrings() & { global_state().show_user_defined_docstrings = false; return *this; } 34 | 35 | options& enable_user_defined_docstrings() & { global_state().show_user_defined_docstrings = true; return *this; } 36 | 37 | options& disable_function_signatures() & { global_state().show_function_signatures = false; return *this; } 38 | 39 | options& enable_function_signatures() & { global_state().show_function_signatures = true; return *this; } 40 | 41 | // Getter methods (return the global state): 42 | 43 | static bool show_user_defined_docstrings() { return global_state().show_user_defined_docstrings; } 44 | 45 | static bool show_function_signatures() { return global_state().show_function_signatures; } 46 | 47 | // This type is not meant to be allocated on the heap. 48 | void* operator new(size_t) = delete; 49 | 50 | private: 51 | 52 | struct state { 53 | bool show_user_defined_docstrings = true; //< Include user-supplied texts in docstrings. 54 | bool show_function_signatures = true; //< Include auto-generated function signatures in docstrings. 55 | }; 56 | 57 | static state &global_state() { 58 | static state instance; 59 | return instance; 60 | } 61 | 62 | state previous_state; 63 | }; 64 | 65 | NAMESPACE_END(pybind11) 66 | -------------------------------------------------------------------------------- /ptocr/postprocess/lanms/include/pybind11/typeid.h: -------------------------------------------------------------------------------- 1 | /* 2 | pybind11/typeid.h: Compiler-independent access to type identifiers 3 | 4 | Copyright (c) 2016 Wenzel Jakob 5 | 6 | All rights reserved. Use of this source code is governed by a 7 | BSD-style license that can be found in the LICENSE file. 
8 | */ 9 | 10 | #pragma once 11 | 12 | #include 13 | #include 14 | 15 | #if defined(__GNUG__) 16 | #include 17 | #endif 18 | 19 | NAMESPACE_BEGIN(pybind11) 20 | NAMESPACE_BEGIN(detail) 21 | /// Erase all occurrences of a substring 22 | inline void erase_all(std::string &string, const std::string &search) { 23 | for (size_t pos = 0;;) { 24 | pos = string.find(search, pos); 25 | if (pos == std::string::npos) break; 26 | string.erase(pos, search.length()); 27 | } 28 | } 29 | 30 | PYBIND11_NOINLINE inline void clean_type_id(std::string &name) { 31 | #if defined(__GNUG__) 32 | int status = 0; 33 | std::unique_ptr res { 34 | abi::__cxa_demangle(name.c_str(), nullptr, nullptr, &status), std::free }; 35 | if (status == 0) 36 | name = res.get(); 37 | #else 38 | detail::erase_all(name, "class "); 39 | detail::erase_all(name, "struct "); 40 | detail::erase_all(name, "enum "); 41 | #endif 42 | detail::erase_all(name, "pybind11::"); 43 | } 44 | NAMESPACE_END(detail) 45 | 46 | /// Return a string representation of a C++ type 47 | template static std::string type_id() { 48 | std::string name(typeid(T).name()); 49 | detail::clean_type_id(name); 50 | return name; 51 | } 52 | 53 | NAMESPACE_END(pybind11) 54 | -------------------------------------------------------------------------------- /ptocr/postprocess/piexlmerge/Makefile: -------------------------------------------------------------------------------- 1 | CXXFLAGS = -I include -std=c++11 -O3 $(shell python3-config --cflags) 2 | 3 | DEPS = lanms.h $(shell find include -xtype f) 4 | CXX_SOURCES = pixelmerge.cpp include/clipper/clipper.cpp 5 | OPENCV = `pkg-config --cflags --libs opencv` 6 | 7 | LIB_SO = pixelmerge.so 8 | 9 | $(LIB_SO): $(CXX_SOURCES) $(DEPS) 10 | $(CXX) -o $@ $(CXXFLAGS) $(LDFLAGS) $(CXX_SOURCES) --shared -fPIC $(OPENCV) 11 | 12 | clean: 13 | rm -rf $(LIB_SO) 14 | 15 | 16 | -------------------------------------------------------------------------------- /ptocr/postprocess/piexlmerge/__init__.py: 
-------------------------------------------------------------------------------- 1 | 2 | import os 3 | import cv2 4 | import torch 5 | import time 6 | import subprocess 7 | import numpy as np 8 | 9 | 10 | 11 | BASE_DIR = os.path.dirname(os.path.realpath(__file__)) 12 | 13 | if subprocess.call(['make', '-C', BASE_DIR]) != 0: # return value 14 | raise RuntimeError('Cannot compile pse: {}'.format(BASE_DIR)) 15 | 16 | 17 | def pse(outputs,config): 18 | 19 | from .pixelmerge import pse_cpp,get_points,get_num 20 | 21 | score = torch.sigmoid(outputs[:, 0, :, :]) 22 | outputs = (torch.sign(outputs - config['postprocess']['binary_th']) + 1) / 2 23 | 24 | text = outputs[:, 0, :, :] 25 | kernels = outputs[:, 0:config['base']['classes'], :, :] * text 26 | 27 | score = score.data.cpu().numpy()[0].astype(np.float32) 28 | kernels = kernels.data.cpu().numpy()[0].astype(np.uint8) 29 | 30 | pred = pse_cpp(kernels,config['postprocess']['min_kernel_area'] / (config['postprocess']['scale'] * config['postprocess']['scale'])) 31 | pred = np.array(pred).astype(np.uint8) 32 | label_num = np.max(pred) + 1 33 | label_points = get_points(pred, score, label_num) 34 | 35 | label_values = [] 36 | for label_idx in range(1, label_num): 37 | label_values.append(label_idx) 38 | 39 | # label_values = [] 40 | # label_sum = get_num(pred, label_num) 41 | # for label_idx in range(1, label_num): 42 | # if label_sum[label_idx] < config['postprocess']['min_kernel_area']: 43 | # continue 44 | # label_values.append(label_idx) 45 | 46 | 47 | return pred,label_points,label_values 48 | 49 | 50 | def pan(preds, config): 51 | 52 | from .pixelmerge import pan_cpp, get_points, get_num 53 | 54 | pred = (torch.sign(preds[0, 0:2, :, :]-config['postprocess']['bin_th']) + 1) / 2 55 | score = torch.sigmoid(preds[0]).cpu().numpy().astype(np.float32) 56 | text = pred[0] # text 57 | kernel = (pred[1] * text).cpu().numpy() # kernel 58 | text = text.cpu().numpy() 59 | 60 | # score = 
torch.sigmoid(preds[0]).cpu().numpy().astype(np.float32) 61 | # pred_t = torch.sigmoid(preds[0, 0:2, :, :]) 62 | # text = pred_t[0]> 0.8 63 | # kernel = (pred_t[1]> 0.8)* text 64 | # text = text.cpu().numpy() 65 | # kernel = kernel.cpu().numpy() 66 | 67 | similarity_vectors = preds[0,2:].permute((1, 2, 0)).cpu().numpy().astype(np.float32) 68 | 69 | 70 | label_num, label = cv2.connectedComponents(kernel.astype(np.uint8), connectivity=4) 71 | label_values = [] 72 | label_sum = get_num(label, label_num) 73 | for label_idx in range(1, label_num): 74 | if label_sum[label_idx] < config['postprocess']['min_kernel_area']: 75 | continue 76 | label_values.append(label_idx) 77 | 78 | pred = pan_cpp(text.astype(np.uint8), similarity_vectors, label, label_num, config['postprocess']['dis_thresh']) 79 | pred = pred.reshape(text.shape) 80 | 81 | label_points = get_points(pred, score, label_num) 82 | 83 | return pred,label_points,label_values 84 | 85 | 86 | 87 | 88 | 89 | -------------------------------------------------------------------------------- /ptocr/postprocess/piexlmerge/include/clipper/clipper.cpp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BADBADBADBOY/pytorchOCR/efae5454651e96794dd7a460bd305dc9ccd95359/ptocr/postprocess/piexlmerge/include/clipper/clipper.cpp -------------------------------------------------------------------------------- /ptocr/postprocess/piexlmerge/include/pybind11/common.h: -------------------------------------------------------------------------------- 1 | #include "detail/common.h" 2 | #warning "Including 'common.h' is deprecated. It will be removed in v3.0. Use 'pybind11.h'." 
3 | -------------------------------------------------------------------------------- /ptocr/postprocess/piexlmerge/include/pybind11/complex.h: -------------------------------------------------------------------------------- 1 | /* 2 | pybind11/complex.h: Complex number support 3 | 4 | Copyright (c) 2016 Wenzel Jakob 5 | 6 | All rights reserved. Use of this source code is governed by a 7 | BSD-style license that can be found in the LICENSE file. 8 | */ 9 | 10 | #pragma once 11 | 12 | #include "pybind11.h" 13 | #include 14 | 15 | /// glibc defines I as a macro which breaks things, e.g., boost template names 16 | #ifdef I 17 | # undef I 18 | #endif 19 | 20 | NAMESPACE_BEGIN(PYBIND11_NAMESPACE) 21 | 22 | template struct format_descriptor, detail::enable_if_t::value>> { 23 | static constexpr const char c = format_descriptor::c; 24 | static constexpr const char value[3] = { 'Z', c, '\0' }; 25 | static std::string format() { return std::string(value); } 26 | }; 27 | 28 | #ifndef PYBIND11_CPP17 29 | 30 | template constexpr const char format_descriptor< 31 | std::complex, detail::enable_if_t::value>>::value[3]; 32 | 33 | #endif 34 | 35 | NAMESPACE_BEGIN(detail) 36 | 37 | template struct is_fmt_numeric, detail::enable_if_t::value>> { 38 | static constexpr bool value = true; 39 | static constexpr int index = is_fmt_numeric::index + 3; 40 | }; 41 | 42 | template class type_caster> { 43 | public: 44 | bool load(handle src, bool convert) { 45 | if (!src) 46 | return false; 47 | if (!convert && !PyComplex_Check(src.ptr())) 48 | return false; 49 | Py_complex result = PyComplex_AsCComplex(src.ptr()); 50 | if (result.real == -1.0 && PyErr_Occurred()) { 51 | PyErr_Clear(); 52 | return false; 53 | } 54 | value = std::complex((T) result.real, (T) result.imag); 55 | return true; 56 | } 57 | 58 | static handle cast(const std::complex &src, return_value_policy /* policy */, handle /* parent */) { 59 | return PyComplex_FromDoubles((double) src.real(), (double) src.imag()); 60 | } 61 | 62 | 
PYBIND11_TYPE_CASTER(std::complex, _("complex")); 63 | }; 64 | NAMESPACE_END(detail) 65 | NAMESPACE_END(PYBIND11_NAMESPACE) 66 | -------------------------------------------------------------------------------- /ptocr/postprocess/piexlmerge/include/pybind11/detail/typeid.h: -------------------------------------------------------------------------------- 1 | /* 2 | pybind11/detail/typeid.h: Compiler-independent access to type identifiers 3 | 4 | Copyright (c) 2016 Wenzel Jakob 5 | 6 | All rights reserved. Use of this source code is governed by a 7 | BSD-style license that can be found in the LICENSE file. 8 | */ 9 | 10 | #pragma once 11 | 12 | #include 13 | #include 14 | 15 | #if defined(__GNUG__) 16 | #include 17 | #endif 18 | 19 | NAMESPACE_BEGIN(PYBIND11_NAMESPACE) 20 | NAMESPACE_BEGIN(detail) 21 | /// Erase all occurrences of a substring 22 | inline void erase_all(std::string &string, const std::string &search) { 23 | for (size_t pos = 0;;) { 24 | pos = string.find(search, pos); 25 | if (pos == std::string::npos) break; 26 | string.erase(pos, search.length()); 27 | } 28 | } 29 | 30 | PYBIND11_NOINLINE inline void clean_type_id(std::string &name) { 31 | #if defined(__GNUG__) 32 | int status = 0; 33 | std::unique_ptr res { 34 | abi::__cxa_demangle(name.c_str(), nullptr, nullptr, &status), std::free }; 35 | if (status == 0) 36 | name = res.get(); 37 | #else 38 | detail::erase_all(name, "class "); 39 | detail::erase_all(name, "struct "); 40 | detail::erase_all(name, "enum "); 41 | #endif 42 | detail::erase_all(name, "pybind11::"); 43 | } 44 | NAMESPACE_END(detail) 45 | 46 | /// Return a string representation of a C++ type 47 | template static std::string type_id() { 48 | std::string name(typeid(T).name()); 49 | detail::clean_type_id(name); 50 | return name; 51 | } 52 | 53 | NAMESPACE_END(PYBIND11_NAMESPACE) 54 | -------------------------------------------------------------------------------- /ptocr/postprocess/piexlmerge/include/pybind11/functional.h: 
-------------------------------------------------------------------------------- 1 | /* 2 | pybind11/functional.h: std::function<> support 3 | 4 | Copyright (c) 2016 Wenzel Jakob 5 | 6 | All rights reserved. Use of this source code is governed by a 7 | BSD-style license that can be found in the LICENSE file. 8 | */ 9 | 10 | #pragma once 11 | 12 | #include "pybind11.h" 13 | #include 14 | 15 | NAMESPACE_BEGIN(PYBIND11_NAMESPACE) 16 | NAMESPACE_BEGIN(detail) 17 | 18 | template 19 | struct type_caster> { 20 | using type = std::function; 21 | using retval_type = conditional_t::value, void_type, Return>; 22 | using function_type = Return (*) (Args...); 23 | 24 | public: 25 | bool load(handle src, bool convert) { 26 | if (src.is_none()) { 27 | // Defer accepting None to other overloads (if we aren't in convert mode): 28 | if (!convert) return false; 29 | return true; 30 | } 31 | 32 | if (!isinstance(src)) 33 | return false; 34 | 35 | auto func = reinterpret_borrow(src); 36 | 37 | /* 38 | When passing a C++ function as an argument to another C++ 39 | function via Python, every function call would normally involve 40 | a full C++ -> Python -> C++ roundtrip, which can be prohibitive. 41 | Here, we try to at least detect the case where the function is 42 | stateless (i.e. function pointer or lambda function without 43 | captured variables), in which case the roundtrip can be avoided. 44 | */ 45 | if (auto cfunc = func.cpp_function()) { 46 | auto c = reinterpret_borrow(PyCFunction_GET_SELF(cfunc.ptr())); 47 | auto rec = (function_record *) c; 48 | 49 | if (rec && rec->is_stateless && 50 | same_type(typeid(function_type), *reinterpret_cast(rec->data[1]))) { 51 | struct capture { function_type f; }; 52 | value = ((capture *) &rec->data)->f; 53 | return true; 54 | } 55 | } 56 | 57 | value = [func](Args... 
args) -> Return { 58 | gil_scoped_acquire acq; 59 | object retval(func(std::forward(args)...)); 60 | /* Visual studio 2015 parser issue: need parentheses around this expression */ 61 | return (retval.template cast()); 62 | }; 63 | return true; 64 | } 65 | 66 | template 67 | static handle cast(Func &&f_, return_value_policy policy, handle /* parent */) { 68 | if (!f_) 69 | return none().inc_ref(); 70 | 71 | auto result = f_.template target(); 72 | if (result) 73 | return cpp_function(*result, policy).release(); 74 | else 75 | return cpp_function(std::forward(f_), policy).release(); 76 | } 77 | 78 | PYBIND11_TYPE_CASTER(type, _("Callable[[") + concat(make_caster::name...) + _("], ") 79 | + make_caster::name + _("]")); 80 | }; 81 | 82 | NAMESPACE_END(detail) 83 | NAMESPACE_END(PYBIND11_NAMESPACE) 84 | -------------------------------------------------------------------------------- /ptocr/postprocess/piexlmerge/include/pybind11/options.h: -------------------------------------------------------------------------------- 1 | /* 2 | pybind11/options.h: global settings that are configurable at runtime. 3 | 4 | Copyright (c) 2016 Wenzel Jakob 5 | 6 | All rights reserved. Use of this source code is governed by a 7 | BSD-style license that can be found in the LICENSE file. 8 | */ 9 | 10 | #pragma once 11 | 12 | #include "detail/common.h" 13 | 14 | NAMESPACE_BEGIN(PYBIND11_NAMESPACE) 15 | 16 | class options { 17 | public: 18 | 19 | // Default RAII constructor, which leaves settings as they currently are. 20 | options() : previous_state(global_state()) {} 21 | 22 | // Class is non-copyable. 23 | options(const options&) = delete; 24 | options& operator=(const options&) = delete; 25 | 26 | // Destructor, which restores settings that were in effect before. 
27 | ~options() { 28 | global_state() = previous_state; 29 | } 30 | 31 | // Setter methods (affect the global state): 32 | 33 | options& disable_user_defined_docstrings() & { global_state().show_user_defined_docstrings = false; return *this; } 34 | 35 | options& enable_user_defined_docstrings() & { global_state().show_user_defined_docstrings = true; return *this; } 36 | 37 | options& disable_function_signatures() & { global_state().show_function_signatures = false; return *this; } 38 | 39 | options& enable_function_signatures() & { global_state().show_function_signatures = true; return *this; } 40 | 41 | // Getter methods (return the global state): 42 | 43 | static bool show_user_defined_docstrings() { return global_state().show_user_defined_docstrings; } 44 | 45 | static bool show_function_signatures() { return global_state().show_function_signatures; } 46 | 47 | // This type is not meant to be allocated on the heap. 48 | void* operator new(size_t) = delete; 49 | 50 | private: 51 | 52 | struct state { 53 | bool show_user_defined_docstrings = true; //< Include user-supplied texts in docstrings. 54 | bool show_function_signatures = true; //< Include auto-generated function signatures in docstrings. 55 | }; 56 | 57 | static state &global_state() { 58 | static state instance; 59 | return instance; 60 | } 61 | 62 | state previous_state; 63 | }; 64 | 65 | NAMESPACE_END(PYBIND11_NAMESPACE) 66 | -------------------------------------------------------------------------------- /ptocr/postprocess/piexlmerge/include/pybind11/typeid.h: -------------------------------------------------------------------------------- 1 | /* 2 | pybind11/typeid.h: Compiler-independent access to type identifiers 3 | 4 | Copyright (c) 2016 Wenzel Jakob 5 | 6 | All rights reserved. Use of this source code is governed by a 7 | BSD-style license that can be found in the LICENSE file. 
8 | */ 9 | 10 | #pragma once 11 | 12 | #include 13 | #include 14 | 15 | #if defined(__GNUG__) 16 | #include 17 | #endif 18 | 19 | NAMESPACE_BEGIN(pybind11) 20 | NAMESPACE_BEGIN(detail) 21 | /// Erase all occurrences of a substring 22 | inline void erase_all(std::string &string, const std::string &search) { 23 | for (size_t pos = 0;;) { 24 | pos = string.find(search, pos); 25 | if (pos == std::string::npos) break; 26 | string.erase(pos, search.length()); 27 | } 28 | } 29 | 30 | PYBIND11_NOINLINE inline void clean_type_id(std::string &name) { 31 | #if defined(__GNUG__) 32 | int status = 0; 33 | std::unique_ptr res { 34 | abi::__cxa_demangle(name.c_str(), nullptr, nullptr, &status), std::free }; 35 | if (status == 0) 36 | name = res.get(); 37 | #else 38 | detail::erase_all(name, "class "); 39 | detail::erase_all(name, "struct "); 40 | detail::erase_all(name, "enum "); 41 | #endif 42 | detail::erase_all(name, "pybind11::"); 43 | } 44 | NAMESPACE_END(detail) 45 | 46 | /// Return a string representation of a C++ type 47 | template static std::string type_id() { 48 | std::string name(typeid(T).name()); 49 | detail::clean_type_id(name); 50 | return name; 51 | } 52 | 53 | NAMESPACE_END(pybind11) 54 | -------------------------------------------------------------------------------- /ptocr/postprocess/piexlmerge/pixelmerge.so: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BADBADBADBOY/pytorchOCR/efae5454651e96794dd7a460bd305dc9ccd95359/ptocr/postprocess/piexlmerge/pixelmerge.so -------------------------------------------------------------------------------- /ptocr/utils/__init__.py: -------------------------------------------------------------------------------- 1 | #-*- coding:utf-8 _*- 2 | """ 3 | @author:fxw 4 | @file: __init__.py.py 5 | @time: 2020/08/07 6 | """ 7 | -------------------------------------------------------------------------------- /ptocr/utils/__pycache__/__init__.cpython-36.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/BADBADBADBOY/pytorchOCR/efae5454651e96794dd7a460bd305dc9ccd95359/ptocr/utils/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /ptocr/utils/__pycache__/cal_iou_acc.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BADBADBADBOY/pytorchOCR/efae5454651e96794dd7a460bd305dc9ccd95359/ptocr/utils/__pycache__/cal_iou_acc.cpython-36.pyc -------------------------------------------------------------------------------- /ptocr/utils/__pycache__/gen_teacher_model.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BADBADBADBOY/pytorchOCR/efae5454651e96794dd7a460bd305dc9ccd95359/ptocr/utils/__pycache__/gen_teacher_model.cpython-36.pyc -------------------------------------------------------------------------------- /ptocr/utils/__pycache__/logger.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BADBADBADBOY/pytorchOCR/efae5454651e96794dd7a460bd305dc9ccd95359/ptocr/utils/__pycache__/logger.cpython-36.pyc -------------------------------------------------------------------------------- /ptocr/utils/__pycache__/metrics.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BADBADBADBOY/pytorchOCR/efae5454651e96794dd7a460bd305dc9ccd95359/ptocr/utils/__pycache__/metrics.cpython-36.pyc -------------------------------------------------------------------------------- /ptocr/utils/__pycache__/prune_script.cpython-36.pyc: -------------------------------------------------------------------------------- 
def _threshold_inplace(scores, thresh):
    """Binarize *scores* in place: values <= thresh become 0, values > thresh become 1."""
    scores[scores <= thresh] = 0
    scores[scores > thresh] = 1
    return scores


def cal_binary_score(binarys, gt_binarys, training_masks, running_metric_binary, thresh=0.5):
    """Update the running metric with masked, thresholded probability maps.

    ``binarys`` is used as-is (no sigmoid is applied here). Both prediction and
    ground truth are multiplied by ``training_masks`` before being cast to
    int32 and handed to ``running_metric_binary.update``.

    Returns the first element of ``running_metric_binary.get_scores()``.
    """
    mask_np = training_masks.data.cpu().numpy()
    prediction = _threshold_inplace(binarys.data.cpu().numpy() * mask_np, thresh)
    prediction = prediction.astype(np.int32)
    target = (gt_binarys.data.cpu().numpy() * mask_np).astype(np.int32)
    running_metric_binary.update(target, prediction)
    scores, _ = running_metric_binary.get_scores()
    return scores


def cal_text_score(texts, gt_texts, training_masks, running_metric_text, thresh=0.5):
    """Same as ``cal_binary_score`` but applies a sigmoid to ``texts`` first."""
    mask_np = training_masks.data.cpu().numpy()
    prediction = _threshold_inplace(
        torch.sigmoid(texts).data.cpu().numpy() * mask_np, thresh)
    prediction = prediction.astype(np.int32)
    target = (gt_texts.data.cpu().numpy() * mask_np).astype(np.int32)
    running_metric_text.update(target, prediction)
    scores, _ = running_metric_text.get_scores()
    return scores


def cal_kernel_score(kernels, gt_kernels, gt_texts, training_masks, running_metric_kernel, thresh=0.5):
    """Score the last kernel channel, restricted to ``gt_texts * training_masks``.

    Only index ``-1`` along dimension 1 of ``kernels`` / ``gt_kernels`` is
    evaluated. The sigmoid output is thresholded *before* the mask is applied.
    """
    region = (gt_texts * training_masks).data.cpu().numpy()
    last_pred = torch.sigmoid(kernels[:, -1, :, :]).data.cpu().numpy()
    last_pred = _threshold_inplace(last_pred, thresh)
    prediction = (last_pred * region).astype(np.int32)
    target = (gt_kernels[:, -1, :, :].data.cpu().numpy() * region).astype(np.int32)
    running_metric_kernel.update(target, prediction)
    scores, _ = running_metric_kernel.get_scores()
    return scores


def cal_PAN_PSE(kernels, gt_kernels, texts, gt_texts, training_masks, running_metric_text, running_metric_kernel):
    """Return ``(iou, acc)`` averaged over the text score and the kernel score.

    3-D ``kernels`` get a channel axis inserted at dimension 1 (together with
    ``gt_kernels``) so that ``cal_kernel_score`` can index the last channel.
    """
    if len(kernels.shape) == 3:
        kernels = kernels.unsqueeze(1)
        gt_kernels = gt_kernels.unsqueeze(1)
    kernel_scores = cal_kernel_score(kernels, gt_kernels, gt_texts, training_masks, running_metric_kernel)
    text_scores = cal_text_score(texts, gt_texts, training_masks, running_metric_text)
    acc = (text_scores['Mean Acc'] + kernel_scores['Mean Acc']) / 2
    iou = (text_scores['Mean IoU'] + kernel_scores['Mean IoU']) / 2
    return iou, acc


def cal_DB(texts, gt_texts, training_masks, running_metric_text):
    """Return ``(iou, acc)`` for DB-style outputs after squeezing dimension 1."""
    text_scores = cal_binary_score(
        texts.squeeze(1), gt_texts.squeeze(1), training_masks.squeeze(1), running_metric_text)
    return text_scores['Mean IoU'], text_scores['Mean Acc']
class DiceLoss(nn.Module):
    """Dice loss between two score maps, restricted to a training mask."""

    def __init__(self, eps=1e-6):
        super(DiceLoss, self).__init__()
        # Added to both denominators so an all-zero map does not divide by zero.
        self.eps = eps

    def forward(self, pre_score, gt_score, train_mask):
        """Return ``1 - mean(dice)`` where dice is computed per batch element.

        All three inputs are flattened to ``(batch, -1)`` and the scores are
        multiplied by ``train_mask`` before the overlap is measured.
        """
        pre_score = pre_score.contiguous().view(pre_score.size()[0], -1)
        gt_score = gt_score.contiguous().view(gt_score.size()[0], -1)
        train_mask = train_mask.contiguous().view(train_mask.size()[0], -1)

        pre_score = pre_score * train_mask
        gt_score = gt_score * train_mask

        a = torch.sum(pre_score * gt_score, 1)
        b = torch.sum(pre_score * pre_score, 1) + self.eps
        c = torch.sum(gt_score * gt_score, 1) + self.eps
        d = (2 * a) / (b + c)
        dice_loss = torch.mean(d)
        return 1 - dice_loss


def GetTeacherModel(args):
    """Build the teacher network described by ``args.t_config`` and load its weights.

    Reads ``args.t_config`` (YAML) and ``args.t_model_path``. The model is
    moved to the GPU when CUDA is available.
    """
    # The config handle is closed as soon as parsing finishes.
    # NOTE(review): yaml.FullLoader is kept for compatibility with existing
    # configs; only point t_config at trusted files.
    with open(args.t_config, 'r', encoding='utf-8') as config_file:
        config = yaml.load(config_file, Loader=yaml.FullLoader)
    model = create_module(config['architectures']['model_function'])(config)
    model = load_model(model, args.t_model_path)
    if torch.cuda.is_available():
        model = model.cuda()
    return model


class DistilLoss(nn.Module):
    """Sum of Dice losses between matching student and teacher output maps."""

    def __init__(self):
        super(DistilLoss, self).__init__()
        self.mse = nn.MSELoss()
        self.diceloss = DiceLoss()
        # Output keys that are excluded from the distillation loss.
        self.ignore = ['thresh']

    def forward(self, s_map, t_map):
        """Accumulate the Dice loss over every key of ``s_map`` not in ``self.ignore``.

        ``s_map`` and ``t_map`` are mappings from output name to tensor;
        every non-ignored key of ``s_map`` must also exist in ``t_map``.
        """
        loss = 0
        for key in s_map.keys():
            if key in self.ignore:
                continue
            # ones_like keeps the mask on the teacher map's device and dtype,
            # so the loss also works on CPU-only machines.
            full_mask = torch.ones_like(t_map[key])
            loss += self.diceloss(s_map[key], t_map[key], full_mask)
        return loss
class TrainLog(object):
    """Thin wrapper around the shared ``'TrainLog'`` logger.

    Records of level INFO and above are written both to ``LOG_FILE`` and to
    the console. Because ``logging.getLogger('TrainLog')`` returns the same
    object on every call, handlers are only attached when an equivalent one
    is not already present; otherwise each new ``TrainLog`` instance would
    make every record appear one more time.
    """

    def __init__(self, LOG_FILE):
        fmt = '%(asctime)s - %(funcName)s - %(lineno)s - %(levelname)s - %(message)s'
        formatter = logging.Formatter(fmt)

        logger = logging.getLogger('TrainLog')
        # The logger level gates what reaches the handlers at all.
        logger.setLevel('INFO')

        # FileHandler subclasses StreamHandler, so exclude it explicitly when
        # looking for an existing console handler.
        has_console = any(
            isinstance(handler, logging.StreamHandler)
            and not isinstance(handler, logging.FileHandler)
            for handler in logger.handlers
        )
        if not has_console:
            console_handler = logging.StreamHandler()  # console output
            console_handler.setLevel('INFO')
            console_handler.setFormatter(formatter)
            logger.addHandler(console_handler)

        file_handler = logging.FileHandler(LOG_FILE)  # file output
        already_logging_to_file = any(
            isinstance(handler, logging.FileHandler)
            and handler.baseFilename == file_handler.baseFilename
            for handler in logger.handlers
        )
        if already_logging_to_file:
            # Same target is already attached; release the redundant handle.
            file_handler.close()
        else:
            file_handler.setLevel('INFO')
            file_handler.setFormatter(formatter)
            logger.addHandler(file_handler)

        self.logger = logger

    def error(self, char):
        """Log ``char`` at ERROR level."""
        self.logger.error(char)

    def debug(self, char):
        """Log ``char`` at DEBUG level (dropped while the logger level is INFO)."""
        self.logger.debug(char)

    def info(self, char):
        """Log ``char`` at INFO level."""
        self.logger.info(char)
class runningScore(object):
    """Accumulates a confusion matrix and derives accuracy / IoU metrics from it."""

    def __init__(self, n_classes):
        self.n_classes = n_classes
        self.confusion_matrix = np.zeros((n_classes, n_classes))

    def _fast_hist(self, label_true, label_pred, n_class):
        """Return the ``n_class x n_class`` histogram for one flattened sample.

        Rows index the true label, columns the predicted label. Positions
        whose true label lies outside ``[0, n_class)`` are dropped.
        """
        valid = (label_true >= 0) & (label_true < n_class)
        kept_pred = label_pred[valid]

        # Diagnostic only: show negative predictions before bincount rejects them.
        if np.any(kept_pred < 0):
            print (label_pred[label_pred < 0])

        flat_index = n_class * label_true[valid].astype(int) + kept_pred
        counts = np.bincount(flat_index, minlength=n_class**2)
        return counts.reshape(n_class, n_class)

    def update(self, label_trues, label_preds):
        """Add every (true, predicted) pair of the batch to the confusion matrix."""
        for truth, guess in zip(label_trues, label_preds):
            sample_hist = self._fast_hist(truth.flatten(), guess.flatten(), self.n_classes)
            self.confusion_matrix += sample_hist

    def get_scores(self):
        """Return ``(summary, per_class_iou)`` computed from the confusion matrix.

        ``summary`` holds 'Overall Acc', 'Mean Acc', 'FreqW Acc' and
        'Mean IoU'; ``per_class_iou`` maps class index to its IoU. Every
        denominator is offset by 0.0001.
        """
        hist = self.confusion_matrix
        diagonal = np.diag(hist)
        total = hist.sum()
        per_true = hist.sum(axis=1)
        per_pred = hist.sum(axis=0)

        overall_acc = diagonal.sum() / (total + 0.0001)
        mean_acc = np.nanmean(diagonal / (per_true + 0.0001))
        iu = diagonal / (per_true + per_pred - diagonal + 0.0001)
        mean_iu = np.nanmean(iu)
        freq = per_true / (total + 0.0001)
        present = freq > 0
        fwavacc = (freq[present] * iu[present]).sum()
        cls_iu = {index: value for index, value in zip(range(self.n_classes), iu)}

        summary = {'Overall Acc': overall_acc,
                   'Mean Acc': mean_acc,
                   'FreqW Acc': fwavacc,
                   'Mean IoU': mean_iu,}
        return summary, cls_iu

    def reset(self):
        """Clear the accumulated confusion matrix."""
        self.confusion_matrix = np.zeros((self.n_classes, self.n_classes))
# Label list files to scan and the alphabet file to produce.
train_list_file = './train_list.txt'
test_list_file = './test_list.txt'
keys_file = './key.txt'


def _read_label_chars(list_file):
    """Return the set of characters found in the last tab-separated field of each line."""
    chars = set()
    with open(list_file, 'r', encoding='utf-8') as fid:
        for line in fid:
            chars.update(line.strip().split('\t')[-1])
    return chars


def build_key_string(list_files):
    """Return every distinct label character of ``list_files`` as one string.

    Characters are sorted so the generated alphabet is identical from run to
    run; iterating a plain ``set`` of strings gives an arbitrary order.
    """
    chars = set()
    for list_file in list_files:
        chars |= _read_label_chars(list_file)
    return ''.join(sorted(chars))


def main():
    """Write the alphabet of the train and test label lists to ``keys_file``."""
    key = build_key_string([train_list_file, test_list_file])
    # Opened only after the lists were read, and closed by the context manager.
    with open(keys_file, 'w+', encoding='utf-8') as fid_key:
        fid_key.write(key)


if __name__ == "__main__":
    main()
def polar(image, center, r, theta=(70, 360+70), rstep=0.8, thetastep=360.0/(360*2)):
    """Unwrap a ring of ``image`` around ``center`` into a rectangular image.

    Args:
        image: source image indexed as ``image[y][x][channel]``; the first
            three channels are copied.
        center: ``(cx, cy)`` centre of the polar transform, in pixels.
        r: ``(min_radius, max_radius)`` of the ring to sample.
        theta: ``(min_angle, max_angle)`` in degrees.
        rstep: radial step between output rows.
        thetastep: angular step, in degrees, between output columns.

    Returns:
        An ``(H, W, 3)`` array of ``image.dtype``. Row ``i`` corresponds to one
        radius, column ``j`` to one angle. Samples that fall outside the
        source image keep the fill value 125.
    """
    # Take the centre and image bounds from the arguments rather than from
    # module-level globals, so the function works when imported.
    cx, cy = center
    h, w = image.shape[:2]

    minr, maxr = r
    mintheta, maxtheta = theta

    # Output height/width: number of radial and angular samples.
    H = int((maxr - minr) / rstep) + 1
    W = int((maxtheta - mintheta) / thetastep) + 1
    O = 125 * np.ones((H, W, 3), image.dtype)

    # (H, W) grids: radius varies along rows, angle along columns.
    radius_grid = np.transpose(np.tile(np.linspace(minr, maxr, H), (W, 1)))
    angle_grid = np.tile(np.linspace(mintheta, maxtheta, W), (H, 1))
    x, y = cv2.polarToCart(radius_grid, angle_grid, angleInDegrees=True)

    # Nearest-neighbour lookup for all pixels at once, replacing the
    # per-pixel Python loop.
    px = (np.rint(x) + cx).astype(int)
    py = (np.rint(y) + cy).astype(int)
    inside = (px >= 0) & (px <= w - 1) & (py >= 0) & (py <= h - 1)
    O[inside] = image[py[inside], px[inside], :3]

    return O
print("h:%s w:%s"%(h,w)) 47 | # 极坐标的变换中心(300,200) 48 | # cx, cy = h//2, w//2 49 | cx, cy = 204, 201 50 | # cx, cy = 131, 123 51 | # 圆的半径为10 颜色:灰 最小位数3 52 | cv2.circle(img, (int(cx), int(cy)), 10, (255, 0, 0, 0), 3) 53 | s = time.time() 54 | L = polar(img, (cx, cy), (h//3, w//2)) 55 | # 旋转 56 | L = cv2.flip(L, 0) 57 | print(time.time()-s) 58 | 59 | # 显示与输出 60 | cv2.imshow('img', img) 61 | cv2.imshow('O', L) 62 | cv2.waitKey(0) 63 | 64 | 65 | 66 | 67 | 68 | -------------------------------------------------------------------------------- /to_onnx.sh: -------------------------------------------------------------------------------- 1 | python ./script/pytorch_to_onnx.py --config ./config/det_PSE_resnet50.yaml --model_path ./checkpoint/ag_PSE_bb_resnet50_he_FPN_Head_bs_8_ep_600_PSE20201004/PSE_best.pth.tar --img_path /src/notebooks/detect_text/icdar2015/ch4_test_images/img_10.jpg --save_path ./onnx/PSEnet_20201012.onnx --batch_size 2 --max_size 1536 --algorithm PSE --add_padding -------------------------------------------------------------------------------- /to_tensorrt.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=2 python3 ./script/onnx_to_tensorrt.py --onnx_path ./onnx/PSEnet.onnx --trt_engine_path ./onnx/PSEnet_batch.engine --img_path /src/notebooks/detect_text/icdar2015/ch4_test_images/img_10.jpg --batch_size 2 --algorithm PSE --max_size 1536 --add_padding -------------------------------------------------------------------------------- /tools/__init__.py: -------------------------------------------------------------------------------- 1 | #-*- coding:utf-8 _*- 2 | """ 3 | @author:fxw 4 | @file: __init__.py.py 5 | @time: 2020/08/07 6 | """ 7 | -------------------------------------------------------------------------------- /tools/__pycache__/MarginLoss.cpython-36.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/BADBADBADBOY/pytorchOCR/efae5454651e96794dd7a460bd305dc9ccd95359/tools/__pycache__/MarginLoss.cpython-36.pyc -------------------------------------------------------------------------------- /tools/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BADBADBADBOY/pytorchOCR/efae5454651e96794dd7a460bd305dc9ccd95359/tools/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /tools/cal.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('./') 3 | from tools.cal_rescall.script import cal_recall_precison_f1 4 | from tools.cal_rescall.cal_det import cal_det_metrics 5 | 6 | 7 | result = cal_recall_precison_f1('/src/notebooks/detect_text/icdar2015/ch4_test_gts/','/src/notebooks/detect_text/PytorchOCR3/result/result_txt') 8 | print(result) 9 | 10 | out = cal_det_metrics('/src/notebooks/detect_text/icdar2015/ch4_test_gts/', '/src/notebooks/detect_text/PytorchOCR3/result/result_txt') 11 | print(out) -------------------------------------------------------------------------------- /tools/cal_rescall/__init__.py: -------------------------------------------------------------------------------- 1 | #-*- coding:utf-8 _*- 2 | """ 3 | @author:fxw 4 | @file: __init__.py.py 5 | @time: 2020/08/13 6 | """ 7 | -------------------------------------------------------------------------------- /tools/cal_rescall/__pycache__/__init__.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BADBADBADBOY/pytorchOCR/efae5454651e96794dd7a460bd305dc9ccd95359/tools/cal_rescall/__pycache__/__init__.cpython-35.pyc -------------------------------------------------------------------------------- /tools/cal_rescall/__pycache__/__init__.cpython-36.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/BADBADBADBOY/pytorchOCR/efae5454651e96794dd7a460bd305dc9ccd95359/tools/cal_rescall/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /tools/cal_rescall/__pycache__/cal_det.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BADBADBADBOY/pytorchOCR/efae5454651e96794dd7a460bd305dc9ccd95359/tools/cal_rescall/__pycache__/cal_det.cpython-36.pyc -------------------------------------------------------------------------------- /tools/cal_rescall/__pycache__/cal_iou.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BADBADBADBOY/pytorchOCR/efae5454651e96794dd7a460bd305dc9ccd95359/tools/cal_rescall/__pycache__/cal_iou.cpython-36.pyc -------------------------------------------------------------------------------- /tools/cal_rescall/__pycache__/rrc_evaluation_funcs.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BADBADBADBOY/pytorchOCR/efae5454651e96794dd7a460bd305dc9ccd95359/tools/cal_rescall/__pycache__/rrc_evaluation_funcs.cpython-35.pyc -------------------------------------------------------------------------------- /tools/cal_rescall/__pycache__/rrc_evaluation_funcs.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BADBADBADBOY/pytorchOCR/efae5454651e96794dd7a460bd305dc9ccd95359/tools/cal_rescall/__pycache__/rrc_evaluation_funcs.cpython-36.pyc -------------------------------------------------------------------------------- /tools/cal_rescall/__pycache__/script.cpython-35.pyc: -------------------------------------------------------------------------------- 
def load_label_infor(label_file_path, do_ignore=False):
    """Load every label file of a directory into ``{image_name: [box, ...]}``.

    Each non-blank line must start with eight comma-separated integer
    coordinates; the last comma-separated field is taken as the text.

    Args:
        label_file_path: directory containing the ``.txt`` label files.
        do_ignore: when true, the files are treated as ground truth: the
            ``gt_`` prefix is removed from the key and boxes whose text is
            ``"###"`` are flagged as ignored. Otherwise the ``res_`` prefix
            is removed and no box is ignored.

    Returns:
        A dict mapping image name to a list of dicts with the keys
        ``'ignore'`` (bool), ``'points'`` (four ``[x, y]`` pairs) and
        ``'text'`` (str).
    """
    prefix = 'gt_' if do_ignore else 'res_'
    img_name_label_dict = {}
    for file in os.listdir(label_file_path):
        bbox_infor = []
        # utf-8-sig drops a leading BOM, which would otherwise end up inside
        # the first coordinate and make int() raise.
        with open(os.path.join(label_file_path, file), "r", encoding='utf-8-sig') as fin:
            for line in fin:
                line = line.strip("\n")
                if not line.strip():
                    # Blank (often trailing) lines carry no box.
                    continue
                substr = line.split(",")
                coord = list(map(int, substr[:8]))
                text = substr[-1]
                bbox_infor.append({
                    'ignore': bool(do_ignore and text == "###"),
                    'points': np.array(coord).reshape(4, 2).tolist(),
                    # Store the transcription itself, not the ignore flag.
                    'text': text,
                })
        img_name_label_dict[file.replace(prefix, '').replace('.txt', '')] = bbox_infor
    return img_name_label_dict


def cal_det_metrics(gt_label_path, save_res_path):
    """
    calculate the detection metrics
    Args:
        gt_label_path(string): The groundtruth detection label file path
        save_res_path(string): The saved predicted detection label path
    return:
        calculated metrics including Hmean, precision and recall
    """
    evaluator = DetectionIoUEvaluator()
    gt_label_infor = load_label_infor(gt_label_path, do_ignore=True)
    dt_label_infor = load_label_infor(save_res_path)
    results = []
    for img_name in gt_label_infor:
        gt_label = gt_label_infor[img_name]
        # An image without a result file counts as having no detections.
        dt_label = dt_label_infor.get(img_name, [])
        result = evaluator.evaluate_image(gt_label, dt_label)
        results.append(result)
    methodMetrics = evaluator.combine_results(results)
    return methodMetrics