├── .github
│   ├── ISSUE_TEMPLATE
│   │   ├── bug_report.yaml
│   │   └── config.yml
│   └── workflows
│       └── rapid_undistort.yml
├── .gitignore
├── .pre-commit-config.yaml
├── README.md
├── README_en.md
├── __init__.py
├── demo.py
├── img
│   └── demo.jpg
├── preview1.gif
├── preview2.gif
├── rapid_undistorted
│   ├── __init__.py
│   ├── binary_predictor.py
│   ├── config.yaml
│   ├── inference.py
│   ├── models
│   │   └── .gitkeep
│   ├── unblur_predictor.py
│   ├── unshadow_predictor.py
│   ├── unwrap_predictor.py
│   └── utils
│       ├── __init__.py
│       ├── download_model.py
│       ├── img_transform.py
│       ├── infer_engine.py
│       ├── load_image.py
│       └── logger.py
├── requirements.txt
├── setup_undistort.py
└── tests
    ├── test_basic.py
    └── test_files
        ├── demo1.jpg
        ├── demo1.png
        ├── demo2.png
        └── demo3.jpg

/.github/ISSUE_TEMPLATE/bug_report.yaml:
--------------------------------------------------------------------------------
1 | ---
2 | name: 🐞 Bug
3 | about: Bug
4 | title: 'Bug'
5 | labels: 'Bug'
6 | assignees: ''
7 | 
8 | ---
9 | 
10 | 请提供下述完整信息以便快速定位问题
11 | (Please provide the following information to quickly locate the problem)
12 | - **系统环境/System Environment**:
13 | - **使用的是哪门语言的程序/Which programming language**:
14 | - **使用当前库的版本/Use version**:
15 | - **可复现问题的demo和文件/Demo of reproducible problems**:
16 | - **完整报错/Complete Error Message**:
17 | - **可能的解决方案/Possible solutions**:
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/config.yml:
--------------------------------------------------------------------------------
1 | blank_issues_enabled: false
2 | contact_links:
3 |   - name: ❓ Questions
4 |     url: https://github.com/RapidAI/TableStructureRec/discussions/categories/q-a
5 |     about: Please use the community forum for help and questions regarding RapidUndistort
6 |   - name: 💡 Feature requests and ideas
7 |     url: https://github.com/RapidAI/TableStructureRec/discussions/new?category=feature-requests
8 |     about: Please vote for and post new feature ideas in the community forum
9 |   - name: 📖 Documentation
10 |     url: https://rapidai.github.io/TableStructureRec/docs/
11 |     about: A great place to find instructions and answers about RapidUndistort.
--------------------------------------------------------------------------------
/.github/workflows/rapid_undistort.yml:
--------------------------------------------------------------------------------
1 | name: Push rapid_undistorted to pypi
2 | 
3 | on:
4 |   push:
5 |     tags:
6 |       - rapid_undistort_v*
7 | 
8 | jobs:
9 |   UnitTesting:
10 |     runs-on: ubuntu-latest
11 |     steps:
12 |       - name: Pull latest code
13 |         uses: actions/checkout@v3
14 | 
15 |       - name: Set up Python 3.10
16 |         uses: actions/setup-python@v4
17 |         with:
18 |           python-version: '3.10'
19 |           architecture: 'x64'
20 | 
21 |       - name: Display Python version
22 |         run: python -c "import sys; print(sys.version)"
23 | 
24 |       - name: Unit tests
25 |         run: |
26 |           pip install -r requirements.txt
27 |           pip install pytest
28 |           pytest tests/test_basic.py
29 | 
30 |   GenerateWHL_PushPyPi:
31 |     needs: UnitTesting
32 |     runs-on: ubuntu-latest
33 | 
34 |     steps:
35 |       - uses: actions/checkout@v3
36 | 
37 |       - name: Set up Python 3.10
38 |         uses: actions/setup-python@v4
39 |         with:
40 |           python-version: '3.10'
41 |           architecture: 'x64'
42 | 
43 |       - name: Run setup.py
44 |         run: |
45 |           python -m pip install --upgrade pip
46 |           pip install -r requirements.txt
47 |           pip install wheel get_pypi_latest_version
48 |           python setup_undistort.py bdist_wheel "${{ github.ref_name }}"
49 | 
50 |       - name: Publish distribution 📦 to PyPI
51 |         uses: pypa/gh-action-pypi-publish@v1.5.0
52 |         with:
53 |           password: ${{ secrets.PYPI_API_TOKEN }}
54 |           packages_dir: dist/
55 | 
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | /output/
2 | /rapid_undistorted/models/*.onnx
3 | 
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | repos:
2 |   - repo: https://gitee.com/SWHL/autoflake
3 |     rev: v2.1.1
4 |     hooks:
5 |       - id: autoflake
6 |         args:
7 |           [
8 |             "--recursive",
9 |             "--in-place",
10 |             "--remove-all-unused-imports",
11 |             "--remove-unused-variable",
12 |             "--ignore-init-module-imports",
13 |           ]
14 |         files: \.py$
15 |   - repo: https://gitee.com/SWHL/black
16 |     rev: 23.1.0
17 |     hooks:
18 |       - id: black
19 |         files: \.py$
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | <div align="center">
2 | <div align="center">
3 |   <h1>📊RapidUndistort</h1>
4 | </div>
5 | 6 | 7 | SemVer2.0 8 | 9 | GitHub 10 | 11 | [English](README_en.md) | 简体中文 12 |
13 | 
14 | ### Recent Updates
15 | 
16 | - **2024.11.15**
17 |   - Completed the initial version of the code, converted the [UVDoc](https://github.com/tanguymagne/UVDoc) model to onnx, and improved pre- and post-processing.
18 | - **2024.12.15**
19 |   - Added deblurring / shadow-removal / binarization functions and models, and upgraded the project to RapidUndistort.
20 | 
21 | 
22 | ### Introduction
23 | 
24 | This repository corrects document distortion and also handles document deblurring, shadow removal, and binarization. \
25 | It provides several models per task, lets you freely combine tasks, and downloads models automatically. \
26 | Sources of the original PyTorch models are listed in [Acknowledgments](#acknowledgments). \
27 | [Quick Start](#quick-start) [Usage Suggestions](#usage-suggestions) [Parameter Explanation](#parameter-explanation) [Model Address](https://www.modelscope.cn/studios/jockerK/DocUnwrap/files)
28 | 
29 | ### Online Demo
30 | [modelscope](https://www.modelscope.cn/studios/jockerK/DocUnwrap) [huggingface](https://huggingface.co/spaces/Joker1212/RapidUnwrap)
31 | ### Showcase
32 | ![res_show.jpg](preview1.gif)
33 | ![res_show1.jpg](preview2.gif)
34 | 
35 | ### Installation
36 | ``` python {linenos=table}
37 | # Recommended: install from the Tsinghua mirror https://pypi.tuna.tsinghua.edu.cn/simple
38 | pip install rapid-undistorted
39 | ```
40 | 
41 | ### Quick Start
42 | 
43 | ``` python {linenos=table}
44 | import cv2
45 | 
46 | from rapid_undistorted.inference import InferenceEngine
47 | img_path = "img/demo.jpg"
48 | engine = InferenceEngine()
49 | # Distortion correction -> shadow removal -> deblurring (deblur model specified explicitly)
50 | output_img, elapse = engine(img_path, ["unwrap", "unshadow", ("unblur", "OpenCvBilateral")])
51 | # Shadow removal -> deblurring (deblur model specified explicitly)
52 | #output_img, elapse = engine(img_path, ["unshadow", ("unblur", "OpenCvBilateral")])
53 | # By default the first unblur model in the yaml config file is used
54 | #output_img, elapse = engine(img_path, ["unshadow", "unblur"])
55 | # Binarization as an alternative to shadow removal
56 | #output_img, elapse = engine(img_path, ["unwrap", "binarize", "unblur"])
57 | print(f"doc unwrap elapse:{elapse}")
58 | cv2.imwrite("result.png", output_img)
59 | 
60 | ```
61 | 
62 | ### Usage Suggestions
63 | - Both the NAFDPM deblur model and the plain OpenCV method have cases where they work best; trying them on your own images (e.g. in the online demo) is the most reliable guide.
64 | - The shadow-removal model is more capable and gives better results than binarization, so using binarization directly is not recommended.
65 | 
66 | 
67 | ### Parameter Explanation
68 | #### Initialization Parameters
69 | You can pass in a config file that declares the required task types, the corresponding models, and the model paths.
70 | ```python
71 | config_path = "configs/config.yaml"
72 | engine = InferenceEngine(config_path)
73 | ```
74 | ```yaml
75 | tasks:
76 |   unwrap:
77 |     models:
78 |       - type: "UVDoc"
79 |         path:
80 |         use_cuda: false
81 | 
82 |   unshadow:
83 |     models:
84 |       - type: "GCDnet"
85 |         sub_models:
86 |           - type: "GCDnet"
87 |             path:
88 |             use_cuda: false
89 |             use_dml: false
90 |           - type: "DRnet"
91 |             path:
92 |             use_cuda: false
93 | 
94 |   binarize:
95 |     models:
96 |       - type: "UnetCnn"
97 |         path:
98 |         use_cuda: false
99 | 
100 |   unblur:
101 |     models:
102 |       - type: "OpenCvBilateral"
103 |         path:
104 |       - type: "NAFDPM"
105 |         path:
106 |         use_cuda: false
107 | 
108 | ```
109 | #### Execution Parameters
110 | ```python
111 | engine(img_path, task_list)
112 | engine(img_path, ["unwrap", "unshadow", ("unblur", "OpenCvBilateral")])
113 | ```
114 | 
115 | ### Acknowledgments
116 | 
117 | unwrap: [UVDoc](https://github.com/tanguymagne/UVDoc)
118 | unshadow: [GCDnet](https://github.com/ZZZHANG-jx/GCDRNet)
119 | unblur: [NAFDPM](https://github.com/ispamm/NAF-DPM)
120 | binarize: [UnetCnn](https://github.com/sajjanvsl/U-Net-CNN-for-binarization-of-Historical-Kannada-Handwritten-Palm-Leaf-Manuscripts)
121 | 
122 | ### Contribution Guidelines
123 | 
124 | Pull requests are welcome. For major changes, please open an issue first to discuss what you would like to change.
125 | 
126 | If you have other good suggestions or integration scenarios, the author will actively respond and support them.
127 | 
128 | ### Open Source License
129 | 
130 | This project is licensed under the [Apache 2.0](https://github.com/RapidAI/TableStructureRec/blob/c41bbd23898cb27a957ed962b0ffee3c74dfeff1/LICENSE)
131 | license.
132 | 
--------------------------------------------------------------------------------
/README_en.md:
--------------------------------------------------------------------------------
1 | <div align="center">
2 | <div align="center">
3 |   <h1>📊RapidUndistort</h1>
4 | </div>
5 | 6 | 7 | SemVer2.0 8 | 9 | GitHub 10 |
11 | 
12 | ### Recent Updates
13 | 
14 | - **2024.11.15**
15 |   - Completed the initial version of the code, converted the [UVDoc](https://github.com/tanguymagne/UVDoc) model to onnx, and improved pre- and post-processing.
16 | - **2024.12.15**
17 |   - Added deblurring/shadow removal/binarization functions and models, upgraded to RapidUndistort.
18 | 
19 | ### Introduction
20 | 
21 | This repository is used for correcting document distortion, deblurring documents, shadow removal, and document binarization.
22 | It provides multiple models, lets tasks be combined freely, and supports automatic model downloading.
23 | Original PyTorch model sources can be found in the [Acknowledgments](#acknowledgments) section.
24 | [Quick Start](#quick-start) [Usage Suggestions](#usage-suggestions) [Parameter Explanation](#parameter-explanation) [Model Address](https://www.modelscope.cn/studios/jockerK/DocUnwrap/files)
25 | 
26 | ### Online Demo
27 | [modelscope](https://www.modelscope.cn/studios/jockerK/DocUnwrap) [huggingface](https://huggingface.co/spaces/Joker1212/RapidUnwrap)
28 | 
29 | ### Effect Showcase
30 | ![res_show.jpg](preview1.gif)
31 | ![res_show1.jpg](preview2.gif)
32 | 
33 | ### Installation
34 | ``` python {linenos=table}
35 | pip install rapid-undistorted
36 | ```
37 | 
38 | ### Quick Start
39 | 
40 | ``` python {linenos=table}
41 | import cv2
42 | 
43 | from rapid_undistorted.inference import InferenceEngine
44 | img_path = "img/demo.jpg"
45 | engine = InferenceEngine()
46 | # Distortion correction -> Shadow removal -> Deblurring (specify deblurring model)
47 | output_img, elapse = engine(img_path, ["unwrap", "unshadow", ("unblur", "OpenCvBilateral")])
48 | # Shadow removal -> Deblurring (specify deblurring model)
49 | #output_img, elapse = engine(img_path, ["unshadow", ("unblur", "OpenCvBilateral")])
50 | # Default selection of the first unblur model in the yaml configuration file
51 | #output_img, elapse = engine(img_path, ["unshadow", "unblur"])
52 | # Binarization as an alternative to shadow removal
53 | #output_img, elapse = engine(img_path, ["unwrap", "binarize", "unblur"])
54 | print(f"doc unwrap elapse:{elapse}")
55 | cv2.imwrite("result.png", output_img)
56 | 
57 | ```
58 | 
59 | ### Usage Suggestions
60 | - For English and numeric text, the NAFDPM deblurring model works better; for Chinese text, the OpenCV method is more suitable.
61 | - The shadow removal model is more capable and gives better results than binarization, so it is not recommended to use the binarization method directly.
62 | 
63 | 
64 | ### Parameter Explanation
65 | #### Initialization Parameters
66 | You can pass in a configuration file that declares the required task types, the corresponding models, and the model paths.
67 | ```python 68 | config_path = "configs/config.yaml" 69 | engine = InferenceEngine(config_path) 70 | ``` 71 | ```yaml 72 | tasks: 73 | unwrap: 74 | models: 75 | - type: "UVDoc" 76 | path: 77 | use_cuda: false 78 | 79 | unshadow: 80 | models: 81 | - type: "GCDnet" 82 | sub_models: 83 | - type: "GCDnet" 84 | path: 85 | use_cuda: false 86 | use_dml: false 87 | - type: "DRnet" 88 | path: 89 | use_cuda: false 90 | 91 | binarize: 92 | models: 93 | - type: "UnetCnn" 94 | path: 95 | use_cuda: false 96 | 97 | unblur: 98 | models: 99 | - type: "OpenCvBilateral" 100 | path: 101 | - type: "NAFDPM" 102 | path: 103 | use_cuda: false 104 | 105 | ``` 106 | #### Execution Parameters 107 | ```python 108 | engine(img_path, task_list) 109 | engine(img_path, ["unwrap", "unshadow", ("unblur", "OpenCvBilateral")]) 110 | ``` 111 | 112 | ### Acknowledgments 113 | 114 | unwrap: [UVDoc](https://github.com/tanguymagne/UVDoc) 115 | unshadow: [GCDnet](https://github.com/ZZZHANG-jx/GCDRNet) 116 | unblur: [NAFDPM](https://github.com/ispamm/NAF-DPM) 117 | binarize: [UnetCnn](https://github.com/sajjanvsl/U-Net-CNN-for-binarization-of-Historical-Kannada-Handwritten-Palm-Leaf-Manuscripts) 118 | 119 | 120 | 121 | ### Contribution Guidelines 122 | 123 | Pull requests are welcome. For major changes, please open an issue first to discuss what you would like to change. 124 | 125 | If you have other good suggestions or integration scenarios, the author will actively respond and support them. 126 | 127 | 128 | ### Open Source License 129 | 130 | This project is licensed under the [Apache 2.0](https://github.com/RapidAI/TableStructureRec/blob/c41bbd23898cb27a957ed962b0ffee3c74dfeff1/LICENSE) license. 131 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RapidAI/RapidUnDistort/b627192d96f2a62222040c1f2bceb61851318c4d/__init__.py -------------------------------------------------------------------------------- /demo.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | from rapid_undistorted.inference import InferenceEngine 3 | 4 | if __name__ == '__main__': 5 | 6 | img_path = "tests/test_files/demo1.jpg" 7 | engine = InferenceEngine() 8 | unwrapped_img, elapse = engine(img_path, ["unwrap", "unshadow", ("unblur", "OpenCvBilateral")]) 9 | print(f"doc unwrap elapse:{elapse}") 10 | cv2.imwrite("unwarped.png", unwrapped_img) 11 | -------------------------------------------------------------------------------- /img/demo.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RapidAI/RapidUnDistort/b627192d96f2a62222040c1f2bceb61851318c4d/img/demo.jpg -------------------------------------------------------------------------------- /preview1.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RapidAI/RapidUnDistort/b627192d96f2a62222040c1f2bceb61851318c4d/preview1.gif -------------------------------------------------------------------------------- /preview2.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RapidAI/RapidUnDistort/b627192d96f2a62222040c1f2bceb61851318c4d/preview2.gif -------------------------------------------------------------------------------- /rapid_undistorted/__init__.py: 
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RapidAI/RapidUnDistort/b627192d96f2a62222040c1f2bceb61851318c4d/rapid_undistorted/__init__.py
--------------------------------------------------------------------------------
/rapid_undistorted/binary_predictor.py:
--------------------------------------------------------------------------------
1 | import time
2 | 
3 | import cv2
4 | import numpy as np
5 | 
6 | from rapid_undistorted.utils.img_transform import restore_original_size, pad_to_multiple_of_n
7 | from rapid_undistorted.utils.infer_engine import OrtInferSession
8 | 
9 | 
10 | class UnetCNN:
11 |     def __init__(self, config: dict = None):
12 |         self.unet_session = OrtInferSession(config)
13 | 
14 |     def __call__(self, img: np.ndarray):
15 |         s = time.time()
16 |         img, pad_info = self.preprocess(img)
17 |         pred = self.unet_session([img])[0]
18 |         out_img = self.postprocess(pred, pad_info)
19 |         elapse = time.time() - s
20 |         return out_img, elapse
21 | 
22 |     def preprocess(self, img: np.ndarray):
23 |         img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
24 |         img, pad_info = pad_to_multiple_of_n(img)
25 |         # Normalize to [0, 1] and switch to CHW layout
26 |         img = img.transpose(2, 0, 1) / 255.0
27 |         # Add a batch dimension -> NCHW float32, the model input format
28 |         img = np.expand_dims(img, axis=0).astype(np.float32)
29 |         return img, pad_info
30 | 
31 |     def postprocess(self,
32 |                     img: np.ndarray,
33 |                     pad_info):
34 |         img = 1 - (img - img.min()) / (img.max() - img.min())
35 |         img = img[0].transpose(1, 2, 0)
36 |         # Repeat the single output channel three times
37 |         img = np.repeat(img, 3, axis=2)
38 |         img = (img * 255 + 0.5).clip(0, 255)
39 |         img = restore_original_size(img, pad_info)
40 |         return img
41 | 
--------------------------------------------------------------------------------
/rapid_undistorted/config.yaml:
--------------------------------------------------------------------------------
1 | tasks:
2 |   unwrap:
3 |     models:
4 |       - type: "UVDoc"
5 |         path:
6 |         use_cuda: false
7 |         use_dml: false
8 | 
9 |   unshadow:
10 |     models:
11 |       - type: "GCDnet"
12 |         sub_models:
13 |           - type: "GCDnet"
14 |             path:
15 |             use_cuda: false
16 |             use_dml: false
17 |           - type: "DRnet"
18 |             path:
19 |             use_cuda: false
20 |             use_dml: false
21 | 
22 |   binarize:
23 |     models:
24 |       - type: "UnetCnn"
25 |         path:
26 |         use_cuda: false
27 |         use_dml: false
28 | 
29 |   unblur:
30 |     models:
31 |       - type: "OpenCvBilateral"
32 |         path:
33 |       - type: "NAFDPM"
34 |         path:
35 |         use_cuda: false
36 |         use_dml: false
37 | 
38 | 
--------------------------------------------------------------------------------
/rapid_undistorted/inference.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | from pathlib import Path
4 | from typing import Union, Tuple, List
5 | import numpy as np
6 | import yaml
7 | from .binary_predictor import UnetCNN
8 | from .unblur_predictor import NAF_DPM, OpenCvBilateral
9 | from .unshadow_predictor import GCDRNET
10 | from .unwrap_predictor import UVDocPredictor
11 | from .utils.download_model import DownloadModel
12 | from .utils.load_image import LoadImage
13 | from .utils.logger import get_logger
14 | 
15 | root_dir = Path(__file__).resolve().parent
16 | model_dir = os.path.join(root_dir, "models")
17 | logger = get_logger("rapid_undistorted")
18 | default_config = os.path.join(root_dir, "config.yaml")
19 | ROOT_URL = "https://www.modelscope.cn/studio/jockerK/DocUnwrap/resolve/master/models/"
20 | KEY_TO_MODEL_URL = {
21 |     "unwrap": {
22 |         "UVDoc": f"{ROOT_URL}/uvdoc.onnx",
23 |     },
24 |     "unshadow": {
25 |         "GCDnet": f"{ROOT_URL}/gcnet.onnx",
26 |         "DRnet": f"{ROOT_URL}/drnet.onnx",
27 |     },
28 |     "binarize": {
29 |         "UnetCnn": f"{ROOT_URL}/unetcnn.onnx",
30 |     },
31 |     "unblur": {
32 |         "NAFDPM": f"{ROOT_URL}/nafdpm.onnx",
33 |     },
34 | }
35 | MODEL_CLASS_MAP = {
36 |     "unwrap": {
37 |         "UVDoc": UVDocPredictor,
38 |     },
39 |     "unshadow": {
40 |         "GCDnet": GCDRNET,
41 |     },
42 |     "binarize": {
43 |         "UnetCnn": UnetCNN,
44 |     },
45 |     "unblur": {
46 |         "NAFDPM": NAF_DPM,
47 |         "OpenCvBilateral": OpenCvBilateral,
48 |     },
49 | }
50 | 
51 | 
52 | class InferenceEngine:
53 |     def __init__(self, config_path: str = str(default_config)):
54 |         with open(config_path, "r") as file:
55 |             config = yaml.safe_load(file)
56 | 
57 |         self.img_loader = LoadImage()
58 |         self.tasks = config["tasks"]
59 |         self.models = {}
60 |         self.configs = {}
61 |         self.initialize_models()
62 | 
63 |     def initialize_models(self):
64 |         for task, task_config in self.tasks.items():
65 |             if "models" not in task_config:
66 |                 raise ValueError(f"config has no models, task:{task}")
67 |             if not self.configs.get(task, None):
68 |                 self.configs[task] = {}
69 |                 self.models[task] = {}
70 |             for model_config in task_config["models"]:
71 |                 model_type = model_config["type"]
72 |                 model_class = MODEL_CLASS_MAP.get(task, {}).get(model_type, None)
73 |                 if not model_class:
74 |                     raise ValueError(f"Model class {model_type} not found in MODEL_CLASS_MAP")
75 |                 if not self.configs[task].get(model_type, None):
76 |                     self.configs[task][model_type] = {}
77 |                     self.models[task][model_type] = {}
78 |                 if "sub_models" in model_config:
79 |                     for sub_model_config in model_config["sub_models"]:
80 |                         self.init_submodel_config(task, model_type, sub_model_config)
81 |                 else:
82 |                     self.init_model_config(task, model_type, model_config)
83 |                 self.models[task][model_type] = model_class(self.configs[task][model_type])
84 | 
85 |     def init_model_config(self, task, model_type, model_config):
86 |         model_path = model_config.get("path", None)
87 |         use_cuda = model_config.get("use_cuda", False)
88 |         use_dml = model_config.get("use_dml", False)
89 |         # use the given model_path, or fall back to downloading the model
90 |         model_path = self.get_model_path(task, model_type, model_path)
91 |         self.configs[task][model_type] = {
92 |             "model_path": model_path,
93 |             "use_cuda": use_cuda,
94 |             "use_dml": use_dml,
95 |         }
96 | 
97 |     def init_submodel_config(self, task, model_type, sub_model_config):
98 |         sub_model_type = sub_model_config["type"]
99 |         sub_model_path = sub_model_config.get("path", None)
100 |         sub_use_cuda = sub_model_config.get("use_cuda", False)
101 |         sub_use_dml = sub_model_config.get("use_dml", False)
102 |         sub_model_path = self.get_model_path(task, sub_model_type, sub_model_path)
103 |         self.configs[task][model_type][sub_model_type] = {
104 |             "model_path": sub_model_path,
105 |             "use_cuda": sub_use_cuda,
106 |             "use_dml": sub_use_dml,
107 |         }
108 | 
109 |     def __call__(
110 |         self,
111 |         img_content: Union[str, np.ndarray, bytes, Path],
112 |         task_list: List[Union[str, Tuple[str, str]]],
113 |     ) -> Tuple[np.ndarray, dict]:
114 |         img = self.img_loader(img_content)
115 |         elapses = {}
116 | 
117 |         for task in task_list:
118 |             if isinstance(task, tuple):
119 |                 task_name, model_type = task
120 |             else:
121 |                 task_name = task
122 |                 model_type = next(iter(self.models.get(task_name, {})), None)
123 |             if not self.models.get(task_name, None):
124 |                 raise ValueError(f"Task '{task_name}' not found in the configuration.")
125 |             if not self.models.get(task_name).get(model_type):
126 |                 raise ValueError(f"Task '{task_name}', model type '{model_type}' not found in the configuration.")
127 |             if not elapses.get(task_name, None):
128 |                 elapses[task_name] = {}
129 |             model_instance = self.models[task_name][model_type]
130 |             img, elapse = model_instance(img)
131 |             elapses[task_name][model_type] = elapse
132 |         return img, elapses
133 | 
134 |     @staticmethod
135 |     def get_model_path(task: str, model_type: str, model_path: Union[str, Path, None]) -> str:
136 |         if model_path is not None:
137 |             return model_path
138 | 
139 |         model_url = KEY_TO_MODEL_URL.get(task, {}).get(model_type, None)
140 |         if model_url:
141 |             model_path = DownloadModel.download(model_url)
142 |             return model_path
143 | 
144 |         logger.info(
145 |             "no model url found for task %s, model %s; returning path %s",
146 |             task, model_type, model_path,
147 |         )
148 |         return model_path
149 | 
--------------------------------------------------------------------------------
/rapid_undistorted/models/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RapidAI/RapidUnDistort/b627192d96f2a62222040c1f2bceb61851318c4d/rapid_undistorted/models/.gitkeep
--------------------------------------------------------------------------------
/rapid_undistorted/unblur_predictor.py:
--------------------------------------------------------------------------------
1 | import time
2 | 
3 | import cv2
4 | import numpy as np
5 | 
6 | from rapid_undistorted.utils.infer_engine import OrtInferSession
7 | 
8 | 
9 | class NAF_DPM:
10 |     def __init__(self, config=None):
11 |         self.naf_dpm_session = OrtInferSession(config)
12 | 
13 |     def __call__(self, img: np.ndarray):
14 |         s = time.time()
15 |         img = self.preprocess(img)
16 |         pred = self.naf_dpm_session([img])[0]
17 |         out_img = self.postprocess(pred)
18 |         elapse = time.time() - s
19 |         return out_img, elapse
20 | 
21 |     def preprocess(self, img: np.ndarray):
22 |         # Normalize to [0, 1] and switch to CHW layout
23 |         img = img.transpose(2, 0, 1) / 255.0
24 |         # Add a batch dimension -> NCHW float32, the model input format
25 |         img = np.expand_dims(img, axis=0).astype(np.float32)
26 |         return img
27 | 
28 |     def postprocess(self,
29 |                     img: np.ndarray):
30 |         img = img[0]
31 |         img = (img * 255 + 0.5).clip(0, 255).transpose(1, 2, 0)
32 |         return img
33 | 
34 | 
35 | class OpenCvBilateral:
36 |     def __init__(self, config=None):
37 |         pass
38 | 
39 |     def __call__(self, img):
40 |         s = time.time()
41 |         img = img.astype(np.uint8)
42 |         # Bilateral filtering (edge-preserving smoothing)
43 |         bilateral = cv2.bilateralFilter(img, d=9, sigmaColor=75, sigmaSpace=75)
44 |         # Adaptive histogram equalization (CLAHE) on the L channel
45 |         lab = cv2.cvtColor(bilateral, cv2.COLOR_BGR2LAB)
46 |         l, a, b = cv2.split(lab)
47 |         clahe = cv2.createCLAHE(clipLimit=3.0, tileGridSize=(8, 8))
48 |         cl = clahe.apply(l)
49 |         limg = cv2.merge((cl, a, b))
50 |         enhanced = cv2.cvtColor(limg, cv2.COLOR_LAB2BGR)
51 | 
52 |         # Apply a sharpening kernel
53 |         kernel = np.array([[0, -1, 0],
54 |                            [-1, 5, -1],
55 |                            [0, -1, 0]])
56 |         sharpened = cv2.filter2D(enhanced, -1, kernel)
57 |         elapse = time.time() - s
58 |         return sharpened, elapse
59 | 
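A minimal sketch of running the OpenCV deblur pass above on its own, outside `InferenceEngine` (the input path is one of the bundled test images; the output file name is arbitrary):

```python
# Sketch: apply the bilateral-filter + CLAHE + sharpen pipeline directly.
import cv2

from rapid_undistorted.unblur_predictor import OpenCvBilateral

img = cv2.imread("tests/test_files/demo1.jpg")
sharpened, elapse = OpenCvBilateral()(img)
print(f"opencv unblur elapse: {elapse:.3f}s")
cv2.imwrite("sharpened.png", sharpened)
```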
--------------------------------------------------------------------------------
/rapid_undistorted/unshadow_predictor.py:
--------------------------------------------------------------------------------
1 | import time
2 | 
3 | import cv2
4 | import numpy as np
5 | 
6 | from rapid_undistorted.utils.infer_engine import OrtInferSession
7 | 
8 | 
9 | class GCDRNET:
10 |     def __init__(self, config=None):
11 |         gcnet_config = config.get("GCDnet")
12 |         drnet_config = config.get("DRnet")
13 |         self.gcnet_session = OrtInferSession(gcnet_config)
14 |         self.drnet_session = OrtInferSession(drnet_config)
15 | 
16 |     def __call__(self, img):
17 |         s = time.time()
18 |         im_padding, padding_h, padding_w = self.preprocess(img.copy())
19 |         img_shadow = im_padding.copy()
20 |         img_shadow = self.gcnet_session([img_shadow])[0]
21 |         model1_im = np.clip(im_padding / img_shadow, 0, 1)
22 |         # Concatenate im_padding and model1_im along the channel axis
23 |         concatenated_input = np.concatenate((im_padding, model1_im), axis=1)
24 |         pred = self.drnet_session([concatenated_input])[0]
25 |         elapse = time.time() - s
26 |         return self.postprocess(pred, padding_h, padding_w), elapse
27 | 
28 |     def stride_integral(self, img, stride=32):
29 |         h, w = img.shape[:2]
30 | 
31 |         if (h % stride) != 0:
32 |             padding_h = stride - (h % stride)
33 |             img = cv2.copyMakeBorder(img, padding_h, 0, 0, 0, borderType=cv2.BORDER_REPLICATE)
34 |         else:
35 |             padding_h = 0
36 | 
37 |         if (w % stride) != 0:
38 |             padding_w = stride - (w % stride)
39 |             img = cv2.copyMakeBorder(img, 0, 0, padding_w, 0, borderType=cv2.BORDER_REPLICATE)
40 |         else:
41 |             padding_w = 0
42 | 
43 |         return img, padding_h, padding_w
44 | 
45 |     def preprocess(self, img):
46 |         img, padding_h, padding_w = self.stride_integral(img)
47 |         # Normalize to [0, 1], add a batch dimension -> NCHW float32
48 |         img = img.transpose(2, 0, 1) / 255.0
49 |         img = np.expand_dims(img, axis=0).astype(np.float32)
50 |         return img, padding_h, padding_w
51 | 
52 |     def postprocess(self, pred, padding_h, padding_w):
53 |         pred = np.transpose(pred[0], (1, 2, 0))
54 |         pred = pred * 255
55 |         enhance_img = pred[padding_h:, padding_w:]
56 |         return enhance_img
57 | 
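To make the padding rule in `stride_integral` concrete, a standalone sketch of the same top/left replicate padding (the 700x500 shape is arbitrary):

```python
# Sketch: pad H and W up to multiples of `stride`, mirroring stride_integral.
import cv2
import numpy as np

def pad_to_stride(img: np.ndarray, stride: int = 32):
    h, w = img.shape[:2]
    pad_h = (stride - h % stride) % stride
    pad_w = (stride - w % stride) % stride
    padded = cv2.copyMakeBorder(img, pad_h, 0, pad_w, 0, borderType=cv2.BORDER_REPLICATE)
    return padded, pad_h, pad_w

img = np.zeros((700, 500, 3), dtype=np.uint8)
padded, pad_h, pad_w = pad_to_stride(img)
print(padded.shape, pad_h, pad_w)  # (704, 512, 3) 4 12
```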
--------------------------------------------------------------------------------
/rapid_undistorted/unwrap_predictor.py:
--------------------------------------------------------------------------------
1 | import time
2 | 
3 | import cv2
4 | import numpy as np
5 | from scipy.ndimage import map_coordinates
6 | 
7 | from .utils.infer_engine import OrtInferSession
8 | 
9 | 
10 | class UVDocPredictor:
11 | 
12 |     def __init__(self, config):
13 |         self.session = OrtInferSession(config)
14 |         self.img_size = [488, 712]
15 |         self.grid_size = [45, 31]
16 | 
17 |     def __call__(self, img):
18 |         s = time.time()
19 |         size = img.shape[:2][::-1]
20 |         img = img.astype(np.float32) / 255
21 |         inp = self.preprocess(img.copy())
22 |         outputs, _ = self.session([inp])
23 |         elapse = time.time() - s
24 |         return self.postprocess(img, size, outputs), elapse
25 | 
26 |     def preprocess(self, img):
27 |         img = cv2.resize(img, self.img_size).transpose(2, 0, 1)
28 |         img = np.expand_dims(img, axis=0)
29 |         return img
30 | 
31 |     def postprocess(self, img, size, output):
32 |         # Convert the image to a batched NCHW float32 array
33 |         warped_img = np.expand_dims(img.transpose(2, 0, 1), axis=0).astype(np.float32)
34 | 
35 |         # Upsample the predicted grid to the original image size
36 |         upsampled_grid = self.interpolate(output, size=(size[1], size[0]), align_corners=True)
37 |         # Reshape the grid to (B, H, W, 2)
38 |         upsampled_grid = upsampled_grid.transpose(0, 2, 3, 1)
39 | 
40 |         # Remap the image through the grid
41 |         unwarped_img = self.grid_sample(warped_img, upsampled_grid)
42 | 
43 |         # Convert the result back to HWC in the [0, 255] range
44 |         return unwarped_img[0].transpose(1, 2, 0) * 255
45 | 
46 |     def interpolate(self, input_tensor, size, align_corners=True):
47 |         """
48 |         Bilinear interpolation to resize the input tensor.
49 | 
50 |         Args:
51 |             input_tensor: numpy.ndarray of shape (B, C, H, W)
52 |             size: tuple of int (new_height, new_width)
53 |             align_corners: bool, whether to align corners in bilinear interpolation
54 | 
55 |         Returns:
56 |             numpy.ndarray of shape (B, C, new_height, new_width)
57 |         """
58 |         B, C, H, W = input_tensor.shape
59 |         new_H, new_W = size
60 |         resized_tensors = []
61 |         for b in range(B):
62 |             resized_channels = []
63 |             for c in range(C):
64 |                 # Compute the coordinate scale factors
65 |                 if align_corners:
66 |                     scale_h = (H - 1) / (new_H - 1) if new_H > 1 else 0
67 |                     scale_w = (W - 1) / (new_W - 1) if new_W > 1 else 0
68 |                 else:
69 |                     scale_h = H / new_H
70 |                     scale_w = W / new_W
71 | 
72 |                 # Build the new coordinate grid
73 |                 y, x = np.indices((new_H, new_W), dtype=np.float32)
74 |                 y = y * scale_h
75 |                 x = x * scale_w
76 | 
77 |                 # Bilinear interpolation
78 |                 coords = np.stack([y.flatten(), x.flatten()], axis=0)
79 |                 resized_channel = map_coordinates(input_tensor[b, c], coords, order=1, mode='constant', cval=0.0)
80 |                 resized_channel = resized_channel.reshape(new_H, new_W)
81 |                 resized_channels.append(resized_channel)
82 | 
83 |             resized_tensors.append(np.stack(resized_channels, axis=0))
84 | 
85 |         return np.stack(resized_tensors, axis=0)
86 | 
87 |     def grid_sample(self, input_tensor, grid, align_corners=True):
88 |         """
89 |         Grid sample function to sample the input tensor using the given grid.
90 | 
91 |         Args:
92 |             input_tensor: numpy.ndarray of shape (B, C, H, W)
93 |             grid: numpy.ndarray of shape (B, H, W, 2) with values in [-1, 1]
94 |             align_corners: bool, whether to align corners in bilinear interpolation
95 | 
96 |         Returns:
97 |             numpy.ndarray of shape (B, C, H, W)
98 |         """
99 |         B, C, H, W = input_tensor.shape
100 |         B_grid, H_grid, W_grid, _ = grid.shape
101 | 
102 |         if B != B_grid or H != H_grid or W != W_grid:
103 |             raise ValueError("Input tensor and grid must have the same spatial dimensions.")
104 | 
105 |         # Convert grid coordinates from [-1, 1] to [0, W-1] and [0, H-1]
106 |         if align_corners:
107 |             grid[:, :, :, 0] = (grid[:, :, :, 0] + 1) * (W - 1) / 2
108 |             grid[:, :, :, 1] = (grid[:, :, :, 1] + 1) * (H - 1) / 2
109 |         else:
110 |             grid[:, :, :, 0] = ((grid[:, :, :, 0] + 1) * W - 1) / 2
111 |             grid[:, :, :, 1] = ((grid[:, :, :, 1] + 1) * H - 1) / 2
112 | 
113 |         sampled_tensors = []
114 |         for b in range(B):
115 |             sampled_channels = []
116 |             for c in range(C):
117 |                 channel = input_tensor[b, c]
118 |                 x_coords = grid[b, :, :, 0].flatten()
119 |                 y_coords = grid[b, :, :, 1].flatten()
120 |                 coords = np.stack([y_coords, x_coords], axis=-1)
121 |                 sampled_channel = map_coordinates(channel, coords.T, order=1, mode='constant', cval=0.0).reshape(H, W)
122 |                 sampled_channels.append(sampled_channel)
123 |             sampled_tensors.append(np.stack(sampled_channels, axis=0))
124 | 
125 |         return np.stack(sampled_tensors, axis=0)
126 | 
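To see the `align_corners=True` convention used by `interpolate` above in action, a tiny numeric check built on the same `map_coordinates` call (corner values are preserved exactly; interior values are bilinear averages):

```python
# Sketch: a 2x2 ramp upsampled to 3x3 with align_corners=True coordinates.
import numpy as np
from scipy.ndimage import map_coordinates

src = np.array([[0.0, 1.0], [2.0, 3.0]], dtype=np.float32)
y, x = np.indices((3, 3), dtype=np.float32)
scale = (2 - 1) / (3 - 1)  # (H - 1) / (new_H - 1) = 0.5
coords = np.stack([y.ravel() * scale, x.ravel() * scale], axis=0)
out = map_coordinates(src, coords, order=1, mode="constant", cval=0.0).reshape(3, 3)
print(out)
# [[0.  0.5 1. ]
#  [1.  1.5 2. ]
#  [2.  2.5 3. ]]
```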
--------------------------------------------------------------------------------
/rapid_undistorted/utils/download_model.py:
--------------------------------------------------------------------------------
1 | import io
2 | from pathlib import Path
3 | from typing import Optional, Union
4 | 
5 | import requests
6 | from tqdm import tqdm
7 | 
8 | from .logger import get_logger
9 | 
10 | logger = get_logger("DownloadModel")
11 | CUR_DIR = Path(__file__).resolve()
12 | PROJECT_DIR = CUR_DIR.parent.parent
13 | 
14 | 
15 | class DownloadModel:
16 |     cur_dir = PROJECT_DIR
17 | 
18 |     @classmethod
19 |     def download(cls, model_full_url: Union[str, Path]) -> str:
20 |         save_dir = cls.cur_dir / "models"
21 |         save_dir.mkdir(parents=True, exist_ok=True)
22 | 
23 |         model_name = Path(model_full_url).name
24 |         save_file_path = save_dir / model_name
25 |         if save_file_path.exists():
26 |             logger.debug("%s already exists", save_file_path)
27 |             return str(save_file_path)
28 | 
29 |         try:
30 |             logger.info("Download %s to %s", model_full_url, save_dir)
31 |             file = cls.download_as_bytes_with_progress(model_full_url, model_name)
32 |             cls.save_file(save_file_path, file)
33 |         except Exception as exc:
34 |             raise DownloadModelError from exc
35 |         return str(save_file_path)
36 | 
37 |     @staticmethod
38 |     def download_as_bytes_with_progress(
39 |         url: Union[str, Path], name: Optional[str] = None
40 |     ) -> bytes:
41 |         resp = requests.get(str(url), stream=True, allow_redirects=True, timeout=180)
42 |         total = int(resp.headers.get("content-length", 0))
43 |         bio = io.BytesIO()
44 |         with tqdm(
45 |             desc=name, total=total, unit="b", unit_scale=True, unit_divisor=1024
46 |         ) as pbar:
47 |             for chunk in resp.iter_content(chunk_size=65536):
48 |                 pbar.update(len(chunk))
49 |                 bio.write(chunk)
50 |         return bio.getvalue()
51 | 
52 |     @staticmethod
53 |     def save_file(save_path: Union[str, Path], file: bytes):
54 |         with open(save_path, "wb") as f:
55 |             f.write(file)
56 | 
57 | 
58 | class DownloadModelError(Exception):
59 |     pass
60 | 
--------------------------------------------------------------------------------
/rapid_undistorted/utils/img_transform.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import numpy as np
3 | 
4 | 
5 | def pad_to_multiple_of_n(image, n=32):
6 |     original_height, original_width = image.shape[:2]
7 | 
8 |     # Compute the target shape
9 |     target_width = ((original_width + n - 1) // n) * n
10 |     target_height = ((original_height + n - 1) // n) * n
11 | 
12 |     # Create a pure-white background image
13 |     padded_image = np.ones((target_height, target_width, 3), dtype=np.uint8) * 255
14 | 
15 |     # Compute the paste position
16 |     start_x = (target_width - original_width) // 2
17 |     start_y = (target_height - original_height) // 2
18 | 
19 |     # Place the original image on the white background
20 |     padded_image[start_y:start_y + original_height, start_x:start_x + original_width] = image
21 | 
22 |     # Return the padded image and the padding info
23 |     return padded_image, (start_x, start_y, original_height, original_width)
24 | 
25 | def restore_original_size(image, pad_info):
26 |     start_x, start_y, original_height, original_width = pad_info
27 | 
28 |     # Crop away the padding
29 |     cropped_image = image[start_y:start_y + original_height, start_x:start_x + original_width]
30 | 
31 |     return cropped_image
32 | 
33 | # def resize_and_pad(image, target_shape):
34 | #     original_height, original_width = image.shape[:2]
35 | #     target_height, target_width = target_shape
36 | #
37 | #     # Compute the scale ratio
38 | #     scale = min(target_width / original_width, target_height / original_height)
39 | #
40 | #     # Compute the new size
41 | #     new_width = int(original_width * scale)
42 | #     new_height = int(original_height * scale)
43 | #
44 | #     # Resize the image
45 | #     resized_image = cv2.resize(image, (new_width, new_height), interpolation=cv2.INTER_LINEAR)
46 | #
47 | #     # Create a pure-white background image
48 | #     padded_image = np.ones((target_height, target_width, 3), dtype=np.uint8) * 255
49 | #
50 | #     # Compute the paste position
51 | #     start_x = (target_width - new_width) // 2
52 | #     start_y = (target_height - new_height) // 2
53 | #
54 | #     # Place the resized image on the white background
55 | #     padded_image[start_y:start_y + new_height, start_x:start_x + new_width] = resized_image
56 | #
57 | #     return padded_image, (start_x, start_y, original_height, original_width, new_height, new_width)
58 | #
59 | # def restore_original_size(image, pad_info):
60 | #     start_x, start_y, original_height, original_width, new_height, new_width = pad_info
61 | #
62 | #     # Crop away the padding
63 | #     cropped_image = image[start_y:start_y + new_height, start_x:start_x + new_width]
64 | #
65 | #     # Resize back to the original size
66 | #     restored_image = cv2.resize(cropped_image, (original_width, original_height), interpolation=cv2.INTER_LINEAR)
67 | #
68 | #     return restored_image
69 | 
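A quick sketch of the two live helpers above round-tripping an image (the 50x70 shape is arbitrary; `pad_to_multiple_of_n` centers the input on a white canvas whose sides are multiples of 32, and `restore_original_size` undoes it exactly):

```python
# Sketch: pad up to a multiple of 32, then crop back to the original size.
import numpy as np

from rapid_undistorted.utils.img_transform import pad_to_multiple_of_n, restore_original_size

img = np.zeros((50, 70, 3), dtype=np.uint8)
padded, pad_info = pad_to_multiple_of_n(img)
print(padded.shape)  # (64, 96, 3)
restored = restore_original_size(padded, pad_info)
assert restored.shape == img.shape
```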
--------------------------------------------------------------------------------
/rapid_undistorted/utils/infer_engine.py:
--------------------------------------------------------------------------------
1 | import os
2 | import platform
3 | import traceback
4 | from enum import Enum
5 | from pathlib import Path
6 | from typing import Any, Dict, List, Tuple, Union
7 | 
8 | import numpy as np
9 | from onnxruntime import (
10 |     GraphOptimizationLevel,
11 |     InferenceSession,
12 |     SessionOptions,
13 |     get_available_providers,
14 |     get_device,
15 | )
16 | 
17 | from .logger import get_logger
18 | 
19 | 
20 | class EP(Enum):
21 |     CPU_EP = "CPUExecutionProvider"
22 |     CUDA_EP = "CUDAExecutionProvider"
23 |     DIRECTML_EP = "DmlExecutionProvider"
24 | 
25 | 
26 | class OrtInferSession:
27 |     def __init__(self, config: Dict[str, Any]):
28 |         self.logger = get_logger("OrtInferSession")
29 | 
30 |         model_path = config.get("model_path", None)
31 |         self._verify_model(model_path)
32 | 
33 |         self.cfg_use_cuda = config.get("use_cuda", None)
34 |         self.cfg_use_dml = config.get("use_dml", None)
35 | 
36 |         self.had_providers: List[str] = get_available_providers()
37 |         EP_list = self._get_ep_list()
38 | 
39 |         sess_opt = self._init_sess_opts(config)
40 |         self.session = InferenceSession(
41 |             model_path,
42 |             sess_options=sess_opt,
43 |             providers=EP_list,
44 |         )
45 |         self._verify_providers()
46 | 
47 |     @staticmethod
48 |     def _init_sess_opts(config: Dict[str, Any]) -> SessionOptions:
49 |         sess_opt = SessionOptions()
50 |         sess_opt.log_severity_level = 4
51 |         sess_opt.enable_cpu_mem_arena = False
52 |         sess_opt.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL
53 | 
54 |         cpu_nums = os.cpu_count()
55 |         intra_op_num_threads = config.get("intra_op_num_threads", -1)
56 |         if intra_op_num_threads != -1 and 1 <= intra_op_num_threads <= cpu_nums:
57 |             sess_opt.intra_op_num_threads = intra_op_num_threads
58 | 
59 |         inter_op_num_threads = config.get("inter_op_num_threads", -1)
60 |         if inter_op_num_threads != -1 and 1 <= inter_op_num_threads <= cpu_nums:
61 |             sess_opt.inter_op_num_threads = inter_op_num_threads
62 | 
63 |         return sess_opt
64 | 
65 |     def _get_ep_list(self) -> List[Tuple[str, Dict[str, Any]]]:
66 |         cpu_provider_opts = {
67 |             "arena_extend_strategy": "kSameAsRequested",
68 |         }
69 |         EP_list = [(EP.CPU_EP.value, cpu_provider_opts)]
70 | 
71 |         cuda_provider_opts = {
72 |             "device_id": 0,
73 |             "arena_extend_strategy": "kNextPowerOfTwo",
74 |             "cudnn_conv_algo_search": "EXHAUSTIVE",
75 |             "do_copy_in_default_stream": True,
76 |         }
77 |         self.use_cuda = self._check_cuda()
78 |         if self.use_cuda:
79 |             EP_list.insert(0, (EP.CUDA_EP.value, cuda_provider_opts))
80 | 
81 |         self.use_directml = self._check_dml()
82 |         if self.use_directml:
83 |             self.logger.info(
84 |                 "Windows 10 or above detected, try to use DirectML as primary provider"
85 |             )
86 |             directml_options = (
87 |                 cuda_provider_opts if self.use_cuda else cpu_provider_opts
88 |             )
89 |             EP_list.insert(0, (EP.DIRECTML_EP.value, directml_options))
90 |         return EP_list
91 | 
92 |     def _check_cuda(self) -> bool:
93 |         if not self.cfg_use_cuda:
94 |             return False
95 | 
96 |         cur_device = get_device()
97 |         if cur_device == "GPU" and EP.CUDA_EP.value in self.had_providers:
98 |             return True
99 | 
100 |         self.logger.warning(
101 |             "%s is not in available providers (%s). Use %s inference by default.",
102 |             EP.CUDA_EP.value,
103 |             self.had_providers,
104 |             self.had_providers[0],
105 |         )
106 |         self.logger.info(
107 |             "(For reference only) If you want to use GPU acceleration, you must do:"
108 |         )
109 |         self.logger.info(
110 |             "First, uninstall all onnxruntime packages in current environment."
111 |         )
112 |         self.logger.info(
113 |             "Second, install onnxruntime-gpu by `pip install onnxruntime-gpu`."
114 |         )
115 |         self.logger.info(
116 |             "\tNote the onnxruntime-gpu version must match your cuda and cudnn version."
117 |         )
118 |         self.logger.info(
119 |             "\tYou can refer to this link: https://onnxruntime.ai/docs/execution-providers/CUDA-EP.html"
120 |         )
121 |         self.logger.info(
122 |             "Third, ensure %s is in available providers list. e.g. ['CUDAExecutionProvider', 'CPUExecutionProvider']",
123 |             EP.CUDA_EP.value,
124 |         )
125 |         return False
126 | 
127 |     def _check_dml(self) -> bool:
128 |         if not self.cfg_use_dml:
129 |             return False
130 | 
131 |         cur_os = platform.system()
132 |         if cur_os != "Windows":
133 |             self.logger.warning(
134 |                 "DirectML is only supported on Windows. The current OS is %s. Use %s inference by default.",
135 |                 cur_os,
136 |                 self.had_providers[0],
137 |             )
138 |             return False
139 | 
140 |         cur_window_version = int(platform.release().split(".")[0])
141 |         if cur_window_version < 10:
142 |             self.logger.warning(
143 |                 "DirectML is only supported on Windows 10 and above. The current Windows version is %s. Use %s inference by default.",
144 |                 cur_window_version,
145 |                 self.had_providers[0],
146 |             )
147 |             return False
148 | 
149 |         if EP.DIRECTML_EP.value in self.had_providers:
150 |             return True
151 | 
152 |         self.logger.warning(
153 |             "%s is not in available providers (%s). Use %s inference by default.",
154 |             EP.DIRECTML_EP.value,
155 |             self.had_providers,
156 |             self.had_providers[0],
157 |         )
158 |         self.logger.info("If you want to use DirectML acceleration, you must do:")
159 |         self.logger.info(
160 |             "First, uninstall all onnxruntime packages in current environment."
161 |         )
162 |         self.logger.info(
163 |             "Second, install onnxruntime-directml by `pip install onnxruntime-directml`"
164 |         )
165 |         self.logger.info(
166 |             "Third, ensure %s is in available providers list. e.g. ['DmlExecutionProvider', 'CPUExecutionProvider']",
167 |             EP.DIRECTML_EP.value,
168 |         )
169 |         return False
170 | 
171 |     def _verify_providers(self):
172 |         session_providers = self.session.get_providers()
173 |         first_provider = session_providers[0]
174 | 
175 |         if self.use_cuda and first_provider != EP.CUDA_EP.value:
176 |             self.logger.warning(
177 |                 "%s is not available for current env, the inference part is automatically shifted to be executed under %s.",
178 |                 EP.CUDA_EP.value,
179 |                 first_provider,
180 |             )
181 | 
182 |         if self.use_directml and first_provider != EP.DIRECTML_EP.value:
183 |             self.logger.warning(
184 |                 "%s is not available for current env, the inference part is automatically shifted to be executed under %s.",
185 |                 EP.DIRECTML_EP.value,
186 |                 first_provider,
187 |             )
188 | 
189 |     def __call__(self, input_content: List[np.ndarray]) -> List[np.ndarray]:
190 |         input_dict = dict(zip(self.get_input_names(), input_content))
191 |         try:
192 |             return self.session.run(None, input_dict)
193 |         except Exception as e:
194 |             error_info = traceback.format_exc()
195 |             raise ONNXRuntimeError(error_info) from e
196 | 
197 |     def get_input_names(self) -> List[str]:
198 |         return [v.name for v in self.session.get_inputs()]
199 | 
200 |     def get_output_names(self) -> List[str]:
201 |         return [v.name for v in self.session.get_outputs()]
202 | 
203 |     def get_character_list(self, key: str = "character") -> List[str]:
204 |         meta_dict = self.session.get_modelmeta().custom_metadata_map
205 |         return meta_dict[key].splitlines()
206 | 
207 |     def have_key(self, key: str = "character") -> bool:
208 |         meta_dict = self.session.get_modelmeta().custom_metadata_map
209 |         if key in meta_dict.keys():
210 |             return True
211 |         return False
212 | 
213 |     @staticmethod
214 |     def _verify_model(model_path: Union[str, Path, None]):
215 |         if model_path is None:
216 |             raise ValueError("model_path is None!")
217 | 
218 |         model_path = Path(model_path)
219 |         if not model_path.exists():
220 |             raise FileNotFoundError(f"{model_path} does not exist.")
221 | 
222 |         if not model_path.is_file():
223 |             raise FileExistsError(f"{model_path} is not a file.")
224 | 
225 | 
226 | class ONNXRuntimeError(Exception):
227 |     pass
228 | 
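For reference, a sketch of the config dict `OrtInferSession` reads; the model path is hypothetical, and since `_verify_model` raises unless it exists, the construction is left commented out:

```python
# Sketch: the keys OrtInferSession consumes from its config dict.
# The .onnx path is hypothetical; _verify_model raises if it does not exist.
config = {
    "model_path": "rapid_undistorted/models/uvdoc.onnx",
    "use_cuda": False,          # checked against get_device() and the provider list
    "use_dml": False,           # Windows 10+ with onnxruntime-directml only
    "intra_op_num_threads": 2,  # optional ORT threading caps, validated in _init_sess_opts
    "inter_op_num_threads": 2,
}
# session = OrtInferSession(config)
# outputs = session([input_tensor])  # input_tensor: np.float32 NCHW array
```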
--------------------------------------------------------------------------------
/rapid_undistorted/utils/load_image.py:
--------------------------------------------------------------------------------
1 | # -*- encoding: utf-8 -*-
2 | # @Author: SWHL
3 | # @Contact: liekkaskono@163.com
4 | from io import BytesIO
5 | from pathlib import Path
6 | from typing import Any, Union
7 | 
8 | import cv2
9 | import numpy as np
10 | from PIL import Image, UnidentifiedImageError
11 | 
12 | root_dir = Path(__file__).resolve().parent
13 | InputType = Union[str, np.ndarray, bytes, Path, Image.Image]
14 | 
15 | 
16 | class LoadImage:
17 |     def __init__(self):
18 |         pass
19 | 
20 |     def __call__(self, img: InputType) -> np.ndarray:
21 |         if not isinstance(img, InputType.__args__):
22 |             raise LoadImageError(
23 |                 f"The img type {type(img)} is not in {InputType.__args__}"
24 |             )
25 | 
26 |         origin_img_type = type(img)
27 |         img = self.load_img(img)
28 |         img = self.convert_img(img, origin_img_type)
29 |         return img
30 | 
31 |     def load_img(self, img: InputType) -> np.ndarray:
32 |         if isinstance(img, (str, Path)):
33 |             self.verify_exist(img)
34 |             try:
35 |                 img = self.img_to_ndarray(Image.open(img))
36 |             except UnidentifiedImageError as e:
37 |                 raise LoadImageError(f"cannot identify image file {img}") from e
38 |             return img
39 | 
40 |         if isinstance(img, bytes):
41 |             img =
self.img_to_ndarray(Image.open(BytesIO(img))) 42 | return img 43 | 44 | if isinstance(img, np.ndarray): 45 | return img 46 | 47 | if isinstance(img, Image.Image): 48 | return self.img_to_ndarray(img) 49 | 50 | raise LoadImageError(f"{type(img)} is not supported!") 51 | 52 | def img_to_ndarray(self, img: Image.Image) -> np.ndarray: 53 | if img.mode == "1": 54 | img = img.convert("L") 55 | return np.array(img) 56 | return np.array(img) 57 | 58 | def convert_img(self, img: np.ndarray, origin_img_type: Any) -> np.ndarray: 59 | if img.ndim == 2: 60 | return cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) 61 | 62 | if img.ndim == 3: 63 | channel = img.shape[2] 64 | if channel == 1: 65 | return cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) 66 | 67 | if channel == 2: 68 | return self.cvt_two_to_three(img) 69 | 70 | if channel == 3: 71 | if issubclass(origin_img_type, (str, Path, bytes, Image.Image)): 72 | return cv2.cvtColor(img, cv2.COLOR_RGB2BGR) 73 | return img 74 | 75 | if channel == 4: 76 | return self.cvt_four_to_three(img) 77 | 78 | raise LoadImageError( 79 | f"The channel({channel}) of the img is not in [1, 2, 3, 4]" 80 | ) 81 | 82 | raise LoadImageError(f"The ndim({img.ndim}) of the img is not in [2, 3]") 83 | 84 | @staticmethod 85 | def cvt_two_to_three(img: np.ndarray) -> np.ndarray: 86 | """gray + alpha → BGR""" 87 | img_gray = img[..., 0] 88 | img_bgr = cv2.cvtColor(img_gray, cv2.COLOR_GRAY2BGR) 89 | 90 | img_alpha = img[..., 1] 91 | not_a = cv2.bitwise_not(img_alpha) 92 | not_a = cv2.cvtColor(not_a, cv2.COLOR_GRAY2BGR) 93 | 94 | new_img = cv2.bitwise_and(img_bgr, img_bgr, mask=img_alpha) 95 | new_img = cv2.add(new_img, not_a) 96 | return new_img 97 | 98 | @staticmethod 99 | def cvt_four_to_three(img: np.ndarray) -> np.ndarray: 100 | """RGBA → BGR""" 101 | r, g, b, a = cv2.split(img) 102 | new_img = cv2.merge((b, g, r)) 103 | 104 | not_a = cv2.bitwise_not(a) 105 | not_a = cv2.cvtColor(not_a, cv2.COLOR_GRAY2BGR) 106 | 107 | new_img = cv2.bitwise_and(new_img, new_img, mask=a) 108 | 109 | mean_color = np.mean(new_img) 110 | if mean_color <= 0.0: 111 | new_img = cv2.add(new_img, not_a) 112 | else: 113 | new_img = cv2.bitwise_not(new_img) 114 | return new_img 115 | 116 | @staticmethod 117 | def verify_exist(file_path: Union[str, Path]): 118 | if not Path(file_path).exists(): 119 | raise LoadImageError(f"{file_path} does not exist.") 120 | 121 | 122 | class LoadImageError(Exception): 123 | pass 124 | -------------------------------------------------------------------------------- /rapid_undistorted/utils/logger.py: -------------------------------------------------------------------------------- 1 | # -*- encoding: utf-8 -*- 2 | # @Author: Jocker1212 3 | # @Contact: xinyijianggo@gmail.com 4 | import logging 5 | from functools import lru_cache 6 | 7 | 8 | @lru_cache(maxsize=32) 9 | def get_logger(name: str) -> logging.Logger: 10 | logger = logging.getLogger(name) 11 | logger.setLevel(logging.DEBUG) 12 | 13 | fmt = "%(asctime)s - %(name)s - %(levelname)s: %(message)s" 14 | format_str = logging.Formatter(fmt) 15 | 16 | sh = logging.StreamHandler() 17 | sh.setLevel(logging.DEBUG) 18 | 19 | logger.addHandler(sh) 20 | sh.setFormatter(format_str) 21 | return logger 22 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | opencv-python 3 | onnxruntime 4 | scipy 5 | pillow 6 | requests 7 | tqdm 8 | pyyaml 9 | 
--------------------------------------------------------------------------------
/setup_undistort.py:
--------------------------------------------------------------------------------
1 | # -*- encoding: utf-8 -*-
2 | # @Author: Jocker1212
3 | # @Contact: xinyijianggo@gmail.com
4 | import sys
5 | from typing import List, Union
6 | from pathlib import Path
7 | from get_pypi_latest_version import GetPyPiLatestVersion
8 | 
9 | import setuptools
10 | 
11 | 
12 | def read_txt(txt_path: Union[Path, str]) -> List[str]:
13 |     with open(txt_path, "r", encoding="utf-8") as f:
14 |         data = [v.rstrip("\n") for v in f]
15 |     return data
16 | 
17 | 
18 | MODULE_NAME = "rapid_undistorted"
19 | 
20 | obtainer = GetPyPiLatestVersion()
21 | try:
22 |     latest_version = obtainer(MODULE_NAME)
23 | except Exception:
24 |     latest_version = "0.0.0"
25 | 
26 | VERSION_NUM = obtainer.version_add_one(latest_version)
27 | 
28 | if len(sys.argv) > 2:
29 |     match_str = " ".join(sys.argv[2:])
30 |     matched_versions = obtainer.extract_version(match_str)
31 |     if matched_versions:
32 |         VERSION_NUM = matched_versions
33 |     sys.argv = sys.argv[:2]
34 | 
35 | setuptools.setup(
36 |     name=MODULE_NAME,
37 |     version=VERSION_NUM,
38 |     platforms="Any",
39 |     description="document unwarping, deshadowing, deblurring and binarization with onnx models",
40 |     long_description="document unwarping, deshadowing, deblurring and binarization with onnx models",
41 |     author="jockerK",
42 |     author_email="xinyijianggo@gmail.com",
43 |     url="https://github.com/RapidAI/RapidUnDistort",
44 |     license="Apache-2.0",
45 |     install_requires=read_txt("requirements.txt"),
46 |     include_package_data=True,
47 |     packages=[MODULE_NAME, f"{MODULE_NAME}.models", f"{MODULE_NAME}.utils"],
48 |     package_data={"": [".gitkeep"], MODULE_NAME: ["config.yaml"]},
49 |     keywords=["document-unwarp,deshadow,deblur,binarization"],
50 |     classifiers=[
51 |         "Programming Language :: Python :: 3.8",
52 |         "Programming Language :: Python :: 3.9",
53 |         "Programming Language :: Python :: 3.10",
54 |         "Programming Language :: Python :: 3.11",
55 |     ],
56 |     python_requires=">=3.8,<3.13",
57 | )
58 | 
--------------------------------------------------------------------------------
/tests/test_basic.py:
--------------------------------------------------------------------------------
1 | import sys
2 | from pathlib import Path
3 | 
4 | import pytest
5 | 
6 | cur_dir = Path(__file__).resolve().parent
7 | root_dir = cur_dir.parent
8 | 
9 | sys.path.append(str(root_dir))
10 | test_file_dir = cur_dir / "test_files"
11 | 
12 | 
13 | @pytest.mark.parametrize(
14 |     "img_path",
15 |     [("demo1.png")],
16 | )
17 | def test_unwrap_uvdoc(img_path):
18 |     from rapid_undistorted.inference import InferenceEngine
19 |     img_path = test_file_dir / img_path
20 |     engine = InferenceEngine()
21 |     unwrapped_img, elapse = engine(img_path, [("unwrap", "UVDoc")])
22 | 
23 | 
24 | @pytest.mark.parametrize(
25 |     "img_path",
26 |     [("demo1.png")],
27 | )
28 | def test_unshadow_gcnet(img_path):
29 |     from rapid_undistorted.inference import InferenceEngine
30 |     img_path = test_file_dir / img_path
31 |     engine = InferenceEngine()
32 |     unwrapped_img, elapse = engine(img_path, [("unshadow", "GCDnet")])
33 | 
34 | 
35 | @pytest.mark.parametrize(
36 |     "img_path",
37 |     [("demo1.png")],
38 | )
39 | def test_unblur_opencv(img_path):
40 |     from rapid_undistorted.inference import InferenceEngine
41 |     img_path = test_file_dir / img_path
42 |     engine = InferenceEngine()
43 |     unwrapped_img, elapse = engine(img_path, [("unblur", "OpenCvBilateral")])
44 | 
45 | 
46 | @pytest.mark.parametrize(
47 |     "img_path",
48 |     [("demo1.png")],
49 | )
50 | def test_unblur_nafdpm(img_path):
51 |     from rapid_undistorted.inference import InferenceEngine
52 |     img_path = test_file_dir / img_path
53 |     engine = InferenceEngine()
54 |     unwrapped_img, elapse = engine(img_path, [("unblur", "NAFDPM")])
55 | 
56 | 
57 | @pytest.mark.parametrize(
58 |     "img_path",
59 |     [("demo1.png")],
60 | )
61 | def test_binarize_unetcnn(img_path):
62 |     from rapid_undistorted.inference import InferenceEngine
63 |     img_path = test_file_dir / img_path
64 |     engine = InferenceEngine()
65 |     unwrapped_img, elapse = engine(img_path, [("binarize", "UnetCnn")])
66 | 
--------------------------------------------------------------------------------
/tests/test_files/demo1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RapidAI/RapidUnDistort/b627192d96f2a62222040c1f2bceb61851318c4d/tests/test_files/demo1.jpg
--------------------------------------------------------------------------------
/tests/test_files/demo1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RapidAI/RapidUnDistort/b627192d96f2a62222040c1f2bceb61851318c4d/tests/test_files/demo1.png
--------------------------------------------------------------------------------
/tests/test_files/demo2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RapidAI/RapidUnDistort/b627192d96f2a62222040c1f2bceb61851318c4d/tests/test_files/demo2.png
--------------------------------------------------------------------------------
/tests/test_files/demo3.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RapidAI/RapidUnDistort/b627192d96f2a62222040c1f2bceb61851318c4d/tests/test_files/demo3.jpg
--------------------------------------------------------------------------------