├── images
    ├── imgs
    └── pipeline-cropped.pdf
├── requirements
    ├── optional.txt
    ├── mminstall.txt
    ├── build.txt
    ├── runtime.txt
    ├── docs.txt
    ├── readthedocs.txt
    └── tests.txt
├── tools
    ├── use_gpu.sh
    ├── data
    │   └── utils
    │   │   └── txt2lmdb.py
    ├── pretrain_kjf.sh
    ├── publish_model.py
    ├── train_1803.sh
    ├── use_gpu.py
    ├── train_1062.sh
    └── test_kjf.sh
├── configs
    ├── vie_custom
    │   ├── e2e_ar_ocr_pretrain
    │   │   └── ephoie
    │   │   │   └── ephoie_sdef_nark_3l_local_600e_1280_1061_kjf.py
    │   ├── e2e_ar_vie
    │   │   ├── sroie
    │   │   │   ├── local
    │   │   │   │   └── readme.py
    │   │   │   └── readme.py
    │   │   └── ephoie
    │   │   │   └── local
    │   │   │   └── readme.py
    │   └── _base_
    │   │   ├── ocr_datasets
    │   │   ├── synthtext.py
    │   │   ├── custom_chn_v2_ar_cloud.py
    │   │   ├── synthtext_cloud.py
    │   │   ├── synth_chn_ar_cloud.py
    │   │   ├── synthtext_ar_cloud.py
    │   │   ├── local
    │   │   │   ├── nfv5_2200_ar_local_9999.py
    │   │   │   └── nfv5_2200_ar_local_1032.py
    │   │   ├── synth_chn_ar_cloud_kjf.py
    │   │   ├── ephoie_local.py
    │   │   ├── cord_ar_cloud.py
    │   │   └── sroie_cloud.py
    │   │   └── vie_datasets
    │   │   ├── nfv5_2200_ar_local_9999.py
    │   │   ├── nfv5_3128_ar_local_1061.py
    │   │   ├── nfv5_3125_ar_local_1032.py
    │   │   ├── nfv5_3125_ar_local_1061.py
    │   │   ├── nfv5_3125_ar_local_1062.py
    │   │   ├── nfv5_3125_ar_local_1803.py
    │   │   ├── nfv5_3128_ar_local_1803.py
    │   │   ├── cord_cloud.py
    │   │   ├── nfv4_ar_local_1803.py
    │   │   ├── cord_ar_local_1032.py
    │   │   ├── sroie_3090.py
    │   │   ├── cord_ar_cloud.py
    │   │   ├── nfv3_ar_cloud.py
    │   │   ├── nfv4_ar_cloud.py
    │   │   ├── nfv5_3125_3090.py
    │   │   ├── sroie_ar_local.py
    │   │   ├── local
    │   │   ├── ephoie_ar_local_9999.py
    │   │   ├── ephoie_ar_local_1033.py
    │   │   ├── ephoie_ar_local_1061.py
    │   │   ├── ephoie_ar_local_sort_1033.py
    │   │   └── ephoie_ar_local_1032.py
    │   │   └── sroie_ar_cloud_ssd.py
    └── _base_
    │   ├── schedules
    │   ├── schedule_adadelta_5e.py
    │   ├── schedule_adam_600e.py
    │   ├── schedule_adam_step_5e.py
    │   ├── schedule_adam_step_6e.py
    │   ├── schedule_adam_step_600e.py
    │   ├── schedule_adadelta_18e.py
    │   ├── schedule_sgd_1500e.py
    │   ├── schedule_sgd_600e.py
    │   ├── schedule_sgd_1200e.py
    │   ├── schedule_adam_step_20e.py
    │   └── schedule_sgd_160e.py
    │   ├── default_runtime.py
    │   ├── runtime_10e.py
    │   ├── recog_models
    │   ├── satrn.py
    │   ├── nrtr_modality_transform.py
    │   ├── crnn.py
    │   ├── crnn_tps.py
    │   ├── sar.py
    │   ├── robust_scanner.py
    │   └── seg.py
    │   ├── det_datasets
    │   ├── ctw1500.py
    │   ├── icdar2017.py
    │   ├── icdar2015.py
    │   └── toy_data.py
    │   ├── recog_datasets
    │   ├── ST_charbox_train.py
    │   ├── MJ_train.py
    │   ├── seg_toy_data.py
    │   ├── ST_MJ_train.py
    │   └── ST_MJ_alphanumeric_train.py
    │   ├── det_models
    │   ├── panet_r50_fpem_ffm.py
    │   ├── drrg_r50_fpn_unet.py
    │   ├── dbnet_r18_fpnc.py
    │   ├── textsnake_r50_fpn_unet.py
    │   ├── dbnet_r50dcnv2_fpnc.py
    │   ├── fcenet_r50_fpn.py
    │   └── fcenet_r50dcnv2_fpn.py
    │   └── recog_pipelines
    │   ├── crnn_pipeline.py
    │   ├── crnn_tps_pipeline.py
    │   ├── nrtr_pipeline.py
    │   └── sar_pipeline.py
├── tests
    ├── data
    │   ├── test_img1.jpg
    │   ├── test_img1.png
    │   ├── test_img2.jpg
    │   ├── toy_dataset
    │   │   ├── annotations
    │   │   │   └── test
    │   │   │   │   ├── gt_img_2.txt
    │   │   │   │   ├── gt_img_5.txt
    │   │   │   │   ├── gt_img_4.txt
    │   │   │   │   ├── gt_img_9.txt
    │   │   │   │   ├── gt_img_1.txt
    │   │   │   │   ├── gt_img_8.txt
    │   │   │   │   ├── gt_img_10.txt
    │   │   │   │   ├── gt_img_3.txt
    │   │   │   │   ├── gt_img_7.txt
    │   │   │   │   └── gt_img_6.txt
    │   │   ├── imgs
    │   │   │   └── test
    │   │   │   │   ├── img_1.jpg
    │   │   │   │   ├── img_2.jpg
    │   │   │   │   ├── img_3.jpg
    │   │   │   │   ├── img_4.jpg
    │   │   │   │   ├── img_5.jpg
    │   │   │   │   ├── img_6.jpg
    │   │   │   │   ├── img_7.jpg
    │   │   │   │   ├── img_8.jpg
    │   │   │   │   ├── img_9.jpg
    │   │   │   │   └── img_10.jpg
    │   │   └── img_list.txt
    │   ├── ocr_toy_dataset
    │   │   ├── imgs
    │   │   │   ├── 1036169.jpg
    │   │   │   ├── 1058891.jpg
    │   │   │   ├── 1058892.jpg
    │   │   │   ├── 1190237.jpg
    │   │   │   ├── 1210236.jpg
    │   │   │   ├── 1223729.jpg
    │   │   │   ├── 1223731.jpg
    │   │   │   ├── 1223732.jpg
    │   │   │   ├── 1223733.jpg
    │   │   │   └── 1240078.jpg
    │   │   ├── label.lmdb
    │   │   │   ├── data.mdb
    │   │   │   └── lock.mdb
    │   │   └── label.txt
    │   ├── ocr_char_ann_toy_dataset
    │   │   ├── imgs
    │   │   │   ├── resort_88_101_1.png
    │   │   │   ├── resort_95_53_6.png
    │   │   │   ├── richard+feynman_101_8_6.png
    │   │   │   ├── richard+feynman_104_58_9.png
    │   │   │   ├── richard+feynman_110_1_6.png
    │   │   │   ├── richard+feynman_12_61_4.png
    │   │   │   ├── richard+feynman_130_74_1.png
    │   │   │   ├── richard+feynman_134_30_15.png
    │   │   │   ├── richard+feynman_15_43_4.png
    │   │   │   └── richard+feynman_18_18_5.png
    │   │   └── instances_test.txt
    │   └── kie_toy_dataset
    │   │   ├── class_list.txt
    │   │   └── dict.txt
    ├── test_models
    │   ├── test_ocr_fuser.py
    │   ├── test_ocr_head.py
    │   ├── test_ocr_neck.py
    │   ├── test_targets.py
    │   └── test_ocr_preprocessor.py
    ├── test_tools
    │   └── test_data_converter.py
    ├── test_core
    │   └── test_end2end_vis.py
    ├── test_utils
    │   ├── test_model.py
    │   ├── test_version_utils.py
    │   ├── test_string_util.py
    │   └── test_check_argument.py
    ├── test_dataset
    │   └── test_test_time_aug.py
    └── test_apis
    │   └── test_image_misc.py
├── docs
    ├── en
    │   ├── requirements.txt
    │   ├── _static
    │   │   ├── images
    │   │   │   └── mmocr.png
    │   │   └── css
    │   │   │   └── readthedocs.css
    │   ├── datasets
    │   │   ├── ner.md
    │   │   └── kie.md
    │   ├── Makefile
    │   ├── make.bat
    │   ├── merge_docs.sh
    │   ├── tools.md
    │   └── index.rst
    └── zh_cn
    │   ├── _static
    │   ├── images
    │   │   └── mmocr.png
    │   └── css
    │   │   └── readthedocs.css
    │   ├── cp_origin_docs.sh
    │   ├── datasets
    │   ├── ner.md
    │   └── kie.md
    │   ├── Makefile
    │   ├── make.bat
    │   ├── merge_docs.sh
    │   └── index.rst
├── mmocr
    ├── version.py
    ├── models
    │   ├── common
    │   │   ├── detectors
    │   │   │   └── __init__.py
    │   │   ├── backbones
    │   │   │   └── __init__.py
    │   │   ├── layers
    │   │   │   └── __init__.py
    │   │   ├── losses
    │   │   │   ├── __init__.py
    │   │   │   ├── dice_loss.py
    │   │   │   └── focal_loss.py
    │   │   ├── __init__.py
    │   │   └── modules
    │   │   │   └── __init__.py
    │   ├── kie
    │   │   ├── extractors
    │   │   │   └── __init__.py
    │   │   ├── heads
    │   │   │   └── __init__.py
    │   │   ├── losses
    │   │   │   └── __init__.py
    │   │   └── __init__.py
    │   ├── spotting
    │   │   ├── modules
    │   │   │   ├── old
    │   │   │   │   └── __init__.py
    │   │   │   ├── ops
    │   │   │   │   ├── modules
    │   │   │   │   │   └── __init__.py
    │   │   │   │   ├── make.sh
    │   │   │   │   ├── functions
    │   │   │   │   │   └── __init__.py
    │   │   │   │   └── src
    │   │   │   │   │   ├── vision.cpp
    │   │   │   │   │   ├── cuda
    │   │   │   │   │   └── ms_deform_attn_cuda.h
    │   │   │   │   │   └── cpu
    │   │   │   │   │   ├── ms_deform_attn_cpu.h
    │   │   │   │   │   └── ms_deform_attn_cpu.cpp
    │   │   │   └── __init__.py
    │   │   ├── optimizers
    │   │   │   └── __init__.py
    │   │   ├── recognizers
    │   │   │   ├── old
    │   │   │   │   └── __init__.py
    │   │   │   ├── re_imple_trie
    │   │   │   │   ├── connects
    │   │   │   │   │   └── __init__.py
    │   │   │   │   ├── embedding
    │   │   │   │   │   └── __init__.py
    │   │   │   │   ├── __init__.py
    │   │   │   │   └── custom_davar_builder.py
    │   │   │   ├── re_imple_pick
    │   │   │   │   ├── __init__.py
    │   │   │   │   └── model
    │   │   │   │   │   └── __init__.py
    │   │   │   └── __init__.py
    │   │   ├── backbone
    │   │   │   └── __init__.py
    │   │   ├── detectors
    │   │   │   └── __init__.py
    │   │   ├── rois
    │   │   │   └── __init__.py
    │   │   ├── losses
    │   │   │   └── __init__.py
    │   │   ├── spotters
    │   │   │   └── __init__.py
    │   │   └── __init__.py
    │   ├── textrecog
    │   │   ├── heads
    │   │   │   └── __init__.py
    │   │   ├── necks
    │   │   │   └── __init__.py
    │   │   ├── fusers
    │   │   │   └── __init__.py
    │   │   ├── preprocessor
    │   │   │   ├── __init__.py
    │   │   │   └── base_preprocessor.py
    │   │   ├── recognizer
    │   │   │   ├── satrn.py
    │   │   │   ├── crnn.py
    │   │   │   ├── sar.py
    │   │   │   ├── nrtr.py
    │   │   │   ├── robust_scanner.py
    │   │   │   └── __init__.py
    │   │   ├── losses
    │   │   │   └── __init__.py
    │   │   ├── convertors
    │   │   │   └── __init__.py
    │   │   ├── encoders
    │   │   │   ├── base_encoder.py
    │   │   │   ├── __init__.py
    │   │   │   └── channel_reduction_encoder.py
    │   │   ├── backbones
    │   │   │   └── __init__.py
    │   │   ├── layers
    │   │   │   ├── __init__.py
    │   │   │   ├── lstm_layer.py
    │   │   │   ├── robust_scanner_fusion_layer.py
    │   │   │   ├── dot_product_attention_layer.py
    │   │   │   └── position_aware_layer.py
    │   │   ├── __init__.py
    │   │   └── decoders
    │   │   │   ├── __init__.py
    │   │   │   └── base_decoder.py
    │   ├── ner
    │   │   ├── decoders
    │   │   │   └── __init__.py
    │   │   ├── encoders
    │   │   │   └── __init__.py
    │   │   ├── convertors
    │   │   │   └── __init__.py
    │   │   ├── classifiers
    │   │   │   └── __init__.py
    │   │   ├── utils
    │   │   │   ├── __init__.py
    │   │   │   └── activations.py
    │   │   ├── losses
    │   │   │   └── __init__.py
    │   │   └── __init__.py
    │   ├── textdet
    │   │   ├── necks
    │   │   │   └── __init__.py
    │   │   ├── modules
    │   │   │   └── __init__.py
    │   │   ├── losses
    │   │   │   └── __init__.py
    │   │   ├── __init__.py
    │   │   ├── dense_heads
    │   │   │   └── __init__.py
    │   │   ├── postprocess
    │   │   │   ├── base_postprocessor.py
    │   │   │   └── __init__.py
    │   │   └── detectors
    │   │   │   ├── __init__.py
    │   │   │   ├── dbnet.py
    │   │   │   ├── psenet.py
    │   │   │   ├── panet.py
    │   │   │   ├── textsnake.py
    │   │   │   └── fcenet.py
    │   └── __init__.py
    ├── datasets
    │   ├── utils
    │   │   └── __init__.py
    │   ├── builder.py
    │   ├── pipelines
    │   │   └── textdet_targets
    │   │   │   ├── __init__.py
    │   │   │   └── psenet_targets.py
    │   ├── __init__.py
    │   └── ocr_dataset.py
    ├── core
    │   ├── deployment
    │   │   └── __init__.py
    │   ├── evaluation
    │   │   ├── __init__.py
    │   │   └── kie_metric.py
    │   └── __init__.py
    ├── apis
    │   └── __init__.py
    └── utils
    │   ├── collect_env.py
    │   ├── logger.py
    │   ├── __init__.py
    │   ├── fileio.py
    │   ├── data_convert_util.py
    │   └── string_util.py
├── requirements.txt
├── .idea
    ├── other.xml
    ├── vcs.xml
    ├── .gitignore
    ├── inspectionProfiles
    │   └── profiles_settings.xml
    ├── misc.xml
    ├── modules.xml
    ├── sshConfigs.xml
    └── ie_e2e.iml
├── custom_utils
    ├── module_list_cuda101.txt
    ├── dict_default.json
    ├── module_list_v100.txt
    ├── clean_pths.py
    └── dataset
    │   └── prepare_pretrain.py
└── setup.cfg
/images/imgs:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/requirements/optional.txt:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/requirements/mminstall.txt:
--------------------------------------------------------------------------------
1 | mmcv-full>=1.3.4
2 | mmdet>=2.11.0
--------------------------------------------------------------------------------
/tools/use_gpu.sh:
--------------------------------------------------------------------------------
1 | python use_gpu.py --size 13000 --gpus 4 --interval 0.01
--------------------------------------------------------------------------------
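tools/use_gpu.py itself is not captured in this dump, so the flags above are the only evidence of its behavior; it reads like a GPU occupier. A minimal sketch consistent with --size (MB per GPU), --gpus, and --interval follows — every detail is an assumption, not the repository's actual code:

    # Hypothetical reconstruction of tools/use_gpu.py (file not in this dump).
    import argparse
    import time

    import torch

    parser = argparse.ArgumentParser()
    parser.add_argument('--size', type=int, default=13000)       # MB to hold per GPU (assumed)
    parser.add_argument('--gpus', type=int, default=4)           # number of GPUs to occupy (assumed)
    parser.add_argument('--interval', type=float, default=0.01)  # seconds between ops (assumed)
    args = parser.parse_args()

    # Reserve roughly `size` MB of float32 on each GPU, then keep them busy.
    blocks = [
        torch.empty(args.size * 1024 * 1024 // 4, device=f'cuda:{i}')
        for i in range(args.gpus)
    ]
    mats = [torch.randn(1024, 1024, device=f'cuda:{i}') for i in range(args.gpus)]
    while True:
        for m in mats:
            m @ m  # small matmul so utilization stays non-zero
        time.sleep(args.interval)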
/configs/vie_custom/e2e_ar_ocr_pretrain/ephoie/ephoie_sdef_nark_3l_local_600e_1280_1061_kjf.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/tests/data/test_img1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jfkuang/CFAM/HEAD/tests/data/test_img1.jpg
--------------------------------------------------------------------------------
/tests/data/test_img1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jfkuang/CFAM/HEAD/tests/data/test_img1.png
--------------------------------------------------------------------------------
/tests/data/test_img2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jfkuang/CFAM/HEAD/tests/data/test_img2.jpg
--------------------------------------------------------------------------------
/docs/en/requirements.txt:
--------------------------------------------------------------------------------
1 | recommonmark
2 | sphinx
3 | sphinx_markdown_tables
4 | sphinx_rtd_theme
--------------------------------------------------------------------------------
/images/pipeline-cropped.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jfkuang/CFAM/HEAD/images/pipeline-cropped.pdf
--------------------------------------------------------------------------------
/docs/en/_static/images/mmocr.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jfkuang/CFAM/HEAD/docs/en/_static/images/mmocr.png
--------------------------------------------------------------------------------
/docs/zh_cn/_static/images/mmocr.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jfkuang/CFAM/HEAD/docs/zh_cn/_static/images/mmocr.png
--------------------------------------------------------------------------------
/mmocr/version.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Open-MMLab. All rights reserved.
2 |
3 | __version__ = '0.4.0'
4 | short_version = __version__
5 |
--------------------------------------------------------------------------------
/requirements/build.txt:
--------------------------------------------------------------------------------
1 | # These must be installed before building mmocr
2 | numpy
3 | pyclipper
4 | torch>=1.1
5 | timm==0.4.5
--------------------------------------------------------------------------------
/tests/data/toy_dataset/annotations/test/gt_img_2.txt:
--------------------------------------------------------------------------------
1 | 602,173,635,175,634,197,602,196,EXIT
2 | 734,310,792,320,792,364,738,361,I2R
--------------------------------------------------------------------------------
/tests/data/toy_dataset/annotations/test/gt_img_5.txt:
--------------------------------------------------------------------------------
1 | 408,409,437,436,434,461,405,433,###
2 | 437,434,443,440,441,467,435,462,###
--------------------------------------------------------------------------------
/tests/data/toy_dataset/imgs/test/img_1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jfkuang/CFAM/HEAD/tests/data/toy_dataset/imgs/test/img_1.jpg
--------------------------------------------------------------------------------
/tests/data/toy_dataset/imgs/test/img_2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jfkuang/CFAM/HEAD/tests/data/toy_dataset/imgs/test/img_2.jpg
--------------------------------------------------------------------------------
/tests/data/toy_dataset/imgs/test/img_3.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jfkuang/CFAM/HEAD/tests/data/toy_dataset/imgs/test/img_3.jpg
--------------------------------------------------------------------------------
/tests/data/toy_dataset/imgs/test/img_4.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jfkuang/CFAM/HEAD/tests/data/toy_dataset/imgs/test/img_4.jpg
--------------------------------------------------------------------------------
/tests/data/toy_dataset/imgs/test/img_5.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jfkuang/CFAM/HEAD/tests/data/toy_dataset/imgs/test/img_5.jpg
--------------------------------------------------------------------------------
/tests/data/toy_dataset/imgs/test/img_6.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jfkuang/CFAM/HEAD/tests/data/toy_dataset/imgs/test/img_6.jpg
--------------------------------------------------------------------------------
/tests/data/toy_dataset/imgs/test/img_7.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jfkuang/CFAM/HEAD/tests/data/toy_dataset/imgs/test/img_7.jpg
--------------------------------------------------------------------------------
/tests/data/toy_dataset/imgs/test/img_8.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jfkuang/CFAM/HEAD/tests/data/toy_dataset/imgs/test/img_8.jpg
--------------------------------------------------------------------------------
/tests/data/toy_dataset/imgs/test/img_9.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jfkuang/CFAM/HEAD/tests/data/toy_dataset/imgs/test/img_9.jpg
--------------------------------------------------------------------------------
/mmocr/models/common/detectors/__init__.py:
--------------------------------------------------------------------------------
1 | from .single_stage import SingleStageDetector
2 |
3 | __all__ = ['SingleStageDetector']
4 |
--------------------------------------------------------------------------------
/tests/data/ocr_toy_dataset/imgs/1036169.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jfkuang/CFAM/HEAD/tests/data/ocr_toy_dataset/imgs/1036169.jpg
--------------------------------------------------------------------------------
/tests/data/ocr_toy_dataset/imgs/1058891.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jfkuang/CFAM/HEAD/tests/data/ocr_toy_dataset/imgs/1058891.jpg
--------------------------------------------------------------------------------
/tests/data/ocr_toy_dataset/imgs/1058892.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jfkuang/CFAM/HEAD/tests/data/ocr_toy_dataset/imgs/1058892.jpg
--------------------------------------------------------------------------------
/tests/data/ocr_toy_dataset/imgs/1190237.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jfkuang/CFAM/HEAD/tests/data/ocr_toy_dataset/imgs/1190237.jpg
--------------------------------------------------------------------------------
/tests/data/ocr_toy_dataset/imgs/1210236.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jfkuang/CFAM/HEAD/tests/data/ocr_toy_dataset/imgs/1210236.jpg
--------------------------------------------------------------------------------
/tests/data/ocr_toy_dataset/imgs/1223729.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jfkuang/CFAM/HEAD/tests/data/ocr_toy_dataset/imgs/1223729.jpg
--------------------------------------------------------------------------------
/tests/data/ocr_toy_dataset/imgs/1223731.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jfkuang/CFAM/HEAD/tests/data/ocr_toy_dataset/imgs/1223731.jpg
--------------------------------------------------------------------------------
/tests/data/ocr_toy_dataset/imgs/1223732.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jfkuang/CFAM/HEAD/tests/data/ocr_toy_dataset/imgs/1223732.jpg
--------------------------------------------------------------------------------
/tests/data/ocr_toy_dataset/imgs/1223733.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jfkuang/CFAM/HEAD/tests/data/ocr_toy_dataset/imgs/1223733.jpg
--------------------------------------------------------------------------------
/tests/data/ocr_toy_dataset/imgs/1240078.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jfkuang/CFAM/HEAD/tests/data/ocr_toy_dataset/imgs/1240078.jpg
--------------------------------------------------------------------------------
/tests/data/toy_dataset/imgs/test/img_10.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jfkuang/CFAM/HEAD/tests/data/toy_dataset/imgs/test/img_10.jpg
--------------------------------------------------------------------------------
/tests/data/ocr_toy_dataset/label.lmdb/data.mdb:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jfkuang/CFAM/HEAD/tests/data/ocr_toy_dataset/label.lmdb/data.mdb
--------------------------------------------------------------------------------
/tests/data/ocr_toy_dataset/label.lmdb/lock.mdb:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jfkuang/CFAM/HEAD/tests/data/ocr_toy_dataset/label.lmdb/lock.mdb
--------------------------------------------------------------------------------
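The label.lmdb directory above is the LMDB form of the neighboring label.txt, presumably produced by tools/data/utils/txt2lmdb.py (whose source is not part of this dump). A minimal sketch for inspecting such a store with the lmdb package — the key layout is deliberately not assumed; the cursor just walks every record:

    # Sketch: dump every key/value pair in the toy LMDB store.
    import lmdb

    env = lmdb.open('tests/data/ocr_toy_dataset/label.lmdb', readonly=True, lock=False)
    with env.begin() as txn:
        for key, value in txn.cursor():
            print(key.decode('utf-8'), value.decode('utf-8', errors='replace'))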
/mmocr/models/common/backbones/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from .unet import UNet
3 |
4 | __all__ = ['UNet']
5 |
--------------------------------------------------------------------------------
/mmocr/models/kie/extractors/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from .sdmgr import SDMGR
3 |
4 | __all__ = ['SDMGR']
5 |
--------------------------------------------------------------------------------
/mmocr/models/spotting/modules/old/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | # @Time : 2022/3/24 21:05
4 | # @Author : WeiHua
5 |
--------------------------------------------------------------------------------
/mmocr/models/spotting/optimizers/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | # @Time : 2022/3/23 15:11
4 | # @Author : WeiHua
5 |
--------------------------------------------------------------------------------
/mmocr/models/kie/heads/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from .sdmgr_head import SDMGRHead
3 |
4 | __all__ = ['SDMGRHead']
5 |
--------------------------------------------------------------------------------
/mmocr/models/kie/losses/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from .sdmgr_loss import SDMGRLoss
3 |
4 | __all__ = ['SDMGRLoss']
5 |
--------------------------------------------------------------------------------
/mmocr/models/spotting/recognizers/old/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | # @Time : 2022/4/18 20:44
4 | # @Author : WeiHua
5 |
--------------------------------------------------------------------------------
/mmocr/models/textrecog/heads/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from .seg_head import SegHead
3 |
4 | __all__ = ['SegHead']
5 |
--------------------------------------------------------------------------------
/mmocr/models/textrecog/necks/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from .fpn_ocr import FPNOCR
3 |
4 | __all__ = ['FPNOCR']
5 |
--------------------------------------------------------------------------------
/mmocr/models/common/layers/__init__.py:
--------------------------------------------------------------------------------
1 | from .transformer_layers import TFDecoderLayer, TFEncoderLayer
2 |
3 | __all__ = ['TFEncoderLayer', 'TFDecoderLayer']
4 |
--------------------------------------------------------------------------------
/mmocr/models/ner/decoders/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from .fc_decoder import FCDecoder
3 |
4 | __all__ = ['FCDecoder']
5 |
--------------------------------------------------------------------------------
/mmocr/models/textrecog/fusers/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from .abi_fuser import ABIFuser
3 |
4 | __all__ = ['ABIFuser']
5 |
--------------------------------------------------------------------------------
/mmocr/models/ner/encoders/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from .bert_encoder import BertEncoder
3 |
4 | __all__ = ['BertEncoder']
5 |
--------------------------------------------------------------------------------
/tests/data/ocr_char_ann_toy_dataset/imgs/resort_88_101_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jfkuang/CFAM/HEAD/tests/data/ocr_char_ann_toy_dataset/imgs/resort_88_101_1.png
--------------------------------------------------------------------------------
/tests/data/ocr_char_ann_toy_dataset/imgs/resort_95_53_6.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jfkuang/CFAM/HEAD/tests/data/ocr_char_ann_toy_dataset/imgs/resort_95_53_6.png
--------------------------------------------------------------------------------
/mmocr/models/ner/convertors/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from .ner_convertor import NerConvertor
3 |
4 | __all__ = ['NerConvertor']
5 |
--------------------------------------------------------------------------------
/tests/data/toy_dataset/annotations/test/gt_img_4.txt:
--------------------------------------------------------------------------------
1 | 692,268,710,268,710,293,692,293,###
2 | 663,224,733,230,737,246,661,242,###
3 | 668,242,737,244,734,260,670,256,###
--------------------------------------------------------------------------------
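The gt_img_*.txt files in tests/data/toy_dataset/annotations/test follow the ICDAR-style layout: eight integers describing a quadrilateral (x1,y1,...,x4,y4), then the transcription, with ### marking regions to ignore. A small parsing sketch under that reading — note the transcription itself may contain commas, hence the rejoin:

    # Sketch: parse one ICDAR-style annotation line into polygon, text, ignore flag.
    def parse_ann_line(line):
        parts = line.strip().split(',')
        polygon = [int(v) for v in parts[:8]]  # x1,y1,x2,y2,x3,y3,x4,y4
        text = ','.join(parts[8:])             # keep commas inside the transcription
        return polygon, text, text == '###'

    polygon, text, ignored = parse_ann_line('692,268,710,268,710,293,692,293,###')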
/mmocr/models/ner/classifiers/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from .ner_classifier import NerClassifier
3 |
4 | __all__ = ['NerClassifier']
5 |
--------------------------------------------------------------------------------
/mmocr/models/spotting/recognizers/re_imple_trie/connects/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | # @Time : 2022/6/4 16:02
4 | # @Author : WeiHua
5 |
--------------------------------------------------------------------------------
/mmocr/models/spotting/recognizers/re_imple_trie/embedding/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | # @Time : 2022/6/4 16:02
4 | # @Author : WeiHua
5 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | -r requirements/build.txt
2 | -r requirements/optional.txt
3 | -r requirements/runtime.txt
4 | -r requirements/tests.txt
5 | ipdb
6 | tqdm
7 | pytorch-crf
--------------------------------------------------------------------------------
/tests/data/ocr_char_ann_toy_dataset/imgs/richard+feynman_101_8_6.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jfkuang/CFAM/HEAD/tests/data/ocr_char_ann_toy_dataset/imgs/richard+feynman_101_8_6.png
--------------------------------------------------------------------------------
/tests/data/ocr_char_ann_toy_dataset/imgs/richard+feynman_104_58_9.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jfkuang/CFAM/HEAD/tests/data/ocr_char_ann_toy_dataset/imgs/richard+feynman_104_58_9.png
--------------------------------------------------------------------------------
/tests/data/ocr_char_ann_toy_dataset/imgs/richard+feynman_110_1_6.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jfkuang/CFAM/HEAD/tests/data/ocr_char_ann_toy_dataset/imgs/richard+feynman_110_1_6.png
--------------------------------------------------------------------------------
/tests/data/ocr_char_ann_toy_dataset/imgs/richard+feynman_12_61_4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jfkuang/CFAM/HEAD/tests/data/ocr_char_ann_toy_dataset/imgs/richard+feynman_12_61_4.png
--------------------------------------------------------------------------------
/tests/data/ocr_char_ann_toy_dataset/imgs/richard+feynman_130_74_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jfkuang/CFAM/HEAD/tests/data/ocr_char_ann_toy_dataset/imgs/richard+feynman_130_74_1.png
--------------------------------------------------------------------------------
/tests/data/ocr_char_ann_toy_dataset/imgs/richard+feynman_134_30_15.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jfkuang/CFAM/HEAD/tests/data/ocr_char_ann_toy_dataset/imgs/richard+feynman_134_30_15.png
--------------------------------------------------------------------------------
/tests/data/ocr_char_ann_toy_dataset/imgs/richard+feynman_15_43_4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jfkuang/CFAM/HEAD/tests/data/ocr_char_ann_toy_dataset/imgs/richard+feynman_15_43_4.png
--------------------------------------------------------------------------------
/tests/data/ocr_char_ann_toy_dataset/imgs/richard+feynman_18_18_5.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jfkuang/CFAM/HEAD/tests/data/ocr_char_ann_toy_dataset/imgs/richard+feynman_18_18_5.png
--------------------------------------------------------------------------------
/tests/data/toy_dataset/img_list.txt:
--------------------------------------------------------------------------------
1 | img_10.jpg
2 | img_1.jpg
3 | img_2.jpg
4 | img_3.jpg
5 | img_4.jpg
6 | img_5.jpg
7 | img_6.jpg
8 | img_7.jpg
9 | img_8.jpg
10 | img_9.jpg
--------------------------------------------------------------------------------
/requirements/runtime.txt:
--------------------------------------------------------------------------------
1 | imgaug
2 | lanms-neo==1.0.2
3 | lmdb
4 | matplotlib
5 | numba>=0.45.1
6 | numpy
7 | pyclipper
8 | rapidfuzz
9 | scikit-image
10 | six
11 | terminaltables
--------------------------------------------------------------------------------
/docs/en/_static/css/readthedocs.css:
--------------------------------------------------------------------------------
1 | .header-logo {
2 |     background-image: url("../images/mmocr.png");
3 |     background-size: 110px 40px;
4 |     height: 40px;
5 |     width: 110px;
6 | }
7 |
--------------------------------------------------------------------------------
/docs/zh_cn/_static/css/readthedocs.css:
--------------------------------------------------------------------------------
1 | .header-logo {
2 |     background-image: url("../images/mmocr.png");
3 |     background-size: 110px 40px;
4 |     height: 40px;
5 |     width: 110px;
6 | }
7 |
--------------------------------------------------------------------------------
/mmocr/models/ner/utils/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from .activations import GeluNew
3 | from .bert import BertModel
4 |
5 | __all__ = ['BertModel', 'GeluNew']
6 |
--------------------------------------------------------------------------------
/.idea/other.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
6 |
--------------------------------------------------------------------------------
/.idea/vcs.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
/mmocr/models/common/losses/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from .dice_loss import DiceLoss
3 | from .focal_loss import FocalLoss
4 |
5 | __all__ = ['DiceLoss', 'FocalLoss']
6 |
--------------------------------------------------------------------------------
/tests/data/toy_dataset/annotations/test/gt_img_9.txt:
--------------------------------------------------------------------------------
1 | 344,206,384,207,381,228,342,227,EXIT
2 | 47,183,94,183,83,212,42,206,###
3 | 913,515,1068,526,1081,595,921,578,STAGE
4 | 240,291,273,291,273,298,240,297,###
--------------------------------------------------------------------------------
/requirements/docs.txt:
--------------------------------------------------------------------------------
1 | docutils==0.16.0
2 | myst-parser
3 | -e git+https://github.com/open-mmlab/pytorch_sphinx_theme.git#egg=pytorch_sphinx_theme
4 | sphinx==4.0.2
5 | sphinx_copybutton
6 | sphinx_markdown_tables
--------------------------------------------------------------------------------
/.idea/.gitignore:
--------------------------------------------------------------------------------
1 | # Default ignored files
2 | /shelf/
3 | /workspace.xml
4 | # Editor-based HTTP Client requests
5 | /httpRequests/
6 | # Datasource local storage ignored files
7 | /dataSources/
8 | /dataSources.local.xml
--------------------------------------------------------------------------------
/mmocr/models/spotting/backbone/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | # @Time : 2022/3/31 10:49
4 | # @Author : WeiHua
5 |
6 | from .custom_resnet import CustomResNet
7 |
8 | __all__ = ['CustomResNet']
--------------------------------------------------------------------------------
/.idea/inspectionProfiles/profiles_settings.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
6 |
--------------------------------------------------------------------------------
/configs/_base_/schedules/schedule_adadelta_5e.py:
--------------------------------------------------------------------------------
1 | # optimizer
2 | optimizer = dict(type='Adadelta', lr=1.0)
3 | optimizer_config = dict(grad_clip=None)
4 | # learning policy
5 | lr_config = dict(policy='step', step=[])
6 | total_epochs = 5
7 |
--------------------------------------------------------------------------------
/configs/_base_/schedules/schedule_adam_600e.py:
--------------------------------------------------------------------------------
1 | # optimizer
2 | optimizer = dict(type='Adam', lr=1e-3)
3 | optimizer_config = dict(grad_clip=None)
4 | # learning policy
5 | lr_config = dict(policy='poly', power=0.9)
6 | total_epochs = 600
7 |
--------------------------------------------------------------------------------
/configs/_base_/schedules/schedule_adam_step_5e.py:
--------------------------------------------------------------------------------
1 | # optimizer
2 | optimizer = dict(type='Adam', lr=1e-3)
3 | optimizer_config = dict(grad_clip=None)
4 | # learning policy
5 | lr_config = dict(policy='step', step=[3, 4])
6 | total_epochs = 5
7 |
--------------------------------------------------------------------------------
/configs/_base_/schedules/schedule_adam_step_6e.py:
--------------------------------------------------------------------------------
1 | # optimizer
2 | optimizer = dict(type='Adam', lr=1e-3)
3 | optimizer_config = dict(grad_clip=None)
4 | # learning policy
5 | lr_config = dict(policy='step', step=[3, 4])
6 | total_epochs = 6
7 |
--------------------------------------------------------------------------------
/mmocr/models/spotting/recognizers/re_imple_trie/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | # @Time : 2022/6/4 16:02
4 | # @Author : WeiHua
5 |
6 |
7 | from .custom_davar_builder import build_connect, build_embedding
8 |
--------------------------------------------------------------------------------
/configs/_base_/schedules/schedule_adam_step_600e.py:
--------------------------------------------------------------------------------
1 | # optimizer
2 | optimizer = dict(type='Adam', lr=1e-4)
3 | optimizer_config = dict(grad_clip=None)
4 | # learning policy
5 | lr_config = dict(policy='step', step=[200, 400])
6 | total_epochs = 600
7 |
--------------------------------------------------------------------------------
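None of these schedule fragments is run on its own; they are pulled into concrete experiment configs through MMCV's _base_ inheritance. A sketch of how a config would compose them — the relative paths are illustrative, not a file that exists in this repository:

    # Sketch: composing shared fragments in an MMCV-style config file.
    _base_ = [
        '../../_base_/default_runtime.py',
        '../../_base_/schedules/schedule_adam_step_600e.py',
    ]
    # Keys set here override the inherited values, e.g. a shorter run:
    total_epochs = 300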
/requirements/readthedocs.txt:
--------------------------------------------------------------------------------
1 | imgaug
2 | kwarray
3 | lanms-neo==1.0.2
4 | lmdb
5 | matplotlib
6 | mmcv
7 | mmdet
8 | pyclipper
9 | rapidfuzz
10 | regex
11 | scikit-image
12 | scipy
13 | shapely
14 | titlecase
15 | torch
16 | torchvision
--------------------------------------------------------------------------------
/configs/vie_custom/e2e_ar_vie/sroie/local/readme.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | # @Time : 2022/5/30 22:44
4 | # @Author : WeiHua
5 |
6 | sdef = sroie_default = "DEFAULT" + "random resize" + "learning rate=2e-4" + "dropout = 0.2"
--------------------------------------------------------------------------------
/mmocr/models/textrecog/preprocessor/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from .base_preprocessor import BasePreprocessor
3 | from .tps_preprocessor import TPSPreprocessor
4 |
5 | __all__ = ['BasePreprocessor', 'TPSPreprocessor']
6 |
--------------------------------------------------------------------------------
/configs/_base_/schedules/schedule_adadelta_18e.py:
--------------------------------------------------------------------------------
1 | # optimizer
2 | optimizer = dict(type='Adadelta', lr=0.5)
3 | optimizer_config = dict(grad_clip=dict(max_norm=0.5))
4 | # learning policy
5 | lr_config = dict(policy='step', step=[8, 14, 16])
6 | total_epochs = 18
7 |
--------------------------------------------------------------------------------
/mmocr/models/spotting/detectors/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | # @Time : 2022/1/8 17:25
4 | # @Author : WeiHua
5 |
6 | from .rei_mask_rcnn import ReI_OCRMaskRCNN
7 |
8 | __all__ = [
9 |     'ReI_OCRMaskRCNN'
10 | ]
11 |
--------------------------------------------------------------------------------
/requirements/tests.txt:
--------------------------------------------------------------------------------
1 | asynctest
2 | codecov
3 | flake8
4 | isort
5 | # Note: used for kwarray.group_items, this may be ported to mmcv in the future.
6 | kwarray
7 | pytest
8 | pytest-cov
9 | pytest-runner
10 | ubelt
11 | xdoctest >= 0.10.0
12 | yapf
--------------------------------------------------------------------------------
/mmocr/models/ner/losses/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from .masked_cross_entropy_loss import MaskedCrossEntropyLoss
3 | from .masked_focal_loss import MaskedFocalLoss
4 |
5 | __all__ = ['MaskedCrossEntropyLoss', 'MaskedFocalLoss']
6 |
--------------------------------------------------------------------------------
/mmocr/models/spotting/recognizers/re_imple_pick/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: UTF-8 -*-
3 | '''
4 | @Project :ie_e2e
5 | @File :__init__.py.py
6 | @IDE :PyCharm
7 | @Author :jfkuang
8 | @Date :2022/7/5 19:57
9 | '''
10 |
11 |
--------------------------------------------------------------------------------
/mmocr/models/spotting/recognizers/re_imple_pick/model/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: UTF-8 -*-
3 | '''
4 | @Project :ie_e2e
5 | @File :__init__.py.py
6 | @IDE :PyCharm
7 | @Author :jfkuang
8 | @Date :2022/7/5 20:19
9 | '''
10 |
--------------------------------------------------------------------------------
/mmocr/models/spotting/rois/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | # @Time : 2022/1/11 15:39
4 | # @Author : WeiHua
5 | from .rei_standard_roi_head import ReI_StandardRoIHead
6 |
7 | __all__ = [
8 |     'ReI_StandardRoIHead'
9 | ]
10 |
--------------------------------------------------------------------------------
/mmocr/models/textdet/necks/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from .fpem_ffm import FPEM_FFM
3 | from .fpn_cat import FPNC
4 | from .fpn_unet import FPN_UNet
5 | from .fpnf import FPNF
6 |
7 | __all__ = ['FPEM_FFM', 'FPNF', 'FPNC', 'FPN_UNet']
8 |
--------------------------------------------------------------------------------
/configs/_base_/schedules/schedule_sgd_1500e.py:
--------------------------------------------------------------------------------
1 | # optimizer
2 | optimizer = dict(type='SGD', lr=1e-3, momentum=0.90, weight_decay=5e-4)
3 | optimizer_config = dict(grad_clip=None)
4 | lr_config = dict(policy='poly', power=0.9, min_lr=1e-7, by_epoch=True)
5 | total_epochs = 1500
6 |
--------------------------------------------------------------------------------
/configs/_base_/schedules/schedule_sgd_600e.py:
--------------------------------------------------------------------------------
1 | # optimizer
2 | optimizer = dict(type='SGD', lr=1e-3, momentum=0.99, weight_decay=5e-4)
3 | optimizer_config = dict(grad_clip=None)
4 | # learning policy
5 | lr_config = dict(policy='step', step=[200, 400])
6 | total_epochs = 600
7 |
--------------------------------------------------------------------------------
/mmocr/models/textdet/modules/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from .gcn import GCN
3 | from .local_graph import LocalGraphs
4 | from .proposal_local_graph import ProposalLocalGraphs
5 |
6 | __all__ = ['LocalGraphs', 'ProposalLocalGraphs', 'GCN']
7 |
--------------------------------------------------------------------------------
/tests/data/ocr_toy_dataset/label.txt:
--------------------------------------------------------------------------------
1 | 1223731.jpg GRAND
2 | 1223733.jpg HOTEL
3 | 1223732.jpg HOTEL
4 | 1223729.jpg PACIFIC
5 | 1036169.jpg 03/09/2009
6 | 1190237.jpg ANING
7 | 1058891.jpg Virgin
8 | 1058892.jpg america
9 | 1240078.jpg ATTACK
10 | 1210236.jpg DAVIDSON
--------------------------------------------------------------------------------
/mmocr/models/spotting/losses/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | # @Time : 2022/1/15 21:16
4 | # @Author : WeiHua
5 |
6 | from .multi_step_loss import MultiStepLoss
7 | from .master_loss import MASTERTFLoss
8 |
9 | __all__ = ['MultiStepLoss', 'MASTERTFLoss']
--------------------------------------------------------------------------------
/configs/_base_/schedules/schedule_sgd_1200e.py:
--------------------------------------------------------------------------------
1 | # optimizer
2 | optimizer = dict(type='SGD', lr=0.007, momentum=0.9, weight_decay=0.0001)
3 | optimizer_config = dict(grad_clip=None)
4 | # learning policy
5 | lr_config = dict(policy='poly', power=0.9, min_lr=1e-7, by_epoch=True)
6 | total_epochs = 1200
7 |
--------------------------------------------------------------------------------
/docs/zh_cn/cp_origin_docs.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | # Copy *.md files from docs/ if it doesn't have a Chinese translation
4 |
5 | for filename in $(find ../en/ -name '*.md' -printf "%P\n");
6 | do
7 |     mkdir -p $(dirname $filename)
8 |     cp -n ../en/$filename ./$filename
9 | done
10 |
--------------------------------------------------------------------------------
/mmocr/models/kie/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from . import extractors, heads, losses
3 |
4 | from .extractors import *  # NOQA
5 | from .heads import *  # NOQA
6 | from .losses import *  # NOQA
7 |
8 | __all__ = extractors.__all__ + heads.__all__ + losses.__all__
9 |
--------------------------------------------------------------------------------
/.idea/misc.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
--------------------------------------------------------------------------------
/.idea/modules.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
--------------------------------------------------------------------------------
/mmocr/models/textrecog/recognizer/satrn.py:
--------------------------------------------------------------------------------
1 | from mmocr.models.builder import DETECTORS
2 | from .encode_decode_recognizer import EncodeDecodeRecognizer
3 |
4 |
5 | @DETECTORS.register_module()
6 | class SATRN(EncodeDecodeRecognizer):
7 |     """Implementation of `SATRN <https://arxiv.org/abs/1910.04396>`_"""
8 |
--------------------------------------------------------------------------------
/mmocr/datasets/utils/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from .loader import HardDiskLoader, LmdbLoader
3 | from .parser import LineJsonParser, LineStrParser, CustomLineJsonParser
4 |
5 | __all__ = ['HardDiskLoader', 'LmdbLoader', 'LineStrParser', 'LineJsonParser', 'CustomLineJsonParser']
6 |
--------------------------------------------------------------------------------
/mmocr/models/spotting/spotters/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | # @Time : 2022/1/8 16:01
4 | # @Author : WeiHua
5 | from .spotter_mixin import SpotterMixin
6 | from .two_stage_vie_spotter import TwoStageSpotter
7 |
8 | __all__ = [
9 |     'SpotterMixin', 'TwoStageSpotter'
10 | ]
11 |
--------------------------------------------------------------------------------
/mmocr/models/textrecog/losses/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from .ce_loss import CELoss, SARLoss, TFLoss
3 | from .ctc_loss import CTCLoss
4 | from .mix_loss import ABILoss
5 | from .seg_loss import SegLoss
6 |
7 | __all__ = ['CELoss', 'SARLoss', 'CTCLoss', 'TFLoss', 'SegLoss', 'ABILoss']
8 |
--------------------------------------------------------------------------------
/configs/vie_custom/e2e_ar_vie/ephoie/local/readme.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | # @Time : 2022/3/9 15:57
4 | # @Author : WeiHua
5 |
6 | """
7 | Here, default means: auto-regression & 1280 & learning rate = 1e4 & both as query & Fuse feature maps & REC Weight = KIE Weight = 10.0
8 |
9 | """
10 |
--------------------------------------------------------------------------------
/mmocr/models/textrecog/recognizer/crnn.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from mmocr.models.builder import DETECTORS
3 | from .encode_decode_recognizer import EncodeDecodeRecognizer
4 |
5 |
6 | @DETECTORS.register_module()
7 | class CRNNNet(EncodeDecodeRecognizer):
8 |     """CTC-loss based recognizer."""
9 |
--------------------------------------------------------------------------------
/configs/_base_/schedules/schedule_adam_step_20e.py:
--------------------------------------------------------------------------------
1 | optimizer = dict(type='Adam', lr=1e-4)
2 | optimizer_config = dict(grad_clip=None)
3 | lr_config = dict(
4 |     policy='step',
5 |     step=[16, 18],
6 |     warmup='linear',
7 |     warmup_iters=1,
8 |     warmup_ratio=0.001,
9 |     warmup_by_epoch=True)
10 | total_epochs = 20
11 |
--------------------------------------------------------------------------------
/tests/data/toy_dataset/annotations/test/gt_img_1.txt:
--------------------------------------------------------------------------------
1 | 377,117,463,117,465,130,378,130,Genaxis Theatre
2 | 493,115,519,115,519,131,493,131,[06]
3 | 374,155,409,155,409,170,374,170,###
4 | 492,151,551,151,551,170,492,170,62-03
5 | 376,198,422,198,422,212,376,212,Carpark
6 | 494,190,539,189,539,205,494,206,###
7 | 374,1,494,0,492,85,372,86,###
--------------------------------------------------------------------------------
/mmocr/core/deployment/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from .deploy_utils import (ONNXRuntimeDetector, ONNXRuntimeRecognizer,
3 |                            TensorRTDetector, TensorRTRecognizer)
4 |
5 | __all__ = [
6 |     'ONNXRuntimeRecognizer', 'ONNXRuntimeDetector', 'TensorRTDetector',
7 |     'TensorRTRecognizer'
8 | ]
9 |
--------------------------------------------------------------------------------
/mmocr/models/textrecog/recognizer/sar.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from mmocr.models.builder import DETECTORS
3 | from .encode_decode_recognizer import EncodeDecodeRecognizer
4 |
5 |
6 | @DETECTORS.register_module()
7 | class SARNet(EncodeDecodeRecognizer):
8 |     """Implementation of `SAR <https://arxiv.org/abs/1811.00751>`_"""
9 |
--------------------------------------------------------------------------------
/mmocr/models/textrecog/recognizer/nrtr.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from mmocr.models.builder import DETECTORS
3 | from .encode_decode_recognizer import EncodeDecodeRecognizer
4 |
5 |
6 | @DETECTORS.register_module()
7 | class NRTR(EncodeDecodeRecognizer):
8 |     """Implementation of `NRTR <https://arxiv.org/abs/1806.00926>`_"""
9 |
--------------------------------------------------------------------------------
/configs/_base_/default_runtime.py:
--------------------------------------------------------------------------------
1 | checkpoint_config = dict(interval=1)
2 | # yapf:disable
3 | log_config = dict(
4 |     interval=5,
5 |     hooks=[
6 |         dict(type='TextLoggerHook')
7 |
8 |     ])
9 | # yapf:enable
10 | dist_params = dict(backend='nccl')
11 | log_level = 'INFO'
12 | load_from = None
13 | resume_from = None
14 | workflow = [('train', 1)]
15 |
--------------------------------------------------------------------------------
/custom_utils/module_list_cuda101.txt:
--------------------------------------------------------------------------------
1 | ipdb 0.13.9
2 | mmcv-full 1.4.2
3 | mmdet 2.19.1
4 | mmocr 0.4.0 /home/whua/code/ie_e2e
5 | opencv-python 4.5.5.62
6 | timm 0.4.5
7 | torch 1.7.1+cu92
8 | torchaudio 0.7.2
9 | torchvision 0.8.2+cu92
10 | tqdm 4.62.3
--------------------------------------------------------------------------------
/tests/data/toy_dataset/annotations/test/gt_img_8.txt:
--------------------------------------------------------------------------------
1 | 568,347,623,350,617,380,568,375,WHY
2 | 626,347,673,345,668,382,625,380,PAY
3 | 675,351,725,350,726,381,678,379,FOR
4 | 598,381,728,385,724,420,598,413,NOTHING?
5 | 762,351,845,357,845,380,760,377,###
6 | 562,588,613,588,611,632,564,633,###
7 | 615,593,730,603,727,646,614,634,###
8 | 560,634,730,650,730,691,556,678,###
--------------------------------------------------------------------------------
/mmocr/models/common/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from . import backbones, layers, losses, modules
3 |
4 | from .backbones import *  # NOQA
5 | from .losses import *  # NOQA
6 | from .layers import *  # NOQA
7 | from .modules import *  # NOQA
8 |
9 | __all__ = backbones.__all__ + losses.__all__ + layers.__all__ + modules.__all__
10 |
--------------------------------------------------------------------------------
/configs/_base_/schedules/schedule_sgd_160e.py:
--------------------------------------------------------------------------------
1 | # optimizer
2 | optimizer = dict(type='SGD', lr=0.08, momentum=0.9, weight_decay=0.0001)
3 | optimizer_config = dict(grad_clip=None)
4 | # learning policy
5 | lr_config = dict(
6 |     policy='step',
7 |     warmup='linear',
8 |     warmup_iters=500,
9 |     warmup_ratio=0.001,
10 |     step=[80, 128])
11 | total_epochs = 160
12 |
--------------------------------------------------------------------------------
/mmocr/models/common/modules/__init__.py:
--------------------------------------------------------------------------------
1 | from .transformer_module import (MultiHeadAttention, PositionalEncoding,
2 |                                  PositionwiseFeedForward,
3 |                                  ScaledDotProductAttention)
4 |
5 | __all__ = [
6 |     'ScaledDotProductAttention', 'MultiHeadAttention',
7 |     'PositionwiseFeedForward', 'PositionalEncoding'
8 | ]
9 |
--------------------------------------------------------------------------------
/tests/data/toy_dataset/annotations/test/gt_img_10.txt:
--------------------------------------------------------------------------------
1 | 261,138,284,140,279,158,260,158,###
2 | 288,138,417,140,416,161,290,157,HarbourFront
3 | 743,145,779,146,780,163,746,163,CC22
4 | 783,129,831,132,833,155,785,153,bua
5 | 831,133,870,135,874,156,835,155,###
6 | 159,205,230,204,231,218,159,219,###
7 | 785,158,856,158,860,178,787,179,###
8 | 1011,157,1079,160,1076,173,1011,170,###
--------------------------------------------------------------------------------
/mmocr/models/textrecog/convertors/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from .abi import ABIConvertor
3 | from .attn import AttnConvertor
4 | from .base import BaseConvertor
5 | from .ctc import CTCConvertor
6 | from .seg import SegConvertor
7 |
8 | __all__ = [
9 |     'BaseConvertor', 'CTCConvertor', 'AttnConvertor', 'SegConvertor',
10 |     'ABIConvertor'
11 | ]
12 |
--------------------------------------------------------------------------------
/mmocr/models/textrecog/encoders/base_encoder.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from mmcv.runner import BaseModule
3 |
4 | from mmocr.models.builder import ENCODERS
5 |
6 |
7 | @ENCODERS.register_module()
8 | class BaseEncoder(BaseModule):
9 |     """Base Encoder class for text recognition."""
10 |
11 |     def forward(self, feat, **kwargs):
12 |         return feat
13 |
--------------------------------------------------------------------------------
/configs/_base_/runtime_10e.py:
--------------------------------------------------------------------------------
1 | checkpoint_config = dict(interval=10)
2 | # yapf:disable
3 | log_config = dict(
4 |     interval=5,
5 |     hooks=[
6 |         dict(type='TextLoggerHook')
7 |         # dict(type='TensorboardLoggerHook')
8 |     ])
9 | # yapf:enable
10 | dist_params = dict(backend='nccl')
11 | log_level = 'INFO'
12 | load_from = None
13 | resume_from = None
14 | workflow = [('train', 1)]
15 |
--------------------------------------------------------------------------------
/tests/data/ocr_char_ann_toy_dataset/instances_test.txt:
--------------------------------------------------------------------------------
1 | resort_88_101_1.png From:
2 | resort_95_53_6.png out
3 | richard+feynman_101_8_6.png the
4 | richard+feynman_104_58_9.png fast
5 | richard+feynman_110_1_6.png many
6 | richard+feynman_12_61_4.png the
7 | richard+feynman_130_74_1.png the
8 | richard+feynman_134_30_15.png how
9 | richard+feynman_15_43_4.png the
10 | richard+feynman_18_18_5.png Lines:
11 |
--------------------------------------------------------------------------------
/configs/_base_/recog_models/satrn.py:
--------------------------------------------------------------------------------
1 | label_convertor = dict(
2 |     type='AttnConvertor', dict_type='DICT36', with_unknown=True, lower=True)
3 |
4 | model = dict(
5 |     type='SATRN',
6 |     backbone=dict(type='ShallowCNN'),
7 |     encoder=dict(type='SatrnEncoder'),
8 |     decoder=dict(type='TFDecoder'),
9 |     loss=dict(type='TFLoss'),
10 |     label_convertor=label_convertor,
11 |     max_seq_len=40)
12 |
--------------------------------------------------------------------------------
/mmocr/models/textdet/losses/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from .db_loss import DBLoss
3 | from .drrg_loss import DRRGLoss
4 | from .fce_loss import FCELoss
5 | from .pan_loss import PANLoss
6 | from .pse_loss import PSELoss
7 | from .textsnake_loss import TextSnakeLoss
8 |
9 | __all__ = [
10 |     'PANLoss', 'PSELoss', 'DBLoss', 'TextSnakeLoss', 'FCELoss', 'DRRGLoss'
11 | ]
12 |
--------------------------------------------------------------------------------
/mmocr/models/textrecog/preprocessor/base_preprocessor.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from mmcv.runner import BaseModule
3 |
4 | from mmocr.models.builder import PREPROCESSOR
5 |
6 |
7 | @PREPROCESSOR.register_module()
8 | class BasePreprocessor(BaseModule):
9 |     """Base Preprocessor class for text recognition."""
10 |
11 |     def forward(self, x, **kwargs):
12 |         return x
13 |
--------------------------------------------------------------------------------
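satrn.py, crnn.py, sar.py, and nrtr.py above are nearly identical: each is a thin subclass of EncodeDecodeRecognizer registered with the DETECTORS registry so configs can select it by type name. Adding another recognizer under the same convention would look like this sketch (the class name is hypothetical; the file would sit next to the others in mmocr/models/textrecog/recognizer/):

    # Sketch: one more recognizer module, mirroring satrn.py's pattern.
    from mmocr.models.builder import DETECTORS
    from .encode_decode_recognizer import EncodeDecodeRecognizer


    @DETECTORS.register_module()
    class MyRecognizer(EncodeDecodeRecognizer):
        """Hypothetical recognizer; configs would pick it via type='MyRecognizer'."""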
2 | from mmocr.models.builder import DETECTORS 3 | from .encode_decode_recognizer import EncodeDecodeRecognizer 4 | 5 | 6 | @DETECTORS.register_module() 7 | class RobustScanner(EncodeDecodeRecognizer): 8 | """Implementation of `RobustScanner. 9 | 10 | <https://arxiv.org/abs/2007.07542>`_ 11 | """ 12 | -------------------------------------------------------------------------------- /mmocr/datasets/builder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from mmcv.utils import Registry, build_from_cfg 3 | 4 | LOADERS = Registry('loader') 5 | PARSERS = Registry('parser') 6 | 7 | 8 | def build_loader(cfg): 9 | """Build anno file loader.""" 10 | return build_from_cfg(cfg, LOADERS) 11 | 12 | 13 | def build_parser(cfg): 14 | """Build anno file parser.""" 15 | return build_from_cfg(cfg, PARSERS) 16 | -------------------------------------------------------------------------------- /mmocr/models/textrecog/backbones/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .nrtr_modality_transformer import NRTRModalityTransform 3 | from .resnet31_ocr import ResNet31OCR 4 | from .resnet_abi import ResNetABI 5 | from .shallow_cnn import ShallowCNN 6 | from .very_deep_vgg import VeryDeepVgg 7 | 8 | __all__ = [ 9 | 'ResNet31OCR', 'VeryDeepVgg', 'NRTRModalityTransform', 'ShallowCNN', 10 | 'ResNetABI' 11 | ] 12 | -------------------------------------------------------------------------------- /configs/_base_/recog_models/nrtr_modality_transform.py: -------------------------------------------------------------------------------- 1 | label_convertor = dict( 2 | type='AttnConvertor', dict_type='DICT36', with_unknown=True, lower=True) 3 | 4 | model = dict( 5 | type='NRTR', 6 | backbone=dict(type='NRTRModalityTransform'), 7 | encoder=dict(type='NRTREncoder', n_layers=12), 8 | decoder=dict(type='NRTRDecoder'), 9 | loss=dict(type='TFLoss'), 10 | label_convertor=label_convertor, 11 | max_seq_len=40) 12 | -------------------------------------------------------------------------------- /tests/test_models/test_ocr_fuser.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import torch 3 | 4 | from mmocr.models.textrecog.fusers import ABIFuser 5 | 6 | 7 | def test_base_alignment(): 8 | model = ABIFuser(d_model=512, num_chars=90, max_seq_len=40) 9 | l_feat = torch.randn(1, 40, 512) 10 | v_feat = torch.randn(1, 40, 512) 11 | result = model(l_feat, v_feat) 12 | assert result['logits'].shape == torch.Size([1, 40, 90]) 13 | -------------------------------------------------------------------------------- /mmocr/models/ner/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from . import classifiers, convertors, decoders, encoders, losses 3 | 4 | from .classifiers import * # NOQA 5 | from .convertors import * # NOQA 6 | from .decoders import * # NOQA 7 | from .encoders import * # NOQA 8 | from .losses import * # NOQA 9 | 10 | __all__ = ( 11 | classifiers.__all__ + convertors.__all__ + decoders.__all__ + 12 | encoders.__all__ + losses.__all__) 13 | -------------------------------------------------------------------------------- /mmocr/models/textdet/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved.
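# Importing the submodules below executes their @*.register_module()
# decorators, which is what allows config files to refer to detectors,
# heads and losses by their string type names.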
2 | from . import dense_heads, detectors, losses, necks, postprocess 3 | 4 | from .dense_heads import * # NOQA 5 | from .detectors import * # NOQA 6 | from .losses import * # NOQA 7 | from .necks import * # NOQA 8 | from .postprocess import * # NOQA 9 | 10 | __all__ = ( 11 | dense_heads.__all__ + detectors.__all__ + losses.__all__ + necks.__all__ + 12 | postprocess.__all__) 13 | -------------------------------------------------------------------------------- /mmocr/models/textdet/dense_heads/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .db_head import DBHead 3 | from .drrg_head import DRRGHead 4 | from .fce_head import FCEHead 5 | from .head_mixin import HeadMixin 6 | from .pan_head import PANHead 7 | from .pse_head import PSEHead 8 | from .textsnake_head import TextSnakeHead 9 | 10 | __all__ = [ 11 | 'PSEHead', 'PANHead', 'DBHead', 'FCEHead', 'TextSnakeHead', 'DRRGHead', 12 | 'HeadMixin' 13 | ] 14 | -------------------------------------------------------------------------------- /configs/_base_/recog_models/crnn.py: -------------------------------------------------------------------------------- 1 | label_convertor = dict( 2 | type='CTCConvertor', dict_type='DICT36', with_unknown=False, lower=True) 3 | 4 | model = dict( 5 | type='CRNNNet', 6 | preprocessor=None, 7 | backbone=dict(type='VeryDeepVgg', leaky_relu=False, input_channels=1), 8 | encoder=None, 9 | decoder=dict(type='CRNNDecoder', in_channels=512, rnn_flag=True), 10 | loss=dict(type='CTCLoss'), 11 | label_convertor=label_convertor, 12 | pretrained=None) 13 | -------------------------------------------------------------------------------- /custom_utils/dict_default.json: -------------------------------------------------------------------------------- 1 | ["0", "1", "2", "3", "4", "5", "6", "7", "8", "9", "a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k", "l", "m", "n", "o", "p", "q", "r", "s", "t", "u", "v", "w", "x", "y", "z", "A", "B", "C", "D", "E", "F", "G", "H", "I", "J", "K", "L", "M", "N", "O", "P", "Q", "R", "S", "T", "U", "V", "W", "X", "Y", "Z", "!", "\"", "#", "$", "%", "&", "'", "(", ")", "*", "+", ",", "-", ".", "/", ":", ";", "<", "=", ">", "?", "@", "[", "\\", "]", "^", "_", "`", "{", "|", "}", "~", " "] -------------------------------------------------------------------------------- /configs/_base_/det_datasets/ctw1500.py: -------------------------------------------------------------------------------- 1 | dataset_type = 'IcdarDataset' 2 | data_root = 'data/ctw1500' 3 | 4 | train = dict( 5 | type=dataset_type, 6 | ann_file=f'{data_root}/instances_training.json', 7 | img_prefix=f'{data_root}/imgs', 8 | pipeline=None) 9 | 10 | test = dict( 11 | type=dataset_type, 12 | ann_file=f'{data_root}/instances_test.json', 13 | img_prefix=f'{data_root}/imgs', 14 | pipeline=None) 15 | 16 | train_list = [train] 17 | 18 | test_list = [test] 19 | -------------------------------------------------------------------------------- /configs/_base_/det_datasets/icdar2017.py: -------------------------------------------------------------------------------- 1 | dataset_type = 'IcdarDataset' 2 | data_root = 'data/icdar2017' 3 | 4 | train = dict( 5 | type=dataset_type, 6 | ann_file=f'{data_root}/instances_training.json', 7 | img_prefix=f'{data_root}/imgs', 8 | pipeline=None) 9 | 10 | test = dict( 11 | type=dataset_type, 12 | ann_file=f'{data_root}/instances_val.json', 13 | img_prefix=f'{data_root}/imgs', 14 | pipeline=None) 15 | 16 
| train_list = [train] 17 | 18 | test_list = [test] 19 | -------------------------------------------------------------------------------- /configs/_base_/det_datasets/icdar2015.py: -------------------------------------------------------------------------------- 1 | dataset_type = 'IcdarDataset' 2 | data_root = 'data/icdar2015' 3 | 4 | train = dict( 5 | type=dataset_type, 6 | ann_file=f'{data_root}/instances_training.json', 7 | img_prefix=f'{data_root}/imgs', 8 | pipeline=None) 9 | 10 | test = dict( 11 | type=dataset_type, 12 | ann_file=f'{data_root}/instances_test.json', 13 | img_prefix=f'{data_root}/imgs', 14 | pipeline=None) 15 | 16 | train_list = [train] 17 | 18 | test_list = [test] 19 | -------------------------------------------------------------------------------- /custom_utils/module_list_v100.txt: -------------------------------------------------------------------------------- 1 | ipdb 0.13.9 2 | mmcv-full 1.3.8 3 | mmdet 2.14.0 4 | mmocr 0.4.0 /apdcephfs/share_887471/common/whua/code/ie_e2e 5 | opencv-python 4.5.4.60 6 | timm 0.4.5 7 | torch 1.7.0+cu101 8 | torchvision 0.8.0 9 | tqdm 4.50.0 10 | -------------------------------------------------------------------------------- /mmocr/core/evaluation/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .hmean import eval_hmean 3 | from .hmean_ic13 import eval_hmean_ic13 4 | from .hmean_iou import eval_hmean_iou 5 | from .kie_metric import compute_f1_score 6 | from .ner_metric import eval_ner_f1 7 | from .ocr_metric import eval_ocr_metric 8 | from .vie_metric import eval_vie_e2e 9 | 10 | __all__ = [ 11 | 'eval_hmean_ic13', 'eval_hmean_iou', 'eval_ocr_metric', 'eval_hmean', 12 | 'compute_f1_score', 'eval_ner_f1', 'eval_vie_e2e' 13 | ] 14 | -------------------------------------------------------------------------------- /configs/vie_custom/e2e_ar_vie/sroie/readme.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # @Time : 2022/3/9 22:14 4 | # @Author : WeiHua 5 | 6 | """ 7 | default -> 8 | 1280 9 | 1e4 10 | feature-fuse-v1 11 | rec & kie weights = 10 12 | Adam optimizer 13 | shuffle instances' order 14 | 4 decoder layer 15 | no node-level modeling 16 | no text encoder 17 | with data augmentation 18 | 19 | v0 -> fuse without sum during up-sampling -> default 20 | v1 -> fuse with sum during up-sampling 21 | 22 | """ -------------------------------------------------------------------------------- /tests/data/kie_toy_dataset/class_list.txt: -------------------------------------------------------------------------------- 1 | 0 Ignore 2 | 1 Store_name_value 3 | 2 Store_name_key 4 | 3 Store_addr_value 5 | 4 Store_addr_key 6 | 5 Tel_value 7 | 6 Tel_key 8 | 7 Date_value 9 | 8 Date_key 10 | 9 Time_value 11 | 10 Time_key 12 | 11 Prod_item_value 13 | 12 Prod_item_key 14 | 13 Prod_quantity_value 15 | 14 Prod_quantity_key 16 | 15 Prod_price_value 17 | 16 Prod_price_key 18 | 17 Subtotal_value 19 | 18 Subtotal_key 20 | 19 Tax_value 21 | 20 Tax_key 22 | 21 Tips_value 23 | 22 Tips_key 24 | 23 Total_value 25 | 24 Total_key 26 | 25 Others -------------------------------------------------------------------------------- /mmocr/apis/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
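# A minimal inference sketch using these APIs; the config and checkpoint
# paths are hypothetical placeholders:
#   from mmocr.apis import init_detector, model_inference
#   model = init_detector('configs/some_config.py', 'work_dirs/latest.pth',
#                         device='cuda:0')
#   result = model_inference(model, 'demo.jpg')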
2 | from .inference import init_detector, model_inference 3 | from .test import single_gpu_test 4 | from .train import init_random_seed, train_detector 5 | from .utils import (disable_text_recog_aug_test, replace_image_to_tensor, 6 | tensor2grayimgs) 7 | 8 | __all__ = [ 9 | 'model_inference', 'train_detector', 'init_detector', 'init_random_seed', 10 | 'replace_image_to_tensor', 'disable_text_recog_aug_test', 11 | 'single_gpu_test', 'tensor2grayimgs' 12 | ] 13 | -------------------------------------------------------------------------------- /mmocr/utils/collect_env.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from mmcv.utils import collect_env as collect_base_env 3 | from mmcv.utils import get_git_hash 4 | 5 | import mmocr 6 | 7 | 8 | def collect_env(): 9 | """Collect the information of the running environments.""" 10 | env_info = collect_base_env() 11 | env_info['MMOCR'] = mmocr.__version__ + '+' + get_git_hash()[:7] 12 | return env_info 13 | 14 | 15 | if __name__ == '__main__': 16 | for name, val in collect_env().items(): 17 | print(f'{name}: {val}') 18 | -------------------------------------------------------------------------------- /tests/data/toy_dataset/annotations/test/gt_img_3.txt: -------------------------------------------------------------------------------- 1 | 58,80,191,71,194,114,61,123,fusionopolis 2 | 147,21,176,21,176,36,147,36,### 3 | 328,75,391,81,387,112,326,113,### 4 | 401,76,448,84,445,108,402,111,### 5 | 780,7,1015,6,1016,37,788,42,### 6 | 221,72,311,80,312,117,222,118,fusionopolis 7 | 113,19,144,19,144,33,113,33,### 8 | 257,28,308,28,308,57,257,57,### 9 | 140,120,196,115,195,129,141,133,### 10 | 86,176,110,177,112,189,89,196,### 11 | 101,193,129,185,132,198,103,204,### 12 | 223,175,244,150,294,183,235,197,### 13 | 140,239,174,232,176,247,142,256,### 14 | -------------------------------------------------------------------------------- /tests/test_models/test_ocr_head.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
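# SegHead predicts a per-pixel character-class map: the test below checks
# that a non-integer or negative num_classes is rejected and that a
# (1, 128, 32, 32) neck feature yields (1, num_classes, 32, 32) logits.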
2 | import pytest 3 | import torch 4 | 5 | from mmocr.models.textrecog import SegHead 6 | 7 | 8 | def test_seg_head(): 9 | with pytest.raises(AssertionError): 10 | SegHead(num_classes='100') 11 | with pytest.raises(AssertionError): 12 | SegHead(num_classes=-1) 13 | 14 | seg_head = SegHead(num_classes=37) 15 | out_neck = (torch.rand(1, 128, 32, 32), ) 16 | out_head = seg_head(out_neck) 17 | assert out_head.shape == torch.Size([1, 37, 32, 32]) 18 | -------------------------------------------------------------------------------- /docs/zh_cn/datasets/ner.md: -------------------------------------------------------------------------------- 1 | # 命名实体识别(专名识别) 2 | 3 | ## 概览 4 | 5 | 命名实体识别任务的数据集,文件目录应按如下配置: 6 | 7 | ```text 8 | └── cluener2020 9 | ├── cluener_predict.json 10 | ├── dev.json 11 | ├── README.md 12 | ├── test.json 13 | ├── train.json 14 | └── vocab.txt 15 | 16 | ``` 17 | 18 | ## 准备步骤 19 | 20 | ### CLUENER2020 21 | 22 | - 下载并解压 [cluener_public.zip](https://storage.googleapis.com/cluebenchmark/tasks/cluener_public.zip) 至 `cluener2020/`。 23 | 24 | - 下载 [vocab.txt](https://download.openmmlab.com/mmocr/data/cluener_public/vocab.txt) 然后将 `vocab.txt` 移动到 `cluener2020/` 文件夹下 25 | -------------------------------------------------------------------------------- /mmocr/models/textrecog/encoders/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .abinet_vision_model import ABIVisionModel 3 | from .base_encoder import BaseEncoder 4 | from .channel_reduction_encoder import ChannelReductionEncoder 5 | from .nrtr_encoder import NRTREncoder 6 | from .sar_encoder import SAREncoder 7 | from .satrn_encoder import SatrnEncoder 8 | from .transformer import TransformerEncoder 9 | 10 | __all__ = [ 11 | 'SAREncoder', 'NRTREncoder', 'BaseEncoder', 'ChannelReductionEncoder', 12 | 'SatrnEncoder', 'TransformerEncoder', 'ABIVisionModel' 13 | ] 14 | -------------------------------------------------------------------------------- /mmocr/datasets/pipelines/textdet_targets/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .base_textdet_targets import BaseTextDetTargets 3 | from .dbnet_targets import DBNetTargets 4 | from .drrg_targets import DRRGTargets 5 | from .fcenet_targets import FCENetTargets 6 | from .panet_targets import PANetTargets 7 | from .psenet_targets import PSENetTargets 8 | from .textsnake_targets import TextSnakeTargets 9 | 10 | __all__ = [ 11 | 'BaseTextDetTargets', 'PANetTargets', 'PSENetTargets', 'DBNetTargets', 12 | 'FCENetTargets', 'TextSnakeTargets', 'DRRGTargets' 13 | ] 14 | -------------------------------------------------------------------------------- /mmocr/models/textdet/postprocess/base_postprocessor.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
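# Common base for detection postprocessors: it fixes whether boundaries are
# emitted as polygons ('poly') or quadrangles ('quad') and provides the
# shared area/confidence validity check used by subclasses.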
2 | 3 | 4 | class BasePostprocessor: 5 | 6 | def __init__(self, text_repr_type='poly'): 7 | assert text_repr_type in ['poly', 'quad' 8 | ], f'Invalid text repr type {text_repr_type}' 9 | 10 | self.text_repr_type = text_repr_type 11 | 12 | def is_valid_instance(self, area, confidence, area_thresh, 13 | confidence_thresh): 14 | 15 | return bool(area >= area_thresh and confidence > confidence_thresh) 16 | -------------------------------------------------------------------------------- /mmocr/models/textrecog/recognizer/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .abinet import ABINet 3 | from .base import BaseRecognizer 4 | from .crnn import CRNNNet 5 | from .encode_decode_recognizer import EncodeDecodeRecognizer 6 | from .nrtr import NRTR 7 | from .robust_scanner import RobustScanner 8 | from .sar import SARNet 9 | from .satrn import SATRN 10 | from .seg_recognizer import SegRecognizer 11 | 12 | __all__ = [ 13 | 'BaseRecognizer', 'EncodeDecodeRecognizer', 'CRNNNet', 'SARNet', 'NRTR', 14 | 'SegRecognizer', 'RobustScanner', 'SATRN', 'ABINet' 15 | ] 16 | -------------------------------------------------------------------------------- /mmocr/models/textdet/detectors/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .dbnet import DBNet 3 | from .drrg import DRRG 4 | from .fcenet import FCENet 5 | from .ocr_mask_rcnn import OCRMaskRCNN 6 | from .panet import PANet 7 | from .psenet import PSENet 8 | from .single_stage_text_detector import SingleStageTextDetector 9 | from .text_detector_mixin import TextDetectorMixin 10 | from .textsnake import TextSnake 11 | 12 | __all__ = [ 13 | 'TextDetectorMixin', 'SingleStageTextDetector', 'OCRMaskRCNN', 'DBNet', 14 | 'PANet', 'PSENet', 'TextSnake', 'FCENet', 'DRRG' 15 | ] 16 | -------------------------------------------------------------------------------- /mmocr/models/spotting/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # @Time : 2022/1/11 15:39 4 | # @Author : WeiHua 5 | 6 | from . import detectors, rois, spotters, recognizers, losses, modules, backbone 7 | from .detectors import * # NOQA 8 | from .rois import * # NOQA 9 | from .spotters import * # NOQA 10 | from .recognizers import * # NOQA 11 | from .losses import * # NOQA 12 | from .modules import * # NOQA 13 | from .backbone import * # NOQA 14 | 15 | __all__ = ( 16 | detectors.__all__ + rois.__all__ + spotters.__all__ + recognizers.__all__ + 17 | losses.__all__ + modules.__all__ + backbone.__all__ 18 | ) 19 | -------------------------------------------------------------------------------- /tests/test_models/test_ocr_neck.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved.
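# FPNOCR fuses a four-level feature pyramid back to the resolution of its
# first level, so with in_channels=[128, 256, 512, 512] and out_channels=256
# the fused map below keeps the (32, 256) spatial size of in_s1.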
2 | import torch 3 | 4 | from mmocr.models.textrecog.necks import FPNOCR 5 | 6 | 7 | def test_fpn_ocr(): 8 | in_s1 = torch.rand(1, 128, 32, 256) 9 | in_s2 = torch.rand(1, 256, 16, 128) 10 | in_s3 = torch.rand(1, 512, 8, 64) 11 | in_s4 = torch.rand(1, 512, 4, 32) 12 | 13 | fpn_ocr = FPNOCR(in_channels=[128, 256, 512, 512], out_channels=256) 14 | fpn_ocr.init_weights() 15 | fpn_ocr.train() 16 | 17 | out_neck = fpn_ocr((in_s1, in_s2, in_s3, in_s4)) 18 | assert out_neck[0].shape == torch.Size([1, 256, 32, 256]) 19 | -------------------------------------------------------------------------------- /tests/data/kie_toy_dataset/dict.txt: -------------------------------------------------------------------------------- 1 | / 2 | \ 3 | . 4 | $ 5 | £ 6 | € 7 | ¥ 8 | : 9 | - 10 | , 11 | * 12 | # 13 | ( 14 | ) 15 | % 16 | @ 17 | ! 18 | ' 19 | & 20 | = 21 | > 22 | + 23 | " 24 | × 25 | ? 26 | < 27 | [ 28 | ] 29 | _ 30 | 0 31 | 1 32 | 2 33 | 3 34 | 4 35 | 5 36 | 6 37 | 7 38 | 8 39 | 9 40 | a 41 | b 42 | c 43 | d 44 | e 45 | f 46 | g 47 | h 48 | i 49 | j 50 | k 51 | l 52 | m 53 | n 54 | o 55 | p 56 | q 57 | r 58 | s 59 | t 60 | u 61 | v 62 | w 63 | x 64 | y 65 | z 66 | A 67 | B 68 | C 69 | D 70 | E 71 | F 72 | G 73 | H 74 | I 75 | J 76 | K 77 | L 78 | M 79 | N 80 | O 81 | P 82 | Q 83 | R 84 | S 85 | T 86 | U 87 | V 88 | W 89 | X 90 | Y 91 | Z -------------------------------------------------------------------------------- /tests/data/toy_dataset/annotations/test/gt_img_7.txt: -------------------------------------------------------------------------------- 1 | 346,133,400,130,401,148,345,153,### 2 | 301,127,349,123,351,154,303,158,### 3 | 869,67,920,61,923,85,872,91,citi 4 | 886,144,934,141,932,157,884,160,smrt 5 | 634,106,812,86,816,104,634,121,### 6 | 418,117,469,112,471,143,420,148,### 7 | 634,124,781,107,783,123,635,135,### 8 | 634,138,844,117,843,141,636,155,### 9 | 468,124,518,117,525,138,468,143,### 10 | 301,181,532,162,530,182,301,201,### 11 | 296,157,396,147,400,165,300,174,### 12 | 420,151,526,136,527,154,421,163,### 13 | 617,251,657,250,656,282,616,285,### 14 | 695,246,738,243,738,276,698,278,### 15 | 739,241,760,241,763,260,742,262,### 16 | -------------------------------------------------------------------------------- /tests/test_tools/test_data_converter.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | """Test orientation check and ignore method.""" 3 | 4 | import shutil 5 | import tempfile 6 | 7 | from mmocr.utils import drop_orientation 8 | 9 | 10 | def test_drop_orientation(): 11 | img_file = 'tests/data/test_img2.jpg' 12 | output_file = drop_orientation(img_file) 13 | assert output_file is img_file 14 | 15 | img_file = 'tests/data/test_img1.jpg' 16 | tmp_dir = tempfile.TemporaryDirectory() 17 | dst_file = shutil.copy(img_file, tmp_dir.name) 18 | output_file = drop_orientation(dst_file) 19 | assert output_file[-3:] == 'png' 20 | -------------------------------------------------------------------------------- /mmocr/models/spotting/modules/ops/modules/__init__.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------------ 2 | # Deformable DETR 3 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 
4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------------------------------ 6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 7 | # ------------------------------------------------------------------------------------------------ 8 | 9 | from .ms_deform_attn import MSDeformAttn 10 | -------------------------------------------------------------------------------- /mmocr/models/spotting/modules/ops/make.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # ------------------------------------------------------------------------------------------------ 3 | # Deformable DETR 4 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 5 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 6 | # ------------------------------------------------------------------------------------------------ 7 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 8 | # ------------------------------------------------------------------------------------------------ 9 | 10 | python setup.py build install 11 | -------------------------------------------------------------------------------- /mmocr/models/textdet/postprocess/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .base_postprocessor import BasePostprocessor 3 | from .db_postprocessor import DBPostprocessor 4 | from .drrg_postprocessor import DRRGPostprocessor 5 | from .fce_postprocessor import FCEPostprocessor 6 | from .pan_postprocessor import PANPostprocessor 7 | from .pse_postprocessor import PSEPostprocessor 8 | from .textsnake_postprocessor import TextSnakePostprocessor 9 | 10 | __all__ = [ 11 | 'BasePostprocessor', 'PSEPostprocessor', 'PANPostprocessor', 12 | 'DBPostprocessor', 'DRRGPostprocessor', 'FCEPostprocessor', 13 | 'TextSnakePostprocessor' 14 | ] 15 | -------------------------------------------------------------------------------- /configs/_base_/recog_models/crnn_tps.py: -------------------------------------------------------------------------------- 1 | # model 2 | label_convertor = dict( 3 | type='CTCConvertor', dict_type='DICT36', with_unknown=False, lower=True) 4 | 5 | model = dict( 6 | type='CRNNNet', 7 | preprocessor=dict( 8 | type='TPSPreprocessor', 9 | num_fiducial=20, 10 | img_size=(32, 100), 11 | rectified_img_size=(32, 100), 12 | num_img_channel=1), 13 | backbone=dict(type='VeryDeepVgg', leaky_relu=False, input_channels=1), 14 | encoder=None, 15 | decoder=dict(type='CRNNDecoder', in_channels=512, rnn_flag=True), 16 | loss=dict(type='CTCLoss'), 17 | label_convertor=label_convertor, 18 | pretrained=None) 19 | -------------------------------------------------------------------------------- /docs/en/datasets/ner.md: -------------------------------------------------------------------------------- 1 | # Named Entity Recognition 2 | 3 | ## Overview 4 | 5 | The structure of the named entity recognition dataset directory is organized as follows. 
6 | 7 | ```text 8 | └── cluener2020 9 | ├── cluener_predict.json 10 | ├── dev.json 11 | ├── README.md 12 | ├── test.json 13 | ├── train.json 14 | └── vocab.txt 15 | ``` 16 | 17 | ## Preparation Steps 18 | 19 | ### CLUENER2020 20 | 21 | - Download and extract [cluener_public.zip](https://storage.googleapis.com/cluebenchmark/tasks/cluener_public.zip) to `cluener2020/` 22 | - Download [vocab.txt](https://download.openmmlab.com/mmocr/data/cluener_public/vocab.txt) and move `vocab.txt` to `cluener2020/` 23 | -------------------------------------------------------------------------------- /mmocr/models/spotting/modules/ops/functions/__init__.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------------ 2 | # Deformable DETR 3 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------------------------------ 6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 7 | # ------------------------------------------------------------------------------------------------ 8 | 9 | from .ms_deform_attn_func import MSDeformAttnFunction 10 | 11 | -------------------------------------------------------------------------------- /mmocr/models/textrecog/layers/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .conv_layer import BasicBlock, Bottleneck 3 | from .dot_product_attention_layer import DotProductAttentionLayer 4 | from .lstm_layer import BidirectionalLSTM 5 | from .position_aware_layer import PositionAwareLayer 6 | from .robust_scanner_fusion_layer import RobustScannerFusionLayer 7 | from .satrn_layers import Adaptive2DPositionalEncoding, SatrnEncoderLayer 8 | 9 | __all__ = [ 10 | 'BidirectionalLSTM', 'Adaptive2DPositionalEncoding', 'BasicBlock', 11 | 'Bottleneck', 'RobustScannerFusionLayer', 'DotProductAttentionLayer', 12 | 'PositionAwareLayer', 'SatrnEncoderLayer' 13 | ] 14 | -------------------------------------------------------------------------------- /mmocr/models/textrecog/layers/lstm_layer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
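# CRNN-style recurrent head: a (T, b, nIn) sequence passes through a
# bidirectional LSTM, and the concatenated 2 * nHidden states are projected
# by a linear layer, giving a (T, b, nOut) output.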
2 | import torch.nn as nn 3 | 4 | 5 | class BidirectionalLSTM(nn.Module): 6 | 7 | def __init__(self, nIn, nHidden, nOut): 8 | super().__init__() 9 | 10 | self.rnn = nn.LSTM(nIn, nHidden, bidirectional=True) 11 | self.embedding = nn.Linear(nHidden * 2, nOut) 12 | 13 | def forward(self, input): 14 | recurrent, _ = self.rnn(input) 15 | T, b, h = recurrent.size() 16 | t_rec = recurrent.view(T * b, h) 17 | 18 | output = self.embedding(t_rec) # [T * b, nOut] 19 | output = output.view(T, b, -1) 20 | 21 | return output 22 | -------------------------------------------------------------------------------- /custom_utils/clean_pths.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # @Time : 2022/3/11 11:20 4 | # @Author : WeiHua 5 | import glob 6 | import shutil 7 | 8 | from tqdm import tqdm 9 | import os 10 | 11 | 12 | if __name__ == '__main__': 13 | dirs = glob.glob('/apdcephfs/share_887471/common/whua/logs/ie_ar_e2e_log/*') 14 | for dir_ in tqdm(dirs): 15 | pths = glob.glob(os.path.join(dir_, '*.pth')) 16 | for pth_ in pths: 17 | if 'epoch_' in pth_: 18 | num_epoch = pth_.split('/')[-1].split('.')[0] 19 | num_epoch = int(num_epoch[6:]) 20 | if num_epoch < 270: 21 | os.remove(pth_) 22 | -------------------------------------------------------------------------------- /configs/_base_/recog_datasets/ST_charbox_train.py: -------------------------------------------------------------------------------- 1 | # Text Recognition Training set, including: 2 | # Synthetic Datasets: SynthText (with character level boxes) 3 | 4 | train_img_root = 'data/mixture' 5 | 6 | train_img_prefix = f'{train_img_root}/SynthText' 7 | 8 | train_ann_file = f'{train_img_root}/SynthText/instances_train.txt' 9 | 10 | train = dict( 11 | type='OCRSegDataset', 12 | img_prefix=train_img_prefix, 13 | ann_file=train_ann_file, 14 | loader=dict( 15 | type='HardDiskLoader', 16 | repeat=1, 17 | parser=dict( 18 | type='LineJsonParser', keys=['file_name', 'annotations', 'text'])), 19 | pipeline=None, 20 | test_mode=False) 21 | 22 | train_list = [train] 23 | -------------------------------------------------------------------------------- /configs/_base_/recog_models/sar.py: -------------------------------------------------------------------------------- 1 | label_convertor = dict( 2 | type='AttnConvertor', dict_type='DICT90', with_unknown=True) 3 | 4 | model = dict( 5 | type='SARNet', 6 | backbone=dict(type='ResNet31OCR'), 7 | encoder=dict( 8 | type='SAREncoder', 9 | enc_bi_rnn=False, 10 | enc_do_rnn=0.1, 11 | enc_gru=False, 12 | ), 13 | decoder=dict( 14 | type='ParallelSARDecoder', 15 | enc_bi_rnn=False, 16 | dec_bi_rnn=False, 17 | dec_do_rnn=0, 18 | dec_gru=False, 19 | pred_dropout=0.1, 20 | d_k=512, 21 | pred_concat=True), 22 | loss=dict(type='SARLoss'), 23 | label_convertor=label_convertor, 24 | max_seq_len=30) 25 | -------------------------------------------------------------------------------- /docs/en/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = . 9 | BUILDDIR = _build 10 | 11 | # Put it first so that "make" without argument is like "make help". 
12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /configs/_base_/recog_datasets/MJ_train.py: -------------------------------------------------------------------------------- 1 | # Text Recognition Training set, including: 2 | # Synthetic Datasets: Syn90k 3 | 4 | train_root = 'data/mixture/Syn90k' 5 | 6 | train_img_prefix = f'{train_root}/mnt/ramdisk/max/90kDICT32px' 7 | train_ann_file = f'{train_root}/label.lmdb' 8 | 9 | train = dict( 10 | type='OCRDataset', 11 | img_prefix=train_img_prefix, 12 | ann_file=train_ann_file, 13 | loader=dict( 14 | type='LmdbLoader', 15 | repeat=1, 16 | parser=dict( 17 | type='LineStrParser', 18 | keys=['filename', 'text'], 19 | keys_idx=[0, 1], 20 | separator=' ')), 21 | pipeline=None, 22 | test_mode=False) 23 | 24 | train_list = [train] 25 | -------------------------------------------------------------------------------- /docs/zh_cn/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = . 9 | BUILDDIR = _build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /.idea/sshConfigs.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | -------------------------------------------------------------------------------- /configs/_base_/det_models/panet_r50_fpem_ffm.py: -------------------------------------------------------------------------------- 1 | model = dict( 2 | type='PANet', 3 | pretrained='torchvision://resnet50', 4 | backbone=dict( 5 | type='mmdet.ResNet', 6 | depth=50, 7 | num_stages=4, 8 | out_indices=(0, 1, 2, 3), 9 | frozen_stages=1, 10 | norm_cfg=dict(type='BN', requires_grad=True), 11 | norm_eval=True, 12 | style='caffe'), 13 | neck=dict(type='FPEM_FFM', in_channels=[256, 512, 1024, 2048]), 14 | bbox_head=dict( 15 | type='PANHead', 16 | in_channels=[128, 128, 128, 128], 17 | out_channels=6, 18 | loss=dict(type='PANLoss', speedup_bbox_thr=32), 19 | postprocessor=dict(type='PANPostprocessor', text_repr_type='poly')), 20 | train_cfg=None, 21 | test_cfg=None) 22 | -------------------------------------------------------------------------------- /tests/test_core/test_end2end_vis.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
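# det_recog_show_result stitches a rendered-text canvas beside the input
# image, which is why the 100x100 input is expected back as 100x200 below;
# the second call checks that non-ASCII transcriptions render without error.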
2 | import numpy as np 3 | 4 | from mmocr.core import det_recog_show_result 5 | 6 | 7 | def test_det_recog_show_result(): 8 | img = np.ones((100, 100, 3), dtype=np.uint8) * 255 9 | det_recog_res = { 10 | 'result': [{ 11 | 'box': [51, 88, 51, 62, 85, 62, 85, 88], 12 | 'box_score': 0.9417, 13 | 'text': 'hell', 14 | 'text_score': 0.8834 15 | }] 16 | } 17 | 18 | vis_img = det_recog_show_result(img, det_recog_res) 19 | 20 | assert vis_img.shape[0] == 100 21 | assert vis_img.shape[1] == 200 22 | assert vis_img.shape[2] == 3 23 | 24 | det_recog_res['result'][0]['text'] = '中文' 25 | det_recog_show_result(img, det_recog_res) 26 | -------------------------------------------------------------------------------- /configs/_base_/recog_models/robust_scanner.py: -------------------------------------------------------------------------------- 1 | label_convertor = dict( 2 | type='AttnConvertor', dict_type='DICT90', with_unknown=True) 3 | 4 | hybrid_decoder = dict(type='SequenceAttentionDecoder') 5 | 6 | position_decoder = dict(type='PositionAttentionDecoder') 7 | 8 | model = dict( 9 | type='RobustScanner', 10 | backbone=dict(type='ResNet31OCR'), 11 | encoder=dict( 12 | type='ChannelReductionEncoder', 13 | in_channels=512, 14 | out_channels=128, 15 | ), 16 | decoder=dict( 17 | type='RobustScannerDecoder', 18 | dim_input=512, 19 | dim_model=128, 20 | hybrid_decoder=hybrid_decoder, 21 | position_decoder=position_decoder), 22 | loss=dict(type='SARLoss'), 23 | label_convertor=label_convertor, 24 | max_seq_len=30) 25 | -------------------------------------------------------------------------------- /.idea/ie_e2e.iml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 13 | 14 | 16 | -------------------------------------------------------------------------------- /configs/_base_/det_models/drrg_r50_fpn_unet.py: -------------------------------------------------------------------------------- 1 | model = dict( 2 | type='DRRG', 3 | backbone=dict( 4 | type='mmdet.ResNet', 5 | depth=50, 6 | num_stages=4, 7 | out_indices=(0, 1, 2, 3), 8 | frozen_stages=-1, 9 | norm_cfg=dict(type='BN', requires_grad=True), 10 | init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50'), 11 | norm_eval=True, 12 | style='caffe'), 13 | neck=dict( 14 | type='FPN_UNet', in_channels=[256, 512, 1024, 2048], out_channels=32), 15 | bbox_head=dict( 16 | type='DRRGHead', 17 | in_channels=32, 18 | text_region_thr=0.3, 19 | center_region_thr=0.4, 20 | loss=dict(type='DRRGLoss'), 21 | postprocessor=dict(type='DRRGPostprocessor', link_thr=0.80))) 22 | -------------------------------------------------------------------------------- /mmocr/models/textrecog/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from . 
import (backbones, convertors, decoders, encoders, fusers, heads, 3 | losses, necks, preprocessor, recognizer) 4 | 5 | from .backbones import * # NOQA 6 | from .convertors import * # NOQA 7 | from .decoders import * # NOQA 8 | from .encoders import * # NOQA 9 | from .heads import * # NOQA 10 | from .losses import * # NOQA 11 | from .necks import * # NOQA 12 | from .preprocessor import * # NOQA 13 | from .recognizer import * # NOQA 14 | from .fusers import * # NOQA 15 | 16 | __all__ = ( 17 | backbones.__all__ + convertors.__all__ + decoders.__all__ + 18 | encoders.__all__ + heads.__all__ + losses.__all__ + necks.__all__ + 19 | preprocessor.__all__ + recognizer.__all__ + fusers.__all__) 20 | -------------------------------------------------------------------------------- /mmocr/models/textrecog/layers/robust_scanner_fusion_layer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import torch 3 | import torch.nn as nn 4 | from mmcv.runner import BaseModule 5 | 6 | 7 | class RobustScannerFusionLayer(BaseModule): 8 | 9 | def __init__(self, dim_model, dim=-1, init_cfg=None): 10 | super().__init__(init_cfg=init_cfg) 11 | 12 | self.dim_model = dim_model 13 | self.dim = dim 14 | 15 | self.linear_layer = nn.Linear(dim_model * 2, dim_model * 2) 16 | self.glu_layer = nn.GLU(dim=dim) 17 | 18 | def forward(self, x0, x1): 19 | assert x0.size() == x1.size() 20 | fusion_input = torch.cat([x0, x1], self.dim) 21 | output = self.linear_layer(fusion_input) 22 | output = self.glu_layer(output) 23 | 24 | return output 25 | -------------------------------------------------------------------------------- /configs/_base_/det_models/dbnet_r18_fpnc.py: -------------------------------------------------------------------------------- 1 | model = dict( 2 | type='DBNet', 3 | backbone=dict( 4 | type='mmdet.ResNet', 5 | depth=18, 6 | num_stages=4, 7 | out_indices=(0, 1, 2, 3), 8 | frozen_stages=-1, 9 | norm_cfg=dict(type='BN', requires_grad=True), 10 | init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet18'), 11 | norm_eval=False, 12 | style='caffe'), 13 | neck=dict( 14 | type='FPNC', in_channels=[64, 128, 256, 512], lateral_channels=256), 15 | bbox_head=dict( 16 | type='DBHead', 17 | in_channels=256, 18 | loss=dict(type='DBLoss', alpha=5.0, beta=10.0, bbce_loss=True), 19 | postprocessor=dict(type='DBPostprocessor', text_repr_type='quad')), 20 | train_cfg=None, 21 | test_cfg=None) 22 | -------------------------------------------------------------------------------- /tests/data/toy_dataset/annotations/test/gt_img_6.txt: -------------------------------------------------------------------------------- 1 | 875,92,910,92,910,112,875,112,### 2 | 748,95,787,95,787,109,748,109,### 3 | 106,395,150,394,153,425,106,424,### 4 | 165,393,213,396,210,421,165,421,### 5 | 706,52,747,49,746,62,705,64,### 6 | 111,459,206,461,207,482,113,480,Reserve 7 | 831,9,894,9,894,22,831,22,### 8 | 641,456,693,454,693,467,641,469,CAUTION 9 | 839,32,891,32,891,47,839,47,### 10 | 788,46,831,46,831,59,788,59,### 11 | 830,95,872,95,872,106,830,106,### 12 | 921,92,952,92,952,111,921,111,### 13 | 968,40,1013,40,1013,53,968,53,### 14 | 1002,89,1031,89,1031,100,1002,100,### 15 | 1043,38,1098,38,1098,52,1043,52,### 16 | 1069,85,1138,85,1138,99,1069,99,### 17 | 1128,36,1178,36,1178,52,1128,52,### 18 | 1168,84,1200,84,1200,97,1168,97,### 19 | 1223,27,1259,27,1255,49,1219,49,### 20 | 1264,28,1279,28,1279,46,1264,46,### 21 | 
-------------------------------------------------------------------------------- /mmocr/models/spotting/modules/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # @Time : 2022/2/26 16:20 4 | # @Author : WeiHua 5 | 6 | from .master_encoder import MasterEncoder 7 | from .align import feature_mask, db_like_fuser, db_fuser 8 | from .text_encoder import build_text_encoder 9 | from .text_decoder import build_text_decoder 10 | from .kie_modules import KIEDecoder, KIEDecoderSerial 11 | from .global_modeling import GlobalModeling 12 | from .cross_interact import InteractBlock, build_mimic 13 | from .kv_catcher import build_kv_catcher 14 | 15 | __all__ = ['MasterEncoder', 'feature_mask', 'db_like_fuser', 16 | 'db_fuser', 'build_text_encoder', 'build_text_decoder', 17 | 'KIEDecoder', 'GlobalModeling', 'KIEDecoderSerial', 18 | 'InteractBlock', 'build_mimic', 'build_kv_catcher'] 19 | -------------------------------------------------------------------------------- /configs/_base_/det_models/textsnake_r50_fpn_unet.py: -------------------------------------------------------------------------------- 1 | model = dict( 2 | type='TextSnake', 3 | backbone=dict( 4 | type='mmdet.ResNet', 5 | depth=50, 6 | num_stages=4, 7 | out_indices=(0, 1, 2, 3), 8 | frozen_stages=-1, 9 | norm_cfg=dict(type='BN', requires_grad=True), 10 | init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50'), 11 | norm_eval=True, 12 | style='caffe'), 13 | neck=dict( 14 | type='FPN_UNet', in_channels=[256, 512, 1024, 2048], out_channels=32), 15 | bbox_head=dict( 16 | type='TextSnakeHead', 17 | in_channels=32, 18 | loss=dict(type='TextSnakeLoss'), 19 | postprocessor=dict( 20 | type='TextSnakePostprocessor', text_repr_type='poly')), 21 | train_cfg=None, 22 | test_cfg=None) 23 | -------------------------------------------------------------------------------- /configs/_base_/recog_models/seg.py: -------------------------------------------------------------------------------- 1 | label_convertor = dict( 2 | type='SegConvertor', dict_type='DICT36', with_unknown=True, lower=True) 3 | 4 | model = dict( 5 | type='SegRecognizer', 6 | backbone=dict( 7 | type='ResNet31OCR', 8 | layers=[1, 2, 5, 3], 9 | channels=[32, 64, 128, 256, 512, 512], 10 | out_indices=[0, 1, 2, 3], 11 | stage4_pool_cfg=dict(kernel_size=2, stride=2), 12 | last_stage_pool=True), 13 | neck=dict( 14 | type='FPNOCR', in_channels=[128, 256, 512, 512], out_channels=256), 15 | head=dict( 16 | type='SegHead', 17 | in_channels=256, 18 | upsample_param=dict(scale_factor=2.0, mode='nearest')), 19 | loss=dict( 20 | type='SegLoss', seg_downsample_ratio=1.0, seg_with_loss_weight=True), 21 | label_convertor=label_convertor) 22 | -------------------------------------------------------------------------------- /tests/test_utils/test_model.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
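# revert_sync_batchnorm swaps SyncBN layers for plain BatchNorm so that
# models trained with distributed settings can run on CPU; the test asserts
# both that SyncBN itself fails on CPU and that train/eval modes survive
# the conversion.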
2 | import pytest 3 | import torch 4 | from mmcv.cnn.bricks import ConvModule 5 | 6 | from mmocr.utils import revert_sync_batchnorm 7 | 8 | 9 | def test_revert_sync_batchnorm(): 10 | conv_syncbn = ConvModule(3, 8, 2, norm_cfg=dict(type='SyncBN')).to('cpu') 11 | conv_syncbn.train() 12 | x = torch.randn(1, 3, 10, 10) 13 | # Will raise a ValueError saying SyncBN does not run on CPU 14 | with pytest.raises(ValueError): 15 | y = conv_syncbn(x) 16 | conv_bn = revert_sync_batchnorm(conv_syncbn) 17 | y = conv_bn(x) 18 | assert y.shape == (1, 8, 9, 9) 19 | assert conv_bn.training == conv_syncbn.training 20 | conv_syncbn.eval() 21 | conv_bn = revert_sync_batchnorm(conv_syncbn) 22 | assert conv_bn.training == conv_syncbn.training 23 | -------------------------------------------------------------------------------- /mmocr/models/spotting/modules/ops/src/vision.cpp: -------------------------------------------------------------------------------- 1 | /*! 2 | ************************************************************************************************** 3 | * Deformable DETR 4 | * Copyright (c) 2020 SenseTime. All Rights Reserved. 5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details] 6 | ************************************************************************************************** 7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 8 | ************************************************************************************************** 9 | */ 10 | 11 | #include "ms_deform_attn.h" 12 | 13 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 14 | m.def("ms_deform_attn_forward", &ms_deform_attn_forward, "ms_deform_attn_forward"); 15 | m.def("ms_deform_attn_backward", &ms_deform_attn_backward, "ms_deform_attn_backward"); 16 | } 17 | -------------------------------------------------------------------------------- /docs/zh_cn/datasets/kie.md: -------------------------------------------------------------------------------- 1 | # 关键信息提取 2 | 3 | ## 概览 4 | 5 | 关键信息提取任务的数据集,文件目录应按如下配置: 6 | 7 | ```text 8 | └── wildreceipt 9 | ├── class_list.txt 10 | ├── dict.txt 11 | ├── image_files 12 | ├── test.txt 13 | └── train.txt 14 | ``` 15 | 16 | ## 准备步骤 17 | 18 | ### WildReceipt 19 | 20 | - 下载并解压 [wildreceipt.tar](https://download.openmmlab.com/mmocr/data/wildreceipt.tar) 21 | 22 | ### WildReceiptOpenset 23 | 24 | - 准备好 [WildReceipt](#WildReceipt)。 25 | - 转换 WildReceipt 成 OpenSet 格式: 26 | ```bash 27 | # 你可以运行以下命令以获取更多可用参数: 28 | # python tools/data/kie/closeset_to_openset.py -h 29 | python tools/data/kie/closeset_to_openset.py data/wildreceipt/train.txt data/wildreceipt/openset_train.txt 30 | python tools/data/kie/closeset_to_openset.py data/wildreceipt/test.txt data/wildreceipt/openset_test.txt 31 | ``` 32 | :::{note} 33 | [这篇教程](../tutorials/kie_closeset_openset.md)里讲述了更多 CloseSet 和 OpenSet 数据格式之间的区别。 34 | ::: 35 | -------------------------------------------------------------------------------- /docs/en/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=. 11 | set BUILDDIR=_build 12 | 13 | if "%1" == "" goto help 14 | 15 | %SPHINXBUILD% >NUL 2>NUL 16 | if errorlevel 9009 ( 17 | echo. 18 | echo.The 'sphinx-build' command was not found.
Make sure you have Sphinx 19 | echo.installed, then set the SPHINXBUILD environment variable to point 20 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 21 | echo.may add the Sphinx directory to PATH. 22 | echo. 23 | echo.If you don't have Sphinx installed, grab it from 24 | echo.http://sphinx-doc.org/ 25 | exit /b 1 26 | ) 27 | 28 | 29 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 30 | goto end 31 | 32 | :help 33 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 34 | 35 | :end 36 | popd 37 | -------------------------------------------------------------------------------- /mmocr/datasets/pipelines/textdet_targets/psenet_targets.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from mmdet.datasets.builder import PIPELINES 3 | 4 | from . import PANetTargets 5 | 6 | 7 | @PIPELINES.register_module() 8 | class PSENetTargets(PANetTargets): 9 | """Generate the ground truth targets of PSENet: Shape robust text detection 10 | with progressive scale expansion network. 11 | 12 | [https://arxiv.org/abs/1903.12473]. This code is partially adapted from 13 | https://github.com/whai362/PSENet. 14 | 15 | Args: 16 | shrink_ratio(tuple(float)): The ratios for shrinking text instances. 17 | max_shrink(int): The maximum shrinking distance. 18 | """ 19 | 20 | def __init__(self, 21 | shrink_ratio=(1.0, 0.9, 0.8, 0.7, 0.6, 0.5, 0.4), 22 | max_shrink=20): 23 | super().__init__(shrink_ratio=shrink_ratio, max_shrink=max_shrink) 24 | -------------------------------------------------------------------------------- /docs/zh_cn/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=. 11 | set BUILDDIR=_build 12 | 13 | if "%1" == "" goto help 14 | 15 | %SPHINXBUILD% >NUL 2>NUL 16 | if errorlevel 9009 ( 17 | echo. 18 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 19 | echo.installed, then set the SPHINXBUILD environment variable to point 20 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 21 | echo.may add the Sphinx directory to PATH. 22 | echo. 
23 | echo.If you don't have Sphinx installed, grab it from 24 | echo.http://sphinx-doc.org/ 25 | exit /b 1 26 | ) 27 | 28 | 29 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 30 | goto end 31 | 32 | :help 33 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 34 | 35 | :end 36 | popd 37 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [bdist_wheel] 2 | universal=1 3 | 4 | [aliases] 5 | test=pytest 6 | 7 | [tool:pytest] 8 | norecursedirs=tests/integration/* 9 | addopts=tests 10 | 11 | [yapf] 12 | based_on_style = pep8 13 | blank_line_before_nested_class_or_def = true 14 | split_before_expression_after_opening_paren = true 15 | split_penalty_import_names=0 16 | SPLIT_PENALTY_AFTER_OPENING_BRACKET=800 17 | 18 | [isort] 19 | line_length = 79 20 | multi_line_output = 0 21 | known_standard_library = setuptools 22 | known_first_party = mmocr 23 | known_third_party = PIL,cv2,imgaug,lanms,lmdb,matplotlib,mmcv,mmdet,numpy,packaging,pyclipper,pytest,pytorch_sphinx_theme,rapidfuzz,requests,scipy,shapely,skimage,titlecase,torch,torchvision,ts,yaml 24 | no_lines_before = STDLIB,LOCALFOLDER 25 | default_section = THIRDPARTY 26 | 27 | [style] 28 | BASED_ON_STYLE = pep8 29 | BLANK_LINE_BEFORE_NESTED_CLASS_OR_DEF = true 30 | SPLIT_BEFORE_EXPRESSION_AFTER_OPENING_PAREN = true 31 | -------------------------------------------------------------------------------- /mmocr/models/textrecog/layers/dot_product_attention_layer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | 7 | class DotProductAttentionLayer(nn.Module): 8 | 9 | def __init__(self, dim_model=None): 10 | super().__init__() 11 | 12 | self.scale = dim_model**-0.5 if dim_model is not None else 1. 13 | 14 | def forward(self, query, key, value, mask=None): 15 | logits = torch.matmul(query.permute(0, 2, 1), key) * self.scale 16 | 17 | if mask is not None: 18 | n, seq_len = mask.size() 19 | mask = mask.view(n, 1, seq_len) 20 | logits = logits.masked_fill(mask, float('-inf')) 21 | 22 | weights = F.softmax(logits, dim=2) 23 | 24 | glimpse = torch.matmul(weights, value.transpose(1, 2)) 25 | 26 | glimpse = glimpse.permute(0, 2, 1).contiguous() 27 | 28 | return glimpse 29 | -------------------------------------------------------------------------------- /mmocr/models/common/losses/dice_loss.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved.
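# A sketch of the formula implemented below:
#   loss = 1 - 2 * sum(p * t) / (sum(p) + sum(t) + eps)
# computed over flattened predictions p and targets t, optionally restricted
# to the region where an element-wise mask is non-zero.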
2 | import torch 3 | import torch.nn as nn 4 | 5 | from mmocr.models.builder import LOSSES 6 | 7 | 8 | @LOSSES.register_module() 9 | class DiceLoss(nn.Module): 10 | 11 | def __init__(self, eps=1e-6): 12 | super().__init__() 13 | assert isinstance(eps, float) 14 | self.eps = eps 15 | 16 | def forward(self, pred, target, mask=None): 17 | 18 | pred = pred.contiguous().view(pred.size()[0], -1) 19 | target = target.contiguous().view(target.size()[0], -1) 20 | 21 | if mask is not None: 22 | mask = mask.contiguous().view(mask.size()[0], -1) 23 | pred = pred * mask 24 | target = target * mask 25 | 26 | a = torch.sum(pred * target) 27 | b = torch.sum(pred) 28 | c = torch.sum(target) 29 | d = (2 * a) / (b + c + self.eps) 30 | 31 | return 1 - d 32 | -------------------------------------------------------------------------------- /mmocr/models/textrecog/decoders/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .abinet_language_decoder import ABILanguageDecoder 3 | from .abinet_vision_decoder import ABIVisionDecoder 4 | from .base_decoder import BaseDecoder 5 | from .crnn_decoder import CRNNDecoder 6 | from .nrtr_decoder import NRTRDecoder 7 | from .position_attention_decoder import PositionAttentionDecoder 8 | from .robust_scanner_decoder import RobustScannerDecoder 9 | from .sar_decoder import ParallelSARDecoder, SequentialSARDecoder 10 | from .sar_decoder_with_bs import ParallelSARDecoderWithBS 11 | from .sequence_attention_decoder import SequenceAttentionDecoder 12 | 13 | __all__ = [ 14 | 'CRNNDecoder', 'ParallelSARDecoder', 'SequentialSARDecoder', 15 | 'ParallelSARDecoderWithBS', 'NRTRDecoder', 'BaseDecoder', 16 | 'SequenceAttentionDecoder', 'PositionAttentionDecoder', 17 | 'RobustScannerDecoder', 'ABILanguageDecoder', 'ABIVisionDecoder' 18 | ] 19 | -------------------------------------------------------------------------------- /tools/data/utils/txt2lmdb.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
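# Thin CLI wrapper around mmocr.utils.lmdb_converter. A hypothetical
# invocation (both paths are illustrative, not shipped data):
#   python tools/data/utils/txt2lmdb.py -i data/train_imglist.txt \
#       -o data/train_label.lmdb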
2 | import argparse 3 | 4 | from mmocr.utils import lmdb_converter 5 | 6 | 7 | def main(): 8 | parser = argparse.ArgumentParser() 9 | parser.add_argument( 10 | '--imglist', '-i', required=True, help='input imglist path') 11 | parser.add_argument( 12 | '--output', '-o', required=True, help='output lmdb path') 13 | parser.add_argument( 14 | '--batch_size', 15 | '-b', 16 | type=int, 17 | default=10000, 18 | help='processing batch size, default 10000') 19 | parser.add_argument( 20 | '--coding', 21 | '-c', 22 | default='utf8', 23 | help='bytes coding scheme, default utf8') 24 | opt = parser.parse_args() 25 | 26 | lmdb_converter( 27 | opt.imglist, opt.output, batch_size=opt.batch_size, coding=opt.coding) 28 | 29 | 30 | if __name__ == '__main__': 31 | main() 32 | -------------------------------------------------------------------------------- /configs/_base_/recog_datasets/seg_toy_data.py: -------------------------------------------------------------------------------- 1 | prefix = 'tests/data/ocr_char_ann_toy_dataset/' 2 | 3 | train = dict( 4 | type='OCRSegDataset', 5 | img_prefix=f'{prefix}/imgs', 6 | ann_file=f'{prefix}/instances_train.txt', 7 | loader=dict( 8 | type='HardDiskLoader', 9 | repeat=100, 10 | parser=dict( 11 | type='LineJsonParser', keys=['file_name', 'annotations', 'text'])), 12 | pipeline=None, 13 | test_mode=True) 14 | 15 | test = dict( 16 | type='OCRDataset', 17 | img_prefix=f'{prefix}/imgs', 18 | ann_file=f'{prefix}/instances_test.txt', 19 | loader=dict( 20 | type='HardDiskLoader', 21 | repeat=1, 22 | parser=dict( 23 | type='LineStrParser', 24 | keys=['filename', 'text'], 25 | keys_idx=[0, 1], 26 | separator=' ')), 27 | pipeline=None, 28 | test_mode=True) 29 | 30 | train_list = [train] 31 | 32 | test_list = [test] 33 | -------------------------------------------------------------------------------- /configs/_base_/det_models/dbnet_r50dcnv2_fpnc.py: -------------------------------------------------------------------------------- 1 | model = dict( 2 | type='DBNet', 3 | backbone=dict( 4 | type='mmdet.ResNet', 5 | depth=50, 6 | num_stages=4, 7 | out_indices=(0, 1, 2, 3), 8 | frozen_stages=-1, 9 | norm_cfg=dict(type='BN', requires_grad=True), 10 | norm_eval=False, 11 | style='pytorch', 12 | dcn=dict(type='DCNv2', deform_groups=1, fallback_on_stride=False), 13 | init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50'), 14 | stage_with_dcn=(False, True, True, True)), 15 | neck=dict( 16 | type='FPNC', in_channels=[256, 512, 1024, 2048], lateral_channels=256), 17 | bbox_head=dict( 18 | type='DBHead', 19 | in_channels=256, 20 | loss=dict(type='DBLoss', alpha=5.0, beta=10.0, bbce_loss=True), 21 | postprocessor=dict(type='DBPostprocessor', text_repr_type='quad')), 22 | train_cfg=None, 23 | test_cfg=None) 24 | -------------------------------------------------------------------------------- /mmocr/core/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from . 
import evaluation 3 | from .mask import extract_boundary, points2boundary, seg2boundary 4 | from .visualize import (det_recog_show_result, imshow_edge, imshow_node, 5 | imshow_pred_boundary, imshow_text_char_boundary, 6 | imshow_text_label, overlay_mask_img, show_feature, 7 | show_img_boundary, show_pred_gt) 8 | from .custom_visualize import imshow_e2e_result 9 | from .e2e_vie_utils import convert_vie_res 10 | 11 | from .evaluation import * # NOQA 12 | 13 | __all__ = [ 14 | 'points2boundary', 'seg2boundary', 'extract_boundary', 'overlay_mask_img', 15 | 'show_feature', 'show_img_boundary', 'show_pred_gt', 16 | 'imshow_pred_boundary', 'imshow_text_char_boundary', 'imshow_text_label', 17 | 'imshow_node', 'det_recog_show_result', 'imshow_edge', 'imshow_e2e_result', 18 | 'convert_vie_res' 19 | ] 20 | __all__ += evaluation.__all__ 21 | -------------------------------------------------------------------------------- /mmocr/models/textrecog/decoders/base_decoder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from mmcv.runner import BaseModule 3 | 4 | from mmocr.models.builder import DECODERS 5 | 6 | 7 | @DECODERS.register_module() 8 | class BaseDecoder(BaseModule): 9 | """Base decoder class for text recognition.""" 10 | 11 | def __init__(self, init_cfg=None, **kwargs): 12 | super().__init__(init_cfg=init_cfg) 13 | 14 | def forward_train(self, feat, out_enc, targets_dict, img_metas): 15 | raise NotImplementedError 16 | 17 | def forward_test(self, feat, out_enc, img_metas): 18 | raise NotImplementedError 19 | 20 | def forward(self, 21 | feat, 22 | out_enc, 23 | targets_dict=None, 24 | img_metas=None, 25 | train_mode=True): 26 | self.train_mode = train_mode 27 | if train_mode: 28 | return self.forward_train(feat, out_enc, targets_dict, img_metas) 29 | 30 | return self.forward_test(feat, out_enc, img_metas) 31 | -------------------------------------------------------------------------------- /configs/_base_/recog_pipelines/crnn_pipeline.py: -------------------------------------------------------------------------------- 1 | img_norm_cfg = dict(mean=[127], std=[127]) 2 | 3 | train_pipeline = [ 4 | dict(type='LoadImageFromFile', color_type='grayscale'), 5 | dict( 6 | type='ResizeOCR', 7 | height=32, 8 | min_width=100, 9 | max_width=100, 10 | keep_aspect_ratio=False), 11 | dict(type='Normalize', **img_norm_cfg), 12 | dict(type='DefaultFormatBundle'), 13 | dict( 14 | type='Collect', 15 | keys=['img'], 16 | meta_keys=['filename', 'resize_shape', 'text', 'valid_ratio']), 17 | ] 18 | test_pipeline = [ 19 | dict(type='LoadImageFromFile', color_type='grayscale'), 20 | dict( 21 | type='ResizeOCR', 22 | height=32, 23 | min_width=32, 24 | max_width=None, 25 | keep_aspect_ratio=True), 26 | dict(type='Normalize', **img_norm_cfg), 27 | dict(type='DefaultFormatBundle'), 28 | dict( 29 | type='Collect', 30 | keys=['img'], 31 | meta_keys=['filename', 'resize_shape', 'valid_ratio']), 32 | ] 33 | -------------------------------------------------------------------------------- /mmocr/utils/logger.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import logging 3 | 4 | from mmcv.utils import get_logger 5 | 6 | 7 | def get_root_logger(log_file=None, log_level=logging.INFO): 8 | """Use `get_logger` method in mmcv to get the root logger. 9 | 10 | The logger will be initialized if it has not been initialized. 
By default a 11 | StreamHandler will be added. If `log_file` is specified, a FileHandler will 12 | also be added. The name of the root logger is the top-level package name, 13 | e.g., "mmocr". 14 | 15 | Args: 16 | log_file (str | None): The log filename. If specified, a FileHandler 17 | will be added to the root logger. 18 | log_level (int): The root logger level. Note that only the process of 19 | rank 0 is affected, while other processes will set the level to 20 | "Error" and be silent most of the time. 21 | 22 | Returns: 23 | logging.Logger: The root logger. 24 | """ 25 | return get_logger(__name__.split('.')[0], log_file, log_level) 26 | -------------------------------------------------------------------------------- /mmocr/models/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from . import common, kie, textdet, textrecog 3 | from .builder import (BACKBONES, CONVERTORS, DECODERS, DETECTORS, ENCODERS, 4 | HEADS, LOSSES, NECKS, PREPROCESSOR, build_backbone, 5 | build_convertor, build_decoder, build_detector, 6 | build_encoder, build_loss, build_preprocessor) 7 | 8 | from .common import * # NOQA 9 | from .kie import * # NOQA 10 | from .ner import * # NOQA 11 | from .textdet import * # NOQA 12 | from .textrecog import * # NOQA 13 | 14 | __all__ = [ 15 | 'BACKBONES', 'DETECTORS', 'HEADS', 'LOSSES', 'NECKS', 'build_backbone', 16 | 'build_detector', 'build_loss', 'CONVERTORS', 'ENCODERS', 'DECODERS', 17 | 'PREPROCESSOR', 'build_convertor', 'build_encoder', 'build_decoder', 18 | 'build_preprocessor' 19 | ] 20 | __all__ += common.__all__ + kie.__all__ + textdet.__all__ + textrecog.__all__ 21 | 22 | from . import spotting 23 | from .spotting import * # NOQA 24 | 25 | __all__ += spotting.__all__ 26 | -------------------------------------------------------------------------------- /configs/_base_/recog_datasets/ST_MJ_train.py: -------------------------------------------------------------------------------- 1 | # Text Recognition Training set, including: 2 | # Synthetic Datasets: SynthText, Syn90k 3 | 4 | train_root = 'data/mixture' 5 | 6 | train_img_prefix1 = f'{train_root}/Syn90k/mnt/ramdisk/max/90kDICT32px' 7 | train_ann_file1 = f'{train_root}/Syn90k/label.lmdb' 8 | 9 | train1 = dict( 10 | type='OCRDataset', 11 | img_prefix=train_img_prefix1, 12 | ann_file=train_ann_file1, 13 | loader=dict( 14 | type='LmdbLoader', 15 | repeat=1, 16 | parser=dict( 17 | type='LineStrParser', 18 | keys=['filename', 'text'], 19 | keys_idx=[0, 1], 20 | separator=' ')), 21 | pipeline=None, 22 | test_mode=False) 23 | 24 | train_img_prefix2 = f'{train_root}/SynthText/' + \ 25 | 'synthtext/SynthText_patch_horizontal' 26 | train_ann_file2 = f'{train_root}/SynthText/label.lmdb' 27 | 28 | train2 = {key: value for key, value in train1.items()} 29 | train2['img_prefix'] = train_img_prefix2 30 | train2['ann_file'] = train_ann_file2 31 | 32 | train_list = [train1, train2] 33 | -------------------------------------------------------------------------------- /docs/zh_cn/merge_docs.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # gather models 4 | sed -e '$a\\n' -s ../../configs/kie/*/*.md | sed "s/md###t/html#t/g" | sed "s/#/#&/" | sed '1i\# 关键信息提取模型' | sed 's/](\/docs\//](/g' | sed 's=](/=](https://github.com/open-mmlab/mmocr/tree/master/=g' >kie_models.md 5 | sed -e '$a\\n' -s ../../configs/textdet/*/*.md | sed "s/md###t/html#t/g" | sed "s/#/#&/" | sed 
'1i\# 文本检测模型' | sed 's/](\/docs\//](/g' | sed 's=](/=](https://github.com/open-mmlab/mmocr/tree/master/=g' >textdet_models.md 6 | sed -e '$a\\n' -s ../../configs/textrecog/*/*.md | sed "s/md###t/html#t/g" | sed "s/#/#&/" | sed '1i\# 文本识别模型' | sed 's/](\/docs\//](/g' | sed 's=](/=](https://github.com/open-mmlab/mmocr/tree/master/=g' >textrecog_models.md 7 | sed -e '$a\\n' -s ../../configs/ner/*/*.md | sed "s/md###t/html#t/g" | sed "s/#/#&/" | sed '1i\# 命名实体识别模型' | sed 's/](\/docs\//](/g' | sed 's=](/=](https://github.com/open-mmlab/mmocr/tree/master/=g' >ner_models.md 8 | 9 | # replace special symbols in demo.md 10 | cp ../../demo/README.md demo.md 11 | sed -i 's/:heavy_check_mark:/Yes/g' demo.md && sed -i 's/:x:/No/g' demo.md 12 | -------------------------------------------------------------------------------- /mmocr/models/ner/utils/activations.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------ 2 | # Adapted from https://github.com/lonePatient/BERT-NER-Pytorch 3 | # Original licence: Copyright (c) 2020 Weitang Liu, under the MIT License. 4 | # ------------------------------------------------------------------------------ 5 | 6 | import math 7 | 8 | import torch 9 | import torch.nn as nn 10 | 11 | from mmocr.models.builder import ACTIVATION_LAYERS 12 | 13 | 14 | @ACTIVATION_LAYERS.register_module() 15 | class GeluNew(nn.Module): 16 | """Implementation of the gelu activation function currently in Google Bert 17 | repo (identical to OpenAI GPT). 18 | 19 | Also see https://arxiv.org/abs/1606.08415 20 | """ 21 | 22 | def forward(self, x): 23 | """Forward function. 24 | 25 | Args: 26 | x (torch.Tensor): The input tensor. 27 | 28 | Returns: 29 | torch.Tensor: Activated tensor. 30 | """ 31 | return 0.5 * x * (1 + torch.tanh( 32 | math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) 33 | -------------------------------------------------------------------------------- /configs/_base_/recog_pipelines/crnn_tps_pipeline.py: -------------------------------------------------------------------------------- 1 | img_norm_cfg = dict(mean=[0.5], std=[0.5]) 2 | 3 | train_pipeline = [ 4 | dict(type='LoadImageFromFile', color_type='grayscale'), 5 | dict( 6 | type='ResizeOCR', 7 | height=32, 8 | min_width=100, 9 | max_width=100, 10 | keep_aspect_ratio=False), 11 | dict(type='ToTensorOCR'), 12 | dict(type='NormalizeOCR', **img_norm_cfg), 13 | dict( 14 | type='Collect', 15 | keys=['img'], 16 | meta_keys=[ 17 | 'filename', 'ori_shape', 'resize_shape', 'text', 'valid_ratio' 18 | ]), 19 | ] 20 | test_pipeline = [ 21 | dict(type='LoadImageFromFile', color_type='grayscale'), 22 | dict( 23 | type='ResizeOCR', 24 | height=32, 25 | min_width=32, 26 | max_width=100, 27 | keep_aspect_ratio=False), 28 | dict(type='ToTensorOCR'), 29 | dict(type='NormalizeOCR', **img_norm_cfg), 30 | dict( 31 | type='Collect', 32 | keys=['img'], 33 | meta_keys=['filename', 'ori_shape', 'resize_shape', 'valid_ratio']), 34 | ] 35 | -------------------------------------------------------------------------------- /mmocr/models/textdet/detectors/dbnet.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
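# DBNet is assembled from a backbone, an FPNC neck and a DBHead; see
# configs/_base_/det_models/dbnet_r50dcnv2_fpnc.py earlier in this listing
# for a complete model config consumed by this detector.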
2 | from mmocr.models.builder import DETECTORS 3 | from .single_stage_text_detector import SingleStageTextDetector 4 | from .text_detector_mixin import TextDetectorMixin 5 | 6 | 7 | @DETECTORS.register_module() 8 | class DBNet(TextDetectorMixin, SingleStageTextDetector): 9 | """The class for implementing DBNet text detector: Real-time Scene Text 10 | Detection with Differentiable Binarization. 11 | 12 | [https://arxiv.org/abs/1911.08947]. 13 | """ 14 | 15 | def __init__(self, 16 | backbone, 17 | neck, 18 | bbox_head, 19 | train_cfg=None, 20 | test_cfg=None, 21 | pretrained=None, 22 | show_score=False, 23 | init_cfg=None): 24 | SingleStageTextDetector.__init__(self, backbone, neck, bbox_head, 25 | train_cfg, test_cfg, pretrained, 26 | init_cfg) 27 | TextDetectorMixin.__init__(self, show_score) 28 | -------------------------------------------------------------------------------- /configs/_base_/recog_pipelines/nrtr_pipeline.py: -------------------------------------------------------------------------------- 1 | img_norm_cfg = dict(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 2 | train_pipeline = [ 3 | dict(type='LoadImageFromFile'), 4 | dict( 5 | type='ResizeOCR', 6 | height=32, 7 | min_width=32, 8 | max_width=160, 9 | keep_aspect_ratio=True, 10 | width_downsample_ratio=0.25), 11 | dict(type='ToTensorOCR'), 12 | dict(type='NormalizeOCR', **img_norm_cfg), 13 | dict( 14 | type='Collect', 15 | keys=['img'], 16 | meta_keys=[ 17 | 'filename', 'ori_shape', 'resize_shape', 'text', 'valid_ratio' 18 | ]), 19 | ] 20 | 21 | test_pipeline = [ 22 | dict(type='LoadImageFromFile'), 23 | dict( 24 | type='ResizeOCR', 25 | height=32, 26 | min_width=32, 27 | max_width=160, 28 | keep_aspect_ratio=True), 29 | dict(type='ToTensorOCR'), 30 | dict(type='NormalizeOCR', **img_norm_cfg), 31 | dict( 32 | type='Collect', 33 | keys=['img'], 34 | meta_keys=['filename', 'ori_shape', 'resize_shape', 'valid_ratio']) 35 | ] 36 | -------------------------------------------------------------------------------- /mmocr/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from mmdet.datasets.builder import DATASETS, build_dataloader, build_dataset 3 | 4 | from . import utils 5 | from .base_dataset import BaseDataset 6 | from .icdar_dataset import IcdarDataset 7 | from .kie_dataset import KIEDataset 8 | from .ner_dataset import NerDataset 9 | from .ocr_dataset import OCRDataset 10 | from .ocr_seg_dataset import OCRSegDataset 11 | from .openset_kie_dataset import OpensetKIEDataset 12 | from .pipelines import CustomFormatBundle, DBNetTargets, FCENetTargets 13 | from .text_det_dataset import TextDetDataset 14 | from .uniform_concat_dataset import UniformConcatDataset 15 | from .vie_e2e_dataset import VIEE2EDataset 16 | 17 | from .utils import * # NOQA 18 | 19 | __all__ = [ 20 | 'DATASETS', 'IcdarDataset', 'build_dataloader', 'build_dataset', 21 | 'BaseDataset', 'OCRDataset', 'TextDetDataset', 'CustomFormatBundle', 22 | 'DBNetTargets', 'OCRSegDataset', 'KIEDataset', 'FCENetTargets', 23 | 'NerDataset', 'UniformConcatDataset', 'OpensetKIEDataset', 24 | 'VIEE2EDataset' 25 | ] 26 | 27 | __all__ += utils.__all__ 28 | -------------------------------------------------------------------------------- /mmocr/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
2 | from mmcv.utils import Registry, build_from_cfg 3 | 4 | from .box_util import is_on_same_line, stitch_boxes_into_lines 5 | from .check_argument import (equal_len, is_2dlist, is_3dlist, is_none_or_type, 6 | is_type_list, valid_boundary) 7 | from .collect_env import collect_env 8 | from .data_convert_util import convert_annotations 9 | from .fileio import list_from_file, list_to_file 10 | from .img_util import drop_orientation, is_not_png 11 | from .lmdb_util import lmdb_converter 12 | from .logger import get_root_logger 13 | from .model import revert_sync_batchnorm 14 | from .string_util import StringStrip 15 | 16 | __all__ = [ 17 | 'Registry', 'build_from_cfg', 'get_root_logger', 'collect_env', 18 | 'is_3dlist', 'is_type_list', 'is_none_or_type', 'equal_len', 'is_2dlist', 19 | 'valid_boundary', 'lmdb_converter', 'drop_orientation', 20 | 'convert_annotations', 'is_not_png', 'list_to_file', 'list_from_file', 21 | 'is_on_same_line', 'stitch_boxes_into_lines', 'StringStrip', 22 | 'revert_sync_batchnorm' 23 | ] 24 | -------------------------------------------------------------------------------- /mmocr/models/textdet/detectors/psenet.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from mmocr.models.builder import DETECTORS 3 | from .single_stage_text_detector import SingleStageTextDetector 4 | from .text_detector_mixin import TextDetectorMixin 5 | 6 | 7 | @DETECTORS.register_module() 8 | class PSENet(TextDetectorMixin, SingleStageTextDetector): 9 | """The class for implementing PSENet text detector: Shape Robust Text 10 | Detection with Progressive Scale Expansion Network. 11 | 12 | [https://arxiv.org/abs/1806.02559]. 13 | """ 14 | 15 | def __init__(self, 16 | backbone, 17 | neck, 18 | bbox_head, 19 | train_cfg=None, 20 | test_cfg=None, 21 | pretrained=None, 22 | show_score=False, 23 | init_cfg=None): 24 | SingleStageTextDetector.__init__(self, backbone, neck, bbox_head, 25 | train_cfg, test_cfg, pretrained, 26 | init_cfg) 27 | TextDetectorMixin.__init__(self, show_score) 28 | -------------------------------------------------------------------------------- /configs/_base_/det_models/fcenet_r50_fpn.py: -------------------------------------------------------------------------------- 1 | model = dict( 2 | type='FCENet', 3 | backbone=dict( 4 | type='mmdet.ResNet', 5 | depth=50, 6 | num_stages=4, 7 | out_indices=(1, 2, 3), 8 | frozen_stages=-1, 9 | norm_cfg=dict(type='BN', requires_grad=True), 10 | init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50'), 11 | norm_eval=False, 12 | style='pytorch'), 13 | neck=dict( 14 | type='mmdet.FPN', 15 | in_channels=[512, 1024, 2048], 16 | out_channels=256, 17 | add_extra_convs='on_output', 18 | num_outs=3, 19 | relu_before_extra_convs=True, 20 | act_cfg=None), 21 | bbox_head=dict( 22 | type='FCEHead', 23 | in_channels=256, 24 | scales=(8, 16, 32), 25 | fourier_degree=5, 26 | loss=dict(type='FCELoss', num_sample=50), 27 | postprocessor=dict( 28 | type='FCEPostprocessor', 29 | text_repr_type='quad', 30 | num_reconstr_points=50, 31 | alpha=1.2, 32 | beta=1.0, 33 | score_thr=0.3))) 34 | -------------------------------------------------------------------------------- /mmocr/models/common/losses/focal_loss.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
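# Sketch of the math implemented below: with p_t denoting the softmax
# probability of the target class, the per-sample loss is
#     FL(p_t) = -(1 - p_t)**gamma * log(p_t)
# i.e. the usual negative log-likelihood term down-weighted for easy
# (high-confidence) samples.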
2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | 7 | class FocalLoss(nn.Module): 8 | """Multi-class Focal loss implementation. 9 | 10 | Args: 11 | gamma (float): The larger the gamma, the smaller 12 | the loss weight of easier samples. 13 | weight (Tensor | None): A manual rescaling weight given to each 14 | class. 15 | ignore_index (int): Specifies a target value that is ignored 16 | and does not contribute to the input gradient. 17 | """ 18 | 19 | def __init__(self, gamma=2, weight=None, ignore_index=-100): 20 | super().__init__() 21 | self.gamma = gamma 22 | self.weight = weight 23 | self.ignore_index = ignore_index 24 | 25 | def forward(self, input, target): 26 | logit = F.log_softmax(input, dim=1) 27 | pt = torch.exp(logit) 28 | logit = (1 - pt)**self.gamma * logit 29 | loss = F.nll_loss( 30 | logit, target, self.weight, ignore_index=self.ignore_index) 31 | return loss 32 | -------------------------------------------------------------------------------- /mmocr/models/textdet/detectors/panet.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from mmocr.models.builder import DETECTORS 3 | from .single_stage_text_detector import SingleStageTextDetector 4 | from .text_detector_mixin import TextDetectorMixin 5 | 6 | 7 | @DETECTORS.register_module() 8 | class PANet(TextDetectorMixin, SingleStageTextDetector): 9 | """The class for implementing PANet text detector: 10 | 11 | Efficient and Accurate Arbitrary-Shaped Text Detection with Pixel 12 | Aggregation Network [https://arxiv.org/abs/1908.05900]. 13 | """ 14 | 15 | def __init__(self, 16 | backbone, 17 | neck, 18 | bbox_head, 19 | train_cfg=None, 20 | test_cfg=None, 21 | pretrained=None, 22 | show_score=False, 23 | init_cfg=None): 24 | SingleStageTextDetector.__init__(self, backbone, neck, bbox_head, 25 | train_cfg, test_cfg, pretrained, 26 | init_cfg) 27 | TextDetectorMixin.__init__(self, show_score) 28 | -------------------------------------------------------------------------------- /mmocr/models/textdet/detectors/textsnake.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from mmocr.models.builder import DETECTORS 3 | from .single_stage_text_detector import SingleStageTextDetector 4 | from .text_detector_mixin import TextDetectorMixin 5 | 6 | 7 | @DETECTORS.register_module() 8 | class TextSnake(TextDetectorMixin, SingleStageTextDetector): 9 | """The class for implementing TextSnake text detector: TextSnake: A 10 | Flexible Representation for Detecting Text of Arbitrary Shapes. 11 | 12 | [https://arxiv.org/abs/1807.01544] 13 | """ 14 | 15 | def __init__(self, 16 | backbone, 17 | neck, 18 | bbox_head, 19 | train_cfg=None, 20 | test_cfg=None, 21 | pretrained=None, 22 | show_score=False, 23 | init_cfg=None): 24 | SingleStageTextDetector.__init__(self, backbone, neck, bbox_head, 25 | train_cfg, test_cfg, pretrained, 26 | init_cfg) 27 | TextDetectorMixin.__init__(self, show_score) 28 | -------------------------------------------------------------------------------- /mmocr/core/evaluation/kie_metric.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import torch 3 | 4 | 5 | def compute_f1_score(preds, gts, ignores=[]): 6 | """Compute the F1-score of prediction.
7 | 8 | Args: 9 | preds (Tensor): The predicted probability NxC map 10 | with N and C being the sample number and class 11 | number respectively. 12 | gts (Tensor): The ground truth vector of size N. 13 | ignores (list): The index set of classes that are ignored when 14 | reporting results. 15 | Note: all samples are included in the computation. 16 | 17 | Returns: 18 | The numpy list of f1-scores of valid classes. 19 | """ 20 | C = preds.size(1) 21 | classes = torch.LongTensor(sorted(set(range(C)) - set(ignores))) 22 | hist = torch.bincount( 23 | gts * C + preds.argmax(1), minlength=C**2).view(C, C).float() 24 | diag = torch.diag(hist) 25 | recalls = diag / hist.sum(1).clamp(min=1) 26 | precisions = diag / hist.sum(0).clamp(min=1) 27 | f1 = 2 * recalls * precisions / (recalls + precisions).clamp(min=1e-8) 28 | return f1[classes].cpu().numpy() 29 | -------------------------------------------------------------------------------- /configs/_base_/det_datasets/toy_data.py: -------------------------------------------------------------------------------- 1 | root = 'tests/data/toy_dataset' 2 | 3 | # dataset with type='TextDetDataset' 4 | train1 = dict( 5 | type='TextDetDataset', 6 | img_prefix=f'{root}/imgs', 7 | ann_file=f'{root}/instances_test.txt', 8 | loader=dict( 9 | type='HardDiskLoader', 10 | repeat=4, 11 | parser=dict( 12 | type='LineJsonParser', 13 | keys=['file_name', 'height', 'width', 'annotations'])), 14 | pipeline=None, 15 | test_mode=False) 16 | 17 | # dataset with type='IcdarDataset' 18 | train2 = dict( 19 | type='IcdarDataset', 20 | ann_file=f'{root}/instances_test.json', 21 | img_prefix=f'{root}/imgs', 22 | pipeline=None) 23 | 24 | test = dict( 25 | type='TextDetDataset', 26 | img_prefix=f'{root}/imgs', 27 | ann_file=f'{root}/instances_test.txt', 28 | loader=dict( 29 | type='HardDiskLoader', 30 | repeat=1, 31 | parser=dict( 32 | type='LineJsonParser', 33 | keys=['file_name', 'height', 'width', 'annotations'])), 34 | pipeline=None, 35 | test_mode=True) 36 | 37 | train_list = [train1, train2] 38 | 39 | test_list = [test] 40 | -------------------------------------------------------------------------------- /docs/en/merge_docs.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # gather models 4 | sed -e '$a\\n' -s ../../configs/kie/*/*.md | sed "s/md###t/html#t/g" | sed "s/#/#&/" | sed '1i\# Key Information Extraction Models' | sed 's/](\/docs\//](/g' | sed 's=](/=](https://github.com/open-mmlab/mmocr/tree/master/=g' >kie_models.md 5 | sed -e '$a\\n' -s ../../configs/textdet/*/*.md | sed "s/md###t/html#t/g" | sed "s/#/#&/" | sed '1i\# Text Detection Models' | sed 's/](\/docs\//](/g' | sed 's=](/=](https://github.com/open-mmlab/mmocr/tree/master/=g' >textdet_models.md 6 | sed -e '$a\\n' -s ../../configs/textrecog/*/*.md | sed "s/md###t/html#t/g" | sed "s/#/#&/" | sed '1i\# Text Recognition Models' | sed 's/](\/docs\//](/g' | sed 's=](/=](https://github.com/open-mmlab/mmocr/tree/master/=g' >textrecog_models.md 7 | sed -e '$a\\n' -s ../../configs/ner/*/*.md | sed "s/md###t/html#t/g" | sed "s/#/#&/" | sed '1i\# Named Entity Recognition Models' | sed 's/](\/docs\//](/g' | sed 's=](/=](https://github.com/open-mmlab/mmocr/tree/master/=g' >ner_models.md 8 | 9 | # replace special symbols in demo.md 10 | cp ../../demo/README.md demo.md 11 | sed -i 's/:heavy_check_mark:/Yes/g' demo.md && sed -i 's/:x:/No/g' demo.md 12 | --------------------------------------------------------------------------------
/configs/_base_/recog_datasets/ST_MJ_alphanumeric_train.py: -------------------------------------------------------------------------------- 1 | # Text Recognition Training set, including: 2 | # Synthetic Datasets: SynthText, Syn90k 3 | # Both annotations are filtered so that 4 | # only alphanumeric terms are left 5 | 6 | train_root = 'data/mixture' 7 | 8 | train_img_prefix1 = f'{train_root}/Syn90k/mnt/ramdisk/max/90kDICT32px' 9 | train_ann_file1 = f'{train_root}/Syn90k/label.lmdb' 10 | 11 | train1 = dict( 12 | type='OCRDataset', 13 | img_prefix=train_img_prefix1, 14 | ann_file=train_ann_file1, 15 | loader=dict( 16 | type='LmdbLoader', 17 | repeat=1, 18 | parser=dict( 19 | type='LineStrParser', 20 | keys=['filename', 'text'], 21 | keys_idx=[0, 1], 22 | separator=' ')), 23 | pipeline=None, 24 | test_mode=False) 25 | 26 | train_img_prefix2 = f'{train_root}/SynthText/' + \ 27 | 'synthtext/SynthText_patch_horizontal' 28 | train_ann_file2 = f'{train_root}/SynthText/alphanumeric_label.lmdb' 29 | 30 | train2 = {key: value for key, value in train1.items()} 31 | train2['img_prefix'] = train_img_prefix2 32 | train2['ann_file'] = train_ann_file2 33 | 34 | train_list = [train1, train2] 35 | -------------------------------------------------------------------------------- /tests/test_utils/test_version_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from mmocr import digit_version 3 | 4 | 5 | def test_digit_version(): 6 | assert digit_version('0.2.16') == (0, 2, 16, 0, 0, 0) 7 | assert digit_version('1.2.3') == (1, 2, 3, 0, 0, 0) 8 | assert digit_version('1.2.3rc0') == (1, 2, 3, 0, -1, 0) 9 | assert digit_version('1.2.3rc1') == (1, 2, 3, 0, -1, 1) 10 | assert digit_version('1.0rc0') == (1, 0, 0, 0, -1, 0) 11 | assert digit_version('1.0') == digit_version('1.0.0') 12 | assert digit_version('1.5.0+cuda90_cudnn7.6.3_lms') == digit_version('1.5') 13 | assert digit_version('1.0.0dev') < digit_version('1.0.0a') 14 | assert digit_version('1.0.0a') < digit_version('1.0.0a1') 15 | assert digit_version('1.0.0a') < digit_version('1.0.0b') 16 | assert digit_version('1.0.0b') < digit_version('1.0.0rc') 17 | assert digit_version('1.0.0rc1') < digit_version('1.0.0') 18 | assert digit_version('1.0.0') < digit_version('1.0.0post') 19 | assert digit_version('1.0.0post') < digit_version('1.0.0post1') 20 | assert digit_version('v1') == (1, 0, 0, 0, 0, 0) 21 | assert digit_version('v1.1.5') == (1, 1, 5, 0, 0, 0) 22 | -------------------------------------------------------------------------------- /docs/en/datasets/kie.md: -------------------------------------------------------------------------------- 1 | # Key Information Extraction 2 | 3 | ## Overview 4 | 5 | The structure of the key information extraction dataset directory is organized as follows. 6 | 7 | ```text 8 | └── wildreceipt 9 | ├── class_list.txt 10 | ├── dict.txt 11 | ├── image_files 12 | ├── openset_train.txt 13 | ├── openset_test.txt 14 | ├── test.txt 15 | └── train.txt 16 | ``` 17 | 18 | ## Preparation Steps 19 | 20 | ### WildReceipt 21 | 22 | - Just download and extract [wildreceipt.tar](https://download.openmmlab.com/mmocr/data/wildreceipt.tar). 23 | 24 | ### WildReceiptOpenset 25 | 26 | - Step0: have [WildReceipt](#WildReceipt) prepared. 
27 | - Step1: Convert annotation files to OpenSet format: 28 | ```bash 29 | # You may find more available arguments by running 30 | # python tools/data/kie/closeset_to_openset.py -h 31 | python tools/data/kie/closeset_to_openset.py data/wildreceipt/train.txt data/wildreceipt/openset_train.txt 32 | python tools/data/kie/closeset_to_openset.py data/wildreceipt/test.txt data/wildreceipt/openset_test.txt 33 | ``` 34 | :::{note} 35 | You can learn more about the key differences between CloseSet and OpenSet annotations in our [tutorial](../tutorials/kie_closeset_openset.md). 36 | ::: 37 | -------------------------------------------------------------------------------- /mmocr/models/textrecog/layers/position_aware_layer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import torch.nn as nn 3 | 4 | 5 | class PositionAwareLayer(nn.Module): 6 | 7 | def __init__(self, dim_model, rnn_layers=2): 8 | super().__init__() 9 | 10 | self.dim_model = dim_model 11 | 12 | self.rnn = nn.LSTM( 13 | input_size=dim_model, 14 | hidden_size=dim_model, 15 | num_layers=rnn_layers, 16 | batch_first=True) 17 | 18 | self.mixer = nn.Sequential( 19 | nn.Conv2d( 20 | dim_model, dim_model, kernel_size=3, stride=1, padding=1), 21 | nn.ReLU(True), 22 | nn.Conv2d( 23 | dim_model, dim_model, kernel_size=3, stride=1, padding=1)) 24 | 25 | def forward(self, img_feature): 26 | n, c, h, w = img_feature.size() 27 | 28 | rnn_input = img_feature.permute(0, 2, 3, 1).contiguous() 29 | rnn_input = rnn_input.view(n * h, w, c) 30 | rnn_output, _ = self.rnn(rnn_input) 31 | rnn_output = rnn_output.view(n, h, w, c) 32 | rnn_output = rnn_output.permute(0, 3, 1, 2).contiguous() 33 | 34 | out = self.mixer(rnn_output) 35 | 36 | return out 37 | -------------------------------------------------------------------------------- /mmocr/utils/fileio.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import os 3 | 4 | import mmcv 5 | 6 | 7 | def list_to_file(filename, lines): 8 | """Write a list of strings to a text file. 9 | 10 | Args: 11 | filename (str): The output filename. It will be created/overwritten. 12 | lines (list(str)): Data to be written. 13 | """ 14 | mmcv.mkdir_or_exist(os.path.dirname(filename)) 15 | with open(filename, 'w', encoding='utf-8') as fw: 16 | for line in lines: 17 | fw.write(f'{line}\n') 18 | 19 | 20 | def list_from_file(filename, encoding='utf-8'): 21 | """Load a text file and parse the content as a list of strings. The 22 | trailing "\\r" and "\\n" of each line will be removed. 23 | 24 | Note: 25 | This will be replaced by mmcv's version after it supports encoding. 26 | 27 | Args: 28 | filename (str): Filename. 29 | encoding (str): Encoding used to open the file. Default utf-8. 30 | 31 | Returns: 32 | list[str]: A list of strings. 
33 | """ 34 | item_list = [] 35 | with open(filename, 'r', encoding=encoding) as f: 36 | for line in f: 37 | item_list.append(line.rstrip('\n\r')) 38 | return item_list 39 | -------------------------------------------------------------------------------- /configs/vie_custom/_base_/ocr_datasets/synthtext.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # @Time : 2022/1/5 16:38 4 | # @Author : WeiHua 5 | 6 | dataset_type = 'VIEE2EDataset' 7 | data_root = '/home/whua/code/MaskTextSpotterV3-master/datasets/synthtext/SynthText/' 8 | ann_root = '/home/whua/code/MaskTextSpotterV3-master/datasets/synthtext/e2e_format/' 9 | 10 | loader = dict( 11 | type='HardDiskLoader', 12 | repeat=1, 13 | parser=dict( 14 | type='LineJsonParser', 15 | keys=['file_name', 'height', 'width', 'annotations'])) 16 | 17 | train = dict( 18 | type=dataset_type, 19 | ann_file=f'{ann_root}/train.txt', 20 | loader=loader, 21 | dict_file=f'{ann_root}/dict.json', 22 | img_prefix=data_root, 23 | pipeline=None, 24 | test_mode=False, 25 | class_file=None, 26 | data_type='ocr', 27 | max_seq_len=60) 28 | 29 | test = dict( 30 | type=dataset_type, 31 | ann_file=f'{ann_root}/train.txt', 32 | loader=loader, 33 | dict_file=f'{ann_root}/dict.json', 34 | img_prefix=data_root, 35 | pipeline=None, 36 | test_mode=False, 37 | class_file=None, 38 | data_type='ocr', 39 | max_seq_len=60) 40 | 41 | train_list = [train] 42 | 43 | test_list = [test] 44 | -------------------------------------------------------------------------------- /custom_utils/dataset/prepare_pretrain.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # @Time : 2022/3/1 19:34 4 | # @Author : WeiHua 5 | import json 6 | from tqdm import tqdm 7 | 8 | 9 | if __name__ == '__main__': 10 | corpus = [] 11 | with open('/apdcephfs/share_887471/common/whua/dataset/ie_e2e/SynthText/train.txt', 'r', encoding='utf-8') as f: 12 | for line_ in tqdm(f.readlines()): 13 | if line_.strip() == "": 14 | continue 15 | info_ = json.loads(line_.strip()) 16 | for anno_ in info_['annotations']: 17 | corpus.append(anno_['text']) 18 | with open('/apdcephfs/share_887471/common/ocr_benchmark/benchmark/MJSynth/annotation.txt', 'r', encoding='utf-8') as f: 19 | for line_ in tqdm(f.readlines()): 20 | if line_.strip() == "": 21 | continue 22 | info_ = line_.strip().split(' ') 23 | if len(info_) != 2: 24 | print(f"invalid line:{line_}, pass it") 25 | corpus.append(info_[1]) 26 | with open('/apdcephfs/share_887471/common/whua/st_mj_corpus.txt', 'w', encoding='utf-8') as saver: 27 | for line_ in tqdm(corpus): 28 | saver.write(line_+'\n') 29 | 30 | 31 | -------------------------------------------------------------------------------- /configs/vie_custom/_base_/ocr_datasets/custom_chn_v2_ar_cloud.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # @Time : 2022/2/2 22:53 4 | # @Author : WeiHua 5 | 6 | dataset_type = 'VIEE2EDataset' 7 | data_root = '/data/whua/dataset/custom_chn_synth/merged' 8 | 9 | loader = dict( 10 | type='HardDiskLoader', 11 | repeat=1, 12 | parser=dict( 13 | type='CustomLineJsonParser', 14 | keys=['file_name', 'height', 'width', 'annotations'], 15 | optional_keys=['entity_dict'])) 16 | 17 | train = dict( 18 | type=dataset_type, 19 | ann_file=f'{data_root}/annotation.txt', 20 | loader=loader, 21 | 
dict_file=f'{data_root}/dict.json', 22 | img_prefix=data_root, 23 | pipeline=None, 24 | test_mode=False, 25 | class_file=None, 26 | data_type='ocr', 27 | max_seq_len=80, 28 | order_type='shuffle', 29 | auto_reg=True, 30 | pre_parse_anno=True) 31 | """ 32 | avg:11.705568749328094, max:25, min:1 33 | avg_height:441.12469581024425, avg_width:1158.8753041897558 34 | ext key:[] 35 | max_len:74 36 | avg_ins_height:69.38160162741796, avg_ins_width:264.441446308023 37 | Total instance num: 2395451, total image num: 204642 38 | """ 39 | 40 | 41 | train_list = [train] 42 | 43 | test_list = [train] 44 | -------------------------------------------------------------------------------- /tools/pretrain_kjf.sh: -------------------------------------------------------------------------------- 1 | #DDP train 2 | CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python -m torch.distributed.launch --nproc_per_node=8 --master_port=10011 \ 3 | /home/jfkuang/code/ie_e2e/tools/train.py \ 4 | /home/jfkuang/code/ie_e2e/configs/vie_custom/e2e_ar_ocr_pretrain/custom_dataset/synth_chn_default_dp02_rc_lr2e4_dpp02_1803_30epoch_kjf.py \ 5 | --work-dir=/home/jfkuang/logs/ie_e2e_log/ephoie_pretrain_chn_bs8_960_higher200_2e4 --launcher pytorch --gpus 8 \ 6 | --deterministic --resume-from=/home/jfkuang/logs/ie_e2e_log/ephoie_pretrain_chn_bs8_960_higher200_2e4/latest.pth 7 | 8 | #train 9 | #CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python \ 10 | #/home/jfkuang/code/ie_e2e/tools/train.py \ 11 | #/home/jfkuang/code/ie_e2e/configs/vie_custom/e2e_ar_ocr_pretrain/custom_dataset/synth_chn_default_dp02_rc_rr_cj_blsh_lr4e4_dpp02_cloud_kjf.py \ 12 | #--work-dir=/home/jfkuang/logs/ie_e2e_log/ephoie_pretrain_chn_bs8_480_lower100 --gpus 8 13 | 14 | #single try 15 | #CUDA_VISIBLE_DEVICES=0 python /home/jfkuang/code/ie_e2e/tools/train.py \ 16 | #/home/jfkuang/code/ie_e2e/configs/vie_custom/e2e_ar_ocr_pretrain/custom_dataset/synth_chn_default_dp02_rc_rr_cj_blsh_lr4e4_dpp02_cloud_kjf.py \ 17 | #--work-dir=/home/jfkuang/logs/ie_e2e_log/ephoie_pretrain_chn_bs8_480_lower100 --gpus 1 \ 18 | #--deterministic -------------------------------------------------------------------------------- /configs/_base_/det_models/fcenet_r50dcnv2_fpn.py: -------------------------------------------------------------------------------- 1 | model = dict( 2 | type='FCENet', 3 | backbone=dict( 4 | type='mmdet.ResNet', 5 | depth=50, 6 | num_stages=4, 7 | out_indices=(1, 2, 3), 8 | frozen_stages=-1, 9 | norm_cfg=dict(type='BN', requires_grad=True), 10 | norm_eval=True, 11 | style='pytorch', 12 | dcn=dict(type='DCNv2', deform_groups=2, fallback_on_stride=False), 13 | init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50'), 14 | stage_with_dcn=(False, True, True, True)), 15 | neck=dict( 16 | type='mmdet.FPN', 17 | in_channels=[512, 1024, 2048], 18 | out_channels=256, 19 | add_extra_convs='on_output', 20 | num_outs=3, 21 | relu_before_extra_convs=True, 22 | act_cfg=None), 23 | bbox_head=dict( 24 | type='FCEHead', 25 | in_channels=256, 26 | scales=(8, 16, 32), 27 | fourier_degree=5, 28 | loss=dict(type='FCELoss', num_sample=50), 29 | postprocessor=dict( 30 | type='FCEPostprocessor', 31 | text_repr_type='poly', 32 | num_reconstr_points=50, 33 | alpha=1.0, 34 | beta=2.0, 35 | score_thr=0.3))) 36 | -------------------------------------------------------------------------------- /docs/zh_cn/index.rst: -------------------------------------------------------------------------------- 1 | 欢迎来到 MMOCR 的中文文档! 2 | ======================================= 3 | 4 | 您可以在页面左下角切换中英文文档。 5 | 6 | .. 
toctree:: 7 | :maxdepth: 2 8 | :caption: 开始 9 | 10 | install.md 11 | getting_started.md 12 | demo.md 13 | training.md 14 | testing.md 15 | deployment.md 16 | model_serving.md 17 | 18 | .. toctree:: 19 | :maxdepth: 2 20 | :caption: 教程 21 | 22 | tutorials/config.md 23 | tutorials/dataset_types.md 24 | tutorials/kie_closeset_openset.md 25 | 26 | .. toctree:: 27 | :maxdepth: 2 28 | :caption: 模型库 29 | 30 | modelzoo.md 31 | textdet_models.md 32 | textrecog_models.md 33 | kie_models.md 34 | ner_models.md 35 | 36 | .. toctree:: 37 | :maxdepth: 2 38 | :caption: 数据集 39 | 40 | datasets/det.md 41 | datasets/recog.md 42 | datasets/kie.md 43 | datasets/ner.md 44 | 45 | .. toctree:: 46 | :maxdepth: 2 47 | :caption: 杂项 48 | 49 | tools.md 50 | changelog.md 51 | 52 | .. toctree:: 53 | :caption: API 参考 54 | 55 | api.rst 56 | 57 | .. toctree:: 58 | :caption: 切换语言 59 | 60 | English 61 | 简体中文 62 | 63 | 导引 64 | ================== 65 | 66 | * :ref:`genindex` 67 | * :ref:`search` 68 | -------------------------------------------------------------------------------- /configs/vie_custom/_base_/ocr_datasets/synthtext_cloud.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # @Time : 2022/1/5 16:38 4 | # @Author : WeiHua 5 | 6 | dataset_type = 'VIEE2EDataset' 7 | data_root = '/apdcephfs/share_887471/common/ocr_benchmark/benchmark/SynthText' 8 | ann_root = '/apdcephfs/share_887471/common/whua/dataset/ie_e2e/SynthText/' 9 | 10 | loader = dict( 11 | type='HardDiskLoader', 12 | repeat=1, 13 | parser=dict( 14 | type='LineJsonParser', 15 | keys=['file_name', 'height', 'width', 'annotations'])) 16 | 17 | train = dict( 18 | type=dataset_type, 19 | ann_file=f'{ann_root}/train.txt', 20 | loader=loader, 21 | dict_file=f'{ann_root}/dict.json', 22 | img_prefix=data_root, 23 | pipeline=None, 24 | test_mode=False, 25 | class_file=None, 26 | data_type='ocr', 27 | max_seq_len=60, 28 | check_outside=True,) 29 | 30 | test = dict( 31 | type=dataset_type, 32 | ann_file=f'{ann_root}/train.txt', 33 | loader=loader, 34 | dict_file=f'{ann_root}/dict.json', 35 | img_prefix=data_root, 36 | pipeline=None, 37 | test_mode=False, 38 | class_file=None, 39 | data_type='ocr', 40 | max_seq_len=60, 41 | check_outside=True,) 42 | 43 | train_list = [train] 44 | 45 | test_list = [test] 46 | -------------------------------------------------------------------------------- /docs/en/tools.md: -------------------------------------------------------------------------------- 1 | # Useful Tools 2 | 3 | We provide some useful tools under the `mmocr/tools` directory. 4 | 5 | ## Publish a Model 6 | 7 | Before you upload a model to AWS, you may want to 8 | (1) convert the model weights to CPU tensors, (2) delete the optimizer states and 9 | (3) compute the hash of the checkpoint file and append the hash id to the filename. These functionalities can be achieved with `tools/publish_model.py`. 10 | 11 | ```shell 12 | python tools/publish_model.py ${INPUT_FILENAME} ${OUTPUT_FILENAME} 13 | ``` 14 | 15 | For example, 16 | 17 | ```shell 18 | python tools/publish_model.py work_dirs/psenet/latest.pth psenet_r50_fpnf_sbn_1x_20190801.pth 19 | ``` 20 | 21 | The final output filename will be `psenet_r50_fpnf_sbn_1x_20190801-{hash id}.pth`. 22 | 23 | 24 | ## Convert txt annotation to lmdb format 25 | Sometimes, loading a large txt annotation file with multiple workers can cause an out-of-memory (OOM) error.
You can convert the file into lmdb format using `tools/data/utils/txt2lmdb.py` and use `LmdbLoader` in your config to avoid this issue. 26 | ```bash 27 | python tools/data/utils/txt2lmdb.py -i <txt_label_path> -o <lmdb_label_path> 28 | ``` 29 | For example, 30 | ```bash 31 | python tools/data/utils/txt2lmdb.py -i data/mixture/Syn90k/label.txt -o data/mixture/Syn90k/label.lmdb 32 | ``` 33 | -------------------------------------------------------------------------------- /mmocr/models/spotting/modules/ops/src/cuda/ms_deform_attn_cuda.h: -------------------------------------------------------------------------------- 1 | /*! 2 | ************************************************************************************************** 3 | * Deformable DETR 4 | * Copyright (c) 2020 SenseTime. All Rights Reserved. 5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details] 6 | ************************************************************************************************** 7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 8 | ************************************************************************************************** 9 | */ 10 | 11 | #pragma once 12 | #include <torch/extension.h> 13 | 14 | at::Tensor ms_deform_attn_cuda_forward( 15 | const at::Tensor &value, 16 | const at::Tensor &spatial_shapes, 17 | const at::Tensor &level_start_index, 18 | const at::Tensor &sampling_loc, 19 | const at::Tensor &attn_weight, 20 | const int im2col_step); 21 | 22 | std::vector<at::Tensor> ms_deform_attn_cuda_backward( 23 | const at::Tensor &value, 24 | const at::Tensor &spatial_shapes, 25 | const at::Tensor &level_start_index, 26 | const at::Tensor &sampling_loc, 27 | const at::Tensor &attn_weight, 28 | const at::Tensor &grad_output, 29 | const int im2col_step); 30 | 31 | -------------------------------------------------------------------------------- /tests/test_models/test_targets.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import numpy as np 3 | 4 | from mmocr.datasets.pipelines.textdet_targets.dbnet_targets import DBNetTargets 5 | 6 | 7 | def test_invalid_polys(): 8 | 9 | dbtarget = DBNetTargets() 10 | 11 | poly = np.array([[256.1229216, 347.17471155], [257.63126133, 347.0069367], 12 | [257.70317729, 347.65337423], 13 | [256.19488113, 347.82114909]]) 14 | 15 | assert dbtarget.invalid_polygon(poly) 16 | 17 | poly = np.array([[570.34735492, 18 | 335.00214526], [570.99778839, 335.00327318], 19 | [569.69077318, 338.47009908], 20 | [569.04038393, 338.46894904]]) 21 | assert dbtarget.invalid_polygon(poly) 22 | 23 | poly = np.array([[481.18343777, 24 | 305.03190065], [479.88478587, 305.10684512], 25 | [479.90976971, 305.53968843], [480.99197962, 26 | 305.4772347]]) 27 | assert dbtarget.invalid_polygon(poly) 28 | 29 | poly = np.array([[0, 0], [2, 0], [2, 2], [0, 2]]) 30 | assert dbtarget.invalid_polygon(poly) 31 | 32 | poly = np.array([[0, 0], [10, 0], [10, 10], [0, 10]]) 33 | assert not dbtarget.invalid_polygon(poly) 34 | -------------------------------------------------------------------------------- /mmocr/models/spotting/modules/ops/src/cpu/ms_deform_attn_cpu.h: -------------------------------------------------------------------------------- 1 | /*! 2 | ************************************************************************************************** 3 | * Deformable DETR 4 | * Copyright (c) 2020 SenseTime. All Rights Reserved.
5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details] 6 | ************************************************************************************************** 7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 8 | ************************************************************************************************** 9 | */ 10 | 11 | #pragma once 12 | #include <torch/extension.h> 13 | 14 | at::Tensor 15 | ms_deform_attn_cpu_forward( 16 | const at::Tensor &value, 17 | const at::Tensor &spatial_shapes, 18 | const at::Tensor &level_start_index, 19 | const at::Tensor &sampling_loc, 20 | const at::Tensor &attn_weight, 21 | const int im2col_step); 22 | 23 | std::vector<at::Tensor> 24 | ms_deform_attn_cpu_backward( 25 | const at::Tensor &value, 26 | const at::Tensor &spatial_shapes, 27 | const at::Tensor &level_start_index, 28 | const at::Tensor &sampling_loc, 29 | const at::Tensor &attn_weight, 30 | const at::Tensor &grad_output, 31 | const int im2col_step); 32 | 33 | 34 | -------------------------------------------------------------------------------- /configs/vie_custom/_base_/ocr_datasets/synth_chn_ar_cloud.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # @Time : 2022/2/2 22:53 4 | # @Author : WeiHua 5 | 6 | dataset_type = 'VIEE2EDataset' 7 | data_root = '/apdcephfs/share_887471/interns/v_willwhua/dataset/ocr_benchmark/synth_chinese/data/syntext' 8 | 9 | loader = dict( 10 | type='HardDiskLoader', 11 | repeat=1, 12 | parser=dict( 13 | type='CustomLineJsonParser', 14 | keys=['file_name', 'height', 'width', 'annotations'], 15 | optional_keys=['entity_dict'])) 16 | 17 | train = dict( 18 | type=dataset_type, 19 | ann_file=f'{data_root}/custom_json_format.txt', 20 | loader=loader, 21 | dict_file=f'{data_root}/custom_dict.json', 22 | img_prefix=f'{data_root}/syn_130k_images', 23 | pipeline=None, 24 | test_mode=False, 25 | class_file=None, 26 | data_type='ocr', 27 | max_seq_len=75, 28 | order_type='shuffle', 29 | auto_reg=True, 30 | pre_parse_anno=True) 31 | """ 32 | Total sample num: 134514 33 | max_pt_num:4 34 | avg_ins:11.237142602257014, max:321, min:1 35 | avg_height:420.61183222564193, avg_width:487.8559555139242 36 | max_len:71 37 | avg_ins_height:35.7767097812647, avg_ins_width:81.25013148728493 38 | """ 39 | 40 | train_list = [train] 41 | 42 | test_list = [train] 43 | -------------------------------------------------------------------------------- /mmocr/models/spotting/recognizers/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # @Time : 2022/1/12 17:11 4 | # @Author : WeiHua 5 | 6 | from .ar_reader import AutoRegReader 7 | from mmocr.models.spotting.recognizers.old.ar_reader_v1 import AutoRegReaderV1 8 | from mmocr.models.spotting.recognizers.old.ar_reader_idpdt import AutoRegReaderIDPDT 9 | from .ar_reader_serial import AutoRegReaderSerial 10 | from .ar_reader_serial_local_ie import AutoRegReaderSerialLocalIE 11 | from .re_imple_trie.trie import CustomTRIE 12 | from .ar_reader_nar_ie import AutoRegReaderNARIE 13 | from .counters import CSRNetDecoder 14 | from .ar_reader_nar_ie_0726 import AutoRegReaderNARIE0726 15 | # from .rnn_attention_nar_ie import RNNRecNARIE 16 | from .ar_reader_nar_ie_0726_kvc import AutoRegReaderNARIE0726_kvc 17 | from .ar_reader_nar_ie_0726_kvc_decoder import AutoRegReaderNARIE0726_kvc_decoder 18 | from 
.ar_reader_nar_ie_0726_kvc_head import AutoRegReaderNARIE0726_kvc_head 19 | 20 | __all__ = [ 21 | 'AutoRegReader', 'AutoRegReaderV1', 22 | 'AutoRegReaderIDPDT', 'AutoRegReaderSerial', 23 | 'AutoRegReaderSerialLocalIE', 'CustomTRIE', 24 | 'AutoRegReaderNARIE', 'CSRNetDecoder', 25 | 'AutoRegReaderNARIE0726', 'AutoRegReaderNARIE0726_kvc', 26 | 'AutoRegReaderNARIE0726_kvc_head', 'AutoRegReaderNARIE0726_kvc_decoder' 27 | ] 28 | -------------------------------------------------------------------------------- /mmocr/models/textrecog/encoders/channel_reduction_encoder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import torch.nn as nn 3 | 4 | from mmocr.models.builder import ENCODERS 5 | from .base_encoder import BaseEncoder 6 | 7 | 8 | @ENCODERS.register_module() 9 | class ChannelReductionEncoder(BaseEncoder): 10 | """Change the channel number with a 1x1 convolutional layer. 11 | 12 | Args: 13 | in_channels (int): Number of input channels. 14 | out_channels (int): Number of output channels. 15 | init_cfg (dict or list[dict], optional): Initialization configs. 16 | """ 17 | 18 | def __init__(self, 19 | in_channels, 20 | out_channels, 21 | init_cfg=dict(type='Xavier', layer='Conv2d')): 22 | super().__init__(init_cfg=init_cfg) 23 | 24 | self.layer = nn.Conv2d( 25 | in_channels, out_channels, kernel_size=1, stride=1, padding=0) 26 | 27 | def forward(self, feat, img_metas=None): 28 | """ 29 | Args: 30 | feat (Tensor): Image features with the shape of 31 | :math:`(N, C_{in}, H, W)`. 32 | img_metas (None): Unused. 33 | 34 | Returns: 35 | Tensor: A tensor of shape :math:`(N, C_{out}, H, W)`. 36 | """ 37 | return self.layer(feat) 38 | -------------------------------------------------------------------------------- /mmocr/datasets/ocr_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from mmdet.datasets.builder import DATASETS 3 | 4 | from mmocr.core.evaluation.ocr_metric import eval_ocr_metric 5 | from mmocr.datasets.base_dataset import BaseDataset 6 | 7 | 8 | @DATASETS.register_module() 9 | class OCRDataset(BaseDataset): 10 | 11 | def pre_pipeline(self, results): 12 | results['img_prefix'] = self.img_prefix 13 | results['text'] = results['img_info']['text'] 14 | 15 | def evaluate(self, results, metric='acc', logger=None, **kwargs): 16 | """Evaluate the dataset. 17 | 18 | Args: 19 | results (list): Testing results of the dataset. 20 | metric (str | list[str]): Metrics to be evaluated. 21 | logger (logging.Logger | str | None): Logger used for printing 22 | related information during evaluation. Default: None.
23 | Returns: 24 | dict[str: float] 25 | """ 26 | gt_texts = [] 27 | pred_texts = [] 28 | for i in range(len(self)): 29 | item_info = self.data_infos[i] 30 | text = item_info['text'] 31 | gt_texts.append(text) 32 | pred_texts.append(results[i]['text']) 33 | 34 | eval_results = eval_ocr_metric(pred_texts, gt_texts) 35 | 36 | return eval_results 37 | -------------------------------------------------------------------------------- /configs/vie_custom/_base_/ocr_datasets/synthtext_ar_cloud.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # @Time : 2022/1/5 16:38 4 | # @Author : WeiHua 5 | 6 | dataset_type = 'VIEE2EDataset' 7 | data_root = '/apdcephfs/share_887471/common/ocr_benchmark/benchmark/SynthText' 8 | ann_root = '/apdcephfs/share_887471/common/whua/dataset/ie_e2e/SynthText/' 9 | 10 | loader = dict( 11 | type='HardDiskLoader', 12 | repeat=1, 13 | parser=dict( 14 | type='LineJsonParser', 15 | keys=['file_name', 'height', 'width', 'annotations'])) 16 | 17 | train = dict( 18 | type=dataset_type, 19 | ann_file=f'{ann_root}/train.txt', 20 | loader=loader, 21 | dict_file=f'{ann_root}/dict.json', 22 | img_prefix=data_root, 23 | pipeline=None, 24 | test_mode=False, 25 | class_file=None, 26 | data_type='ocr', 27 | max_seq_len=60, 28 | check_outside=True, 29 | order_type='shuffle', 30 | auto_reg=True) 31 | 32 | test = dict( 33 | type=dataset_type, 34 | ann_file=f'{ann_root}/train.txt', 35 | loader=loader, 36 | dict_file=f'{ann_root}/dict.json', 37 | img_prefix=data_root, 38 | pipeline=None, 39 | test_mode=False, 40 | class_file=None, 41 | data_type='ocr', 42 | max_seq_len=60, 43 | check_outside=True, 44 | auto_reg=True) 45 | 46 | train_list = [train] 47 | 48 | test_list = [test] 49 | -------------------------------------------------------------------------------- /configs/vie_custom/_base_/vie_datasets/nfv5_2200_ar_local_9999.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # @Time : 2022/2/2 22:53 4 | # @Author : WeiHua 5 | 6 | dataset_type = 'VIEE2EDataset' 7 | data_root = '/data/whua/dataset/ie_e2e/nfv5_2200' 8 | 9 | loader = dict( 10 | type='HardDiskLoader', 11 | repeat=1, 12 | parser=dict( 13 | type='CustomLineJsonParser', 14 | keys=['file_name', 'height', 'width', 'annotations'], 15 | optional_keys=['entity_dict'])) 16 | 17 | train = dict( 18 | type=dataset_type, 19 | ann_file=f'{data_root}/train.txt', 20 | loader=loader, 21 | dict_file=f'{data_root}/dict.json', 22 | img_prefix=data_root, 23 | pipeline=None, 24 | test_mode=False, 25 | class_file=f'{data_root}/class_list.json', 26 | data_type='vie', 27 | max_seq_len=110, 28 | order_type='shuffle', 29 | auto_reg=True, 30 | pre_parse_anno=True) 31 | 32 | test = dict( 33 | type=dataset_type, 34 | ann_file=f'{data_root}/test.txt', 35 | loader=loader, 36 | dict_file=f'{data_root}/dict.json', 37 | img_prefix=data_root, 38 | pipeline=None, 39 | test_mode=True, 40 | class_file=f'{data_root}/class_list.json', 41 | data_type='vie', 42 | max_seq_len=110, 43 | order_type='origin', 44 | auto_reg=True) 45 | 46 | train_list = [train] 47 | 48 | test_list = [test] 49 | -------------------------------------------------------------------------------- /configs/vie_custom/_base_/vie_datasets/nfv5_3128_ar_local_1061.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # @Time : 2022/2/2 
22:53 4 | # @Author : WeiHua 5 | 6 | dataset_type = 'VIEE2EDataset' 7 | data_root = '/home/whua/dataset/ie_e2e/nfv5_3128' 8 | 9 | loader = dict( 10 | type='HardDiskLoader', 11 | repeat=1, 12 | parser=dict( 13 | type='CustomLineJsonParser', 14 | keys=['file_name', 'height', 'width', 'annotations'], 15 | optional_keys=['entity_dict'])) 16 | 17 | train = dict( 18 | type=dataset_type, 19 | ann_file=f'{data_root}/train.txt', 20 | loader=loader, 21 | dict_file=f'{data_root}/dict.json', 22 | img_prefix=data_root, 23 | pipeline=None, 24 | test_mode=False, 25 | class_file=f'{data_root}/class_list.json', 26 | data_type='vie', 27 | max_seq_len=125, 28 | order_type='shuffle', 29 | auto_reg=True, 30 | pre_parse_anno=True) 31 | 32 | test = dict( 33 | type=dataset_type, 34 | ann_file=f'{data_root}/test.txt', 35 | loader=loader, 36 | dict_file=f'{data_root}/dict.json', 37 | img_prefix=data_root, 38 | pipeline=None, 39 | test_mode=True, 40 | class_file=f'{data_root}/class_list.json', 41 | data_type='vie', 42 | max_seq_len=125, 43 | order_type='origin', 44 | auto_reg=True) 45 | 46 | train_list = [train] 47 | 48 | test_list = [test] 49 | -------------------------------------------------------------------------------- /tools/publish_model.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # Copyright (c) OpenMMLab. All rights reserved. 3 | import argparse 4 | import subprocess 5 | 6 | import torch 7 | 8 | 9 | def parse_args(): 10 | parser = argparse.ArgumentParser( 11 | description='Process a checkpoint to be published') 12 | parser.add_argument('in_file', help='input checkpoint filename') 13 | parser.add_argument('out_file', help='output checkpoint filename') 14 | args = parser.parse_args() 15 | return args 16 | 17 | 18 | def process_checkpoint(in_file, out_file): 19 | checkpoint = torch.load(in_file, map_location='cpu') 20 | # remove optimizer for smaller file size 21 | if 'optimizer' in checkpoint: 22 | del checkpoint['optimizer'] 23 | # if it is necessary to remove some sensitive data in checkpoint['meta'], 24 | # add the code here. 
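# A minimal sketch of an alternative: instead of replacing 'meta' wholesale
# as the line below does, pop individual fields, e.g.
#     checkpoint['meta'].pop('hook_msgs', None)
# (the key name here is illustrative; scrub whatever your training run
# actually stored in checkpoint['meta']).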
25 | if 'meta' in checkpoint: 26 | checkpoint['meta'] = {'CLASSES': 0} 27 | torch.save(checkpoint, out_file, _use_new_zipfile_serialization=False) 28 | sha = subprocess.check_output(['sha256sum', out_file]).decode() 29 | # str.rstrip('.pth') strips trailing '.', 'p', 't', 'h' characters rather 30 | # than the suffix, so cut the extension explicitly. 31 | out_name = out_file[:-4] if out_file.endswith('.pth') else out_file 32 | final_file = out_name + '-{}.pth'.format(sha[:8]) 33 | subprocess.Popen(['mv', out_file, final_file]) 34 | 35 | 36 | def main(): 37 | args = parse_args() 38 | process_checkpoint(args.in_file, args.out_file) 39 | 40 | 41 | if __name__ == '__main__': 42 | main() 43 | -------------------------------------------------------------------------------- /configs/vie_custom/_base_/ocr_datasets/local/nfv5_2200_ar_local_9999.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # @Time : 2022/2/2 22:53 4 | # @Author : WeiHua 5 | 6 | dataset_type = 'VIEE2EDataset' 7 | data_root = '/data/whua/dataset/ie_e2e/nfv5_2200' 8 | 9 | loader = dict( 10 | type='HardDiskLoader', 11 | repeat=1, 12 | parser=dict( 13 | type='CustomLineJsonParser', 14 | keys=['file_name', 'height', 'width', 'annotations'], 15 | optional_keys=['entity_dict'])) 16 | 17 | train = dict( 18 | type=dataset_type, 19 | ann_file=f'{data_root}/train.txt', 20 | loader=loader, 21 | dict_file=f'{data_root}/dict.json', 22 | img_prefix=data_root, 23 | pipeline=None, 24 | test_mode=False, 25 | class_file=f'{data_root}/class_list.json', 26 | data_type='ocr', 27 | max_seq_len=110, 28 | order_type='shuffle', 29 | auto_reg=True, 30 | pre_parse_anno=True) 31 | 32 | test = dict( 33 | type=dataset_type, 34 | ann_file=f'{data_root}/test.txt', 35 | loader=loader, 36 | dict_file=f'{data_root}/dict.json', 37 | img_prefix=data_root, 38 | pipeline=None, 39 | test_mode=True, 40 | class_file=f'{data_root}/class_list.json', 41 | data_type='ocr', 42 | max_seq_len=110, 43 | order_type='origin', 44 | auto_reg=True) 45 | 46 | train_list = [train] 47 | 48 | test_list = [test] 49 | -------------------------------------------------------------------------------- /configs/vie_custom/_base_/vie_datasets/nfv5_3125_ar_local_1032.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # @Time : 2022/2/2 22:53 4 | # @Author : WeiHua 5 | 6 | dataset_type = 'VIEE2EDataset' 7 | data_root = '/home/whua/dataset/ie_e2e_dataset/nfv5_3125' 8 | 9 | loader = dict( 10 | type='HardDiskLoader', 11 | repeat=1, 12 | parser=dict( 13 | type='CustomLineJsonParser', 14 | keys=['file_name', 'height', 'width', 'annotations'], 15 | optional_keys=['entity_dict'])) 16 | 17 | train = dict( 18 | type=dataset_type, 19 | ann_file=f'{data_root}/train.txt', 20 | loader=loader, 21 | dict_file=f'{data_root}/dict.json', 22 | img_prefix=data_root, 23 | pipeline=None, 24 | test_mode=False, 25 | class_file=f'{data_root}/class_list.json', 26 | data_type='vie', 27 | max_seq_len=125, 28 | order_type='shuffle', 29 | auto_reg=True, 30 | pre_parse_anno=True) 31 | 32 | test = dict( 33 | type=dataset_type, 34 | ann_file=f'{data_root}/test.txt', 35 | loader=loader, 36 | dict_file=f'{data_root}/dict.json', 37 | img_prefix=data_root, 38 | pipeline=None, 39 | test_mode=True, 40 | class_file=f'{data_root}/class_list.json', 41 | data_type='vie', 42 | max_seq_len=125, 43 | order_type='origin', 44 | auto_reg=True) 45 | 46 | train_list = [train] 47 | 48 | test_list = [test] 49 | -------------------------------------------------------------------------------- /configs/vie_custom/_base_/vie_datasets/nfv5_3125_ar_local_1061.py: 
-------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # @Time : 2022/2/2 22:53 4 | # @Author : WeiHua 5 | 6 | dataset_type = 'VIEE2EDataset' 7 | data_root = '/home/whua/dataset/ie_e2e_dataset/nfv5_3125' 8 | 9 | loader = dict( 10 | type='HardDiskLoader', 11 | repeat=1, 12 | parser=dict( 13 | type='CustomLineJsonParser', 14 | keys=['file_name', 'height', 'width', 'annotations'], 15 | optional_keys=['entity_dict'])) 16 | 17 | train = dict( 18 | type=dataset_type, 19 | ann_file=f'{data_root}/train.txt', 20 | loader=loader, 21 | dict_file=f'{data_root}/dict.json', 22 | img_prefix=data_root, 23 | pipeline=None, 24 | test_mode=False, 25 | class_file=f'{data_root}/class_list.json', 26 | data_type='vie', 27 | max_seq_len=125, 28 | order_type='shuffle', 29 | auto_reg=True, 30 | pre_parse_anno=True) 31 | 32 | test = dict( 33 | type=dataset_type, 34 | ann_file=f'{data_root}/test.txt', 35 | loader=loader, 36 | dict_file=f'{data_root}/dict.json', 37 | img_prefix=data_root, 38 | pipeline=None, 39 | test_mode=True, 40 | class_file=f'{data_root}/class_list.json', 41 | data_type='vie', 42 | max_seq_len=125, 43 | order_type='origin', 44 | auto_reg=True) 45 | 46 | train_list = [train] 47 | 48 | test_list = [test] 49 | -------------------------------------------------------------------------------- /configs/vie_custom/_base_/vie_datasets/nfv5_3125_ar_local_1062.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # @Time : 2022/2/2 22:53 4 | # @Author : WeiHua 5 | 6 | dataset_type = 'VIEE2EDataset' 7 | data_root = '/home/whua/dataset/ie_e2e_dataset/nfv5_3125' 8 | 9 | loader = dict( 10 | type='HardDiskLoader', 11 | repeat=1, 12 | parser=dict( 13 | type='CustomLineJsonParser', 14 | keys=['file_name', 'height', 'width', 'annotations'], 15 | optional_keys=['entity_dict'])) 16 | 17 | train = dict( 18 | type=dataset_type, 19 | ann_file=f'{data_root}/train.txt', 20 | loader=loader, 21 | dict_file=f'{data_root}/dict.json', 22 | img_prefix=data_root, 23 | pipeline=None, 24 | test_mode=False, 25 | class_file=f'{data_root}/class_list.json', 26 | data_type='vie', 27 | max_seq_len=125, 28 | order_type='shuffle', 29 | auto_reg=True, 30 | pre_parse_anno=True) 31 | 32 | test = dict( 33 | type=dataset_type, 34 | ann_file=f'{data_root}/test.txt', 35 | loader=loader, 36 | dict_file=f'{data_root}/dict.json', 37 | img_prefix=data_root, 38 | pipeline=None, 39 | test_mode=True, 40 | class_file=f'{data_root}/class_list.json', 41 | data_type='vie', 42 | max_seq_len=125, 43 | order_type='origin', 44 | auto_reg=True) 45 | 46 | train_list = [train] 47 | 48 | test_list = [test] 49 | -------------------------------------------------------------------------------- /configs/vie_custom/_base_/vie_datasets/nfv5_3125_ar_local_1803.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # @Time : 2022/2/2 22:53 4 | # @Author : WeiHua 5 | 6 | dataset_type = 'VIEE2EDataset' 7 | data_root = '/home/whua/dataset/ie_e2e_dataset/nfv5_3125' 8 | 9 | loader = dict( 10 | type='HardDiskLoader', 11 | repeat=1, 12 | parser=dict( 13 | type='CustomLineJsonParser', 14 | keys=['file_name', 'height', 'width', 'annotations'], 15 | optional_keys=['entity_dict'])) 16 | 17 | train = dict( 18 | type=dataset_type, 19 | ann_file=f'{data_root}/train.txt', 20 | loader=loader, 21 | dict_file=f'{data_root}/dict.json', 
22 | img_prefix=data_root, 23 | pipeline=None, 24 | test_mode=False, 25 | class_file=f'{data_root}/class_list.json', 26 | data_type='vie', 27 | max_seq_len=125, 28 | order_type='shuffle', 29 | auto_reg=True, 30 | pre_parse_anno=True) 31 | 32 | test = dict( 33 | type=dataset_type, 34 | ann_file=f'{data_root}/test.txt', 35 | loader=loader, 36 | dict_file=f'{data_root}/dict.json', 37 | img_prefix=data_root, 38 | pipeline=None, 39 | test_mode=True, 40 | class_file=f'{data_root}/class_list.json', 41 | data_type='vie', 42 | max_seq_len=125, 43 | order_type='origin', 44 | auto_reg=True) 45 | 46 | train_list = [train] 47 | 48 | test_list = [test] 49 | -------------------------------------------------------------------------------- /configs/vie_custom/_base_/vie_datasets/nfv5_3128_ar_local_1803.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # @Time : 2022/2/2 22:53 4 | # @Author : WeiHua 5 | 6 | dataset_type = 'VIEE2EDataset' 7 | data_root = '/home/whua/dataset/ie_e2e_dataset/nfv5_3128' 8 | 9 | loader = dict( 10 | type='HardDiskLoader', 11 | repeat=1, 12 | parser=dict( 13 | type='CustomLineJsonParser', 14 | keys=['file_name', 'height', 'width', 'annotations'], 15 | optional_keys=['entity_dict'])) 16 | 17 | train = dict( 18 | type=dataset_type, 19 | ann_file=f'{data_root}/train.txt', 20 | loader=loader, 21 | dict_file=f'{data_root}/dict.json', 22 | img_prefix=data_root, 23 | pipeline=None, 24 | test_mode=False, 25 | class_file=f'{data_root}/class_list.json', 26 | data_type='vie', 27 | max_seq_len=125, 28 | order_type='shuffle', 29 | auto_reg=True, 30 | pre_parse_anno=True) 31 | 32 | test = dict( 33 | type=dataset_type, 34 | ann_file=f'{data_root}/test.txt', 35 | loader=loader, 36 | dict_file=f'{data_root}/dict.json', 37 | img_prefix=data_root, 38 | pipeline=None, 39 | test_mode=True, 40 | class_file=f'{data_root}/class_list.json', 41 | data_type='vie', 42 | max_seq_len=125, 43 | order_type='origin', 44 | auto_reg=True) 45 | 46 | train_list = [train] 47 | 48 | test_list = [test] 49 | -------------------------------------------------------------------------------- /tools/train_1803.sh: -------------------------------------------------------------------------------- 1 | #CUDA_VISIBLE_DEVICES=2 python /home/jfkuang/code/ie_e2e/tools/train.py \ 2 | #/home/jfkuang/code/ie_e2e/configs/vie_custom/e2e_ar_vie/v5/local/nfv5_3125_sdef_rnn_kvc_200e_720_local_1803_vis.py \ 3 | #--work-dir=/home/jfkuang/logs/ie_e2e_log/test 4 | 5 | #CUDA_VISIBLE_DEVICES=4,5,6,7 python -m torch.distributed.launch --nproc_per_node=4 --master_port=10011 \ 6 | #/home/jfkuang/code/ie_e2e/tools/train.py \ 7 | #/home/jfkuang/code/ie_e2e/configs/vie_custom/e2e_ar_vie/sroie/local/sroie_ie_1803.py \ 8 | #--work-dir=/home/jfkuang/logs/ie_e2e_log/ours_another_seed --launcher pytorch --gpus 4 \ 9 | #--deterministic --seed 1364371869 10 | 11 | #single test + vis 12 | #CUDA_VISIBLE_DEVICES=3 python /home/jfkuang/code/ie_e2e/tools/test.py \ 13 | #/home/jfkuang/code/ie_e2e/configs/vie_custom/e2e_ar_vie/v5/local/nfv5_3125_sdef_rnn_kvc_200e_720_local_1803_vis.py \ 14 | #/home/jfkuang/logs/ie_e2e_log/vies_kvc_nfv5/epoch_180.pth \ 15 | #--eval hmean-iou --show-dir /home/jfkuang/logs/vis/vies_kvc_nfv5 16 | 17 | CUDA_VISIBLE_DEVICES=4,5,6,7 python -m torch.distributed.launch --nproc_per_node=4 --master_port=10019 \ 18 | /home/jfkuang/code/ie_e2e/tools/train.py \ 19 | /home/jfkuang/code/ie_e2e/configs/vie_custom/e2e_ar_vie/cord/cord_baseline_1280_200e_1803.py 
\ 20 | --work-dir=/home/jfkuang/logs/ie_e2e_log/ours_cord_600e_new_weights --launcher pytorch --gpus 4 \ 21 | --deterministic --seed 3407 22 | -------------------------------------------------------------------------------- /configs/vie_custom/_base_/ocr_datasets/local/nfv5_2200_ar_local_1032.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # @Time : 2022/2/2 22:53 4 | # @Author : WeiHua 5 | 6 | dataset_type = 'VIEE2EDataset' 7 | data_root = '/home/whua/datasets/ie_e2e_format/nfv5_2200' 8 | 9 | loader = dict( 10 | type='HardDiskLoader', 11 | repeat=1, 12 | parser=dict( 13 | type='CustomLineJsonParser', 14 | keys=['file_name', 'height', 'width', 'annotations'], 15 | optional_keys=['entity_dict'])) 16 | 17 | train = dict( 18 | type=dataset_type, 19 | ann_file=f'{data_root}/train.txt', 20 | loader=loader, 21 | dict_file=f'{data_root}/dict.json', 22 | img_prefix=data_root, 23 | pipeline=None, 24 | test_mode=False, 25 | class_file=f'{data_root}/class_list.json', 26 | data_type='ocr', 27 | max_seq_len=110, 28 | order_type='shuffle', 29 | auto_reg=True, 30 | pre_parse_anno=True) 31 | 32 | test = dict( 33 | type=dataset_type, 34 | ann_file=f'{data_root}/test.txt', 35 | loader=loader, 36 | dict_file=f'{data_root}/dict.json', 37 | img_prefix=data_root, 38 | pipeline=None, 39 | test_mode=True, 40 | class_file=f'{data_root}/class_list.json', 41 | data_type='ocr', 42 | max_seq_len=110, 43 | order_type='origin', 44 | auto_reg=True) 45 | 46 | train_list = [train] 47 | 48 | test_list = [test] 49 | -------------------------------------------------------------------------------- /mmocr/utils/data_convert_util.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import mmcv 3 | 4 | 5 | def convert_annotations(image_infos, out_json_name): 6 | """Convert the annotation into coco style. 7 | 8 | Args: 9 | image_infos(list): The list of image information dicts 10 | out_json_name(str): The output json filename 11 | 12 | Returns: 13 | out_json(dict): The coco style dict 14 | """ 15 | assert isinstance(image_infos, list) 16 | assert isinstance(out_json_name, str) 17 | assert out_json_name 18 | 19 | out_json = dict() 20 | img_id = 0 21 | ann_id = 0 22 | out_json['images'] = [] 23 | out_json['categories'] = [] 24 | out_json['annotations'] = [] 25 | for image_info in image_infos: 26 | image_info['id'] = img_id 27 | anno_infos = image_info.pop('anno_info') 28 | out_json['images'].append(image_info) 29 | for anno_info in anno_infos: 30 | anno_info['image_id'] = img_id 31 | anno_info['id'] = ann_id 32 | out_json['annotations'].append(anno_info) 33 | ann_id += 1 34 | img_id += 1 35 | cat = dict(id=1, name='text') 36 | out_json['categories'].append(cat) 37 | 38 | if len(out_json['annotations']) == 0: 39 | out_json.pop('annotations') 40 | mmcv.dump(out_json, out_json_name) 41 | 42 | return out_json 43 | -------------------------------------------------------------------------------- /tests/test_utils/test_string_util.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
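# Behaviour being pinned down, sketched on a tiny example (illustrative only):
#   StringStrip()(' hi ')                 -> 'hi'
#   StringStrip(strip_pos='left')(' hi ') -> 'hi '
#   StringStrip(strip=False)(' hi ')      -> ' hi '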
2 | import pytest 3 | 4 | from mmocr.utils import StringStrip 5 | 6 | 7 | def test_string_strip(): 8 | strip_list = [True, False] 9 | strip_pos_list = ['both', 'left', 'right'] 10 | strip_str_list = [None, ' '] 11 | 12 | in_str_list = [ 13 | ' hello ', 'hello ', ' hello', ' hello', 'hello ', 'hello ', 'hello', 14 | 'hello', 'hello', 'hello', 'hello', 'hello' 15 | ] 16 | out_str_list = [ 17 | 'hello', 'hello', 'hello', 'hello', 'hello', 'hello', 'hello', 'hello', 18 | 'hello', 'hello', 'hello', 'hello' 19 | ] 20 | 21 | for idx1, strip in enumerate(strip_list): 22 | for idx2, strip_pos in enumerate(strip_pos_list): 23 | for idx3, strip_str in enumerate(strip_str_list): 24 | tmp_args = dict( 25 | strip=strip, strip_pos=strip_pos, strip_str=strip_str) 26 | strip_class = StringStrip(**tmp_args) 27 | i = idx1 * len(strip_pos_list) * len( 28 | strip_str_list) + idx2 * len(strip_str_list) + idx3 29 | 30 | assert strip_class(in_str_list[i]) == out_str_list[i] 31 | 32 | # Each invalid argument needs its own raises-block: statements after the 33 | # first AssertionError would otherwise never execute. 34 | with pytest.raises(AssertionError): 35 | StringStrip(strip='strip') 36 | with pytest.raises(AssertionError): 37 | StringStrip(strip_pos='head') 38 | with pytest.raises(AssertionError): 39 | StringStrip(strip_str=['\n', '\t']) 40 | -------------------------------------------------------------------------------- /configs/vie_custom/_base_/vie_datasets/cord_cloud.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # @Time : 2022/2/2 22:53 4 | # @Author : WeiHua 5 | 6 | dataset_type = 'VIEE2EDataset' 7 | data_root = '/apdcephfs/share_887471/common/whua/dataset/ie_e2e/cord/e2e_format' 8 | # data_root = '/data/whua/dataset/ie_e2e/nfv1/ie_e2e_data/mm_format/table' 9 | # data_root = '/mnt/whua/ie_e2e_data/mm_format/table' 10 | 11 | loader = dict( 12 | type='HardDiskLoader', 13 | repeat=1, 14 | parser=dict( 15 | type='LineJsonParser', 16 | keys=['file_name', 'height', 'width', 'annotations'])) 17 | 18 | train = dict( 19 | type=dataset_type, 20 | ann_file=f'{data_root}/train.txt', 21 | loader=loader, 22 | dict_file=f'{data_root}/dict.json', 23 | img_prefix=data_root, 24 | pipeline=None, 25 | test_mode=False, 26 | class_file=f'{data_root}/class_list.json', 27 | data_type='vie', 28 | max_seq_len=36, 29 | order_type='shuffle') 30 | 31 | test = dict( 32 | type=dataset_type, 33 | ann_file=f'{data_root}/test.txt', 34 | loader=loader, 35 | dict_file=f'{data_root}/dict.json', 36 | img_prefix=data_root, 37 | pipeline=None, 38 | test_mode=True, 39 | class_file=f'{data_root}/class_list.json', 40 | data_type='vie', 41 | max_seq_len=36, 42 | order_type='origin') 43 | 44 | train_list = [train] 45 | 46 | test_list = [test] 47 | -------------------------------------------------------------------------------- /tests/test_dataset/test_test_time_aug.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved.
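# What this test establishes: MultiRotateAugOCR keeps a wide image (64x256)
# as a single view, but expands a narrow image (64x32) into one view per
# entry in rotate_degrees (0/90/270 -> 3 views); the aspect-ratio threshold
# that triggers the expansion is internal to the transform.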
2 | import numpy as np 3 | import pytest 4 | 5 | from mmocr.datasets.pipelines.test_time_aug import MultiRotateAugOCR 6 | 7 | 8 | def test_resize_ocr(): 9 | input_img1 = np.ones((64, 256, 3), dtype=np.uint8) 10 | input_img2 = np.ones((64, 32, 3), dtype=np.uint8) 11 | 12 | rci = MultiRotateAugOCR(transforms=[], rotate_degrees=[0, 90, 270]) 13 | 14 | # test invalid arguments 15 | with pytest.raises(AssertionError): 16 | MultiRotateAugOCR(transforms=[], rotate_degrees=[45]) 17 | with pytest.raises(AssertionError): 18 | MultiRotateAugOCR(transforms=[], rotate_degrees=[20.5]) 19 | 20 | # test call with input_img1 21 | results = {'img_shape': input_img1.shape, 'img': input_img1} 22 | results = rci(results) 23 | assert np.allclose([64, 256, 3], results['img_shape']) 24 | assert len(results['img']) == 1 25 | assert len(results['img_shape']) == 1 26 | assert np.allclose([64, 256, 3], results['img_shape'][0]) 27 | 28 | # test call with input_img2 29 | results = {'img_shape': input_img2.shape, 'img': input_img2} 30 | results = rci(results) 31 | assert np.allclose([64, 32, 3], results['img_shape']) 32 | assert len(results['img']) == 3 33 | assert len(results['img_shape']) == 3 34 | assert np.allclose([64, 32, 3], results['img_shape'][0]) 35 | -------------------------------------------------------------------------------- /tests/test_models/test_ocr_preprocessor.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import pytest 3 | import torch 4 | 5 | from mmocr.models.textrecog.preprocessor import (BasePreprocessor, 6 | TPSPreprocessor) 7 | 8 | 9 | def test_tps_preprocessor(): 10 | with pytest.raises(AssertionError): 11 | TPSPreprocessor(num_fiducial=-1) 12 | with pytest.raises(AssertionError): 13 | TPSPreprocessor(img_size=32) 14 | with pytest.raises(AssertionError): 15 | TPSPreprocessor(rectified_img_size=100) 16 | with pytest.raises(AssertionError): 17 | TPSPreprocessor(num_img_channel='bgr') 18 | 19 | tps_preprocessor = TPSPreprocessor( 20 | num_fiducial=20, 21 | img_size=(32, 100), 22 | rectified_img_size=(32, 100), 23 | num_img_channel=1) 24 | tps_preprocessor.init_weights() 25 | tps_preprocessor.train() 26 | 27 | batch_img = torch.randn(1, 1, 32, 100) 28 | processed = tps_preprocessor(batch_img) 29 | assert processed.shape == torch.Size([1, 1, 32, 100]) 30 | 31 | 32 | def test_base_preprocessor(): 33 | preprocessor = BasePreprocessor() 34 | preprocessor.init_weights() 35 | preprocessor.train() 36 | 37 | batch_img = torch.randn(1, 1, 32, 100) 38 | processed = preprocessor(batch_img) 39 | assert processed.shape == torch.Size([1, 1, 32, 100]) 40 | -------------------------------------------------------------------------------- /mmocr/models/textdet/detectors/fcenet.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
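# FCENet is registered in DETECTORS, so it is normally built from a config
# rather than imported directly. A minimal sketch (the cfg values below are
# placeholders, not a runnable model spec):
#   from mmocr.models import build_detector
#   model = build_detector(
#       dict(type='FCENet', backbone=..., neck=..., bbox_head=...))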
2 | from mmocr.models.builder import DETECTORS 3 | from .single_stage_text_detector import SingleStageTextDetector 4 | from .text_detector_mixin import TextDetectorMixin 5 | 6 | 7 | @DETECTORS.register_module() 8 | class FCENet(TextDetectorMixin, SingleStageTextDetector): 9 | """The class for implementing FCENet text detector 10 | FCENet(CVPR2021): Fourier Contour Embedding for Arbitrary-shaped Text 11 | Detection 12 | 13 | [https://arxiv.org/abs/2104.10442] 14 | """ 15 | 16 | def __init__(self, 17 | backbone, 18 | neck, 19 | bbox_head, 20 | train_cfg=None, 21 | test_cfg=None, 22 | pretrained=None, 23 | show_score=False, 24 | init_cfg=None): 25 | SingleStageTextDetector.__init__(self, backbone, neck, bbox_head, 26 | train_cfg, test_cfg, pretrained, 27 | init_cfg) 28 | TextDetectorMixin.__init__(self, show_score) 29 | 30 | def simple_test(self, img, img_metas, rescale=False): 31 | x = self.extract_feat(img) 32 | outs = self.bbox_head(x) 33 | boundaries = self.bbox_head.get_boundary(outs, img_metas, rescale) 34 | 35 | return [boundaries] 36 | -------------------------------------------------------------------------------- /tools/use_gpu.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: UTF-8 -*- 3 | ''' 4 | @Project :ie_e2e 5 | @File :use_gpu.py 6 | @IDE :PyCharm 7 | @Author :jfkuang 8 | @Date :2022/10/13 22:39 9 | ''' 10 | import argparse 11 | import time 12 | 13 | import torch 14 | 15 | 16 | def parse_args(): 17 | parser = argparse.ArgumentParser(description='Matrix multiplication') 18 | parser.add_argument('--gpus', help='gpu amount', required=True, type=int) 19 | parser.add_argument('--size', help='matrix size', required=True, type=int) 20 | parser.add_argument('--interval', help='sleep interval', required=True, type=float) 21 | args = parser.parse_args() 22 | return args 23 | 24 | 25 | def matrix_multiplication(args): 26 | a_list, b_list, result = [], [], [] 27 | size = (args.size, args.size) 28 | 29 | # NOTE: tensors are placed on GPUs starting from index 2; this offset is 30 | # kept from the original script and assumes devices 0/1 are occupied. 31 | for i in range(args.gpus): 32 | a_list.append(torch.rand(size, device=i + 2)) 33 | b_list.append(torch.rand(size, device=i + 2)) 34 | result.append(torch.rand(size, device=i + 2)) 35 | 36 | while True: 37 | for i in range(args.gpus): 38 | # A true matrix product keeps the GPU busy; the original 39 | # elementwise `a * b` exercised it far less. 40 | result[i] = torch.matmul(a_list[i], b_list[i]) 41 | time.sleep(args.interval) 42 | 43 | 44 | if __name__ == "__main__": 45 | # usage: python use_gpu.py --size 20000 --gpus 2 --interval 0.01 46 | args = parse_args() 47 | matrix_multiplication(args) -------------------------------------------------------------------------------- /mmocr/models/spotting/recognizers/re_imple_trie/custom_davar_builder.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # @Time : 2022/6/4 15:51 4 | # @Author : WeiHua 5 | 6 | from mmocr.models.spotting.recognizers.re_imple_trie.connects.multimodal_context_module import MultiModalContextModule 7 | from mmocr.models.spotting.recognizers.re_imple_trie.connects.multimodal_feature_merge import MultiModalFusion 8 | from mmocr.models.spotting.recognizers.re_imple_trie.connects.bert_encoder import BertEncoder 9 | from mmocr.models.spotting.recognizers.re_imple_trie.embedding.node_embedding import NodeEmbedding 10 | from mmocr.models.spotting.recognizers.re_imple_trie.embedding.position_embedding import PositionEmbedding2D 11 | from mmocr.models.spotting.recognizers.re_imple_trie.embedding.sentence_embedding import
SentenceEmbeddingCNN 12 | 13 | CONNECT_MODULE = { 14 | "MultiModalContextModule": MultiModalContextModule, 15 | "MultiModalFusion": MultiModalFusion, 16 | "BertEncoder": BertEncoder 17 | } 18 | 19 | EMBEDDING_MODULE = { 20 | "NodeEmbedding": NodeEmbedding, 21 | "PositionEmbedding2D": PositionEmbedding2D, 22 | "SentenceEmbeddingCNN": SentenceEmbeddingCNN 23 | } 24 | 25 | 26 | def build_connect(cfg): 27 | func = CONNECT_MODULE[cfg.pop('type')] 28 | return func(**cfg) 29 | 30 | 31 | def build_embedding(cfg): 32 | func = EMBEDDING_MODULE[cfg.pop('type')] 33 | return func(**cfg) 34 | -------------------------------------------------------------------------------- /tools/train_1062.sh: -------------------------------------------------------------------------------- 1 | #single train 2 | #CUDA_VISIBLE_DEVICES=0 python /home/jfkuang/code/ie_e2e/tools/train.py \ 3 | #/home/jfkuang/code/ie_e2e/configs/vie_custom/e2e_ar_vie/cord/cord_baseline_ie_head_kvc_1280_200e_1062.py \ 4 | #--work-dir=/home/jfkuang/logs/ie_e2e_log/test 5 | 6 | #11.6 7 | #CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 --master_port=10063 \ 8 | #/home/jfkuang/code/ie_e2e/tools/train.py \ 9 | #/home/jfkuang/code/ie_e2e/configs/vie_custom/e2e_ar_vie/v5/local/nfv5_3125_sdef_rnn_kvc_200e_720_local_1062.py \ 10 | #--work-dir=/home/jfkuang/logs/ie_e2e_log/VIES_GT --launcher pytorch --gpus 4 \ 11 | #--deterministic --seed 3407 12 | 13 | CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 --master_port=10019 \ 14 | /home/jfkuang/code/ie_e2e/tools/train.py \ 15 | /home/jfkuang/code/ie_e2e/configs/vie_custom/e2e_ar_vie/v5/local/nfv5_3125_sdef_rnn_kvc_200e_720_local_1062.py \ 16 | --work-dir=/home/jfkuang/logs/ie_e2e_log/encoed_feature_as_entity --launcher pytorch --gpus 4 \ 17 | --deterministic --seed 3407 18 | 19 | 20 | 21 | #single test 22 | #CUDA_VISIBLE_DEVICES=0 python /home/jfkuang/code/ie_e2e/tools/test.py \ 23 | #/home/jfkuang/code/ie_e2e/configs/vie_custom/e2e_ar_vie/v5/local/nfv5_3125_sdef_rnn_kvc_200e_720_local_1062.py \ 24 | #/home/jfkuang/logs/ie_e2e_log/ours_GT/epoch_10.pth \ 25 | #--eval hmean-iou --show-dir /data2/jfkuang/logs/vis/test -------------------------------------------------------------------------------- /docs/en/index.rst: -------------------------------------------------------------------------------- 1 | Welcome to MMOCR's documentation! 2 | ======================================= 3 | 4 | You can switch between English and Chinese in the lower-left corner of the layout. 5 | 6 | .. toctree:: 7 | :maxdepth: 2 8 | :caption: Getting Started 9 | 10 | install.md 11 | getting_started.md 12 | demo.md 13 | training.md 14 | testing.md 15 | deployment.md 16 | model_serving.md 17 | 18 | .. toctree:: 19 | :maxdepth: 2 20 | :caption: Tutorials 21 | 22 | tutorials/config.md 23 | tutorials/dataset_types.md 24 | tutorials/kie_closeset_openset.md 25 | 26 | .. toctree:: 27 | :maxdepth: 2 28 | :caption: Model Zoo 29 | 30 | modelzoo.md 31 | textdet_models.md 32 | textrecog_models.md 33 | kie_models.md 34 | ner_models.md 35 | 36 | .. toctree:: 37 | :maxdepth: 2 38 | :caption: Dataset Zoo 39 | 40 | datasets/det.md 41 | datasets/recog.md 42 | datasets/kie.md 43 | datasets/ner.md 44 | 45 | .. toctree:: 46 | :maxdepth: 2 47 | :caption: Miscellaneous 48 | 49 | tools.md 50 | changelog.md 51 | 52 | .. toctree:: 53 | :caption: API Reference 54 | 55 | api.rst 56 | 57 | .. 
toctree:: 58 | :caption: Switch Language 59 | 60 | English 61 | 简体中文 62 | 63 | Indices and tables 64 | ================== 65 | 66 | * :ref:`genindex` 67 | * :ref:`search` 68 | -------------------------------------------------------------------------------- /configs/_base_/recog_pipelines/sar_pipeline.py: -------------------------------------------------------------------------------- 1 | img_norm_cfg = dict(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) 2 | train_pipeline = [ 3 | dict(type='LoadImageFromFile'), 4 | dict( 5 | type='ResizeOCR', 6 | height=48, 7 | min_width=48, 8 | max_width=160, 9 | keep_aspect_ratio=True, 10 | width_downsample_ratio=0.25), 11 | dict(type='ToTensorOCR'), 12 | dict(type='NormalizeOCR', **img_norm_cfg), 13 | dict( 14 | type='Collect', 15 | keys=['img'], 16 | meta_keys=[ 17 | 'filename', 'ori_shape', 'resize_shape', 'text', 'valid_ratio' 18 | ]), 19 | ] 20 | test_pipeline = [ 21 | dict(type='LoadImageFromFile'), 22 | dict( 23 | type='MultiRotateAugOCR', 24 | rotate_degrees=[0, 90, 270], 25 | transforms=[ 26 | dict( 27 | type='ResizeOCR', 28 | height=48, 29 | min_width=48, 30 | max_width=160, 31 | keep_aspect_ratio=True, 32 | width_downsample_ratio=0.25), 33 | dict(type='ToTensorOCR'), 34 | dict(type='NormalizeOCR', **img_norm_cfg), 35 | dict( 36 | type='Collect', 37 | keys=['img'], 38 | meta_keys=[ 39 | 'filename', 'ori_shape', 'resize_shape', 'valid_ratio' 40 | ]), 41 | ]) 42 | ] 43 | -------------------------------------------------------------------------------- /tools/test_kjf.sh: -------------------------------------------------------------------------------- 1 | #single test 2 | #CUDA_VISIBLE_DEVICES=6 python /home/jfkuang/code/ie_e2e/tools/test.py \ 3 | #/home/jfkuang/code/ie_e2e/configs/vie_custom/e2e_ar_ocr_pretrain/ephoie/ephoie_default_dp02_lr2e4_noalign_add_det_epoch600_pretrain_1032_kjf.py \ 4 | #/home/jfkuang/logs/ie_e2e_log/ephoie_baseline_noalign_epoch600_adddet_pretrain25_1280/epoch_600.pth \ 5 | #--eval hmean-iou --show-dir /home/jfkuang/logs/vis/test_new 6 | 7 | #vies 8 | CUDA_VISIBLE_DEVICES=5 python /home/jfkuang/code/ie_e2e/tools/test.py \ 9 | /home/jfkuang/code/ie_e2e/configs/vie_custom/e2e_ar_vie/sroie/local/sroie_kvc_ie_200e.py \ 10 | /data3/jfkuang/logs/ie_e2e_log/vies_sroie_600epoch/epoch_600.pth \ 11 | --eval hmean-iou-sroie --show-dir /data3/jfkuang/vis_sroie/vis_text_red/ 12 | 13 | #trie 14 | #CUDA_VISIBLE_DEVICES=5 python /home/jfkuang/code/ie_e2e/tools/test.py \ 15 | #/home/jfkuang/code/ie_e2e/configs/vie_custom/e2e_trie/v5/local/nfv5_3125_sdef_3l_disen_200e_720_local_3090.py \ 16 | #/data3/jfkuang/vis_weights_trie/epoch_170.pth \ 17 | #--eval hmean-iou --show-dir /data3/jfkuang/vis_weights_trie/vis_no_text_green/ 18 | 19 | #ours 20 | #CUDA_VISIBLE_DEVICES=5 python /home/jfkuang/code/ie_e2e/tools/test.py \ 21 | #/home/jfkuang/code/ie_e2e/configs/vie_custom/e2e_ar_vie/v5/local/nfv5_3125_sdef_rnn_kvc_200e_720_3090_vis.py \ 22 | #/data3/jfkuang/vis_weights_ours/epoch_160.pth \ 23 | #--eval hmean-iou --show-dir /data3/jfkuang/vis_weights_ours/vis_no_text_red/ -------------------------------------------------------------------------------- /configs/vie_custom/_base_/ocr_datasets/synth_chn_ar_cloud_kjf.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: UTF-8 -*- 3 | ''' 4 | @Project :ie_e2e 5 | @File :synth_chn_ar_cloud_kjf.py 6 | @IDE :PyCharm 7 | @Author :jfkuang 8 | @Date :2022/6/4 17:12 9 | ''' 10 | dataset_type = 'VIEE2EDataset' 11 | #1032 12 | # data_root = 
'/data/jfkuang/syntext' 13 | #1803 14 | data_root = '/home/jfkuang/data/syntext' 15 | 16 | loader = dict( 17 | type='HardDiskLoader', 18 | repeat=1, 19 | parser=dict( 20 | type='CustomLineJsonParser', 21 | keys=['file_name', 'height', 'width', 'annotations'], 22 | optional_keys=['entity_dict'])) 23 | 24 | train = dict( 25 | type=dataset_type, 26 | ann_file=f'{data_root}/custom_json_format.txt', 27 | loader=loader, 28 | dict_file=f'{data_root}/custom_dict.json', 29 | img_prefix=f'{data_root}/syn_130k_images', 30 | pipeline=None, 31 | test_mode=False, 32 | class_file=None, 33 | data_type='ocr', 34 | max_seq_len=75, 35 | order_type='shuffle', 36 | auto_reg=True, 37 | pre_parse_anno=True) 38 | """ 39 | Total sample num: 134514 40 | max_pt_num:4 41 | avg_ins:11.237142602257014, max:321, min:1 42 | avg_height:420.61183222564193, avg_width:487.8559555139242 43 | max_len:71 44 | avg_ins_height:35.7767097812647, avg_ins_width:81.25013148728493 45 | """ 46 | 47 | train_list = [train] 48 | 49 | test_list = [train] 50 | -------------------------------------------------------------------------------- /configs/vie_custom/_base_/vie_datasets/nfv4_ar_local_1803.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # @Time : 2022/2/2 22:53 4 | # @Author : WeiHua 5 | 6 | dataset_type = 'VIEE2EDataset' 7 | data_root = '/home/whua/dataset/ie_e2e/nfv4' 8 | # data_root = '/data/whua/dataset/ie_e2e/nfv1/ie_e2e_data/mm_format/table' 9 | # data_root = '/mnt/whua/ie_e2e_data/mm_format/table' 10 | 11 | loader = dict( 12 | type='HardDiskLoader', 13 | repeat=1, 14 | parser=dict( 15 | type='CustomLineJsonParser', 16 | keys=['file_name', 'height', 'width', 'annotations'],)) 17 | 18 | train = dict( 19 | type=dataset_type, 20 | ann_file=f'{data_root}/train.txt', 21 | loader=loader, 22 | dict_file=f'{data_root}/dict.json', 23 | img_prefix=data_root, 24 | pipeline=None, 25 | test_mode=False, 26 | class_file=f'{data_root}/class_list.json', 27 | data_type='vie', 28 | max_seq_len=125, 29 | order_type='shuffle', 30 | auto_reg=True) 31 | 32 | test = dict( 33 | type=dataset_type, 34 | ann_file=f'{data_root}/test.txt', 35 | loader=loader, 36 | dict_file=f'{data_root}/dict.json', 37 | img_prefix=data_root, 38 | pipeline=None, 39 | test_mode=True, 40 | class_file=f'{data_root}/class_list.json', 41 | data_type='vie', 42 | max_seq_len=125, 43 | order_type='origin', 44 | auto_reg=True) 45 | 46 | train_list = [train] 47 | 48 | test_list = [test] 49 | -------------------------------------------------------------------------------- /configs/vie_custom/_base_/ocr_datasets/ephoie_local.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # @Time : 2022/2/2 22:53 4 | # @Author : WeiHua 5 | 6 | dataset_type = 'VIEE2EDataset' 7 | data_root = '/home/whua/datasets/ie_e2e_format/ephoie/e2e_format' 8 | # data_root = '/data/whua/dataset/ie_e2e/nfv1/ie_e2e_data/mm_format/table' 9 | # data_root = '/mnt/whua/ie_e2e_data/mm_format/table' 10 | 11 | loader = dict( 12 | type='HardDiskLoader', 13 | repeat=1, 14 | parser=dict( 15 | type='CustomLineJsonParser', 16 | keys=['file_name', 'height', 'width', 'annotations'], 17 | optional_keys=['entity_dict'])) 18 | 19 | train = dict( 20 | type=dataset_type, 21 | ann_file=f'{data_root}/train.txt', 22 | loader=loader, 23 | dict_file=f'{data_root}/dict.json', 24 | img_prefix=data_root, 25 | pipeline=None, 26 | test_mode=False, 27 | 
class_file=f'{data_root}/class_list.json', 28 | data_type='ocr', 29 | max_seq_len=80, 30 | order_type='shuffle') 31 | 32 | test = dict( 33 | type=dataset_type, 34 | ann_file=f'{data_root}/test.txt', 35 | loader=loader, 36 | dict_file=f'{data_root}/dict.json', 37 | img_prefix=data_root, 38 | pipeline=None, 39 | test_mode=True, 40 | class_file=f'{data_root}/class_list.json', 41 | data_type='ocr', 42 | max_seq_len=80, 43 | order_type='origin') 44 | 45 | train_list = [train] 46 | 47 | test_list = [test] 48 | -------------------------------------------------------------------------------- /configs/vie_custom/_base_/vie_datasets/cord_ar_local_1032.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # @Time : 2022/2/2 22:53 4 | # @Author : WeiHua 5 | 6 | dataset_type = 'VIEE2EDataset' 7 | data_root = '/home/whua/dataset/ie_e2e_dataset/cord/e2e_format' 8 | # data_root = '/data/whua/dataset/ie_e2e/nfv1/ie_e2e_data/mm_format/table' 9 | # data_root = '/mnt/whua/ie_e2e_data/mm_format/table' 10 | 11 | loader = dict( 12 | type='HardDiskLoader', 13 | repeat=1, 14 | parser=dict( 15 | type='LineJsonParser', 16 | keys=['file_name', 'height', 'width', 'annotations'])) 17 | 18 | train = dict( 19 | type=dataset_type, 20 | ann_file=f'{data_root}/train.txt', 21 | loader=loader, 22 | dict_file=f'{data_root}/dict.json', 23 | img_prefix=data_root, 24 | pipeline=None, 25 | test_mode=False, 26 | class_file=f'{data_root}/class_list.json', 27 | data_type='vie', 28 | max_seq_len=36, 29 | order_type='shuffle', 30 | auto_reg=True) 31 | 32 | test = dict( 33 | type=dataset_type, 34 | ann_file=f'{data_root}/test.txt', 35 | loader=loader, 36 | dict_file=f'{data_root}/dict.json', 37 | img_prefix=data_root, 38 | pipeline=None, 39 | test_mode=True, 40 | class_file=f'{data_root}/class_list.json', 41 | data_type='vie', 42 | max_seq_len=36, 43 | order_type='origin', 44 | auto_reg=True) 45 | 46 | train_list = [train] 47 | 48 | test_list = [test] 49 | -------------------------------------------------------------------------------- /configs/vie_custom/_base_/vie_datasets/sroie_3090.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: UTF-8 -*- 3 | ''' 4 | @Project :ie_e2e 5 | @File :sroie_3090.py 6 | @IDE :PyCharm 7 | @Author :jfkuang 8 | @Date :2022/10/13 17:46 9 | ''' 10 | dataset_type = 'VIEE2EDataset' 11 | data_root = '/data3/jfkuang/data/sroie/e2e_format' 12 | 13 | loader = dict( 14 | type='HardDiskLoader', 15 | repeat=1, 16 | parser=dict( 17 | type='CustomLineJsonParser', 18 | keys=['file_name', 'height', 'width', 'annotations'], 19 | optional_keys=['entity_dict'])) 20 | 21 | train = dict( 22 | type=dataset_type, 23 | ann_file=f'{data_root}/train_update_screen.txt', 24 | loader=loader, 25 | dict_file=f'{data_root}/dict.json', 26 | img_prefix=data_root, 27 | pipeline=None, 28 | test_mode=False, 29 | class_file=f'{data_root}/class_list.json', 30 | data_type='vie', 31 | max_seq_len=72, 32 | order_type='shuffle', 33 | auto_reg=True) 34 | 35 | test = dict( 36 | type=dataset_type, 37 | ann_file=f'{data_root}/test_screen.txt', 38 | loader=loader, 39 | dict_file=f'{data_root}/dict.json', 40 | img_prefix=data_root, 41 | pipeline=None, 42 | test_mode=True, 43 | class_file=f'{data_root}/class_list.json', 44 | data_type='vie', 45 | max_seq_len=72, 46 | order_type='origin', 47 | auto_reg=True) 48 | 49 | train_list = [train] 50 | 51 | test_list = [test] 52 | 
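# How these _base_ dataset files are meant to be consumed (sketch; the
# relative path below is hypothetical): an experiment config inherits the
# file and pulls the lists out via mmcv's base-variable syntax, e.g.
#   _base_ = ['../_base_/vie_datasets/sroie_3090.py']
#   train_list = {{_base_.train_list}}
#   test_list = {{_base_.test_list}}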
-------------------------------------------------------------------------------- /tests/test_apis/test_image_misc.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import numpy as np 3 | import pytest 4 | import torch 5 | from numpy.testing import assert_array_equal 6 | 7 | from mmocr.apis.utils import tensor2grayimgs 8 | 9 | 10 | @pytest.mark.skipif(not torch.cuda.is_available(), reason='requires cuda') 11 | def test_tensor2grayimgs(): 12 | 13 | # test tensor obj 14 | with pytest.raises(AssertionError): 15 | tensor = np.random.rand(2, 3, 3) 16 | tensor2grayimgs(tensor) 17 | 18 | # test tensor ndim 19 | with pytest.raises(AssertionError): 20 | tensor = torch.randn(2, 3, 3) 21 | tensor2grayimgs(tensor) 22 | 23 | # test tensor dim-1 24 | with pytest.raises(AssertionError): 25 | tensor = torch.randn(2, 3, 5, 5) 26 | tensor2grayimgs(tensor) 27 | 28 | # test mean length 29 | with pytest.raises(AssertionError): 30 | tensor = torch.randn(2, 1, 5, 5) 31 | tensor2grayimgs(tensor, mean=(1, 1, 1)) 32 | 33 | # test std length 34 | with pytest.raises(AssertionError): 35 | tensor = torch.randn(2, 1, 5, 5) 36 | tensor2grayimgs(tensor, std=(1, 1, 1)) 37 | 38 | tensor = torch.randn(2, 1, 5, 5) 39 | gts = [t.squeeze(0).cpu().numpy().astype(np.uint8) for t in tensor] 40 | outputs = tensor2grayimgs(tensor, mean=(0, ), std=(1, )) 41 | for gt, output in zip(gts, outputs): 42 | assert_array_equal(gt, output) 43 | -------------------------------------------------------------------------------- /configs/vie_custom/_base_/ocr_datasets/cord_ar_cloud.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # @Time : 2022/2/2 22:53 4 | # @Author : WeiHua 5 | 6 | dataset_type = 'VIEE2EDataset' 7 | data_root = '/apdcephfs/share_887471/common/whua/dataset/ie_e2e/cord/e2e_format' 8 | # data_root = '/data/whua/dataset/ie_e2e/nfv1/ie_e2e_data/mm_format/table' 9 | # data_root = '/mnt/whua/ie_e2e_data/mm_format/table' 10 | 11 | loader = dict( 12 | type='HardDiskLoader', 13 | repeat=1, 14 | parser=dict( 15 | type='LineJsonParser', 16 | keys=['file_name', 'height', 'width', 'annotations'])) 17 | 18 | train = dict( 19 | type=dataset_type, 20 | ann_file=f'{data_root}/train.txt', 21 | loader=loader, 22 | dict_file=f'{data_root}/dict.json', 23 | img_prefix=data_root, 24 | pipeline=None, 25 | test_mode=False, 26 | class_file=f'{data_root}/class_list.json', 27 | data_type='ocr', 28 | max_seq_len=36, 29 | order_type='shuffle', 30 | auto_reg=True) 31 | 32 | test = dict( 33 | type=dataset_type, 34 | ann_file=f'{data_root}/test.txt', 35 | loader=loader, 36 | dict_file=f'{data_root}/dict.json', 37 | img_prefix=data_root, 38 | pipeline=None, 39 | test_mode=True, 40 | class_file=f'{data_root}/class_list.json', 41 | data_type='ocr', 42 | max_seq_len=36, 43 | order_type='origin', 44 | auto_reg=True) 45 | 46 | train_list = [train] 47 | 48 | test_list = [test] 49 | -------------------------------------------------------------------------------- /configs/vie_custom/_base_/ocr_datasets/sroie_cloud.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # @Time : 2022/2/2 22:53 4 | # @Author : WeiHua 5 | 6 | dataset_type = 'VIEE2EDataset' 7 | data_root = '/apdcephfs/share_887471/common/whua/dataset/ie_e2e/sroie/e2e_format' 8 | # data_root = 
'/data/whua/dataset/ie_e2e/nfv1/ie_e2e_data/mm_format/table' 9 | # data_root = '/mnt/whua/ie_e2e_data/mm_format/table' 10 | 11 | loader = dict( 12 | type='HardDiskLoader', 13 | repeat=1, 14 | parser=dict( 15 | type='CustomLineJsonParser', 16 | keys=['file_name', 'height', 'width', 'annotations'], 17 | optional_keys=['entity_dict'])) 18 | 19 | train = dict( 20 | type=dataset_type, 21 | ann_file=f'{data_root}/train.txt', 22 | loader=loader, 23 | dict_file=f'{data_root}/dict.json', 24 | img_prefix=data_root, 25 | pipeline=None, 26 | test_mode=False, 27 | class_file=f'{data_root}/class_list.json', 28 | data_type='ocr', 29 | max_seq_len=72, 30 | order_type='shuffle') 31 | 32 | test = dict( 33 | type=dataset_type, 34 | ann_file=f'{data_root}/test.txt', 35 | loader=loader, 36 | dict_file=f'{data_root}/dict.json', 37 | img_prefix=data_root, 38 | pipeline=None, 39 | test_mode=True, 40 | class_file=f'{data_root}/class_list.json', 41 | data_type='ocr', 42 | max_seq_len=72, 43 | order_type='origin') 44 | 45 | train_list = [train] 46 | 47 | test_list = [test] 48 | -------------------------------------------------------------------------------- /configs/vie_custom/_base_/vie_datasets/cord_ar_cloud.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # @Time : 2022/2/2 22:53 4 | # @Author : WeiHua 5 | 6 | dataset_type = 'VIEE2EDataset' 7 | data_root = '/apdcephfs/share_887471/common/whua/dataset/ie_e2e/cord/e2e_format' 8 | # data_root = '/data/whua/dataset/ie_e2e/nfv1/ie_e2e_data/mm_format/table' 9 | # data_root = '/mnt/whua/ie_e2e_data/mm_format/table' 10 | 11 | loader = dict( 12 | type='HardDiskLoader', 13 | repeat=1, 14 | parser=dict( 15 | type='LineJsonParser', 16 | keys=['file_name', 'height', 'width', 'annotations'])) 17 | 18 | train = dict( 19 | type=dataset_type, 20 | ann_file=f'{data_root}/train.txt', 21 | loader=loader, 22 | dict_file=f'{data_root}/dict.json', 23 | img_prefix=data_root, 24 | pipeline=None, 25 | test_mode=False, 26 | class_file=f'{data_root}/class_list.json', 27 | data_type='vie', 28 | max_seq_len=36, 29 | order_type='shuffle', 30 | auto_reg=True) 31 | 32 | test = dict( 33 | type=dataset_type, 34 | ann_file=f'{data_root}/test.txt', 35 | loader=loader, 36 | dict_file=f'{data_root}/dict.json', 37 | img_prefix=data_root, 38 | pipeline=None, 39 | test_mode=True, 40 | class_file=f'{data_root}/class_list.json', 41 | data_type='vie', 42 | max_seq_len=36, 43 | order_type='origin', 44 | auto_reg=True) 45 | 46 | train_list = [train] 47 | 48 | test_list = [test] 49 | -------------------------------------------------------------------------------- /configs/vie_custom/_base_/vie_datasets/nfv3_ar_cloud.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # @Time : 2022/2/2 22:53 4 | # @Author : WeiHua 5 | 6 | dataset_type = 'VIEE2EDataset' 7 | data_root = '/apdcephfs/share_887471/common/whua/dataset/ie_e2e/nfv3/e2e_format' 8 | # data_root = '/data/whua/dataset/ie_e2e/nfv1/ie_e2e_data/mm_format/table' 9 | # data_root = '/mnt/whua/ie_e2e_data/mm_format/table' 10 | 11 | loader = dict( 12 | type='HardDiskLoader', 13 | repeat=1, 14 | parser=dict( 15 | type='CustomLineJsonParser', 16 | keys=['file_name', 'height', 'width', 'annotations'],)) 17 | 18 | train = dict( 19 | type=dataset_type, 20 | ann_file=f'{data_root}/train.txt', 21 | loader=loader, 22 | dict_file=f'{data_root}/dict.json', 23 | img_prefix=data_root, 24 | 
pipeline=None, 25 | test_mode=False, 26 | class_file=f'{data_root}/class_list.json', 27 | data_type='vie', 28 | max_seq_len=125, 29 | order_type='shuffle', 30 | auto_reg=True) 31 | 32 | test = dict( 33 | type=dataset_type, 34 | ann_file=f'{data_root}/test.txt', 35 | loader=loader, 36 | dict_file=f'{data_root}/dict.json', 37 | img_prefix=data_root, 38 | pipeline=None, 39 | test_mode=True, 40 | class_file=f'{data_root}/class_list.json', 41 | data_type='vie', 42 | max_seq_len=125, 43 | order_type='origin', 44 | auto_reg=True) 45 | 46 | train_list = [train] 47 | 48 | test_list = [test] 49 | -------------------------------------------------------------------------------- /configs/vie_custom/_base_/vie_datasets/nfv4_ar_cloud.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # @Time : 2022/2/2 22:53 4 | # @Author : WeiHua 5 | 6 | dataset_type = 'VIEE2EDataset' 7 | data_root = '/apdcephfs/share_887471/interns/v_willwhua/dataset/ie_e2e/nfv4' 8 | # data_root = '/data/whua/dataset/ie_e2e/nfv1/ie_e2e_data/mm_format/table' 9 | # data_root = '/mnt/whua/ie_e2e_data/mm_format/table' 10 | 11 | loader = dict( 12 | type='HardDiskLoader', 13 | repeat=1, 14 | parser=dict( 15 | type='CustomLineJsonParser', 16 | keys=['file_name', 'height', 'width', 'annotations'],)) 17 | 18 | train = dict( 19 | type=dataset_type, 20 | ann_file=f'{data_root}/train.txt', 21 | loader=loader, 22 | dict_file=f'{data_root}/dict.json', 23 | img_prefix=data_root, 24 | pipeline=None, 25 | test_mode=False, 26 | class_file=f'{data_root}/class_list.json', 27 | data_type='vie', 28 | max_seq_len=125, 29 | order_type='shuffle', 30 | auto_reg=True) 31 | 32 | test = dict( 33 | type=dataset_type, 34 | ann_file=f'{data_root}/test.txt', 35 | loader=loader, 36 | dict_file=f'{data_root}/dict.json', 37 | img_prefix=data_root, 38 | pipeline=None, 39 | test_mode=True, 40 | class_file=f'{data_root}/class_list.json', 41 | data_type='vie', 42 | max_seq_len=125, 43 | order_type='origin', 44 | auto_reg=True) 45 | 46 | train_list = [train] 47 | 48 | test_list = [test] 49 | -------------------------------------------------------------------------------- /configs/vie_custom/_base_/vie_datasets/nfv5_3125_3090.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: UTF-8 -*- 3 | ''' 4 | @Project :ie_e2e 5 | @File :nfv5_3125_3090.py 6 | @IDE :PyCharm 7 | @Author :jfkuang 8 | @Date :2022/10/13 17:45 9 | ''' 10 | 11 | dataset_type = 'VIEE2EDataset' 12 | data_root = '/data3/jfkuang/data/nfv5_3125' 13 | 14 | loader = dict( 15 | type='HardDiskLoader', 16 | repeat=1, 17 | parser=dict( 18 | type='CustomLineJsonParser', 19 | keys=['file_name', 'height', 'width', 'annotations'], 20 | optional_keys=['entity_dict'])) 21 | 22 | train = dict( 23 | type=dataset_type, 24 | ann_file=f'{data_root}/train.txt', 25 | loader=loader, 26 | dict_file=f'{data_root}/dict.json', 27 | img_prefix=data_root, 28 | pipeline=None, 29 | test_mode=False, 30 | class_file=f'{data_root}/class_list.json', 31 | data_type='vie', 32 | max_seq_len=125, 33 | order_type='shuffle', 34 | auto_reg=True, 35 | pre_parse_anno=True) 36 | 37 | test = dict( 38 | type=dataset_type, 39 | ann_file=f'{data_root}/test.txt', 40 | loader=loader, 41 | dict_file=f'{data_root}/dict.json', 42 | img_prefix=data_root, 43 | pipeline=None, 44 | test_mode=True, 45 | class_file=f'{data_root}/class_list.json', 46 | data_type='vie', 47 | max_seq_len=125, 48 | 
order_type='origin', 49 | auto_reg=True) 50 | 51 | train_list = [train] 52 | 53 | test_list = [test] 54 | -------------------------------------------------------------------------------- /mmocr/models/spotting/modules/ops/src/cpu/ms_deform_attn_cpu.cpp: -------------------------------------------------------------------------------- 1 | /*! 2 | ************************************************************************************************** 3 | * Deformable DETR 4 | * Copyright (c) 2020 SenseTime. All Rights Reserved. 5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details] 6 | ************************************************************************************************** 7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 8 | ************************************************************************************************** 9 | */ 10 | 11 | #include <vector> 12 | 13 | #include <ATen/ATen.h> 14 | #include <ATen/cuda/CUDAContext.h> 15 | 16 | 17 | at::Tensor 18 | ms_deform_attn_cpu_forward( 19 | const at::Tensor &value, 20 | const at::Tensor &spatial_shapes, 21 | const at::Tensor &level_start_index, 22 | const at::Tensor &sampling_loc, 23 | const at::Tensor &attn_weight, 24 | const int im2col_step) 25 | { 26 | AT_ERROR("Not implemented on cpu"); 27 | } 28 | 29 | std::vector<at::Tensor> 30 | ms_deform_attn_cpu_backward( 31 | const at::Tensor &value, 32 | const at::Tensor &spatial_shapes, 33 | const at::Tensor &level_start_index, 34 | const at::Tensor &sampling_loc, 35 | const at::Tensor &attn_weight, 36 | const at::Tensor &grad_output, 37 | const int im2col_step) 38 | { 39 | AT_ERROR("Not implemented on cpu"); 40 | } 41 | 42 | -------------------------------------------------------------------------------- /configs/vie_custom/_base_/vie_datasets/sroie_ar_local.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # @Time : 2022/2/2 22:53 4 | # @Author : WeiHua 5 | 6 | dataset_type = 'VIEE2EDataset' 7 | data_root = '/home/whua/dataset/ie_e2e/sroie/e2e_format' 8 | # data_root = '/data/whua/dataset/ie_e2e/nfv1/ie_e2e_data/mm_format/table' 9 | # data_root = '/mnt/whua/ie_e2e_data/mm_format/table' 10 | 11 | loader = dict( 12 | type='HardDiskLoader', 13 | repeat=1, 14 | parser=dict( 15 | type='CustomLineJsonParser', 16 | keys=['file_name', 'height', 'width', 'annotations'], 17 | optional_keys=['entity_dict'])) 18 | 19 | train = dict( 20 | type=dataset_type, 21 | ann_file=f'{data_root}/train.txt', 22 | loader=loader, 23 | dict_file=f'{data_root}/dict.json', 24 | img_prefix=data_root, 25 | pipeline=None, 26 | test_mode=False, 27 | class_file=f'{data_root}/class_list.json', 28 | data_type='vie', 29 | max_seq_len=72, 30 | order_type='shuffle', 31 | auto_reg=True) 32 | 33 | test = dict( 34 | type=dataset_type, 35 | ann_file=f'{data_root}/test.txt', 36 | loader=loader, 37 | dict_file=f'{data_root}/dict.json', 38 | img_prefix=data_root, 39 | pipeline=None, 40 | test_mode=True, 41 | class_file=f'{data_root}/class_list.json', 42 | data_type='vie', 43 | max_seq_len=72, 44 | order_type='origin', 45 | auto_reg=True) 46 | 47 | train_list = [train] 48 | 49 | test_list = [test] 50 | -------------------------------------------------------------------------------- /configs/vie_custom/_base_/vie_datasets/local/ephoie_ar_local_9999.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # @Time : 2022/2/2 22:53 4 | # @Author : WeiHua
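# The *_local_<id> suffix on these files appears to encode the target host;
# the config bodies are otherwise identical and only data_root changes.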
5 | 6 | dataset_type = 'VIEE2EDataset' 7 | data_root = '/data/whua/ie_e2e/ephoie/e2e_format' 8 | # data_root = '/data/whua/dataset/ie_e2e/nfv1/ie_e2e_data/mm_format/table' 9 | # data_root = '/mnt/whua/ie_e2e_data/mm_format/table' 10 | 11 | loader = dict( 12 | type='HardDiskLoader', 13 | repeat=1, 14 | parser=dict( 15 | type='CustomLineJsonParser', 16 | keys=['file_name', 'height', 'width', 'annotations'], 17 | optional_keys=['entity_dict'])) 18 | 19 | train = dict( 20 | type=dataset_type, 21 | ann_file=f'{data_root}/train.txt', 22 | loader=loader, 23 | dict_file=f'{data_root}/dict.json', 24 | img_prefix=data_root, 25 | pipeline=None, 26 | test_mode=False, 27 | class_file=f'{data_root}/class_list.json', 28 | data_type='vie', 29 | max_seq_len=80, 30 | order_type='shuffle', 31 | auto_reg=True) 32 | 33 | test = dict( 34 | type=dataset_type, 35 | ann_file=f'{data_root}/test.txt', 36 | loader=loader, 37 | dict_file=f'{data_root}/dict.json', 38 | img_prefix=data_root, 39 | pipeline=None, 40 | test_mode=True, 41 | class_file=f'{data_root}/class_list.json', 42 | data_type='vie', 43 | max_seq_len=80, 44 | order_type='origin', 45 | auto_reg=True) 46 | 47 | train_list = [train] 48 | 49 | test_list = [test] 50 | -------------------------------------------------------------------------------- /mmocr/utils/string_util.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | class StringStrip: 3 | """Removing the leading and/or the trailing characters based on the string 4 | argument passed. 5 | 6 | Args: 7 | strip (bool): Whether remove characters from both left and right of 8 | the string. Default: True. 9 | strip_pos (str): Which position for removing, can be one of 10 | ('both', 'left', 'right'), Default: 'both'. 11 | strip_str (str|None): A string specifying the set of characters 12 | to be removed from the left and right part of the string. 13 | If None, all leading and trailing whitespaces 14 | are removed from the string. Default: None. 15 | """ 16 | 17 | def __init__(self, strip=True, strip_pos='both', strip_str=None): 18 | assert isinstance(strip, bool) 19 | assert strip_pos in ('both', 'left', 'right') 20 | assert strip_str is None or isinstance(strip_str, str) 21 | 22 | self.strip = strip 23 | self.strip_pos = strip_pos 24 | self.strip_str = strip_str 25 | 26 | def __call__(self, in_str): 27 | 28 | if not self.strip: 29 | return in_str 30 | 31 | if self.strip_pos == 'left': 32 | return in_str.lstrip(self.strip_str) 33 | elif self.strip_pos == 'right': 34 | return in_str.rstrip(self.strip_str) 35 | else: 36 | return in_str.strip(self.strip_str) 37 | -------------------------------------------------------------------------------- /tests/test_utils/test_check_argument.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
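# Conventions inferred from the assertions below: valid_boundary expects a
# flat x1,y1,...,xn,yn list with at least four points (8 values), plus one
# trailing score value when with_score=True.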
2 | import numpy as np 3 | 4 | import mmocr.utils as utils 5 | 6 | 7 | def test_is_3dlist(): 8 | 9 | assert utils.is_3dlist([]) 10 | assert utils.is_3dlist([[]]) 11 | assert utils.is_3dlist([[[]]]) 12 | assert utils.is_3dlist([[[1]]]) 13 | assert not utils.is_3dlist([[1, 2]]) 14 | assert not utils.is_3dlist([[np.array([1, 2])]]) 15 | 16 | 17 | def test_is_2dlist(): 18 | 19 | assert utils.is_2dlist([]) 20 | assert utils.is_2dlist([[]]) 21 | assert utils.is_2dlist([[1]]) 22 | 23 | 24 | def test_is_type_list(): 25 | assert utils.is_type_list([], int) 26 | assert utils.is_type_list([], float) 27 | assert utils.is_type_list([np.array([])], np.ndarray) 28 | assert utils.is_type_list([1], int) 29 | assert utils.is_type_list(['str'], str) 30 | 31 | 32 | def test_is_none_or_type(): 33 | 34 | assert utils.is_none_or_type(None, int) 35 | assert utils.is_none_or_type(1.0, float) 36 | assert utils.is_none_or_type(np.ndarray([]), np.ndarray) 37 | assert utils.is_none_or_type(1, int) 38 | assert utils.is_none_or_type('str', str) 39 | 40 | 41 | def test_valid_boundary(): 42 | 43 | x = [0, 0, 1, 0, 1, 1, 0, 1] 44 | assert not utils.valid_boundary(x, True) 45 | assert not utils.valid_boundary([0]) 46 | assert utils.valid_boundary(x, False) 47 | x = [0, 0, 1, 0, 1, 1, 0, 1, 1] 48 | assert utils.valid_boundary(x, True) 49 | -------------------------------------------------------------------------------- /configs/vie_custom/_base_/vie_datasets/local/ephoie_ar_local_1033.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # @Time : 2022/2/2 22:53 4 | # @Author : WeiHua 5 | 6 | dataset_type = 'VIEE2EDataset' 7 | data_root = '/share/whua/dataset/ie_e2e/ephoie/e2e_format' 8 | # data_root = '/data/whua/dataset/ie_e2e/nfv1/ie_e2e_data/mm_format/table' 9 | # data_root = '/mnt/whua/ie_e2e_data/mm_format/table' 10 | 11 | loader = dict( 12 | type='HardDiskLoader', 13 | repeat=1, 14 | parser=dict( 15 | type='CustomLineJsonParser', 16 | keys=['file_name', 'height', 'width', 'annotations'], 17 | optional_keys=['entity_dict'])) 18 | 19 | train = dict( 20 | type=dataset_type, 21 | ann_file=f'{data_root}/train.txt', 22 | loader=loader, 23 | dict_file=f'{data_root}/dict.json', 24 | img_prefix=data_root, 25 | pipeline=None, 26 | test_mode=False, 27 | class_file=f'{data_root}/class_list.json', 28 | data_type='vie', 29 | max_seq_len=80, 30 | order_type='shuffle', 31 | auto_reg=True) 32 | 33 | test = dict( 34 | type=dataset_type, 35 | ann_file=f'{data_root}/test.txt', 36 | loader=loader, 37 | dict_file=f'{data_root}/dict.json', 38 | img_prefix=data_root, 39 | pipeline=None, 40 | test_mode=True, 41 | class_file=f'{data_root}/class_list.json', 42 | data_type='vie', 43 | max_seq_len=80, 44 | order_type='origin', 45 | auto_reg=True) 46 | 47 | train_list = [train] 48 | 49 | test_list = [test] 50 | -------------------------------------------------------------------------------- /configs/vie_custom/_base_/vie_datasets/local/ephoie_ar_local_1061.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # @Time : 2022/2/2 22:53 4 | # @Author : WeiHua 5 | 6 | dataset_type = 'VIEE2EDataset' 7 | data_root = '/home/whua/dataset/ie_e2e/ephoie/e2e_format' 8 | # data_root = '/data/whua/dataset/ie_e2e/nfv1/ie_e2e_data/mm_format/table' 9 | # data_root = '/mnt/whua/ie_e2e_data/mm_format/table' 10 | 11 | loader = dict( 12 | type='HardDiskLoader', 13 | repeat=1, 14 | 
parser=dict( 15 | type='CustomLineJsonParser', 16 | keys=['file_name', 'height', 'width', 'annotations'], 17 | optional_keys=['entity_dict'])) 18 | 19 | train = dict( 20 | type=dataset_type, 21 | ann_file=f'{data_root}/train.txt', 22 | loader=loader, 23 | dict_file=f'{data_root}/dict.json', 24 | img_prefix=data_root, 25 | pipeline=None, 26 | test_mode=False, 27 | class_file=f'{data_root}/class_list.json', 28 | data_type='vie', 29 | max_seq_len=80, 30 | order_type='shuffle', 31 | auto_reg=True) 32 | 33 | test = dict( 34 | type=dataset_type, 35 | ann_file=f'{data_root}/test.txt', 36 | loader=loader, 37 | dict_file=f'{data_root}/dict.json', 38 | img_prefix=data_root, 39 | pipeline=None, 40 | test_mode=True, 41 | class_file=f'{data_root}/class_list.json', 42 | data_type='vie', 43 | max_seq_len=80, 44 | order_type='origin', 45 | auto_reg=True) 46 | 47 | train_list = [train] 48 | 49 | test_list = [test] 50 | -------------------------------------------------------------------------------- /configs/vie_custom/_base_/vie_datasets/local/ephoie_ar_local_sort_1033.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # @Time : 2022/2/2 22:53 4 | # @Author : WeiHua 5 | 6 | dataset_type = 'VIEE2EDataset' 7 | data_root = '/share/whua/dataset/ie_e2e/ephoie/e2e_format' 8 | # data_root = '/data/whua/dataset/ie_e2e/nfv1/ie_e2e_data/mm_format/table' 9 | # data_root = '/mnt/whua/ie_e2e_data/mm_format/table' 10 | 11 | loader = dict( 12 | type='HardDiskLoader', 13 | repeat=1, 14 | parser=dict( 15 | type='CustomLineJsonParser', 16 | keys=['file_name', 'height', 'width', 'annotations'], 17 | optional_keys=['entity_dict'])) 18 | 19 | train = dict( 20 | type=dataset_type, 21 | ann_file=f'{data_root}/train.txt', 22 | loader=loader, 23 | dict_file=f'{data_root}/dict.json', 24 | img_prefix=data_root, 25 | pipeline=None, 26 | test_mode=False, 27 | class_file=f'{data_root}/class_list.json', 28 | data_type='vie', 29 | max_seq_len=80, 30 | order_type='sort', 31 | auto_reg=True) 32 | 33 | test = dict( 34 | type=dataset_type, 35 | ann_file=f'{data_root}/test.txt', 36 | loader=loader, 37 | dict_file=f'{data_root}/dict.json', 38 | img_prefix=data_root, 39 | pipeline=None, 40 | test_mode=True, 41 | class_file=f'{data_root}/class_list.json', 42 | data_type='vie', 43 | max_seq_len=80, 44 | order_type='origin', 45 | auto_reg=True) 46 | 47 | train_list = [train] 48 | 49 | test_list = [test] 50 | -------------------------------------------------------------------------------- /configs/vie_custom/_base_/vie_datasets/sroie_ar_cloud_ssd.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # @Time : 2022/2/2 22:53 4 | # @Author : WeiHua 5 | 6 | dataset_type = 'VIEE2EDataset' 7 | data_root = '/data/docker/data_whua/ie_e2e/sroie/e2e_format' 8 | # data_root = '/data/whua/dataset/ie_e2e/nfv1/ie_e2e_data/mm_format/table' 9 | # data_root = '/mnt/whua/ie_e2e_data/mm_format/table' 10 | 11 | loader = dict( 12 | type='HardDiskLoader', 13 | repeat=1, 14 | parser=dict( 15 | type='CustomLineJsonParser', 16 | keys=['file_name', 'height', 'width', 'annotations'], 17 | optional_keys=['entity_dict'])) 18 | 19 | train = dict( 20 | type=dataset_type, 21 | ann_file=f'{data_root}/train_update.txt', 22 | loader=loader, 23 | dict_file=f'{data_root}/dict.json', 24 | img_prefix=data_root, 25 | pipeline=None, 26 | test_mode=False, 27 | class_file=f'{data_root}/class_list.json', 28 | 
data_type='vie', 29 | max_seq_len=72, 30 | order_type='shuffle', 31 | auto_reg=True) 32 | 33 | test = dict( 34 | type=dataset_type, 35 | ann_file=f'{data_root}/test.txt', 36 | loader=loader, 37 | dict_file=f'{data_root}/dict.json', 38 | img_prefix=data_root, 39 | pipeline=None, 40 | test_mode=True, 41 | class_file=f'{data_root}/class_list.json', 42 | data_type='vie', 43 | max_seq_len=72, 44 | order_type='origin', 45 | auto_reg=True) 46 | 47 | train_list = [train] 48 | 49 | test_list = [test] 50 | -------------------------------------------------------------------------------- /configs/vie_custom/_base_/vie_datasets/local/ephoie_ar_local_1032.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # @Time : 2022/2/2 22:53 4 | # @Author : WeiHua 5 | 6 | dataset_type = 'VIEE2EDataset' 7 | data_root = '/home/whua/datasets/ie_e2e_format/ephoie/e2e_format' 8 | # data_root = '/data/whua/dataset/ie_e2e/nfv1/ie_e2e_data/mm_format/table' 9 | # data_root = '/mnt/whua/ie_e2e_data/mm_format/table' 10 | 11 | loader = dict( 12 | type='HardDiskLoader', 13 | repeat=1, 14 | parser=dict( 15 | type='CustomLineJsonParser', 16 | keys=['file_name', 'height', 'width', 'annotations'], 17 | optional_keys=['entity_dict'])) 18 | 19 | train = dict( 20 | type=dataset_type, 21 | ann_file=f'{data_root}/train.txt', 22 | loader=loader, 23 | dict_file=f'{data_root}/dict.json', 24 | img_prefix=data_root, 25 | pipeline=None, 26 | test_mode=False, 27 | class_file=f'{data_root}/class_list.json', 28 | data_type='vie', 29 | max_seq_len=80, 30 | order_type='shuffle', 31 | auto_reg=True) 32 | 33 | test = dict( 34 | type=dataset_type, 35 | ann_file=f'{data_root}/test.txt', 36 | loader=loader, 37 | dict_file=f'{data_root}/dict.json', 38 | img_prefix=data_root, 39 | pipeline=None, 40 | test_mode=True, 41 | class_file=f'{data_root}/class_list.json', 42 | data_type='vie', 43 | max_seq_len=80, 44 | order_type='origin', 45 | auto_reg=True) 46 | 47 | train_list = [train] 48 | 49 | test_list = [test] 50 | --------------------------------------------------------------------------------
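# For reference, one annotation line in the shape the CustomLineJsonParser
# configs above would accept (sketch; all field values are illustrative):
# {"file_name": "imgs/0001.jpg", "height": 1000, "width": 760,
#  "annotations": [...], "entity_dict": {...}}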