├── .github ├── examples │ └── Pretendard-Medium.ttf └── images │ ├── arch.jpg │ ├── boxes-0.png │ ├── demo-0.webp │ ├── demo-1.png │ ├── demo-2.png │ ├── demo-3.webp │ ├── logo.png │ └── starmie.jpg ├── .gitignore ├── .prettierignore ├── .prettierrc.js ├── .pylintrc ├── .vscode ├── extensions.json └── settings.json ├── LICENSE ├── README.md ├── betterocr ├── __init__.py ├── detect.py ├── engines │ └── easy_pororo_ocr │ │ ├── __init__.py │ │ ├── pororo │ │ ├── __init__.py │ │ ├── models │ │ │ └── brainOCR │ │ │ │ ├── __init__.py │ │ │ │ ├── _dataset.py │ │ │ │ ├── _modules.py │ │ │ │ ├── brainocr.py │ │ │ │ ├── craft.py │ │ │ │ ├── craft_utils.py │ │ │ │ ├── detection.py │ │ │ │ ├── imgproc.py │ │ │ │ ├── model.py │ │ │ │ ├── modules │ │ │ │ ├── __init__.py │ │ │ │ ├── basenet.py │ │ │ │ ├── feature_extraction.py │ │ │ │ ├── prediction.py │ │ │ │ ├── sequence_modeling.py │ │ │ │ └── transformation.py │ │ │ │ ├── recognition.py │ │ │ │ └── utils.py │ │ ├── pororo.py │ │ ├── tasks │ │ │ ├── __init__.py │ │ │ ├── optical_character_recognition.py │ │ │ └── utils │ │ │ │ ├── __init__.py │ │ │ │ ├── base.py │ │ │ │ ├── config.py │ │ │ │ └── download_utils.py │ │ └── utils.py │ │ └── utils │ │ ├── __init__.py │ │ ├── image_convert.py │ │ ├── image_util.py │ │ └── pre_processing.py ├── parsers.py └── wrappers │ ├── __init__.py │ ├── easy_ocr.py │ ├── easy_pororo_ocr.py │ └── tesseract │ ├── __init__.py │ ├── job.py │ └── mapping.py ├── examples ├── detect_boxes.py └── detect_text.py ├── poetry.lock ├── pyproject.toml └── tests ├── __init__.py └── parsers ├── test_extract_json.py └── test_extract_list.py /.github/examples/Pretendard-Medium.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/junhoyeo/BetterOCR/765fffe2988384e0775124b48b6d3b3ad1dbf069/.github/examples/Pretendard-Medium.ttf -------------------------------------------------------------------------------- /.github/images/arch.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/junhoyeo/BetterOCR/765fffe2988384e0775124b48b6d3b3ad1dbf069/.github/images/arch.jpg -------------------------------------------------------------------------------- /.github/images/boxes-0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/junhoyeo/BetterOCR/765fffe2988384e0775124b48b6d3b3ad1dbf069/.github/images/boxes-0.png -------------------------------------------------------------------------------- /.github/images/demo-0.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/junhoyeo/BetterOCR/765fffe2988384e0775124b48b6d3b3ad1dbf069/.github/images/demo-0.webp -------------------------------------------------------------------------------- /.github/images/demo-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/junhoyeo/BetterOCR/765fffe2988384e0775124b48b6d3b3ad1dbf069/.github/images/demo-1.png -------------------------------------------------------------------------------- /.github/images/demo-2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/junhoyeo/BetterOCR/765fffe2988384e0775124b48b6d3b3ad1dbf069/.github/images/demo-2.png -------------------------------------------------------------------------------- /.github/images/demo-3.webp: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/junhoyeo/BetterOCR/765fffe2988384e0775124b48b6d3b3ad1dbf069/.github/images/demo-3.webp -------------------------------------------------------------------------------- /.github/images/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/junhoyeo/BetterOCR/765fffe2988384e0775124b48b6d3b3ad1dbf069/.github/images/logo.png -------------------------------------------------------------------------------- /.github/images/starmie.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/junhoyeo/BetterOCR/765fffe2988384e0775124b48b6d3b3ad1dbf069/.github/images/starmie.jpg -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | *.pyc 3 | .pytest_cache 4 | dist 5 | 6 | .DS_Store 7 | *.traineddata 8 | 9 | *.png 10 | !.github/images/*.png 11 | -------------------------------------------------------------------------------- /.prettierignore: -------------------------------------------------------------------------------- 1 | *.md 2 | -------------------------------------------------------------------------------- /.prettierrc.js: -------------------------------------------------------------------------------- 1 | module.exports = { 2 | bracketSpacing: true, 3 | jsxBracketSameLine: false, 4 | singleQuote: true, 5 | trailingComma: 'all', 6 | semi: true, 7 | }; 8 | -------------------------------------------------------------------------------- /.pylintrc: -------------------------------------------------------------------------------- 1 | # A comma-separated list of package or module names from where C extensions may 2 | # be loaded. Extensions are loading into the active Python interpreter and may 3 | # run arbitrary code. 
4 | extension-pkg-whitelist=cv2 5 | -------------------------------------------------------------------------------- /.vscode/extensions.json: -------------------------------------------------------------------------------- 1 | { 2 | "recommendations": [ 3 | "esbenp.prettier-vscode", 4 | "ms-python.pylint", 5 | "ms-python.black-formatter" 6 | ] 7 | } 8 | -------------------------------------------------------------------------------- /.vscode/settings.json: -------------------------------------------------------------------------------- 1 | { 2 | "editor.formatOnSave": true, 3 | "[json]": { 4 | "editor.defaultFormatter": "esbenp.prettier-vscode" 5 | }, 6 | "[jsonc]": { 7 | "editor.defaultFormatter": "esbenp.prettier-vscode" 8 | }, 9 | "[javascript]": { 10 | "editor.defaultFormatter": "esbenp.prettier-vscode" 11 | }, 12 | "[typescript]": { 13 | "editor.defaultFormatter": "esbenp.prettier-vscode" 14 | }, 15 | "[typescriptreact]": { 16 | "editor.defaultFormatter": "esbenp.prettier-vscode" 17 | }, 18 | "[python]": { 19 | "editor.defaultFormatter": "ms-python.black-formatter" 20 | }, 21 | "pylint.args": [ 22 | "--disable=dangerous-default-value", 23 | "--disable=missing-module-docstring", 24 | "--disable=missing-class-docstring", 25 | "--disable=missing-function-docstring", 26 | "--disable=line-too-long", 27 | "--disable=invalid-name" 28 | ] 29 | } 30 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Junho Yeo 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

2 | 3 | 4 | 5 |

6 |

BetterOCR

7 | 8 |

9 | PyPI 10 | MIT License 11 |

12 | 13 | > 🔍 Better text detection by combining multiple OCR engines with 🧠 LLM. 14 | 15 | OCR _still_ sucks! ... Especially when you're from the _other side_ of the world (and face a significant lack of training data in your language) — or just not thrilled with noisy results. 16 | 17 | **BetterOCR** combines results from multiple OCR engines with an LLM to correct & reconstruct the output. 18 | 19 |

20 |
21 | My open source work is supported by the community
22 | Special thanks to (주)한국모바일상품권(Korea Mobile Git Card, Inc.) and others 23 |

24 | 25 | ### 🔍 OCR Engines 26 | Currently supports [EasyOCR](https://github.com/JaidedAI/EasyOCR) (JaidedAI), [Tesseract](https://github.com/tesseract-ocr/tesseract) (Google), and [Pororo](https://github.com/kakaobrain/pororo) (KakaoBrain). 27 | 28 | - For Pororo, we're using the code from https://github.com/black7375/korean_ocr_using_pororo
29 | (Pre-processing ➡️ _Text detection_ with EasyOCR ➡️ _Text recognition_ with BrainOCR (Pororo's OCR module)). 30 | - Pororo is used only if the language options (`lang`) specified include either 🇺🇸 English (`en`) or 🇰🇷 Korean (`ko`). Additionally, the dependencies listed in [tool.poetry.group.pororo.dependencies] must be available. (If not, Pororo is automatically excluded from the enabled engines.) 31 | 32 | ### 🧠 LLM 33 | Supports [Chat models](https://github.com/openai/openai-python#chat-completions) from OpenAI. 34 | 35 | ### 📒 Custom Context 36 | Allows users to provide an optional context with specific keywords, such as proper nouns and product names. This assists in spelling correction and noise identification, ensuring accuracy even with rare or unconventional words. 37 | 38 | ### 🛢️ Resources 39 | 40 | - Head over to [💯 Examples](#-Examples) to view performance by language (🇺🇸, 🇰🇷, 🇮🇳). 41 | - Coming Soon: ~~box detection~~ 🧪✅, improved interface 🚧, async support, and more. Contributions are welcome. 42 | 43 | > **Warning**
44 | > This package is under rapid development 🛠 45 | 46 | 47 | 48 | 49 | 50 | > Architecture 51 | 52 | ## 🚀 Usage (WIP) 53 | 54 | ```bash 55 | pip install betterocr 56 | # pip3 install betterocr 57 | ``` 58 | 59 | ```py 60 | import betterocr 61 | 62 | # text detection 63 | text = betterocr.detect_text( 64 | "demo.png", 65 | ["ko", "en"], # language codes (from EasyOCR) 66 | context="", # (optional) context 67 | tesseract={ 68 | # Tesseract options here 69 | "config": "--tessdata-dir ./tessdata" 70 | }, 71 | openai={ 72 | # OpenAI options here 73 | 74 | # `os.environ["OPENAI_API_KEY"]` is used by default 75 | "API_KEY": "sk-xxxxxxx", 76 | 77 | # rest are used to pass params to `client.chat.completions.create` 78 | # `{"model": "gpt-4"}` by default 79 | "model": "gpt-3.5-turbo", 80 | }, 81 | ) 82 | print(text) 83 | ``` 84 | 85 | ### 📦 Box Detection 86 | 87 | | Original | Detected | 88 | |:---:|:---:| 89 | | | | 90 | 91 | Example Script: https://github.com/junhoyeo/BetterOCR/blob/main/examples/detect_boxes.py (Uses OpenCV and Matplotlib to draw rectangles) 92 | 93 | ```py 94 | import betterocr 95 | 96 | image_path = ".github/images/demo-1.png" 97 | items = betterocr.detect_boxes( 98 | image_path, 99 | ["ko", "en"], 100 | context="퍼멘테이션 펩타인 아이케어 크림", # product name 101 | tesseract={ 102 | "config": "--psm 6 --tessdata-dir ./tessdata -c tessedit_create_boxfile=1" 103 | }, 104 | ) 105 | print(items) 106 | ``` 107 | 108 |
109 | View Output 110 | 111 | ```py 112 | [ 113 | {'text': 'JUST FOR YOU', 'box': [[543, 87], [1013, 87], [1013, 151], [543, 151]]}, 114 | {'text': '이런 분들께 추천드리는 퍼멘테이션 펩타인 아이케어 크림', 'box': [[240, 171], [1309, 171], [1309, 224], [240, 224]]}, 115 | {'text': '매일매일 진해지는 다크서클을 개선하고 싶다면', 'box': [[123, 345], [1166, 345], [1166, 396], [123, 396]]}, 116 | {'text': '축축 처지는 피부를 탄력 있게 바꾸고 싶다면', 'box': [[125, 409], [1242, 409], [1242, 470], [125, 470]]}, 117 | {'text': '나날이 늘어가는 눈가 주름을 완화하고 싶다면', 'box': [[123, 479], [1112, 479], [1112, 553], [123, 553]]}, 118 | {'text': 'FERMENATION', 'box': [[1216, 578], [1326, 578], [1326, 588], [1216, 588]]}, 119 | {'text': '민감성 피부에도 사용할 수 있는 아이크림을 찾는다면', 'box': [[134, 534], [1071, 534], [1071, 618], [134, 618]]}, 120 | {'text': '얇고 예민한 눈가 주변 피부를 관리하고 싶다면', 'box': [[173, 634], [1098, 634], [1098, 690], [173, 690]]} 121 | ] 122 | ``` 123 | 124 |
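
If you only need a quick visualization of the boxes rather than the full example script, a minimal sketch like the one below should work. It is not the bundled `examples/detect_boxes.py`; it assumes `opencv-python`, `numpy`, and `matplotlib` are installed and simply draws each returned `box` polygon in green:

```py
import cv2
import numpy as np
import matplotlib.pyplot as plt

import betterocr

image_path = ".github/images/demo-1.png"
items = betterocr.detect_boxes(image_path, ["ko", "en"])  # same call as above, default options

image = cv2.cvtColor(cv2.imread(image_path), cv2.COLOR_BGR2RGB)
for item in items:
    # each `box` is four [x, y] corners: [[x, y], [x+w, y], [x+w, y+h], [x, y+h]]
    points = np.array(item["box"], dtype=np.int32).reshape((-1, 1, 2))
    cv2.polylines(image, [points], isClosed=True, color=(0, 200, 0), thickness=2)

plt.imshow(image)
plt.axis("off")
plt.show()
```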
125 | 126 | ## 💯 Examples 127 | 128 | > **Note**
129 | > Results may vary due to inherent variability and potential future updates to OCR engines or the OpenAI API. 130 | 131 | ### Example 1 (English with Noise) 132 | 133 | 134 | 135 | | Source | Text | 136 | | ------ | ---- | 137 | | EasyOCR | `CHAINSAWMANChapter 109:The Easy Way to Stop Bullying~BV-THTSUKIFUUIMUTU ETT` | 138 | | Tesseract | `A\ ira \| LT ge a TE ay NS\nye SE F Pa Ce YI AIG 44\nopr See aC\n; a) Ny 7S =u \|\n_ F2 SENN\n\ ZR\n3 ~ 1 A \ Ws —— “s 7 “A\n=) 24 4 = rt fl /1\n£72 7 a NS dA Chapter 109:77/ ¢ 4\nZz % = ~ oes os \| \STheEasf Way.to Stop Bullying:\n© Wa) ROT\n\n` | 139 | | Pororo | `CHAINSAWNAN\nChapter 109\nThe Easy Way.to Stop Bullying.\nCBY=TATSUKI FUJIMDTO` | 140 | | LLM | 🤖 GPT-3.5 | 141 | | **Result** | **`CHAINSAW MAN\n\nChapter 109: The Easy Way to Stop Bullying\n\nBY: TATSUKI FUJIMOTO`** | 142 | 143 | ### Example 2 (Korean+English) 144 | 145 | 146 | 147 | | Source | Text | 148 | | ------ | ---- | 149 | | EasyOCR | `JUST FOR YOU이런 분들께 추천드리는 퍼멘테이선 팬타인 아이켜어 크림매일매일 진해지논 다크서클올 개선하고 싶다면축축 처지논 피부름 탄력 잇게 바꾸고 싶다면나날이 늘어가는 눈가 주름올 완화하고 싶다면FERMENATION민감성 피부에도 사용할 수잇는 아이크림올 찾는다면얇고 예민한 눈가 주변 피부름 관리하고 싶다면` | 150 | | Tesseract | `9051 508 \ㅇ4\n이런 분들께 추천드리는 퍼멘테이션 타인 아이케어 크림\n.매일매일 진해지는 다크서클을 개선하고 싶다면 "도\nㆍ축축 처지는 피부를 탄력 있게 바꾸고 싶다면 7\nㆍ나날이 늘어가는 눈가 주름을 완화하고 싶다면 /\n-민감성 피부에도 사용할 수 있는 아이크림을 찾는다면 (프\nㆍ않고 예민한 눈가 주변 피부를 관리하고 싶다면 밸\n\n` | 151 | | Pororo | `JUST FOR YOU\n이런 분들께 추천드리는 퍼맨테이션 펩타인 아이케어 크림\n매일매일 진해지는 다크서클을 개선하고 싶다면\n촉촉 처지는 피부를 탄력 있게 바꾸고 싶다면\n나날이 늘어가는 눈가 주름을 완화하고 싶다면\nFERMENTATIOM\n민감성 피부에도 사용할 수 있는 아이크림을 찾는다면\n얇고 예민한 눈가 주변 피부를 관리하고 싶다면` | 152 | | LLM | 🤖 GPT-3.5 | 153 | | **Result** | **`JUST FOR YOU\n이런 분들께 추천드리는 퍼멘테이션 펩타인 아이케어 크림\n매일매일 진해지는 다크서클을 개선하고 싶다면\n축축 처지는 피부를 탄력 있게 바꾸고 싶다면\n나날이 늘어가는 눈가 주름을 완화하고 싶다면\nFERMENTATION\n민감성 피부에도 사용할 수 있는 아이크림을 찾는다면\n얇고 예민한 눈가 주변 피부를 관리하고 싶다면`** | 154 | 155 | ### Example 3 (Korean with custom `context`) 156 | 157 | 158 | 159 | | Source | Text | 160 | | ------ | ---- | 161 | | EasyOCR | `바이오함보#세로모공존존세럼6글로우픽 설문단 100인이꼼꼼하게 평가햇어요"#누적 판매액 40억#제품만족도 1009` | 162 | | Tesseract | `바이오힐보\n#세로모공폰폰세럼\n“글로 으피 석무다 1 00인이\n꼼꼼하게평가했어요”\n\n` | 163 | | Pororo | `바이오힐보\n#세로모공쫀쫀세럼\n'.\n'글로우픽 설문단 100인이\n꼼꼼하게 평가했어요'"\n#누적 판매액 40억\n# 제품 만족도 100%` | 164 | | Context | `[바이오힐보] 세로모공쫀쫀세럼으로 콜라겐 타이트닝! (6S)` | 165 | | LLM | 🤖 GPT-4 | 166 | | **Result** | **`바이오힐보\n#세로모공쫀쫀세럼\n글로우픽 설문단 100인이 꼼꼼하게 평가했어요\n#누적 판매액 40억\n#제품 만족도 100%`** | 167 | 168 | #### 🧠 LLM Reasoning (*Old) 169 | 170 | Based on the given OCR results and the context, here is the combined and corrected result: 171 | 172 | ``` 173 | { 174 | "data": "바이오힐보\n#세로모공쫀쫀세럼\n글로우픽 설문단 100인이 꼼꼼하게 평가했어요\n#누적 판매액 40억\n#제품만족도 100%" 175 | } 176 | ``` 177 | 178 | - `바이오힐보` is the correct brand name, taken from [1] and the context. 179 | - `#세로모공쫀쫀세럼` seems to be the product name and is derived from the context. 180 | - `글로우픽 설문단 100인이 꼼꼼하게 평가했어요` is extracted and corrected from both OCR results. 181 | - `#누적 판매액 40억` is taken from [0]. 182 | - `#제품만족도 100%` is corrected from [0]. 
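
For reference, a call that reproduces Example 3 looks roughly like the sketch below. The image path is a placeholder for the Example 3 image; the `context` string and the GPT-4 model match the table above:

```py
import betterocr

text = betterocr.detect_text(
    "biohealbo.png",  # placeholder path: use your copy of the Example 3 image
    ["ko", "en"],
    context="[바이오힐보] 세로모공쫀쫀세럼으로 콜라겐 타이트닝! (6S)",
    openai={"model": "gpt-4"},
)
print(text)
```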
183 | 184 | ### Example 4 (Hindi) 185 | 186 | 187 | 188 | | Source | Text | 189 | | ------ | ---- | 190 | | EasyOCR | `` `७नवभारतटाइम्सतोक्यो ओलिंपिक के लिए भारतीय दलका थीम सॉन्ग लॉन्च कर दिया गयाबुधवार को इस सॉन्ग को किया गया लॉन्चसिंगर मोहित चौहान ने दी है आवाज7लखेल मंत्री किरण रिजिजू ने ट्विटर पर शेयरकिया थीम सॉन्ग का वीडियो0ब४0 २०२०गीत का नाम- '्लक्ष्य तेरा सामने है' , खेलमंत्री ने ५७ सेकंड का वीडियो किया शेयर `` | 191 | | Tesseract | `'8ा.\nनवभोरत टैइम्स\n\nतोक्यो ओलिंपिक के लिंए भारतीय दल\n\nका थीम सॉन्ग लॉन्च कर दिया गया\n\nबुधवार को हस सॉन्ग को किया गया लॉन्च\nसिंगर मोहित चौहान ने दी है आवाज\n\nखेल मंत्री किरण रिजिजू ने द्विटर पर शेयर\nकिया थीम सॉन्ग का वीडियो\n\nपृ 0 (९ है 0 2 0 2 0 गीत का नाम- 'लक्ष्य तेरा सामने है', खेल\n\n(2 (9९) मंत्री ने 57 सेकंड का वीडियो किया शेयर\n\n` | 192 | | LLM | 🤖 GPT-4 | 193 | | **Result** | **`नवभारत टाइम्स\nतोक्यो ओलिंपिक के लिए भारतीय दल का थीम सॉन्ग लॉन्च कर दिया गया\nबुधवार को इस सॉन्ग को किया गया लॉन्च\nसिंगर मोहित चौहान ने दी है आवाज\n\nखेल मंत्री किरण रिजिजू ने ट्विटर पर शेयर किया थीम सॉन्ग का वीडियो\n2020 गीत का नाम- 'लक्ष्य तेरा सामने है', खेल मंत्री ने 57 सेकंड का वीडियो किया शेयर`** | 194 | 195 | ## License (Starmie!) 196 | 197 |

198 | MIT © Junho Yeo 199 |

200 | 201 |

202 | 203 | 204 | 205 |

206 | 207 | If you find this project interesting, **please consider giving it a star(⭐)** and following me on [GitHub](https://github.com/junhoyeo). I code 24/7 and ship mind-breaking things on a regular basis, so your support definitely won't be in vain! 208 | -------------------------------------------------------------------------------- /betterocr/__init__.py: -------------------------------------------------------------------------------- 1 | from .detect import ( 2 | detect, 3 | detect_async, 4 | detect_text, 5 | detect_text_async, 6 | detect_boxes, 7 | detect_boxes_async, 8 | NoTextDetectedError, 9 | ) 10 | from .wrappers import job_easy_ocr, job_tesseract 11 | from .parsers import extract_json 12 | 13 | __all__ = [ 14 | "detect", 15 | "detect_async", 16 | "detect_text", 17 | "detect_text_async", 18 | "detect_boxes", 19 | "detect_boxes_async", 20 | "NoTextDetectedError", 21 | "job_easy_ocr", 22 | "job_tesseract", 23 | "extract_json", 24 | ] 25 | 26 | __author__ = "junhoyeo" 27 | -------------------------------------------------------------------------------- /betterocr/detect.py: -------------------------------------------------------------------------------- 1 | from threading import Thread 2 | import json 3 | from queue import Queue 4 | import os 5 | 6 | from openai import OpenAI 7 | 8 | from .parsers import extract_json, extract_list, rectangle_corners 9 | from .wrappers import ( 10 | job_easy_ocr, 11 | job_easy_ocr_boxes, 12 | job_tesseract, 13 | job_tesseract_boxes, 14 | ) 15 | 16 | 17 | def wrapper(func, args, queue): 18 | queue.put(func(args)) 19 | 20 | 21 | # custom error 22 | class NoTextDetectedError(Exception): 23 | pass 24 | 25 | 26 | def detect(): 27 | """Unimplemented""" 28 | raise NotImplementedError 29 | 30 | 31 | def detect_async(): 32 | """Unimplemented""" 33 | raise NotImplementedError 34 | 35 | 36 | def get_jobs(languages: list[str], boxes=False): 37 | jobs = [ 38 | job_easy_ocr if not boxes else job_easy_ocr_boxes, 39 | job_tesseract if not boxes else job_tesseract_boxes, 40 | ] 41 | # ko or en in languages 42 | if "ko" in languages or "en" in languages: 43 | try: 44 | if not boxes: 45 | from .wrappers.easy_pororo_ocr import job_easy_pororo_ocr 46 | 47 | jobs.append(job_easy_pororo_ocr) 48 | else: 49 | from .wrappers.easy_pororo_ocr import job_easy_pororo_ocr_boxes 50 | 51 | jobs.append(job_easy_pororo_ocr_boxes) 52 | except ImportError as e: 53 | print(e) 54 | print( 55 | "[!] Pororo dependencies is not installed. Skipping Pororo (EasyPororoOCR)." 
56 | ) 57 | pass 58 | return jobs 59 | 60 | 61 | def detect_text( 62 | image_path: str, 63 | lang: list[str], 64 | context: str = "", 65 | tesseract: dict = {}, 66 | openai: dict = {"model": "gpt-4"}, 67 | ): 68 | """Detect text from an image using EasyOCR and Tesseract, then combine and correct the results using OpenAI's LLM.""" 69 | options = { 70 | "path": image_path, # "demo.png", 71 | "lang": lang, # ["ko", "en"] 72 | "context": context, 73 | "tesseract": tesseract, 74 | "openai": openai, 75 | } 76 | jobs = get_jobs(languages=options["lang"], boxes=False) 77 | 78 | queues = [] 79 | for job in jobs: 80 | queue = Queue() 81 | Thread(target=wrapper, args=(job, options, queue)).start() 82 | queues.append(queue) 83 | 84 | results = [queue.get() for queue in queues] 85 | 86 | result_indexes_prompt = "" # "[0][1][2]" 87 | result_prompt = "" # "[0]: result_0\n[1]: result_1\n[2]: result_2" 88 | 89 | for i in range(len(results)): 90 | result_indexes_prompt += f"[{i}]" 91 | result_prompt += f"[{i}]: {results[i]}" 92 | 93 | if i != len(results) - 1: 94 | result_prompt += "\n" 95 | 96 | optional_context_prompt = ( 97 | f"[context]: {options['context']}" if options["context"] else "" 98 | ) 99 | 100 | prompt = f"""Combine and correct OCR results {result_indexes_prompt}, using \\n for line breaks. Langauge is in {'+'.join(options['lang'])}. Remove unintended noise. Refer to the [context] keywords. Answer in the JSON format {{data:}}: 101 | {result_prompt} 102 | {optional_context_prompt}""" 103 | 104 | prompt = prompt.strip() 105 | 106 | print("=====") 107 | print(prompt) 108 | 109 | # Prioritize user-specified API_KEY 110 | api_key = options["openai"].get("API_KEY", os.environ.get("OPENAI_API_KEY")) 111 | 112 | # Make a shallow copy of the openai options and remove the API_KEY 113 | openai_options = options["openai"].copy() 114 | if "API_KEY" in openai_options: 115 | del openai_options["API_KEY"] 116 | 117 | client = OpenAI( 118 | api_key=api_key, 119 | ) 120 | 121 | print("=====") 122 | 123 | completion = client.chat.completions.create( 124 | messages=[ 125 | {"role": "user", "content": prompt}, 126 | ], 127 | **openai_options, 128 | ) 129 | output = completion.choices[0].message.content 130 | print("[*] LLM", output) 131 | 132 | result = extract_json(output) 133 | print(result) 134 | 135 | if "data" in result: 136 | return result["data"] 137 | if isinstance(result, str): 138 | return result 139 | raise NoTextDetectedError("No text detected") 140 | 141 | 142 | def detect_text_async(): 143 | """Unimplemented""" 144 | raise NotImplementedError 145 | 146 | 147 | def detect_boxes( 148 | image_path: str, 149 | lang: list[str], 150 | context: str = "", 151 | tesseract: dict = {}, 152 | openai: dict = {"model": "gpt-4"}, 153 | ): 154 | options = { 155 | "path": image_path, # "demo.png", 156 | "lang": lang, # ["ko", "en"] 157 | "context": context, 158 | "tesseract": tesseract, 159 | "openai": openai, 160 | } 161 | jobs = get_jobs(languages=options["lang"], boxes=True) 162 | 163 | queues = [] 164 | for job in jobs: 165 | queue = Queue() 166 | Thread(target=wrapper, args=(job, options, queue)).start() 167 | queues.append(queue) 168 | 169 | results = [queue.get() for queue in queues] 170 | 171 | result_indexes_prompt = "" # "[0][1][2]" 172 | result_prompt = "" # "[0]: result_0\n[1]: result_1\n[2]: result_2" 173 | 174 | for i in range(len(results)): 175 | result_indexes_prompt += f"[{i}]" 176 | 177 | boxes = results[i] 178 | boxes_json = json.dumps(boxes, ensure_ascii=False, default=int) 179 | 180 | 
result_prompt += f"[{i}]: {boxes_json}" 181 | 182 | if i != len(results) - 1: 183 | result_prompt += "\n" 184 | 185 | optional_context_prompt = ( 186 | " " + "Please refer to the keywords and spelling in [context]" 187 | if options["context"] 188 | else "" 189 | ) 190 | optional_context_prompt_data = ( 191 | f"[context]: {options['context']}" if options["context"] else "" 192 | ) 193 | 194 | prompt = f"""Combine and correct OCR data {result_indexes_prompt}. Include many items as possible. Langauge is in {'+'.join(options['lang'])} (Avoid arbitrary translations). Remove unintended noise.{optional_context_prompt} Answer in the JSON format. Ensure coordinates are integers (round based on confidence if necessary) and output in the same JSON format (indent=0): Array({{box:[[x,y],[x+w,y],[x+w,y+h],[x,y+h]],text:str}}): 195 | {result_prompt} 196 | {optional_context_prompt_data}""" 197 | 198 | prompt = prompt.strip() 199 | 200 | print("=====") 201 | print(prompt) 202 | 203 | # Prioritize user-specified API_KEY 204 | api_key = options["openai"].get("API_KEY", os.environ.get("OPENAI_API_KEY")) 205 | 206 | # Make a shallow copy of the openai options and remove the API_KEY 207 | openai_options = options["openai"].copy() 208 | if "API_KEY" in openai_options: 209 | del openai_options["API_KEY"] 210 | 211 | client = OpenAI( 212 | api_key=api_key, 213 | ) 214 | 215 | print("=====") 216 | 217 | completion = client.chat.completions.create( 218 | messages=[ 219 | {"role": "user", "content": prompt}, 220 | ], 221 | **openai_options, 222 | ) 223 | output = completion.choices[0].message.content 224 | output = output.replace("\n", "") 225 | print("[*] LLM", output) 226 | 227 | items = extract_list(output) 228 | 229 | for idx, item in enumerate(items): 230 | box = item["box"] 231 | 232 | # [x,y,w,h] 233 | if len(box) == 4 and isinstance(box[0], int): 234 | rect = rectangle_corners(box) 235 | items[idx]["box"] = rect 236 | 237 | # [[x,y],[w,h]] 238 | elif len(box) == 2 and isinstance(box[0], list) and len(box[0]) == 2: 239 | flattened = [i for sublist in box for i in sublist] 240 | rect = rectangle_corners(flattened) 241 | items[idx]["box"] = rect 242 | 243 | return items 244 | 245 | 246 | def detect_boxes_async(): 247 | """Unimplemented""" 248 | raise NotImplementedError 249 | -------------------------------------------------------------------------------- /betterocr/engines/easy_pororo_ocr/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | this code is adapted from https://github.com/black7375/korean_ocr_using_pororo 3 | 4 | Apache License 2.0 @yunwoong7 5 | Apache License 2.0 @black7375 6 | """ 7 | 8 | import cv2 9 | from abc import ABC, abstractmethod 10 | from .pororo import Pororo 11 | from .utils.image_util import plt_imshow, put_text 12 | from .utils.image_convert import convert_coord, crop 13 | from .utils.pre_processing import load_with_filter, roi_filter 14 | from easyocr import Reader 15 | import warnings 16 | 17 | warnings.filterwarnings("ignore") 18 | 19 | 20 | class BaseOcr(ABC): 21 | def __init__(self): 22 | self.img_path = None 23 | self.ocr_result = {} 24 | 25 | def get_ocr_result(self): 26 | return self.ocr_result 27 | 28 | def get_img_path(self): 29 | return self.img_path 30 | 31 | def show_img(self): 32 | plt_imshow(img=self.img_path) 33 | 34 | def show_img_with_ocr(self, bounding, description, vertices, point): 35 | img = ( 36 | cv2.imread(self.img_path) 37 | if isinstance(self.img_path, str) 38 | else self.img_path 39 | ) 40 | roi_img = 
img.copy() 41 | color = (0, 200, 0) 42 | 43 | x, y = point 44 | ocr_result = self.ocr_result if bounding is None else self.ocr_result[bounding] 45 | for text_result in ocr_result: 46 | text = text_result[description] 47 | rect = text_result[vertices] 48 | 49 | topLeft, topRight, bottomRight, bottomLeft = [ 50 | (round(point[x]), round(point[y])) for point in rect 51 | ] 52 | 53 | cv2.line(roi_img, topLeft, topRight, color, 2) 54 | cv2.line(roi_img, topRight, bottomRight, color, 2) 55 | cv2.line(roi_img, bottomRight, bottomLeft, color, 2) 56 | cv2.line(roi_img, bottomLeft, topLeft, color, 2) 57 | roi_img = put_text(roi_img, text, topLeft[0], topLeft[1] - 20, color=color) 58 | 59 | plt_imshow(["Original", "ROI"], [img, roi_img], figsize=(16, 10)) 60 | 61 | @abstractmethod 62 | def run_ocr(self, img_path: str, debug: bool = False): 63 | pass 64 | 65 | 66 | class EasyPororoOcr(BaseOcr): 67 | def __init__(self, lang: list[str] = ["ko", "en"], gpu=True, **kwargs): 68 | super().__init__() 69 | self._detector = Reader(lang_list=lang, gpu=gpu, **kwargs).detect 70 | self.detect_result = None 71 | self.languages = lang 72 | 73 | def create_result(self, points): 74 | roi = crop(self.img, points) 75 | result = self._ocr(roi_filter(roi)) 76 | text = " ".join(result) 77 | 78 | return [points, text] 79 | 80 | def run_ocr(self, img_path: str, debug: bool = False, **kwargs): 81 | self.img_path = img_path 82 | self.img = cv2.imread(img_path) if isinstance(img_path, str) else self.img_path 83 | 84 | lang = "ko" if "ko" in self.languages else "en" 85 | self._ocr = Pororo(task="ocr", lang=lang, model="brainocr", **kwargs) 86 | 87 | self.detect_result = self._detector(self.img, slope_ths=0.3, height_ths=1) 88 | if debug: 89 | print(self.detect_result) 90 | 91 | horizontal_list, free_list = self.detect_result 92 | 93 | rois = [convert_coord(point) for point in horizontal_list[0]] + free_list[0] 94 | 95 | self.ocr_result = list( 96 | filter( 97 | lambda result: len(result[1]) > 0, 98 | [self.create_result(roi) for roi in rois], 99 | ) 100 | ) 101 | 102 | if len(self.ocr_result) != 0: 103 | ocr_text = list(map(lambda result: result[1], self.ocr_result)) 104 | else: 105 | ocr_text = "No text detected." 
106 | 107 | if debug: 108 | self.show_img_with_ocr(None, 1, 0, [0, 1]) 109 | 110 | return ocr_text 111 | 112 | def get_boxes(self): 113 | x, y = [0, 1] 114 | ocr_result = self.ocr_result 115 | description = 1 116 | vertices = 0 117 | 118 | items = [] 119 | 120 | for text_result in ocr_result: 121 | text = text_result[description] 122 | rect = text_result[vertices] 123 | rect = [[round(point[x]), round(point[y])] for point in rect] 124 | items.append({"box": rect, "text": text}) 125 | 126 | return items 127 | 128 | 129 | __all__ = ["EasyPororoOcr", "load_with_filter"] 130 | -------------------------------------------------------------------------------- /betterocr/engines/easy_pororo_ocr/pororo/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | this code is adapted from https://github.com/black7375/korean_ocr_using_pororo 3 | 4 | Apache License 2.0 @yunwoong7 5 | Apache License 2.0 @black7375 6 | """ 7 | 8 | from .pororo import Pororo # noqa 9 | -------------------------------------------------------------------------------- /betterocr/engines/easy_pororo_ocr/pororo/models/brainOCR/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | this code is adapted from https://github.com/black7375/korean_ocr_using_pororo 3 | 4 | Apache License 2.0 @yunwoong7 5 | Apache License 2.0 @black7375 6 | """ 7 | 8 | from .brainocr import Reader # noqa 9 | -------------------------------------------------------------------------------- /betterocr/engines/easy_pororo_ocr/pororo/models/brainOCR/_dataset.py: -------------------------------------------------------------------------------- 1 | """ 2 | this code is adapted from https://github.com/black7375/korean_ocr_using_pororo 3 | 4 | Apache License 2.0 @yunwoong7 5 | Apache License 2.0 @black7375 6 | """ 7 | 8 | import os 9 | 10 | from natsort import natsorted 11 | from PIL import Image 12 | from torch.utils.data import Dataset 13 | 14 | 15 | class RawDataset(Dataset): 16 | def __init__(self, root, imgW, imgH): 17 | self.imgW = imgW 18 | self.imgH = imgH 19 | self.image_path_list = [] 20 | for dirpath, _, filenames in os.walk(root): 21 | for name in filenames: 22 | _, ext = os.path.splitext(name) 23 | ext = ext.lower() 24 | if ext in (".jpg", ".jpeg", ".png"): 25 | self.image_path_list.append(os.path.join(dirpath, name)) 26 | 27 | self.image_path_list = natsorted(self.image_path_list) 28 | self.nSamples = len(self.image_path_list) 29 | 30 | def __len__(self): 31 | return self.nSamples 32 | 33 | def __getitem__(self, index): 34 | try: 35 | img = Image.open(self.image_path_list[index]).convert("L") 36 | 37 | except IOError: 38 | print(f"Corrupted image for {index}") 39 | img = Image.new("L", (self.imgW, self.imgH)) 40 | 41 | return img, self.image_path_list[index] 42 | -------------------------------------------------------------------------------- /betterocr/engines/easy_pororo_ocr/pororo/models/brainOCR/_modules.py: -------------------------------------------------------------------------------- 1 | """ 2 | this code is adapted from https://github.com/black7375/korean_ocr_using_pororo 3 | 4 | Apache License 2.0 @yunwoong7 5 | Apache License 2.0 @black7375 6 | """ 7 | 8 | from collections import namedtuple 9 | 10 | import numpy as np 11 | import torch 12 | import torch.nn as nn 13 | import torch.nn.functional as F 14 | import torch.nn.init as init 15 | from torchvision import models 16 | 17 | # from torchvision.models.vgg import model_urls 18 | 19 | device 
= torch.device("cuda" if torch.cuda.is_available() else "cpu") 20 | 21 | 22 | def init_weights(modules): 23 | for m in modules: 24 | if isinstance(m, nn.Conv2d): 25 | init.xavier_uniform_(m.weight.data) 26 | if m.bias is not None: 27 | m.bias.data.zero_() 28 | elif isinstance(m, nn.BatchNorm2d): 29 | m.weight.data.fill_(1) 30 | m.bias.data.zero_() 31 | elif isinstance(m, nn.Linear): 32 | m.weight.data.normal_(0, 0.01) 33 | m.bias.data.zero_() 34 | 35 | 36 | class Vgg16BN(torch.nn.Module): 37 | def __init__(self, pretrained: bool = True, freeze: bool = True): 38 | super(Vgg16BN, self).__init__() 39 | # model_urls["vgg16_bn"] = model_urls["vgg16_bn"].replace( 40 | # "https://", "http://") 41 | vgg_pretrained_features = models.vgg16_bn(pretrained=pretrained).features 42 | self.slice1 = torch.nn.Sequential() 43 | self.slice2 = torch.nn.Sequential() 44 | self.slice3 = torch.nn.Sequential() 45 | self.slice4 = torch.nn.Sequential() 46 | self.slice5 = torch.nn.Sequential() 47 | for x in range(12): # conv2_2 48 | self.slice1.add_module(str(x), vgg_pretrained_features[x]) 49 | for x in range(12, 19): # conv3_3 50 | self.slice2.add_module(str(x), vgg_pretrained_features[x]) 51 | for x in range(19, 29): # conv4_3 52 | self.slice3.add_module(str(x), vgg_pretrained_features[x]) 53 | for x in range(29, 39): # conv5_3 54 | self.slice4.add_module(str(x), vgg_pretrained_features[x]) 55 | 56 | # fc6, fc7 without atrous conv 57 | self.slice5 = torch.nn.Sequential( 58 | nn.MaxPool2d(kernel_size=3, stride=1, padding=1), 59 | nn.Conv2d(512, 1024, kernel_size=3, padding=6, dilation=6), 60 | nn.Conv2d(1024, 1024, kernel_size=1), 61 | ) 62 | 63 | if not pretrained: 64 | init_weights(self.slice1.modules()) 65 | init_weights(self.slice2.modules()) 66 | init_weights(self.slice3.modules()) 67 | init_weights(self.slice4.modules()) 68 | 69 | init_weights(self.slice5.modules()) # no pretrained model for fc6 and fc7 70 | 71 | if freeze: 72 | for param in self.slice1.parameters(): # only first conv 73 | param.requires_grad = False 74 | 75 | def forward(self, x): 76 | h = self.slice1(x) 77 | h_relu2_2 = h 78 | h = self.slice2(h) 79 | h_relu3_2 = h 80 | h = self.slice3(h) 81 | h_relu4_3 = h 82 | h = self.slice4(h) 83 | h_relu5_3 = h 84 | h = self.slice5(h) 85 | h_fc7 = h 86 | vgg_outputs = namedtuple( 87 | "VggOutputs", ["fc7", "relu5_3", "relu4_3", "relu3_2", "relu2_2"] 88 | ) 89 | out = vgg_outputs(h_fc7, h_relu5_3, h_relu4_3, h_relu3_2, h_relu2_2) 90 | return out 91 | 92 | 93 | class VGGFeatureExtractor(nn.Module): 94 | """FeatureExtractor of CRNN (https://arxiv.org/pdf/1507.05717.pdf)""" 95 | 96 | def __init__(self, n_input_channels: int = 1, n_output_channels: int = 512): 97 | super(VGGFeatureExtractor, self).__init__() 98 | 99 | self.output_channel = [ 100 | int(n_output_channels / 8), 101 | int(n_output_channels / 4), 102 | int(n_output_channels / 2), 103 | n_output_channels, 104 | ] # [64, 128, 256, 512] 105 | self.ConvNet = nn.Sequential( 106 | nn.Conv2d(n_input_channels, self.output_channel[0], 3, 1, 1), 107 | nn.ReLU(True), 108 | nn.MaxPool2d(2, 2), # 64x16x50 109 | nn.Conv2d(self.output_channel[0], self.output_channel[1], 3, 1, 1), 110 | nn.ReLU(True), 111 | nn.MaxPool2d(2, 2), # 128x8x25 112 | nn.Conv2d(self.output_channel[1], self.output_channel[2], 3, 1, 1), 113 | nn.ReLU(True), # 256x8x25 114 | nn.Conv2d(self.output_channel[2], self.output_channel[2], 3, 1, 1), 115 | nn.ReLU(True), 116 | nn.MaxPool2d((2, 1), (2, 1)), # 256x4x25 117 | nn.Conv2d( 118 | self.output_channel[2], 119 | self.output_channel[3], 120 | 
3, 121 | 1, 122 | 1, 123 | bias=False, 124 | ), 125 | nn.BatchNorm2d(self.output_channel[3]), 126 | nn.ReLU(True), # 512x4x25 127 | nn.Conv2d( 128 | self.output_channel[3], 129 | self.output_channel[3], 130 | 3, 131 | 1, 132 | 1, 133 | bias=False, 134 | ), 135 | nn.BatchNorm2d(self.output_channel[3]), 136 | nn.ReLU(True), 137 | nn.MaxPool2d((2, 1), (2, 1)), # 512x2x25 138 | nn.Conv2d(self.output_channel[3], self.output_channel[3], 2, 1, 0), 139 | nn.ReLU(True), 140 | ) # 512x1x24 141 | 142 | def forward(self, x): 143 | return self.ConvNet(x) 144 | 145 | 146 | class BidirectionalLSTM(nn.Module): 147 | def __init__(self, input_size: int, hidden_size: int, output_size: int): 148 | super(BidirectionalLSTM, self).__init__() 149 | self.rnn = nn.LSTM( 150 | input_size, 151 | hidden_size, 152 | bidirectional=True, 153 | batch_first=True, 154 | ) 155 | self.linear = nn.Linear(hidden_size * 2, output_size) 156 | 157 | def forward(self, x): 158 | """ 159 | x : visual feature [batch_size x T x input_size] 160 | output : contextual feature [batch_size x T x output_size] 161 | """ 162 | self.rnn.flatten_parameters() 163 | recurrent, _ = self.rnn( 164 | x 165 | ) # batch_size x T x input_size -> batch_size x T x (2*hidden_size) 166 | output = self.linear(recurrent) # batch_size x T x output_size 167 | return output 168 | 169 | 170 | class ResNetFeatureExtractor(nn.Module): 171 | """ 172 | FeatureExtractor of FAN 173 | (http://openaccess.thecvf.com/content_ICCV_2017/papers/Cheng_Focusing_Attention_Towards_ICCV_2017_paper.pdf) 174 | 175 | """ 176 | 177 | def __init__(self, n_input_channels: int = 1, n_output_channels: int = 512): 178 | super(ResNetFeatureExtractor, self).__init__() 179 | self.ConvNet = ResNet( 180 | n_input_channels, 181 | n_output_channels, 182 | BasicBlock, 183 | [1, 2, 5, 3], 184 | ) 185 | 186 | def forward(self, inputs): 187 | return self.ConvNet(inputs) 188 | 189 | 190 | class BasicBlock(nn.Module): 191 | expansion = 1 192 | 193 | def __init__(self, inplanes: int, planes: int, stride: int = 1, downsample=None): 194 | super(BasicBlock, self).__init__() 195 | self.conv1 = self._conv3x3(inplanes, planes) 196 | self.bn1 = nn.BatchNorm2d(planes) 197 | self.conv2 = self._conv3x3(planes, planes) 198 | self.bn2 = nn.BatchNorm2d(planes) 199 | self.relu = nn.ReLU(inplace=True) 200 | self.downsample = downsample 201 | self.stride = stride 202 | 203 | def _conv3x3(self, in_planes, out_planes, stride=1): 204 | "3x3 convolution with padding" 205 | return nn.Conv2d( 206 | in_planes, 207 | out_planes, 208 | kernel_size=3, 209 | stride=stride, 210 | padding=1, 211 | bias=False, 212 | ) 213 | 214 | def forward(self, x): 215 | residual = x 216 | 217 | out = self.conv1(x) 218 | out = self.bn1(out) 219 | out = self.relu(out) 220 | 221 | out = self.conv2(out) 222 | out = self.bn2(out) 223 | 224 | if self.downsample is not None: 225 | residual = self.downsample(x) 226 | out += residual 227 | out = self.relu(out) 228 | 229 | return out 230 | 231 | 232 | class ResNet(nn.Module): 233 | def __init__( 234 | self, 235 | n_input_channels: int, 236 | n_output_channels: int, 237 | block, 238 | layers, 239 | ): 240 | """ 241 | :param n_input_channels (int): The number of input channels of the feature extractor 242 | :param n_output_channels (int): The number of output channels of the feature extractor 243 | :param block: 244 | :param layers: 245 | """ 246 | super(ResNet, self).__init__() 247 | 248 | self.output_channel_blocks = [ 249 | int(n_output_channels / 4), 250 | int(n_output_channels / 2), 251 | 
n_output_channels, 252 | n_output_channels, 253 | ] 254 | 255 | self.inplanes = int(n_output_channels / 8) 256 | self.conv0_1 = nn.Conv2d( 257 | n_input_channels, 258 | int(n_output_channels / 16), 259 | kernel_size=3, 260 | stride=1, 261 | padding=1, 262 | bias=False, 263 | ) 264 | self.bn0_1 = nn.BatchNorm2d(int(n_output_channels / 16)) 265 | self.conv0_2 = nn.Conv2d( 266 | int(n_output_channels / 16), 267 | self.inplanes, 268 | kernel_size=3, 269 | stride=1, 270 | padding=1, 271 | bias=False, 272 | ) 273 | self.bn0_2 = nn.BatchNorm2d(self.inplanes) 274 | self.relu = nn.ReLU(inplace=True) 275 | 276 | self.maxpool1 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0) 277 | self.layer1 = self._make_layer( 278 | block, 279 | self.output_channel_blocks[0], 280 | layers[0], 281 | ) 282 | self.conv1 = nn.Conv2d( 283 | self.output_channel_blocks[0], 284 | self.output_channel_blocks[0], 285 | kernel_size=3, 286 | stride=1, 287 | padding=1, 288 | bias=False, 289 | ) 290 | self.bn1 = nn.BatchNorm2d(self.output_channel_blocks[0]) 291 | 292 | self.maxpool2 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0) 293 | self.layer2 = self._make_layer( 294 | block, 295 | self.output_channel_blocks[1], 296 | layers[1], 297 | stride=1, 298 | ) 299 | self.conv2 = nn.Conv2d( 300 | self.output_channel_blocks[1], 301 | self.output_channel_blocks[1], 302 | kernel_size=3, 303 | stride=1, 304 | padding=1, 305 | bias=False, 306 | ) 307 | self.bn2 = nn.BatchNorm2d(self.output_channel_blocks[1]) 308 | 309 | self.maxpool3 = nn.MaxPool2d( 310 | kernel_size=2, 311 | stride=(2, 1), 312 | padding=(0, 1), 313 | ) 314 | self.layer3 = self._make_layer( 315 | block, 316 | self.output_channel_blocks[2], 317 | layers[2], 318 | stride=1, 319 | ) 320 | self.conv3 = nn.Conv2d( 321 | self.output_channel_blocks[2], 322 | self.output_channel_blocks[2], 323 | kernel_size=3, 324 | stride=1, 325 | padding=1, 326 | bias=False, 327 | ) 328 | self.bn3 = nn.BatchNorm2d(self.output_channel_blocks[2]) 329 | 330 | self.layer4 = self._make_layer( 331 | block, 332 | self.output_channel_blocks[3], 333 | layers[3], 334 | stride=1, 335 | ) 336 | self.conv4_1 = nn.Conv2d( 337 | self.output_channel_blocks[3], 338 | self.output_channel_blocks[3], 339 | kernel_size=2, 340 | stride=(2, 1), 341 | padding=(0, 1), 342 | bias=False, 343 | ) 344 | self.bn4_1 = nn.BatchNorm2d(self.output_channel_blocks[3]) 345 | self.conv4_2 = nn.Conv2d( 346 | self.output_channel_blocks[3], 347 | self.output_channel_blocks[3], 348 | kernel_size=2, 349 | stride=1, 350 | padding=0, 351 | bias=False, 352 | ) 353 | self.bn4_2 = nn.BatchNorm2d(self.output_channel_blocks[3]) 354 | 355 | def _make_layer(self, block, planes, blocks, stride=1): 356 | downsample = None 357 | if stride != 1 or self.inplanes != planes * block.expansion: 358 | downsample = nn.Sequential( 359 | nn.Conv2d( 360 | self.inplanes, 361 | planes * block.expansion, 362 | kernel_size=1, 363 | stride=stride, 364 | bias=False, 365 | ), 366 | nn.BatchNorm2d(planes * block.expansion), 367 | ) 368 | 369 | layers = [] 370 | layers.append(block(self.inplanes, planes, stride, downsample)) 371 | self.inplanes = planes * block.expansion 372 | for i in range(1, blocks): 373 | layers.append(block(self.inplanes, planes)) 374 | 375 | return nn.Sequential(*layers) 376 | 377 | def forward(self, x): 378 | x = self.conv0_1(x) 379 | x = self.bn0_1(x) 380 | x = self.relu(x) 381 | x = self.conv0_2(x) 382 | x = self.bn0_2(x) 383 | x = self.relu(x) 384 | 385 | x = self.maxpool1(x) 386 | x = self.layer1(x) 387 | x = self.conv1(x) 388 | x = 
self.bn1(x) 389 | x = self.relu(x) 390 | 391 | x = self.maxpool2(x) 392 | x = self.layer2(x) 393 | x = self.conv2(x) 394 | x = self.bn2(x) 395 | x = self.relu(x) 396 | 397 | x = self.maxpool3(x) 398 | x = self.layer3(x) 399 | x = self.conv3(x) 400 | x = self.bn3(x) 401 | x = self.relu(x) 402 | 403 | x = self.layer4(x) 404 | x = self.conv4_1(x) 405 | x = self.bn4_1(x) 406 | x = self.relu(x) 407 | x = self.conv4_2(x) 408 | x = self.bn4_2(x) 409 | x = self.relu(x) 410 | 411 | return x 412 | 413 | 414 | class TpsSpatialTransformerNetwork(nn.Module): 415 | """Rectification Network of RARE, namely TPS based STN""" 416 | 417 | def __init__(self, F, I_size, I_r_size, I_channel_num: int = 1): 418 | """Based on RARE TPS 419 | input: 420 | batch_I: Batch Input Image [batch_size x I_channel_num x I_height x I_width] 421 | I_size : (height, width) of the input image I 422 | I_r_size : (height, width) of the rectified image I_r 423 | I_channel_num : the number of channels of the input image I 424 | output: 425 | batch_I_r: rectified image [batch_size x I_channel_num x I_r_height x I_r_width] 426 | """ 427 | super(TpsSpatialTransformerNetwork, self).__init__() 428 | self.F = F 429 | self.I_size = I_size 430 | self.I_r_size = I_r_size # = (I_r_height, I_r_width) 431 | self.I_channel_num = I_channel_num 432 | self.LocalizationNetwork = LocalizationNetwork( 433 | self.F, 434 | self.I_channel_num, 435 | ) 436 | self.GridGenerator = GridGenerator(self.F, self.I_r_size) 437 | 438 | def forward(self, batch_I): 439 | batch_C_prime = self.LocalizationNetwork(batch_I) # batch_size x K x 2 440 | build_P_prime = self.GridGenerator.build_P_prime( 441 | batch_C_prime 442 | ) # batch_size x n (= I_r_width x I_r_height) x 2 443 | build_P_prime_reshape = build_P_prime.reshape( 444 | [build_P_prime.size(0), self.I_r_size[0], self.I_r_size[1], 2] 445 | ) 446 | 447 | batch_I_r = F.grid_sample( 448 | batch_I, 449 | build_P_prime_reshape, 450 | padding_mode="border", 451 | ) 452 | 453 | return batch_I_r 454 | 455 | 456 | class LocalizationNetwork(nn.Module): 457 | """Localization Network of RARE, which predicts C' (K x 2) from I (I_width x I_height)""" 458 | 459 | def __init__(self, F, I_channel_num: int): 460 | super(LocalizationNetwork, self).__init__() 461 | self.F = F 462 | self.I_channel_num = I_channel_num 463 | self.conv = nn.Sequential( 464 | nn.Conv2d( 465 | in_channels=self.I_channel_num, 466 | out_channels=64, 467 | kernel_size=3, 468 | stride=1, 469 | padding=1, 470 | bias=False, 471 | ), 472 | nn.BatchNorm2d(64), 473 | nn.ReLU(True), 474 | nn.MaxPool2d(2, 2), # batch_size x 64 x I_height/2 x I_width/2 475 | nn.Conv2d(64, 128, 3, 1, 1, bias=False), 476 | nn.BatchNorm2d(128), 477 | nn.ReLU(True), 478 | nn.MaxPool2d(2, 2), # batch_size x 128 x I_height/4 x I_width/4 479 | nn.Conv2d(128, 256, 3, 1, 1, bias=False), 480 | nn.BatchNorm2d(256), 481 | nn.ReLU(True), 482 | nn.MaxPool2d(2, 2), # batch_size x 256 x I_height/8 x I_width/8 483 | nn.Conv2d(256, 512, 3, 1, 1, bias=False), 484 | nn.BatchNorm2d(512), 485 | nn.ReLU(True), 486 | nn.AdaptiveAvgPool2d(1), # batch_size x 512 487 | ) 488 | 489 | self.localization_fc1 = nn.Sequential(nn.Linear(512, 256), nn.ReLU(True)) 490 | self.localization_fc2 = nn.Linear(256, self.F * 2) 491 | 492 | # Init fc2 in LocalizationNetwork 493 | self.localization_fc2.weight.data.fill_(0) 494 | 495 | # see RARE paper Fig. 
6 (a) 496 | ctrl_pts_x = np.linspace(-1.0, 1.0, int(F / 2)) 497 | ctrl_pts_y_top = np.linspace(0.0, -1.0, num=int(F / 2)) 498 | ctrl_pts_y_bottom = np.linspace(1.0, 0.0, num=int(F / 2)) 499 | ctrl_pts_top = np.stack([ctrl_pts_x, ctrl_pts_y_top], axis=1) 500 | ctrl_pts_bottom = np.stack([ctrl_pts_x, ctrl_pts_y_bottom], axis=1) 501 | initial_bias = np.concatenate([ctrl_pts_top, ctrl_pts_bottom], axis=0) 502 | self.localization_fc2.bias.data = ( 503 | torch.from_numpy(initial_bias).float().view(-1) 504 | ) 505 | 506 | def forward(self, batch_I): 507 | """ 508 | :param batch_I : Batch Input Image [batch_size x I_channel_num x I_height x I_width] 509 | :return: batch_C_prime : Predicted coordinates of fiducial points for input batch [batch_size x F x 2] 510 | """ 511 | batch_size = batch_I.size(0) 512 | features = self.conv(batch_I).view(batch_size, -1) 513 | batch_C_prime = self.localization_fc2(self.localization_fc1(features)).view( 514 | batch_size, self.F, 2 515 | ) 516 | return batch_C_prime 517 | 518 | 519 | class GridGenerator(nn.Module): 520 | """Grid Generator of RARE, which produces P_prime by multipling T with P""" 521 | 522 | def __init__(self, F, I_r_size): 523 | """Generate P_hat and inv_delta_C for later""" 524 | super(GridGenerator, self).__init__() 525 | self.eps = 1e-6 526 | self.I_r_height, self.I_r_width = I_r_size 527 | self.F = F 528 | self.C = self._build_C(self.F) # F x 2 529 | self.P = self._build_P(self.I_r_width, self.I_r_height) 530 | 531 | # for multi-gpu, you need register buffer 532 | self.register_buffer( 533 | "inv_delta_C", 534 | torch.tensor( 535 | self._build_inv_delta_C( 536 | self.F, 537 | self.C, 538 | ) 539 | ).float(), 540 | ) # F+3 x F+3 541 | self.register_buffer( 542 | "P_hat", 543 | torch.tensor( 544 | self._build_P_hat( 545 | self.F, 546 | self.C, 547 | self.P, 548 | ) 549 | ).float(), 550 | ) # n x F+3 551 | 552 | def _build_C(self, F): 553 | """Return coordinates of fiducial points in I_r; C""" 554 | ctrl_pts_x = np.linspace(-1.0, 1.0, int(F / 2)) 555 | ctrl_pts_y_top = -1 * np.ones(int(F / 2)) 556 | ctrl_pts_y_bottom = np.ones(int(F / 2)) 557 | ctrl_pts_top = np.stack([ctrl_pts_x, ctrl_pts_y_top], axis=1) 558 | ctrl_pts_bottom = np.stack([ctrl_pts_x, ctrl_pts_y_bottom], axis=1) 559 | C = np.concatenate([ctrl_pts_top, ctrl_pts_bottom], axis=0) 560 | return C # F x 2 561 | 562 | def _build_inv_delta_C(self, F, C): 563 | """Return inv_delta_C which is needed to calculate T""" 564 | hat_C = np.zeros((F, F), dtype=float) # F x F 565 | for i in range(0, F): 566 | for j in range(i, F): 567 | r = np.linalg.norm(C[i] - C[j]) 568 | hat_C[i, j] = r 569 | hat_C[j, i] = r 570 | np.fill_diagonal(hat_C, 1) 571 | hat_C = (hat_C**2) * np.log(hat_C) 572 | # print(C.shape, hat_C.shape) 573 | delta_C = np.concatenate( # F+3 x F+3 574 | [ 575 | np.concatenate([np.ones((F, 1)), C, hat_C], axis=1), # F x F+3 576 | np.concatenate([np.zeros((2, 3)), np.transpose(C)], axis=1), # 2 x F+3 577 | np.concatenate([np.zeros((1, 3)), np.ones((1, F))], axis=1), # 1 x F+3 578 | ], 579 | axis=0, 580 | ) 581 | inv_delta_C = np.linalg.inv(delta_C) 582 | return inv_delta_C # F+3 x F+3 583 | 584 | def _build_P(self, I_r_width, I_r_height): 585 | I_r_grid_x = ( 586 | np.arange(-I_r_width, I_r_width, 2) + 1.0 587 | ) / I_r_width # self.I_r_width 588 | I_r_grid_y = ( 589 | np.arange(-I_r_height, I_r_height, 2) + 1.0 590 | ) / I_r_height # self.I_r_height 591 | P = np.stack( # self.I_r_width x self.I_r_height x 2 592 | np.meshgrid(I_r_grid_x, I_r_grid_y), axis=2 593 | ) 594 | return 
P.reshape([-1, 2]) # n (= self.I_r_width x self.I_r_height) x 2 595 | 596 | def _build_P_hat(self, F, C, P): 597 | n = P.shape[0] # n (= self.I_r_width x self.I_r_height) 598 | P_tile = np.tile( 599 | np.expand_dims(P, axis=1), (1, F, 1) 600 | ) # n x 2 -> n x 1 x 2 -> n x F x 2 601 | C_tile = np.expand_dims(C, axis=0) # 1 x F x 2 602 | P_diff = P_tile - C_tile # n x F x 2 603 | rbf_norm = np.linalg.norm( 604 | P_diff, 605 | ord=2, 606 | axis=2, 607 | keepdims=False, 608 | ) # n x F 609 | rbf = np.multiply( 610 | np.square(rbf_norm), 611 | np.log(rbf_norm + self.eps), 612 | ) # n x F 613 | P_hat = np.concatenate([np.ones((n, 1)), P, rbf], axis=1) 614 | return P_hat # n x F+3 615 | 616 | def build_P_prime(self, batch_C_prime): 617 | """Generate Grid from batch_C_prime [batch_size x F x 2]""" 618 | batch_size = batch_C_prime.size(0) 619 | batch_inv_delta_C = self.inv_delta_C.repeat(batch_size, 1, 1) 620 | batch_P_hat = self.P_hat.repeat(batch_size, 1, 1) 621 | batch_C_prime_with_zeros = torch.cat( 622 | (batch_C_prime, torch.zeros(batch_size, 3, 2).float().to(device)), dim=1 623 | ) # batch_size x F+3 x 2 624 | batch_T = torch.bmm( 625 | batch_inv_delta_C, 626 | batch_C_prime_with_zeros, 627 | ) # batch_size x F+3 x 2 628 | batch_P_prime = torch.bmm(batch_P_hat, batch_T) # batch_size x n x 2 629 | return batch_P_prime # batch_size x n x 2 630 | -------------------------------------------------------------------------------- /betterocr/engines/easy_pororo_ocr/pororo/models/brainOCR/brainocr.py: -------------------------------------------------------------------------------- 1 | """ 2 | this code is adapted from https://github.com/black7375/korean_ocr_using_pororo 3 | 4 | Apache License 2.0 @yunwoong7 5 | Apache License 2.0 @black7375 6 | """ 7 | 8 | """ 9 | This code is primarily based on the following: 10 | https://github.com/JaidedAI/EasyOCR/blob/8af936ba1b2f3c230968dc1022d0cd3e9ca1efbb/easyocr/easyocr.py 11 | 12 | Basic usage: 13 | >>> from pororo import Pororo 14 | >>> ocr = Pororo(task="ocr", lang="ko") 15 | >>> ocr("IMAGE_FILE") 16 | """ 17 | 18 | import ast 19 | from logging import getLogger 20 | from typing import List 21 | 22 | import cv2 23 | import numpy as np 24 | from PIL import Image 25 | 26 | from .detection import get_detector, get_textbox 27 | from .recognition import get_recognizer, get_text 28 | from .utils import ( 29 | diff, 30 | get_image_list, 31 | get_paragraph, 32 | group_text_box, 33 | reformat_input, 34 | ) 35 | 36 | LOGGER = getLogger(__name__) 37 | 38 | 39 | class Reader(object): 40 | def __init__( 41 | self, 42 | lang: str, 43 | det_model_ckpt_fp: str, 44 | rec_model_ckpt_fp: str, 45 | opt_fp: str, 46 | device: str, 47 | ) -> None: 48 | """ 49 | TODO @karter: modify this such that you download the pretrained checkpoint files 50 | Parameters: 51 | lang: language code. 
e.g, "en" or "ko" 52 | det_model_ckpt_fp: Detection model's checkpoint path e.g., 'craft_mlt_25k.pth' 53 | rec_model_ckpt_fp: Recognition model's checkpoint path 54 | opt_fp: option file path 55 | """ 56 | # Plug options in the dictionary 57 | opt2val = self.parse_options(opt_fp) # e.g., {"imgH": 64, ...} 58 | opt2val["vocab"] = self.build_vocab(opt2val["character"]) 59 | opt2val["vocab_size"] = len(opt2val["vocab"]) 60 | opt2val["device"] = device 61 | opt2val["lang"] = lang 62 | opt2val["det_model_ckpt_fp"] = det_model_ckpt_fp 63 | opt2val["rec_model_ckpt_fp"] = rec_model_ckpt_fp 64 | 65 | # Get model objects 66 | self.detector = get_detector(det_model_ckpt_fp, opt2val["device"]) 67 | self.recognizer, self.converter = get_recognizer(opt2val) 68 | self.opt2val = opt2val 69 | 70 | @staticmethod 71 | def parse_options(opt_fp: str) -> dict: 72 | opt2val = dict() 73 | for line in open(opt_fp, "r", encoding="utf8"): 74 | line = line.strip() 75 | if ": " in line: 76 | opt, val = line.split(": ", 1) 77 | try: 78 | opt2val[opt] = ast.literal_eval(val) 79 | except: 80 | opt2val[opt] = val 81 | 82 | return opt2val 83 | 84 | @staticmethod 85 | def build_vocab(character: str) -> List[str]: 86 | """Returns vocabulary (=list of characters)""" 87 | vocab = ["[blank]"] + list( 88 | character 89 | ) # dummy '[blank]' token for CTCLoss (index 0) 90 | return vocab 91 | 92 | def detect(self, img: np.ndarray, opt2val: dict): 93 | """ 94 | :return: 95 | horizontal_list (list): e.g., [[613, 1496, 51, 190], [136, 1544, 134, 508]] 96 | free_list (list): e.g., [] 97 | """ 98 | text_box = get_textbox(self.detector, img, opt2val) 99 | horizontal_list, free_list = group_text_box( 100 | text_box, 101 | opt2val["slope_ths"], 102 | opt2val["ycenter_ths"], 103 | opt2val["height_ths"], 104 | opt2val["width_ths"], 105 | opt2val["add_margin"], 106 | ) 107 | 108 | min_size = opt2val["min_size"] 109 | if min_size: 110 | horizontal_list = [ 111 | i for i in horizontal_list if max(i[1] - i[0], i[3] - i[2]) > min_size 112 | ] 113 | free_list = [ 114 | i 115 | for i in free_list 116 | if max(diff([c[0] for c in i]), diff([c[1] for c in i])) > min_size 117 | ] 118 | 119 | return horizontal_list, free_list 120 | 121 | def recognize( 122 | self, 123 | img_cv_grey: np.ndarray, 124 | horizontal_list: list, 125 | free_list: list, 126 | opt2val: dict, 127 | ): 128 | """ 129 | Read text in the image 130 | :return: 131 | result (list): bounding box, text and confident score 132 | e.g., [([[189, 75], [469, 75], [469, 165], [189, 165]], '愚园路', 0.3754989504814148), 133 | ([[86, 80], [134, 80], [134, 128], [86, 128]], '西', 0.40452659130096436), 134 | ([[517, 81], [565, 81], [565, 123], [517, 123]], '东', 0.9989598989486694), 135 | ([[78, 126], [136, 126], [136, 156], [78, 156]], '315', 0.8125889301300049), 136 | ([[514, 126], [574, 126], [574, 156], [514, 156]], '309', 0.4971577227115631), 137 | ([[226, 170], [414, 170], [414, 220], [226, 220]], 'Yuyuan Rd.', 0.8261902332305908), 138 | ([[79, 173], [125, 173], [125, 213], [79, 213]], 'W', 0.9848111271858215), 139 | ([[529, 173], [569, 173], [569, 213], [529, 213]], 'E', 0.8405593633651733)] 140 | or list of texts (if skip_details is True) 141 | e.g., ['愚园路', '西', '东', '315', '309', 'Yuyuan Rd.', 'W', 'E'] 142 | """ 143 | imgH = opt2val["imgH"] 144 | paragraph = opt2val["paragraph"] 145 | skip_details = opt2val["skip_details"] 146 | 147 | if (horizontal_list is None) and (free_list is None): 148 | y_max, x_max = img_cv_grey.shape 149 | ratio = x_max / y_max 150 | max_width = int(imgH * ratio) 
151 | crop_img = cv2.resize( 152 | img_cv_grey, 153 | (max_width, imgH), 154 | interpolation=Image.LANCZOS, 155 | ) 156 | image_list = [([[0, 0], [x_max, 0], [x_max, y_max], [0, y_max]], crop_img)] 157 | else: 158 | image_list, max_width = get_image_list( 159 | horizontal_list, 160 | free_list, 161 | img_cv_grey, 162 | model_height=imgH, 163 | ) 164 | 165 | result = get_text(image_list, self.recognizer, self.converter, opt2val) 166 | 167 | if paragraph: 168 | result = get_paragraph(result, mode="ltr") 169 | 170 | if skip_details: # texts only 171 | return [item[1] for item in result] 172 | else: # full outputs: bounding box, text and confident score 173 | return result 174 | 175 | def __call__( 176 | self, 177 | image, 178 | batch_size: int = 1, 179 | n_workers: int = 0, 180 | skip_details: bool = False, 181 | paragraph: bool = False, 182 | min_size: int = 20, 183 | contrast_ths: float = 0.1, 184 | adjust_contrast: float = 0.5, 185 | filter_ths: float = 0.003, 186 | text_threshold: float = 0.7, 187 | low_text: float = 0.4, 188 | link_threshold: float = 0.4, 189 | canvas_size: int = 2560, 190 | mag_ratio: float = 1.0, 191 | slope_ths: float = 0.1, 192 | ycenter_ths: float = 0.5, 193 | height_ths: float = 0.5, 194 | width_ths: float = 0.5, 195 | add_margin: float = 0.1, 196 | ): 197 | """ 198 | Detect text in the image and then recognize it. 199 | :param image: file path or numpy-array or a byte stream object 200 | :param batch_size: 201 | :param n_workers: 202 | :param skip_details: 203 | :param paragraph: 204 | :param min_size: 205 | :param contrast_ths: 206 | :param adjust_contrast: 207 | :param filter_ths: 208 | :param text_threshold: 209 | :param low_text: 210 | :param link_threshold: 211 | :param canvas_size: 212 | :param mag_ratio: 213 | :param slope_ths: 214 | :param ycenter_ths: 215 | :param height_ths: 216 | :param width_ths: 217 | :param add_margin: 218 | :return: 219 | """ 220 | # update `opt2val` 221 | self.opt2val["batch_size"] = batch_size 222 | self.opt2val["n_workers"] = n_workers 223 | self.opt2val["skip_details"] = skip_details 224 | self.opt2val["paragraph"] = paragraph 225 | self.opt2val["min_size"] = min_size 226 | self.opt2val["contrast_ths"] = contrast_ths 227 | self.opt2val["adjust_contrast"] = adjust_contrast 228 | self.opt2val["filter_ths"] = filter_ths 229 | self.opt2val["text_threshold"] = text_threshold 230 | self.opt2val["low_text"] = low_text 231 | self.opt2val["link_threshold"] = link_threshold 232 | self.opt2val["canvas_size"] = canvas_size 233 | self.opt2val["mag_ratio"] = mag_ratio 234 | self.opt2val["slope_ths"] = slope_ths 235 | self.opt2val["ycenter_ths"] = ycenter_ths 236 | self.opt2val["height_ths"] = height_ths 237 | self.opt2val["width_ths"] = width_ths 238 | self.opt2val["add_margin"] = add_margin 239 | 240 | img, img_cv_grey = reformat_input(image) # img, img_cv_grey: array 241 | 242 | horizontal_list, free_list = self.detect(img, self.opt2val) 243 | result = self.recognize( 244 | img_cv_grey, 245 | horizontal_list, 246 | free_list, 247 | self.opt2val, 248 | ) 249 | 250 | return result 251 | -------------------------------------------------------------------------------- /betterocr/engines/easy_pororo_ocr/pororo/models/brainOCR/craft.py: -------------------------------------------------------------------------------- 1 | """ 2 | this code is adapted from https://github.com/black7375/korean_ocr_using_pororo 3 | 4 | Apache License 2.0 @yunwoong7 5 | Apache License 2.0 @black7375 6 | """ 7 | 8 | """ 9 | This code is adapted from 
https://github.com/clovaai/CRAFT-pytorch/blob/master/craft.py. 10 | Copyright (c) 2019-present NAVER Corp. 11 | MIT License 12 | """ 13 | 14 | import torch 15 | import torch.nn as nn 16 | import torch.nn.functional as F 17 | from torch import Tensor 18 | 19 | from ._modules import Vgg16BN, init_weights 20 | 21 | 22 | class DoubleConv(nn.Module): 23 | def __init__(self, in_ch: int, mid_ch: int, out_ch: int) -> None: 24 | super(DoubleConv, self).__init__() 25 | self.conv = nn.Sequential( 26 | nn.Conv2d(in_ch + mid_ch, mid_ch, kernel_size=1), 27 | nn.BatchNorm2d(mid_ch), 28 | nn.ReLU(inplace=True), 29 | nn.Conv2d(mid_ch, out_ch, kernel_size=3, padding=1), 30 | nn.BatchNorm2d(out_ch), 31 | nn.ReLU(inplace=True), 32 | ) 33 | 34 | def forward(self, x: Tensor): 35 | x = self.conv(x) 36 | return x 37 | 38 | 39 | class CRAFT(nn.Module): 40 | def __init__(self, pretrained: bool = False, freeze: bool = False) -> None: 41 | super(CRAFT, self).__init__() 42 | 43 | # Base network 44 | self.basenet = Vgg16BN(pretrained, freeze) 45 | 46 | # U network 47 | self.upconv1 = DoubleConv(1024, 512, 256) 48 | self.upconv2 = DoubleConv(512, 256, 128) 49 | self.upconv3 = DoubleConv(256, 128, 64) 50 | self.upconv4 = DoubleConv(128, 64, 32) 51 | 52 | num_class = 2 53 | self.conv_cls = nn.Sequential( 54 | nn.Conv2d(32, 32, kernel_size=3, padding=1), 55 | nn.ReLU(inplace=True), 56 | nn.Conv2d(32, 32, kernel_size=3, padding=1), 57 | nn.ReLU(inplace=True), 58 | nn.Conv2d(32, 16, kernel_size=3, padding=1), 59 | nn.ReLU(inplace=True), 60 | nn.Conv2d(16, 16, kernel_size=1), 61 | nn.ReLU(inplace=True), 62 | nn.Conv2d(16, num_class, kernel_size=1), 63 | ) 64 | 65 | init_weights(self.upconv1.modules()) 66 | init_weights(self.upconv2.modules()) 67 | init_weights(self.upconv3.modules()) 68 | init_weights(self.upconv4.modules()) 69 | init_weights(self.conv_cls.modules()) 70 | 71 | def forward(self, x: Tensor): 72 | # Base network 73 | sources = self.basenet(x) 74 | 75 | # U network 76 | y = torch.cat([sources[0], sources[1]], dim=1) 77 | y = self.upconv1(y) 78 | 79 | y = F.interpolate( 80 | y, 81 | size=sources[2].size()[2:], 82 | mode="bilinear", 83 | align_corners=False, 84 | ) 85 | y = torch.cat([y, sources[2]], dim=1) 86 | y = self.upconv2(y) 87 | 88 | y = F.interpolate( 89 | y, 90 | size=sources[3].size()[2:], 91 | mode="bilinear", 92 | align_corners=False, 93 | ) 94 | y = torch.cat([y, sources[3]], dim=1) 95 | y = self.upconv3(y) 96 | 97 | y = F.interpolate( 98 | y, 99 | size=sources[4].size()[2:], 100 | mode="bilinear", 101 | align_corners=False, 102 | ) 103 | y = torch.cat([y, sources[4]], dim=1) 104 | feature = self.upconv4(y) 105 | 106 | y = self.conv_cls(feature) 107 | 108 | return y.permute(0, 2, 3, 1), feature 109 | -------------------------------------------------------------------------------- /betterocr/engines/easy_pororo_ocr/pororo/models/brainOCR/craft_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | this code is adapted from https://github.com/black7375/korean_ocr_using_pororo 3 | 4 | Apache License 2.0 @yunwoong7 5 | Apache License 2.0 @black7375 6 | """ 7 | 8 | """ 9 | This code is adapted from avhttps://github.com/clovaai/CRAFT-pytorch/blob/master/craft_utils.py 10 | MIT License 11 | """ 12 | 13 | import math 14 | 15 | import cv2 16 | import numpy as np 17 | 18 | 19 | def warp_coord(Minv, pt): 20 | """auxilary functions: unwarp corodinates:""" 21 | out = np.matmul(Minv, (pt[0], pt[1], 1)) 22 | return np.array([out[0] / out[2], out[1] / out[2]]) 23 | 24 | 
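# Hedged usage sketch (annotation, not part of the upstream file): get_det_boxes_core
# below converts the CRAFT text/link score maps into one rotated box per word region.
# It thresholds both maps, labels the connected components of their union, dilates each
# component, and fits a minimum-area rectangle. With the defaults exposed by
# Reader.__call__ (text_threshold=0.7, link_threshold=0.4, low_text=0.4) a call might
# look like:
#
#     boxes, labels, mapper = get_det_boxes_core(
#         score_text, score_link,
#         text_threshold=0.7, link_threshold=0.4, low_text=0.4,
#     )
#
# where `score_text` and `score_link` are the two output channels produced by the CRAFT
# forward pass (see test_net in detection.py).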
25 | def get_det_boxes_core(textmap, linkmap, text_threshold, link_threshold, low_text): 26 | # prepare data 27 | linkmap = linkmap.copy() 28 | textmap = textmap.copy() 29 | img_h, img_w = textmap.shape 30 | 31 | # labeling method 32 | ret, text_score = cv2.threshold(textmap, low_text, 1, 0) 33 | ret, link_score = cv2.threshold(linkmap, link_threshold, 1, 0) 34 | 35 | text_score_comb = np.clip(text_score + link_score, 0, 1) 36 | nLabels, labels, stats, centroids = cv2.connectedComponentsWithStats( 37 | text_score_comb.astype(np.uint8), connectivity=4 38 | ) 39 | 40 | det = [] 41 | mapper = [] 42 | for k in range(1, nLabels): 43 | # size filtering 44 | size = stats[k, cv2.CC_STAT_AREA] 45 | if size < 10: 46 | continue 47 | 48 | # thresholding 49 | if np.max(textmap[labels == k]) < text_threshold: 50 | continue 51 | 52 | # make segmentation map 53 | segmap = np.zeros(textmap.shape, dtype=np.uint8) 54 | segmap[labels == k] = 255 55 | segmap[np.logical_and(link_score == 1, text_score == 0)] = 0 # remove link area 56 | x, y = stats[k, cv2.CC_STAT_LEFT], stats[k, cv2.CC_STAT_TOP] 57 | w, h = stats[k, cv2.CC_STAT_WIDTH], stats[k, cv2.CC_STAT_HEIGHT] 58 | niter = int(math.sqrt(size * min(w, h) / (w * h)) * 2) 59 | sx, ex, sy, ey = x - niter, x + w + niter + 1, y - niter, y + h + niter + 1 60 | # boundary check 61 | if sx < 0: 62 | sx = 0 63 | if sy < 0: 64 | sy = 0 65 | if ex >= img_w: 66 | ex = img_w 67 | if ey >= img_h: 68 | ey = img_h 69 | kernel = cv2.getStructuringElement( 70 | cv2.MORPH_RECT, 71 | (1 + niter, 1 + niter), 72 | ) 73 | segmap[sy:ey, sx:ex] = cv2.dilate(segmap[sy:ey, sx:ex], kernel) 74 | 75 | # make box 76 | np_contours = ( 77 | np.roll(np.array(np.where(segmap != 0)), 1, axis=0) 78 | .transpose() 79 | .reshape(-1, 2) 80 | ) 81 | rectangle = cv2.minAreaRect(np_contours) 82 | box = cv2.boxPoints(rectangle) 83 | 84 | # align diamond-shape 85 | w, h = np.linalg.norm(box[0] - box[1]), np.linalg.norm(box[1] - box[2]) 86 | box_ratio = max(w, h) / (min(w, h) + 1e-5) 87 | if abs(1 - box_ratio) <= 0.1: 88 | l, r = min(np_contours[:, 0]), max(np_contours[:, 0]) 89 | t, b = min(np_contours[:, 1]), max(np_contours[:, 1]) 90 | box = np.array([[l, t], [r, t], [r, b], [l, b]], dtype=np.float32) 91 | 92 | # make clock-wise order 93 | startidx = box.sum(axis=1).argmin() 94 | box = np.roll(box, 4 - startidx, 0) 95 | box = np.array(box) 96 | 97 | det.append(box) 98 | mapper.append(k) 99 | 100 | return det, labels, mapper 101 | 102 | 103 | def get_poly_core(boxes, labels, mapper, linkmap): 104 | # configs 105 | num_cp = 5 106 | max_len_ratio = 0.7 107 | expand_ratio = 1.45 108 | max_r = 2.0 109 | step_r = 0.2 110 | 111 | polys = [] 112 | for k, box in enumerate(boxes): 113 | # size filter for small instance 114 | w, h = int(np.linalg.norm(box[0] - box[1]) + 1), int( 115 | np.linalg.norm(box[1] - box[2]) + 1 116 | ) 117 | if w < 10 or h < 10: 118 | polys.append(None) 119 | continue 120 | 121 | # warp image 122 | tar = np.float32([[0, 0], [w, 0], [w, h], [0, h]]) 123 | M = cv2.getPerspectiveTransform(box, tar) 124 | word_label = cv2.warpPerspective( 125 | labels, 126 | M, 127 | (w, h), 128 | flags=cv2.INTER_NEAREST, 129 | ) 130 | try: 131 | Minv = np.linalg.inv(M) 132 | except: 133 | polys.append(None) 134 | continue 135 | 136 | # binarization for selected label 137 | cur_label = mapper[k] 138 | word_label[word_label != cur_label] = 0 139 | word_label[word_label > 0] = 1 140 | 141 | # Polygon generation: find top/bottom contours 142 | cp = [] 143 | max_len = -1 144 | for i in range(w): 145 | region = 
np.where(word_label[:, i] != 0)[0] 146 | if len(region) < 2: 147 | continue 148 | cp.append((i, region[0], region[-1])) 149 | length = region[-1] - region[0] + 1 150 | if length > max_len: 151 | max_len = length 152 | 153 | # pass if max_len is similar to h 154 | if h * max_len_ratio < max_len: 155 | polys.append(None) 156 | continue 157 | 158 | # get pivot points with fixed length 159 | tot_seg = num_cp * 2 + 1 160 | seg_w = w / tot_seg # segment width 161 | pp = [None] * num_cp # init pivot points 162 | cp_section = [[0, 0]] * tot_seg 163 | seg_height = [0] * num_cp 164 | seg_num = 0 165 | num_sec = 0 166 | prev_h = -1 167 | for i in range(0, len(cp)): 168 | (x, sy, ey) = cp[i] 169 | if (seg_num + 1) * seg_w <= x and seg_num <= tot_seg: 170 | # average previous segment 171 | if num_sec == 0: 172 | break 173 | cp_section[seg_num] = [ 174 | cp_section[seg_num][0] / num_sec, 175 | cp_section[seg_num][1] / num_sec, 176 | ] 177 | num_sec = 0 178 | 179 | # reset variables 180 | seg_num += 1 181 | prev_h = -1 182 | 183 | # accumulate center points 184 | cy = (sy + ey) * 0.5 185 | cur_h = ey - sy + 1 186 | cp_section[seg_num] = [ 187 | cp_section[seg_num][0] + x, 188 | cp_section[seg_num][1] + cy, 189 | ] 190 | num_sec += 1 191 | 192 | if seg_num % 2 == 0: 193 | continue # No polygon area 194 | 195 | if prev_h < cur_h: 196 | pp[int((seg_num - 1) / 2)] = (x, cy) 197 | seg_height[int((seg_num - 1) / 2)] = cur_h 198 | prev_h = cur_h 199 | 200 | # processing last segment 201 | if num_sec != 0: 202 | cp_section[-1] = [cp_section[-1][0] / num_sec, cp_section[-1][1] / num_sec] 203 | 204 | # pass if num of pivots is not sufficient or segment width is smaller than character height 205 | if None in pp or seg_w < np.max(seg_height) * 0.25: 206 | polys.append(None) 207 | continue 208 | 209 | # calc median maximum of pivot points 210 | half_char_h = np.median(seg_height) * expand_ratio / 2 211 | 212 | # calc gradiant and apply to make horizontal pivots 213 | new_pp = [] 214 | for i, (x, cy) in enumerate(pp): 215 | dx = cp_section[i * 2 + 2][0] - cp_section[i * 2][0] 216 | dy = cp_section[i * 2 + 2][1] - cp_section[i * 2][1] 217 | if dx == 0: # gradient if zero 218 | new_pp.append([x, cy - half_char_h, x, cy + half_char_h]) 219 | continue 220 | rad = -math.atan2(dy, dx) 221 | c, s = half_char_h * math.cos(rad), half_char_h * math.sin(rad) 222 | new_pp.append([x - s, cy - c, x + s, cy + c]) 223 | 224 | # get edge points to cover character heatmaps 225 | isSppFound, isEppFound = False, False 226 | grad_s = (pp[1][1] - pp[0][1]) / (pp[1][0] - pp[0][0]) + ( 227 | pp[2][1] - pp[1][1] 228 | ) / (pp[2][0] - pp[1][0]) 229 | grad_e = (pp[-2][1] - pp[-1][1]) / (pp[-2][0] - pp[-1][0]) + ( 230 | pp[-3][1] - pp[-2][1] 231 | ) / (pp[-3][0] - pp[-2][0]) 232 | for r in np.arange(0.5, max_r, step_r): 233 | dx = 2 * half_char_h * r 234 | if not isSppFound: 235 | line_img = np.zeros(word_label.shape, dtype=np.uint8) 236 | dy = grad_s * dx 237 | p = np.array(new_pp[0]) - np.array([dx, dy, dx, dy]) 238 | cv2.line( 239 | line_img, 240 | (int(p[0]), int(p[1])), 241 | (int(p[2]), int(p[3])), 242 | 1, 243 | thickness=1, 244 | ) 245 | if ( 246 | np.sum(np.logical_and(word_label, line_img)) == 0 247 | or r + 2 * step_r >= max_r 248 | ): 249 | spp = p 250 | isSppFound = True 251 | if not isEppFound: 252 | line_img = np.zeros(word_label.shape, dtype=np.uint8) 253 | dy = grad_e * dx 254 | p = np.array(new_pp[-1]) + np.array([dx, dy, dx, dy]) 255 | cv2.line( 256 | line_img, 257 | (int(p[0]), int(p[1])), 258 | (int(p[2]), int(p[3])), 259 | 
1, 260 | thickness=1, 261 | ) 262 | if ( 263 | np.sum(np.logical_and(word_label, line_img)) == 0 264 | or r + 2 * step_r >= max_r 265 | ): 266 | epp = p 267 | isEppFound = True 268 | if isSppFound and isEppFound: 269 | break 270 | 271 | # pass if boundary of polygon is not found 272 | if not (isSppFound and isEppFound): 273 | polys.append(None) 274 | continue 275 | 276 | # make final polygon 277 | poly = [] 278 | poly.append(warp_coord(Minv, (spp[0], spp[1]))) 279 | for p in new_pp: 280 | poly.append(warp_coord(Minv, (p[0], p[1]))) 281 | poly.append(warp_coord(Minv, (epp[0], epp[1]))) 282 | poly.append(warp_coord(Minv, (epp[2], epp[3]))) 283 | for p in reversed(new_pp): 284 | poly.append(warp_coord(Minv, (p[2], p[3]))) 285 | poly.append(warp_coord(Minv, (spp[2], spp[3]))) 286 | 287 | # add to final result 288 | polys.append(np.array(poly)) 289 | 290 | return polys 291 | 292 | 293 | def get_det_boxes( 294 | textmap, 295 | linkmap, 296 | text_threshold, 297 | link_threshold, 298 | low_text, 299 | poly=False, 300 | ): 301 | boxes, labels, mapper = get_det_boxes_core( 302 | textmap, 303 | linkmap, 304 | text_threshold, 305 | link_threshold, 306 | low_text, 307 | ) 308 | 309 | if poly: 310 | polys = get_poly_core(boxes, labels, mapper, linkmap) 311 | else: 312 | polys = [None] * len(boxes) 313 | 314 | return boxes, polys 315 | 316 | 317 | def adjust_result_coordinates(polys, ratio_w, ratio_h, ratio_net=2): 318 | if len(polys) > 0: 319 | polys = np.array(polys) 320 | for k in range(len(polys)): 321 | if polys[k] is not None: 322 | polys[k] *= (ratio_w * ratio_net, ratio_h * ratio_net) 323 | return polys 324 | -------------------------------------------------------------------------------- /betterocr/engines/easy_pororo_ocr/pororo/models/brainOCR/detection.py: -------------------------------------------------------------------------------- 1 | """ 2 | this code is adapted from https://github.com/black7375/korean_ocr_using_pororo 3 | 4 | Apache License 2.0 @yunwoong7 5 | Apache License 2.0 @black7375 6 | """ 7 | 8 | """ 9 | This code is adapted from https://github.com/JaidedAI/EasyOCR/blob/master/easyocr/detection.py 10 | """ 11 | 12 | from collections import OrderedDict 13 | 14 | import cv2 15 | import numpy as np 16 | import torch 17 | import torch.backends.cudnn as cudnn 18 | from torch.autograd import Variable 19 | 20 | from .craft import CRAFT 21 | from .craft_utils import adjust_result_coordinates, get_det_boxes 22 | from .imgproc import normalize_mean_variance, resize_aspect_ratio 23 | 24 | 25 | def copy_state_dict(state_dict): 26 | if list(state_dict.keys())[0].startswith("module"): 27 | start_idx = 1 28 | else: 29 | start_idx = 0 30 | new_state_dict = OrderedDict() 31 | for k, v in state_dict.items(): 32 | name = ".".join(k.split(".")[start_idx:]) 33 | new_state_dict[name] = v 34 | return new_state_dict 35 | 36 | 37 | def test_net(image: np.ndarray, net, opt2val: dict): 38 | canvas_size = opt2val["canvas_size"] 39 | mag_ratio = opt2val["mag_ratio"] 40 | text_threshold = opt2val["text_threshold"] 41 | link_threshold = opt2val["link_threshold"] 42 | low_text = opt2val["low_text"] 43 | device = opt2val["device"] 44 | 45 | # resize 46 | img_resized, target_ratio, size_heatmap = resize_aspect_ratio( 47 | image, canvas_size, interpolation=cv2.INTER_LINEAR, mag_ratio=mag_ratio 48 | ) 49 | ratio_h = ratio_w = 1 / target_ratio 50 | 51 | # preprocessing 52 | x = normalize_mean_variance(img_resized) 53 | x = torch.from_numpy(x).permute(2, 0, 1) # [h, w, c] to [c, h, w] 54 | x = 
Variable(x.unsqueeze(0)) # [c, h, w] to [b, c, h, w] 55 | x = x.to(device) 56 | 57 | # forward pass 58 | with torch.no_grad(): 59 | y, feature = net(x) 60 | 61 | # make score and link map 62 | score_text = y[0, :, :, 0].cpu().data.numpy() 63 | score_link = y[0, :, :, 1].cpu().data.numpy() 64 | 65 | # Post-processing 66 | boxes, polys = get_det_boxes( 67 | score_text, 68 | score_link, 69 | text_threshold, 70 | link_threshold, 71 | low_text, 72 | ) 73 | 74 | # coordinate adjustment 75 | boxes = adjust_result_coordinates(boxes, ratio_w, ratio_h) 76 | polys = adjust_result_coordinates(polys, ratio_w, ratio_h) 77 | for k in range(len(polys)): 78 | if polys[k] is None: 79 | polys[k] = boxes[k] 80 | 81 | return boxes, polys 82 | 83 | 84 | def get_detector(det_model_ckpt_fp: str, device: str = "cpu"): 85 | net = CRAFT() 86 | 87 | net.load_state_dict( 88 | copy_state_dict(torch.load(det_model_ckpt_fp, map_location=device)) 89 | ) 90 | if device == "cuda": 91 | net = torch.nn.DataParallel(net).to(device) 92 | cudnn.benchmark = False 93 | 94 | net.eval() 95 | return net 96 | 97 | 98 | def get_textbox(detector, image: np.ndarray, opt2val: dict): 99 | bboxes, polys = test_net(image, detector, opt2val) 100 | result = [] 101 | for i, box in enumerate(polys): 102 | poly = np.array(box).astype(np.int32).reshape((-1)) 103 | result.append(poly) 104 | 105 | return result 106 | -------------------------------------------------------------------------------- /betterocr/engines/easy_pororo_ocr/pororo/models/brainOCR/imgproc.py: -------------------------------------------------------------------------------- 1 | """ 2 | this code is adapted from https://github.com/black7375/korean_ocr_using_pororo 3 | 4 | Apache License 2.0 @yunwoong7 5 | Apache License 2.0 @black7375 6 | """ 7 | 8 | """ 9 | This is adapted from https://github.com/clovaai/CRAFT-pytorch/blob/master/imgproc.py 10 | Copyright (c) 2019-present NAVER Corp. 
11 | MIT License 12 | """ 13 | 14 | import cv2 15 | import numpy as np 16 | from skimage import io 17 | 18 | 19 | def load_image(img_file): 20 | img = io.imread(img_file) # RGB order 21 | if img.shape[0] == 2: 22 | img = img[0] 23 | if len(img.shape) == 2: 24 | img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB) 25 | if img.shape[2] == 4: 26 | img = img[:, :, :3] 27 | img = np.array(img) 28 | 29 | return img 30 | 31 | 32 | def normalize_mean_variance( 33 | in_img, 34 | mean=(0.485, 0.456, 0.406), 35 | variance=(0.229, 0.224, 0.225), 36 | ): 37 | # should be RGB order 38 | img = in_img.copy().astype(np.float32) 39 | 40 | img -= np.array( 41 | [mean[0] * 255.0, mean[1] * 255.0, mean[2] * 255.0], dtype=np.float32 42 | ) 43 | img /= np.array( 44 | [variance[0] * 255.0, variance[1] * 255.0, variance[2] * 255.0], 45 | dtype=np.float32, 46 | ) 47 | return img 48 | 49 | 50 | def denormalize_mean_variance( 51 | in_img, 52 | mean=(0.485, 0.456, 0.406), 53 | variance=(0.229, 0.224, 0.225), 54 | ): 55 | # should be RGB order 56 | img = in_img.copy() 57 | img *= variance 58 | img += mean 59 | img *= 255.0 60 | img = np.clip(img, 0, 255).astype(np.uint8) 61 | return img 62 | 63 | 64 | def resize_aspect_ratio( 65 | img: np.ndarray, 66 | square_size: int, 67 | interpolation: int, 68 | mag_ratio: float = 1.0, 69 | ): 70 | height, width, channel = img.shape 71 | 72 | # magnify image size 73 | target_size = mag_ratio * max(height, width) 74 | 75 | # set original image size 76 | if target_size > square_size: 77 | target_size = square_size 78 | 79 | ratio = target_size / max(height, width) 80 | 81 | target_h, target_w = int(height * ratio), int(width * ratio) 82 | proc = cv2.resize(img, (target_w, target_h), interpolation=interpolation) 83 | 84 | # make canvas and paste image 85 | target_h32, target_w32 = target_h, target_w 86 | if target_h % 32 != 0: 87 | target_h32 = target_h + (32 - target_h % 32) 88 | if target_w % 32 != 0: 89 | target_w32 = target_w + (32 - target_w % 32) 90 | resized = np.zeros((target_h32, target_w32, channel), dtype=np.float32) 91 | resized[0:target_h, 0:target_w, :] = proc 92 | target_h, target_w = target_h32, target_w32 93 | 94 | size_heatmap = (int(target_w / 2), int(target_h / 2)) 95 | 96 | return resized, ratio, size_heatmap 97 | 98 | 99 | def cvt2heatmap_img(img): 100 | img = (np.clip(img, 0, 1) * 255).astype(np.uint8) 101 | img = cv2.applyColorMap(img, cv2.COLORMAP_JET) 102 | return img 103 | -------------------------------------------------------------------------------- /betterocr/engines/easy_pororo_ocr/pororo/models/brainOCR/model.py: -------------------------------------------------------------------------------- 1 | """ 2 | this code is adapted from https://github.com/black7375/korean_ocr_using_pororo 3 | 4 | Apache License 2.0 @yunwoong7 5 | Apache License 2.0 @black7375 6 | """ 7 | 8 | """ 9 | This code is adapted from 10 | https://github.com/clovaai/deep-text-recognition-benchmark/blob/master/model.py 11 | """ 12 | 13 | import torch.nn as nn 14 | from torch import Tensor 15 | 16 | from .modules.feature_extraction import ( 17 | ResNetFeatureExtractor, 18 | VGGFeatureExtractor, 19 | ) 20 | from .modules.prediction import Attention 21 | from .modules.sequence_modeling import BidirectionalLSTM 22 | from .modules.transformation import TpsSpatialTransformerNetwork 23 | 24 | 25 | class Model(nn.Module): 26 | def __init__(self, opt2val: dict): 27 | super(Model, self).__init__() 28 | 29 | input_channel = opt2val["input_channel"] 30 | output_channel = opt2val["output_channel"] 31 | 
hidden_size = opt2val["hidden_size"] 32 | vocab_size = opt2val["vocab_size"] 33 | num_fiducial = opt2val["num_fiducial"] 34 | imgH = opt2val["imgH"] 35 | imgW = opt2val["imgW"] 36 | FeatureExtraction = opt2val["FeatureExtraction"] 37 | Transformation = opt2val["Transformation"] 38 | SequenceModeling = opt2val["SequenceModeling"] 39 | Prediction = opt2val["Prediction"] 40 | 41 | # Transformation 42 | if Transformation == "TPS": 43 | self.Transformation = TpsSpatialTransformerNetwork( 44 | F=num_fiducial, 45 | I_size=(imgH, imgW), 46 | I_r_size=(imgH, imgW), 47 | I_channel_num=input_channel, 48 | ) 49 | else: 50 | print("No Transformation module specified") 51 | 52 | # FeatureExtraction 53 | if FeatureExtraction == "VGG": 54 | extractor = VGGFeatureExtractor 55 | else: # ResNet 56 | extractor = ResNetFeatureExtractor 57 | self.FeatureExtraction = extractor( 58 | input_channel, 59 | output_channel, 60 | opt2val, 61 | ) 62 | self.FeatureExtraction_output = output_channel # int(imgH/16-1) * 512 63 | self.AdaptiveAvgPool = nn.AdaptiveAvgPool2d( 64 | (None, 1) 65 | ) # Transform final (imgH/16-1) -> 1 66 | 67 | # Sequence modeling 68 | if SequenceModeling == "BiLSTM": 69 | self.SequenceModeling = nn.Sequential( 70 | BidirectionalLSTM( 71 | self.FeatureExtraction_output, 72 | hidden_size, 73 | hidden_size, 74 | ), 75 | BidirectionalLSTM(hidden_size, hidden_size, hidden_size), 76 | ) 77 | self.SequenceModeling_output = hidden_size 78 | else: 79 | print("No SequenceModeling module specified") 80 | self.SequenceModeling_output = self.FeatureExtraction_output 81 | 82 | # Prediction 83 | if Prediction == "CTC": 84 | self.Prediction = nn.Linear( 85 | self.SequenceModeling_output, 86 | vocab_size, 87 | ) 88 | elif Prediction == "Attn": 89 | self.Prediction = Attention( 90 | self.SequenceModeling_output, 91 | hidden_size, 92 | vocab_size, 93 | ) 94 | elif Prediction == "Transformer": # TODO 95 | pass 96 | else: 97 | raise Exception("Prediction is neither CTC or Attn") 98 | 99 | def forward(self, x: Tensor): 100 | """ 101 | :param x: (batch, input_channel, height, width) 102 | :return: 103 | """ 104 | # Transformation stage 105 | x = self.Transformation(x) 106 | 107 | # Feature extraction stage 108 | visual_feature = self.FeatureExtraction(x) # (b, output_channel=512, h=3, w) 109 | visual_feature = self.AdaptiveAvgPool( 110 | visual_feature.permute(0, 3, 1, 2) 111 | ) # (b, w, channel=512, h=1) 112 | visual_feature = visual_feature.squeeze(3) # (b, w, channel=512) 113 | 114 | # Sequence modeling stage 115 | self.SequenceModeling.eval() 116 | contextual_feature = self.SequenceModeling(visual_feature) 117 | 118 | # Prediction stage 119 | prediction = self.Prediction( 120 | contextual_feature.contiguous() 121 | ) # (b, T, num_classes) 122 | 123 | return prediction 124 | -------------------------------------------------------------------------------- /betterocr/engines/easy_pororo_ocr/pororo/models/brainOCR/modules/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | this code is adapted from https://github.com/black7375/korean_ocr_using_pororo 3 | 4 | Apache License 2.0 @yunwoong7 5 | Apache License 2.0 @black7375 6 | """ 7 | -------------------------------------------------------------------------------- /betterocr/engines/easy_pororo_ocr/pororo/models/brainOCR/modules/basenet.py: -------------------------------------------------------------------------------- 1 | """ 2 | this code is adapted from https://github.com/black7375/korean_ocr_using_pororo 3 | 4 
| Apache License 2.0 @yunwoong7 5 | Apache License 2.0 @black7375 6 | """ 7 | 8 | from collections import namedtuple 9 | 10 | import torch 11 | import torch.nn as nn 12 | import torch.nn.init as init 13 | from torchvision import models 14 | from torchvision.models.vgg import model_urls 15 | 16 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 17 | 18 | 19 | def init_weights(modules): 20 | for m in modules: 21 | if isinstance(m, nn.Conv2d): 22 | init.xavier_uniform_(m.weight.data) 23 | if m.bias is not None: 24 | m.bias.data.zero_() 25 | elif isinstance(m, nn.BatchNorm2d): 26 | m.weight.data.fill_(1) 27 | m.bias.data.zero_() 28 | elif isinstance(m, nn.Linear): 29 | m.weight.data.normal_(0, 0.01) 30 | m.bias.data.zero_() 31 | 32 | 33 | class Vgg16BN(torch.nn.Module): 34 | def __init__(self, pretrained: bool = True, freeze: bool = True): 35 | super(Vgg16BN, self).__init__() 36 | model_urls["vgg16_bn"] = model_urls["vgg16_bn"].replace("https://", "http://") 37 | vgg_pretrained_features = models.vgg16_bn(pretrained=pretrained).features 38 | self.slice1 = torch.nn.Sequential() 39 | self.slice2 = torch.nn.Sequential() 40 | self.slice3 = torch.nn.Sequential() 41 | self.slice4 = torch.nn.Sequential() 42 | self.slice5 = torch.nn.Sequential() 43 | for x in range(12): # conv2_2 44 | self.slice1.add_module(str(x), vgg_pretrained_features[x]) 45 | for x in range(12, 19): # conv3_3 46 | self.slice2.add_module(str(x), vgg_pretrained_features[x]) 47 | for x in range(19, 29): # conv4_3 48 | self.slice3.add_module(str(x), vgg_pretrained_features[x]) 49 | for x in range(29, 39): # conv5_3 50 | self.slice4.add_module(str(x), vgg_pretrained_features[x]) 51 | 52 | # fc6, fc7 without atrous conv 53 | self.slice5 = torch.nn.Sequential( 54 | nn.MaxPool2d(kernel_size=3, stride=1, padding=1), 55 | nn.Conv2d(512, 1024, kernel_size=3, padding=6, dilation=6), 56 | nn.Conv2d(1024, 1024, kernel_size=1), 57 | ) 58 | 59 | if not pretrained: 60 | init_weights(self.slice1.modules()) 61 | init_weights(self.slice2.modules()) 62 | init_weights(self.slice3.modules()) 63 | init_weights(self.slice4.modules()) 64 | 65 | init_weights(self.slice5.modules()) # no pretrained model for fc6 and fc7 66 | 67 | if freeze: 68 | for param in self.slice1.parameters(): # only first conv 69 | param.requires_grad = False 70 | 71 | def forward(self, x): 72 | h = self.slice1(x) 73 | h_relu2_2 = h 74 | h = self.slice2(h) 75 | h_relu3_2 = h 76 | h = self.slice3(h) 77 | h_relu4_3 = h 78 | h = self.slice4(h) 79 | h_relu5_3 = h 80 | h = self.slice5(h) 81 | h_fc7 = h 82 | vgg_outputs = namedtuple( 83 | "VggOutputs", ["fc7", "relu5_3", "relu4_3", "relu3_2", "relu2_2"] 84 | ) 85 | out = vgg_outputs(h_fc7, h_relu5_3, h_relu4_3, h_relu3_2, h_relu2_2) 86 | return out 87 | -------------------------------------------------------------------------------- /betterocr/engines/easy_pororo_ocr/pororo/models/brainOCR/modules/feature_extraction.py: -------------------------------------------------------------------------------- 1 | """ 2 | this code is adapted from https://github.com/black7375/korean_ocr_using_pororo 3 | 4 | Apache License 2.0 @yunwoong7 5 | Apache License 2.0 @black7375 6 | """ 7 | 8 | import torch.nn as nn 9 | 10 | 11 | class VGGFeatureExtractor(nn.Module): 12 | """FeatureExtractor of CRNN (https://arxiv.org/pdf/1507.05717.pdf)""" 13 | 14 | def __init__( 15 | self, n_input_channels: int = 1, n_output_channels: int = 512, opt2val=None 16 | ): 17 | super(VGGFeatureExtractor, self).__init__() 18 | 19 | self.output_channel = [ 20 | 
int(n_output_channels / 8), 21 | int(n_output_channels / 4), 22 | int(n_output_channels / 2), 23 | n_output_channels, 24 | ] # [64, 128, 256, 512] 25 | 26 | rec_model_ckpt_fp = opt2val["rec_model_ckpt_fp"] 27 | if "baseline" in rec_model_ckpt_fp: 28 | self.ConvNet = nn.Sequential( 29 | nn.Conv2d(n_input_channels, self.output_channel[0], 3, 1, 1), 30 | nn.ReLU(True), 31 | nn.MaxPool2d(2, 2), # 64x16x50 32 | nn.Conv2d(self.output_channel[0], self.output_channel[1], 3, 1, 1), 33 | nn.ReLU(True), 34 | nn.MaxPool2d(2, 2), # 128x8x25 35 | nn.Conv2d(self.output_channel[1], self.output_channel[2], 3, 1, 1), 36 | nn.ReLU(True), # 256x8x25 37 | nn.Conv2d(self.output_channel[2], self.output_channel[2], 3, 1, 1), 38 | nn.ReLU(True), 39 | nn.MaxPool2d((2, 1), (2, 1)), # 256x4x25 40 | nn.Conv2d( 41 | self.output_channel[2], self.output_channel[3], 3, 1, 1, bias=False 42 | ), 43 | nn.BatchNorm2d(self.output_channel[3]), 44 | nn.ReLU(True), # 512x4x25 45 | nn.Conv2d( 46 | self.output_channel[3], self.output_channel[3], 3, 1, 1, bias=False 47 | ), 48 | nn.BatchNorm2d(self.output_channel[3]), 49 | nn.ReLU(True), 50 | nn.MaxPool2d((2, 1), (2, 1)), # 512x2x25 51 | # nn.Conv2d(self.output_channel[3], self.output_channel[3], 2, 1, 0), nn.ReLU(True)) # 512x1x24 52 | nn.ConvTranspose2d( 53 | self.output_channel[3], self.output_channel[3], 2, 2 54 | ), 55 | nn.ReLU(True), 56 | ) # 512x4x50 57 | else: 58 | self.ConvNet = nn.Sequential( 59 | nn.Conv2d(n_input_channels, self.output_channel[0], 3, 1, 1), 60 | nn.ReLU(True), 61 | nn.MaxPool2d(2, 2), # 64x16x50 62 | nn.Conv2d(self.output_channel[0], self.output_channel[1], 3, 1, 1), 63 | nn.ReLU(True), 64 | nn.MaxPool2d(2, 2), # 128x8x25 65 | nn.Conv2d(self.output_channel[1], self.output_channel[2], 3, 1, 1), 66 | nn.ReLU(True), # 256x8x25 67 | nn.Conv2d(self.output_channel[2], self.output_channel[2], 3, 1, 1), 68 | nn.ReLU(True), 69 | nn.MaxPool2d((2, 1), (2, 1)), # 256x4x25 70 | nn.Conv2d( 71 | self.output_channel[2], self.output_channel[3], 3, 1, 1, bias=False 72 | ), 73 | nn.BatchNorm2d(self.output_channel[3]), 74 | nn.ReLU(True), # 512x4x25 75 | nn.Conv2d( 76 | self.output_channel[3], self.output_channel[3], 3, 1, 1, bias=False 77 | ), 78 | nn.BatchNorm2d(self.output_channel[3]), 79 | nn.ReLU(True), 80 | nn.MaxPool2d((2, 1), (2, 1)), # 512x2x25 81 | # nn.Conv2d(self.output_channel[3], self.output_channel[3], 2, 1, 0), nn.ReLU(True)) # 512x1x24 82 | nn.ConvTranspose2d( 83 | self.output_channel[3], self.output_channel[3], 2, 2 84 | ), 85 | nn.ReLU(True), # 512x4x50 86 | nn.ConvTranspose2d( 87 | self.output_channel[3], self.output_channel[3], 2, 2 88 | ), 89 | nn.ReLU(True), 90 | ) # 512x4x50 91 | 92 | def forward(self, x): 93 | return self.ConvNet(x) 94 | 95 | 96 | class ResNetFeatureExtractor(nn.Module): 97 | """ 98 | FeatureExtractor of FAN 99 | (http://openaccess.thecvf.com/content_ICCV_2017/papers/Cheng_Focusing_Attention_Towards_ICCV_2017_paper.pdf) 100 | """ 101 | 102 | def __init__(self, n_input_channels: int = 1, n_output_channels: int = 512): 103 | super(ResNetFeatureExtractor, self).__init__() 104 | self.ConvNet = ResNet( 105 | n_input_channels, n_output_channels, BasicBlock, [1, 2, 5, 3] 106 | ) 107 | 108 | def forward(self, inputs): 109 | return self.ConvNet(inputs) 110 | 111 | 112 | class BasicBlock(nn.Module): 113 | expansion = 1 114 | 115 | def __init__(self, inplanes: int, planes: int, stride: int = 1, downsample=None): 116 | super(BasicBlock, self).__init__() 117 | self.conv1 = self._conv3x3(inplanes, planes) 118 | self.bn1 = nn.BatchNorm2d(planes) 
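# (annotation) Second 3x3 conv / batch-norm pair of the residual block; ReLU is applied
# once more only after the skip connection is added in forward().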
119 | self.conv2 = self._conv3x3(planes, planes) 120 | self.bn2 = nn.BatchNorm2d(planes) 121 | self.relu = nn.ReLU(inplace=True) 122 | self.downsample = downsample 123 | self.stride = stride 124 | 125 | def _conv3x3(self, in_planes, out_planes, stride=1): 126 | "3x3 convolution with padding" 127 | return nn.Conv2d( 128 | in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False 129 | ) 130 | 131 | def forward(self, x): 132 | residual = x 133 | 134 | out = self.conv1(x) 135 | out = self.bn1(out) 136 | out = self.relu(out) 137 | 138 | out = self.conv2(out) 139 | out = self.bn2(out) 140 | 141 | if self.downsample is not None: 142 | residual = self.downsample(x) 143 | out += residual 144 | out = self.relu(out) 145 | 146 | return out 147 | 148 | 149 | class ResNet(nn.Module): 150 | def __init__(self, n_input_channels: int, n_output_channels: int, block, layers): 151 | """ 152 | :param n_input_channels (int): The number of input channels of the feature extractor 153 | :param n_output_channels (int): The number of output channels of the feature extractor 154 | :param block: 155 | :param layers: 156 | """ 157 | super(ResNet, self).__init__() 158 | 159 | self.output_channel_blocks = [ 160 | int(n_output_channels / 4), 161 | int(n_output_channels / 2), 162 | n_output_channels, 163 | n_output_channels, 164 | ] 165 | 166 | self.inplanes = int(n_output_channels / 8) 167 | self.conv0_1 = nn.Conv2d( 168 | n_input_channels, 169 | int(n_output_channels / 16), 170 | kernel_size=3, 171 | stride=1, 172 | padding=1, 173 | bias=False, 174 | ) 175 | self.bn0_1 = nn.BatchNorm2d(int(n_output_channels / 16)) 176 | self.conv0_2 = nn.Conv2d( 177 | int(n_output_channels / 16), 178 | self.inplanes, 179 | kernel_size=3, 180 | stride=1, 181 | padding=1, 182 | bias=False, 183 | ) 184 | self.bn0_2 = nn.BatchNorm2d(self.inplanes) 185 | self.relu = nn.ReLU(inplace=True) 186 | 187 | self.maxpool1 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0) 188 | self.layer1 = self._make_layer(block, self.output_channel_blocks[0], layers[0]) 189 | self.conv1 = nn.Conv2d( 190 | self.output_channel_blocks[0], 191 | self.output_channel_blocks[0], 192 | kernel_size=3, 193 | stride=1, 194 | padding=1, 195 | bias=False, 196 | ) 197 | self.bn1 = nn.BatchNorm2d(self.output_channel_blocks[0]) 198 | 199 | self.maxpool2 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0) 200 | self.layer2 = self._make_layer( 201 | block, self.output_channel_blocks[1], layers[1], stride=1 202 | ) 203 | self.conv2 = nn.Conv2d( 204 | self.output_channel_blocks[1], 205 | self.output_channel_blocks[1], 206 | kernel_size=3, 207 | stride=1, 208 | padding=1, 209 | bias=False, 210 | ) 211 | self.bn2 = nn.BatchNorm2d(self.output_channel_blocks[1]) 212 | 213 | self.maxpool3 = nn.MaxPool2d(kernel_size=2, stride=(2, 1), padding=(0, 1)) 214 | self.layer3 = self._make_layer( 215 | block, self.output_channel_blocks[2], layers[2], stride=1 216 | ) 217 | self.conv3 = nn.Conv2d( 218 | self.output_channel_blocks[2], 219 | self.output_channel_blocks[2], 220 | kernel_size=3, 221 | stride=1, 222 | padding=1, 223 | bias=False, 224 | ) 225 | self.bn3 = nn.BatchNorm2d(self.output_channel_blocks[2]) 226 | 227 | self.layer4 = self._make_layer( 228 | block, self.output_channel_blocks[3], layers[3], stride=1 229 | ) 230 | self.conv4_1 = nn.Conv2d( 231 | self.output_channel_blocks[3], 232 | self.output_channel_blocks[3], 233 | kernel_size=2, 234 | stride=(2, 1), 235 | padding=(0, 1), 236 | bias=False, 237 | ) 238 | self.bn4_1 = nn.BatchNorm2d(self.output_channel_blocks[3]) 239 | 
self.conv4_2 = nn.Conv2d( 240 | self.output_channel_blocks[3], 241 | self.output_channel_blocks[3], 242 | kernel_size=2, 243 | stride=1, 244 | padding=0, 245 | bias=False, 246 | ) 247 | self.bn4_2 = nn.BatchNorm2d(self.output_channel_blocks[3]) 248 | 249 | def _make_layer(self, block, planes, blocks, stride=1): 250 | downsample = None 251 | if stride != 1 or self.inplanes != planes * block.expansion: 252 | downsample = nn.Sequential( 253 | nn.Conv2d( 254 | self.inplanes, 255 | planes * block.expansion, 256 | kernel_size=1, 257 | stride=stride, 258 | bias=False, 259 | ), 260 | nn.BatchNorm2d(planes * block.expansion), 261 | ) 262 | 263 | layers = [] 264 | layers.append(block(self.inplanes, planes, stride, downsample)) 265 | self.inplanes = planes * block.expansion 266 | for i in range(1, blocks): 267 | layers.append(block(self.inplanes, planes)) 268 | 269 | return nn.Sequential(*layers) 270 | 271 | def forward(self, x): 272 | x = self.conv0_1(x) 273 | x = self.bn0_1(x) 274 | x = self.relu(x) 275 | x = self.conv0_2(x) 276 | x = self.bn0_2(x) 277 | x = self.relu(x) 278 | 279 | x = self.maxpool1(x) 280 | x = self.layer1(x) 281 | x = self.conv1(x) 282 | x = self.bn1(x) 283 | x = self.relu(x) 284 | 285 | x = self.maxpool2(x) 286 | x = self.layer2(x) 287 | x = self.conv2(x) 288 | x = self.bn2(x) 289 | x = self.relu(x) 290 | 291 | x = self.maxpool3(x) 292 | x = self.layer3(x) 293 | x = self.conv3(x) 294 | x = self.bn3(x) 295 | x = self.relu(x) 296 | 297 | x = self.layer4(x) 298 | x = self.conv4_1(x) 299 | x = self.bn4_1(x) 300 | x = self.relu(x) 301 | x = self.conv4_2(x) 302 | x = self.bn4_2(x) 303 | x = self.relu(x) 304 | 305 | return x 306 | -------------------------------------------------------------------------------- /betterocr/engines/easy_pororo_ocr/pororo/models/brainOCR/modules/prediction.py: -------------------------------------------------------------------------------- 1 | """ 2 | this code is adapted from https://github.com/black7375/korean_ocr_using_pororo 3 | 4 | Apache License 2.0 @yunwoong7 5 | Apache License 2.0 @black7375 6 | """ 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | 12 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 13 | 14 | 15 | class Attention(nn.Module): 16 | def __init__(self, input_size, hidden_size, num_classes): 17 | super(Attention, self).__init__() 18 | self.attention_cell = AttentionCell(input_size, hidden_size, num_classes) 19 | self.hidden_size = hidden_size 20 | self.num_classes = num_classes 21 | self.generator = nn.Linear(hidden_size, num_classes) 22 | 23 | def _char_to_onehot(self, input_char, onehot_dim=38): 24 | input_char = input_char.unsqueeze(1) 25 | batch_size = input_char.size(0) 26 | one_hot = torch.FloatTensor(batch_size, onehot_dim).zero_().to(device) 27 | one_hot = one_hot.scatter_(1, input_char, 1) 28 | return one_hot 29 | 30 | def forward(self, batch_H, text, is_train=True, batch_max_length=25): 31 | """ 32 | input: 33 | batch_H : contextual_feature H = hidden state of encoder. [batch_size x num_steps x contextual_feature_channels] 34 | text : the text-index of each image. [batch_size x (max_length+1)]. +1 for [GO] token. text[:, 0] = [GO]. 35 | output: probability distribution at each step [batch_size x num_steps x num_classes] 36 | """ 37 | batch_size = batch_H.size(0) 38 | num_steps = batch_max_length + 1 # +1 for [s] at end of sentence. 
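# (annotation) Decoder state set-up: `output_hiddens` accumulates the attention-LSTM
# hidden state at every decoding step, and `hidden` is the zero-initialised (h, c) pair
# of the AttentionCell.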
39 | 40 | output_hiddens = ( 41 | torch.FloatTensor(batch_size, num_steps, self.hidden_size) 42 | .fill_(0) 43 | .to(device) 44 | ) 45 | hidden = ( 46 | torch.FloatTensor(batch_size, self.hidden_size).fill_(0).to(device), 47 | torch.FloatTensor(batch_size, self.hidden_size).fill_(0).to(device), 48 | ) 49 | 50 | if is_train: 51 | for i in range(num_steps): 52 | # one-hot vectors for a i-th char. in a batch 53 | char_onehots = self._char_to_onehot( 54 | text[:, i], onehot_dim=self.num_classes 55 | ) 56 | # hidden : decoder's hidden s_{t-1}, batch_H : encoder's hidden H, char_onehots : one-hot(y_{t-1}) 57 | hidden, alpha = self.attention_cell(hidden, batch_H, char_onehots) 58 | output_hiddens[:, i, :] = hidden[ 59 | 0 60 | ] # LSTM hidden index (0: hidden, 1: Cell) 61 | probs = self.generator(output_hiddens) 62 | 63 | else: 64 | targets = torch.LongTensor(batch_size).fill_(0).to(device) # [GO] token 65 | probs = ( 66 | torch.FloatTensor(batch_size, num_steps, self.num_classes) 67 | .fill_(0) 68 | .to(device) 69 | ) 70 | 71 | for i in range(num_steps): 72 | char_onehots = self._char_to_onehot( 73 | targets, onehot_dim=self.num_classes 74 | ) 75 | hidden, alpha = self.attention_cell(hidden, batch_H, char_onehots) 76 | probs_step = self.generator(hidden[0]) 77 | probs[:, i, :] = probs_step 78 | _, next_input = probs_step.max(1) 79 | targets = next_input 80 | 81 | return probs # batch_size x num_steps x num_classes 82 | 83 | 84 | class AttentionCell(nn.Module): 85 | def __init__(self, input_size, hidden_size, num_embeddings): 86 | super(AttentionCell, self).__init__() 87 | self.i2h = nn.Linear(input_size, hidden_size, bias=False) 88 | self.h2h = nn.Linear( 89 | hidden_size, hidden_size 90 | ) # either i2i or h2h should have bias 91 | self.score = nn.Linear(hidden_size, 1, bias=False) 92 | self.rnn = nn.LSTMCell(input_size + num_embeddings, hidden_size) 93 | self.hidden_size = hidden_size 94 | 95 | def forward(self, prev_hidden, batch_H, char_onehots): 96 | # [batch_size x num_encoder_step x num_channel] -> [batch_size x num_encoder_step x hidden_size] 97 | batch_H_proj = self.i2h(batch_H) 98 | prev_hidden_proj = self.h2h(prev_hidden[0]).unsqueeze(1) 99 | e = self.score( 100 | torch.tanh(batch_H_proj + prev_hidden_proj) 101 | ) # batch_size x num_encoder_step * 1 102 | 103 | alpha = F.softmax(e, dim=1) 104 | context = torch.bmm(alpha.permute(0, 2, 1), batch_H).squeeze( 105 | 1 106 | ) # batch_size x num_channel 107 | concat_context = torch.cat( 108 | [context, char_onehots], 1 109 | ) # batch_size x (num_channel + num_embedding) 110 | cur_hidden = self.rnn(concat_context, prev_hidden) 111 | return cur_hidden, alpha 112 | -------------------------------------------------------------------------------- /betterocr/engines/easy_pororo_ocr/pororo/models/brainOCR/modules/sequence_modeling.py: -------------------------------------------------------------------------------- 1 | """ 2 | this code is adapted from https://github.com/black7375/korean_ocr_using_pororo 3 | 4 | Apache License 2.0 @yunwoong7 5 | Apache License 2.0 @black7375 6 | """ 7 | 8 | import torch.nn as nn 9 | 10 | 11 | class BidirectionalLSTM(nn.Module): 12 | def __init__(self, input_size: int, hidden_size: int, output_size: int): 13 | super(BidirectionalLSTM, self).__init__() 14 | self.rnn = nn.LSTM( 15 | input_size, hidden_size, bidirectional=True, batch_first=True 16 | ) 17 | self.linear = nn.Linear(hidden_size * 2, output_size) 18 | 19 | def forward(self, x): 20 | """ 21 | x : visual feature [batch_size x T=24 x input_size=512] 22 | 
output : contextual feature [batch_size x T x output_size] 23 | """ 24 | self.rnn.flatten_parameters() 25 | recurrent, _ = self.rnn( 26 | x 27 | ) # batch_size x T x input_size -> batch_size x T x (2*hidden_size) 28 | output = self.linear(recurrent) # batch_size x T x output_size 29 | return output 30 | -------------------------------------------------------------------------------- /betterocr/engines/easy_pororo_ocr/pororo/models/brainOCR/modules/transformation.py: -------------------------------------------------------------------------------- 1 | """ 2 | this code is adapted from https://github.com/black7375/korean_ocr_using_pororo 3 | 4 | Apache License 2.0 @yunwoong7 5 | Apache License 2.0 @black7375 6 | """ 7 | 8 | import numpy as np 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | 13 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 14 | 15 | 16 | class TpsSpatialTransformerNetwork(nn.Module): 17 | """Rectification Network of RARE, namely TPS based STN""" 18 | 19 | def __init__(self, F, I_size, I_r_size, I_channel_num: int = 1): 20 | """Based on RARE TPS 21 | input: 22 | batch_I: Batch Input Image [batch_size x I_channel_num x I_height x I_width] 23 | I_size : (height, width) of the input image I 24 | I_r_size : (height, width) of the rectified image I_r 25 | I_channel_num : the number of channels of the input image I 26 | output: 27 | batch_I_r: rectified image [batch_size x I_channel_num x I_r_height x I_r_width] 28 | """ 29 | super(TpsSpatialTransformerNetwork, self).__init__() 30 | self.F = F 31 | self.I_size = I_size 32 | self.I_r_size = I_r_size # = (I_r_height, I_r_width) 33 | self.I_channel_num = I_channel_num 34 | self.LocalizationNetwork = LocalizationNetwork(self.F, self.I_channel_num) 35 | self.GridGenerator = GridGenerator(self.F, self.I_r_size) 36 | 37 | def forward(self, batch_I): 38 | batch_C_prime = self.LocalizationNetwork(batch_I) # batch_size x K x 2 39 | build_P_prime = self.GridGenerator.build_P_prime( 40 | batch_C_prime 41 | ) # batch_size x n (= I_r_width x I_r_height) x 2 42 | build_P_prime_reshape = build_P_prime.reshape( 43 | [build_P_prime.size(0), self.I_r_size[0], self.I_r_size[1], 2] 44 | ) 45 | 46 | # if torch.__version__ > "1.2.0": 47 | # batch_I_r = F.grid_sample(batch_I, build_P_prime_reshape, padding_mode='border', align_corners=True) 48 | # else: 49 | batch_I_r = F.grid_sample(batch_I, build_P_prime_reshape, padding_mode="border") 50 | 51 | return batch_I_r 52 | 53 | 54 | class LocalizationNetwork(nn.Module): 55 | """Localization Network of RARE, which predicts C' (K x 2) from I (I_width x I_height)""" 56 | 57 | def __init__(self, F, I_channel_num: int): 58 | super(LocalizationNetwork, self).__init__() 59 | self.F = F 60 | self.I_channel_num = I_channel_num 61 | self.conv = nn.Sequential( 62 | nn.Conv2d( 63 | in_channels=self.I_channel_num, 64 | out_channels=64, 65 | kernel_size=3, 66 | stride=1, 67 | padding=1, 68 | bias=False, 69 | ), 70 | nn.BatchNorm2d(64), 71 | nn.ReLU(True), 72 | nn.MaxPool2d(2, 2), # batch_size x 64 x I_height/2 x I_width/2 73 | nn.Conv2d(64, 128, 3, 1, 1, bias=False), 74 | nn.BatchNorm2d(128), 75 | nn.ReLU(True), 76 | nn.MaxPool2d(2, 2), # batch_size x 128 x I_height/4 x I_width/4 77 | nn.Conv2d(128, 256, 3, 1, 1, bias=False), 78 | nn.BatchNorm2d(256), 79 | nn.ReLU(True), 80 | nn.MaxPool2d(2, 2), # batch_size x 256 x I_height/8 x I_width/8 81 | nn.Conv2d(256, 512, 3, 1, 1, bias=False), 82 | nn.BatchNorm2d(512), 83 | nn.ReLU(True), 84 | nn.AdaptiveAvgPool2d(1), # 
batch_size x 512 85 | ) 86 | 87 | self.localization_fc1 = nn.Sequential(nn.Linear(512, 256), nn.ReLU(True)) 88 | self.localization_fc2 = nn.Linear(256, self.F * 2) 89 | 90 | # Init fc2 in LocalizationNetwork 91 | self.localization_fc2.weight.data.fill_(0) 92 | 93 | # see RARE paper Fig. 6 (a) 94 | ctrl_pts_x = np.linspace(-1.0, 1.0, int(F / 2)) 95 | ctrl_pts_y_top = np.linspace(0.0, -1.0, num=int(F / 2)) 96 | ctrl_pts_y_bottom = np.linspace(1.0, 0.0, num=int(F / 2)) 97 | ctrl_pts_top = np.stack([ctrl_pts_x, ctrl_pts_y_top], axis=1) 98 | ctrl_pts_bottom = np.stack([ctrl_pts_x, ctrl_pts_y_bottom], axis=1) 99 | initial_bias = np.concatenate([ctrl_pts_top, ctrl_pts_bottom], axis=0) 100 | self.localization_fc2.bias.data = ( 101 | torch.from_numpy(initial_bias).float().view(-1) 102 | ) 103 | 104 | def forward(self, batch_I): 105 | """ 106 | :param batch_I : Batch Input Image [batch_size x I_channel_num x I_height x I_width] 107 | :return: batch_C_prime : Predicted coordinates of fiducial points for input batch [batch_size x F x 2] 108 | """ 109 | batch_size = batch_I.size(0) 110 | features = self.conv(batch_I).view(batch_size, -1) 111 | batch_C_prime = self.localization_fc2(self.localization_fc1(features)).view( 112 | batch_size, self.F, 2 113 | ) 114 | return batch_C_prime 115 | 116 | 117 | class GridGenerator(nn.Module): 118 | """Grid Generator of RARE, which produces P_prime by multipling T with P""" 119 | 120 | def __init__(self, F, I_r_size): 121 | """Generate P_hat and inv_delta_C for later""" 122 | super(GridGenerator, self).__init__() 123 | self.eps = 1e-6 124 | self.I_r_height, self.I_r_width = I_r_size 125 | self.F = F 126 | self.C = self._build_C(self.F) # F x 2 127 | self.P = self._build_P(self.I_r_width, self.I_r_height) 128 | 129 | # for multi-gpu, you need register buffer 130 | self.register_buffer( 131 | "inv_delta_C", torch.tensor(self._build_inv_delta_C(self.F, self.C)).float() 132 | ) # F+3 x F+3 133 | self.register_buffer( 134 | "P_hat", torch.tensor(self._build_P_hat(self.F, self.C, self.P)).float() 135 | ) # n x F+3 136 | 137 | def _build_C(self, F): 138 | """Return coordinates of fiducial points in I_r; C""" 139 | ctrl_pts_x = np.linspace(-1.0, 1.0, int(F / 2)) 140 | ctrl_pts_y_top = -1 * np.ones(int(F / 2)) 141 | ctrl_pts_y_bottom = np.ones(int(F / 2)) 142 | ctrl_pts_top = np.stack([ctrl_pts_x, ctrl_pts_y_top], axis=1) 143 | ctrl_pts_bottom = np.stack([ctrl_pts_x, ctrl_pts_y_bottom], axis=1) 144 | C = np.concatenate([ctrl_pts_top, ctrl_pts_bottom], axis=0) 145 | return C # F x 2 146 | 147 | def _build_inv_delta_C(self, F, C): 148 | """Return inv_delta_C which is needed to calculate T""" 149 | hat_C = np.zeros((F, F), dtype=float) # F x F 150 | for i in range(0, F): 151 | for j in range(i, F): 152 | r = np.linalg.norm(C[i] - C[j]) 153 | hat_C[i, j] = r 154 | hat_C[j, i] = r 155 | np.fill_diagonal(hat_C, 1) 156 | hat_C = (hat_C**2) * np.log(hat_C) 157 | # print(C.shape, hat_C.shape) 158 | delta_C = np.concatenate( # F+3 x F+3 159 | [ 160 | np.concatenate([np.ones((F, 1)), C, hat_C], axis=1), # F x F+3 161 | np.concatenate([np.zeros((2, 3)), np.transpose(C)], axis=1), # 2 x F+3 162 | np.concatenate([np.zeros((1, 3)), np.ones((1, F))], axis=1), # 1 x F+3 163 | ], 164 | axis=0, 165 | ) 166 | inv_delta_C = np.linalg.inv(delta_C) 167 | return inv_delta_C # F+3 x F+3 168 | 169 | def _build_P(self, I_r_width, I_r_height): 170 | I_r_grid_x = ( 171 | np.arange(-I_r_width, I_r_width, 2) + 1.0 172 | ) / I_r_width # self.I_r_width 173 | I_r_grid_y = ( 174 | np.arange(-I_r_height, 
I_r_height, 2) + 1.0 175 | ) / I_r_height # self.I_r_height 176 | P = np.stack( # self.I_r_width x self.I_r_height x 2 177 | np.meshgrid(I_r_grid_x, I_r_grid_y), axis=2 178 | ) 179 | return P.reshape([-1, 2]) # n (= self.I_r_width x self.I_r_height) x 2 180 | 181 | def _build_P_hat(self, F, C, P): 182 | n = P.shape[0] # n (= self.I_r_width x self.I_r_height) 183 | P_tile = np.tile( 184 | np.expand_dims(P, axis=1), (1, F, 1) 185 | ) # n x 2 -> n x 1 x 2 -> n x F x 2 186 | C_tile = np.expand_dims(C, axis=0) # 1 x F x 2 187 | P_diff = P_tile - C_tile # n x F x 2 188 | rbf_norm = np.linalg.norm(P_diff, ord=2, axis=2, keepdims=False) # n x F 189 | rbf = np.multiply(np.square(rbf_norm), np.log(rbf_norm + self.eps)) # n x F 190 | P_hat = np.concatenate([np.ones((n, 1)), P, rbf], axis=1) 191 | return P_hat # n x F+3 192 | 193 | def build_P_prime(self, batch_C_prime): 194 | """Generate Grid from batch_C_prime [batch_size x F x 2]""" 195 | batch_size = batch_C_prime.size(0) 196 | batch_inv_delta_C = self.inv_delta_C.repeat(batch_size, 1, 1) 197 | batch_P_hat = self.P_hat.repeat(batch_size, 1, 1) 198 | batch_C_prime_with_zeros = torch.cat( 199 | (batch_C_prime, torch.zeros(batch_size, 3, 2).float().to(device)), dim=1 200 | ) # batch_size x F+3 x 2 201 | batch_T = torch.bmm( 202 | batch_inv_delta_C, batch_C_prime_with_zeros 203 | ) # batch_size x F+3 x 2 204 | batch_P_prime = torch.bmm(batch_P_hat, batch_T) # batch_size x n x 2 205 | return batch_P_prime # batch_size x n x 2 206 | -------------------------------------------------------------------------------- /betterocr/engines/easy_pororo_ocr/pororo/models/brainOCR/recognition.py: -------------------------------------------------------------------------------- 1 | """ 2 | this code is adapted from https://github.com/black7375/korean_ocr_using_pororo 3 | 4 | Apache License 2.0 @yunwoong7 5 | Apache License 2.0 @black7375 6 | """ 7 | 8 | """ 9 | This code is adapted from https://github.com/JaidedAI/EasyOCR/blob/8af936ba1b2f3c230968dc1022d0cd3e9ca1efbb/easyocr/recognition.py 10 | """ 11 | 12 | import math 13 | 14 | import numpy as np 15 | import torch 16 | import torch.nn.functional as F 17 | import torch.utils.data 18 | import torchvision.transforms as transforms 19 | from PIL import Image 20 | 21 | from .model import Model 22 | from .utils import CTCLabelConverter 23 | 24 | 25 | def contrast_grey(img): 26 | high = np.percentile(img, 90) 27 | low = np.percentile(img, 10) 28 | return (high - low) / np.maximum(10, high + low), high, low 29 | 30 | 31 | def adjust_contrast_grey(img, target: float = 0.4): 32 | contrast, high, low = contrast_grey(img) 33 | if contrast < target: 34 | img = img.astype(int) 35 | ratio = 200.0 / np.maximum(10, high - low) 36 | img = (img - low + 25) * ratio 37 | img = np.maximum( 38 | np.full(img.shape, 0), 39 | np.minimum( 40 | np.full(img.shape, 255), 41 | img, 42 | ), 43 | ).astype(np.uint8) 44 | return img 45 | 46 | 47 | class NormalizePAD(object): 48 | def __init__(self, max_size, PAD_type: str = "right"): 49 | self.toTensor = transforms.ToTensor() 50 | self.max_size = max_size 51 | self.max_width_half = math.floor(max_size[2] / 2) 52 | self.PAD_type = PAD_type 53 | 54 | def __call__(self, img): 55 | img = self.toTensor(img) 56 | img.sub_(0.5).div_(0.5) 57 | c, h, w = img.size() 58 | Pad_img = torch.FloatTensor(*self.max_size).fill_(0) 59 | Pad_img[:, :, :w] = img # right pad 60 | if self.max_size[2] != w: # add border Pad 61 | Pad_img[:, :, w:] = ( 62 | img[:, :, w - 1] 63 | .unsqueeze(2) 64 | .expand( 65 | c, 66 | h, 67 
| self.max_size[2] - w, 68 | ) 69 | ) 70 | 71 | return Pad_img 72 | 73 | 74 | class ListDataset(torch.utils.data.Dataset): 75 | def __init__(self, image_list: list): 76 | self.image_list = image_list 77 | self.nSamples = len(image_list) 78 | 79 | def __len__(self): 80 | return self.nSamples 81 | 82 | def __getitem__(self, index): 83 | img = self.image_list[index] 84 | return Image.fromarray(img, "L") 85 | 86 | 87 | class AlignCollate(object): 88 | def __init__(self, imgH: int, imgW: int, adjust_contrast: float): 89 | self.imgH = imgH 90 | self.imgW = imgW 91 | self.keep_ratio_with_pad = True # Do Not Change 92 | self.adjust_contrast = adjust_contrast 93 | 94 | def __call__(self, batch): 95 | batch = filter(lambda x: x is not None, batch) 96 | images = batch 97 | 98 | resized_max_w = self.imgW 99 | input_channel = 1 100 | transform = NormalizePAD((input_channel, self.imgH, resized_max_w)) 101 | 102 | resized_images = [] 103 | for image in images: 104 | w, h = image.size 105 | # augmentation here - change contrast 106 | if self.adjust_contrast > 0: 107 | image = np.array(image.convert("L")) 108 | image = adjust_contrast_grey(image, target=self.adjust_contrast) 109 | image = Image.fromarray(image, "L") 110 | 111 | ratio = w / float(h) 112 | if math.ceil(self.imgH * ratio) > self.imgW: 113 | resized_w = self.imgW 114 | else: 115 | resized_w = math.ceil(self.imgH * ratio) 116 | 117 | resized_image = image.resize((resized_w, self.imgH), Image.BICUBIC) 118 | resized_images.append(transform(resized_image)) 119 | 120 | image_tensors = torch.cat([t.unsqueeze(0) for t in resized_images], 0) 121 | return image_tensors 122 | 123 | 124 | def recognizer_predict(model, converter, test_loader, opt2val: dict): 125 | device = opt2val["device"] 126 | 127 | model.eval() 128 | result = [] 129 | with torch.no_grad(): 130 | for image_tensors in test_loader: 131 | batch_size = image_tensors.size(0) 132 | inputs = image_tensors.to(device) 133 | preds = model(inputs) # (N, length, num_classes) 134 | 135 | # rebalance 136 | preds_prob = F.softmax(preds, dim=2) 137 | preds_prob = preds_prob.cpu().detach().numpy() 138 | pred_norm = preds_prob.sum(axis=2) 139 | preds_prob = preds_prob / np.expand_dims(pred_norm, axis=-1) 140 | preds_prob = torch.from_numpy(preds_prob).float().to(device) 141 | 142 | # Select max probabilty (greedy decoding), then decode index to character 143 | preds_lengths = torch.IntTensor([preds.size(1)] * batch_size) # (N,) 144 | _, preds_indices = preds_prob.max(2) # (N, length) 145 | preds_indices = preds_indices.view(-1) # (N*length) 146 | preds_str = converter.decode_greedy(preds_indices, preds_lengths) 147 | 148 | preds_max_prob, _ = preds_prob.max(dim=2) 149 | 150 | for pred, pred_max_prob in zip(preds_str, preds_max_prob): 151 | confidence_score = pred_max_prob.cumprod(dim=0)[-1] 152 | result.append([pred, confidence_score.item()]) 153 | 154 | return result 155 | 156 | 157 | def get_recognizer(opt2val: dict): 158 | """ 159 | :return: 160 | recognizer: recognition net 161 | converter: CTCLabelConverter 162 | """ 163 | # converter 164 | vocab = opt2val["vocab"] 165 | converter = CTCLabelConverter(vocab) 166 | 167 | # recognizer 168 | recognizer = Model(opt2val) 169 | 170 | # state_dict 171 | rec_model_ckpt_fp = opt2val["rec_model_ckpt_fp"] 172 | device = opt2val["device"] 173 | state_dict = torch.load(rec_model_ckpt_fp, map_location=device) 174 | 175 | if device == "cuda": 176 | recognizer = torch.nn.DataParallel(recognizer).to(device) 177 | else: 178 | # TODO temporary: multigpu 학습한 뒤 ckpt 
loading 문제 179 | from collections import OrderedDict 180 | 181 | def _sync_tensor_name(state_dict): 182 | state_dict_ = OrderedDict() 183 | for name, val in state_dict.items(): 184 | name = name.replace("module.", "") 185 | state_dict_[name] = val 186 | return state_dict_ 187 | 188 | state_dict = _sync_tensor_name(state_dict) 189 | 190 | recognizer.load_state_dict(state_dict) 191 | 192 | return recognizer, converter 193 | 194 | 195 | def get_text(image_list, recognizer, converter, opt2val: dict): 196 | imgW = opt2val["imgW"] 197 | imgH = opt2val["imgH"] 198 | adjust_contrast = opt2val["adjust_contrast"] 199 | batch_size = opt2val["batch_size"] 200 | n_workers = opt2val["n_workers"] 201 | contrast_ths = opt2val["contrast_ths"] 202 | 203 | # TODO: figure out what is this for 204 | # batch_max_length = int(imgW / 10) 205 | 206 | coord = [item[0] for item in image_list] 207 | img_list = [item[1] for item in image_list] 208 | AlignCollate_normal = AlignCollate(imgH, imgW, adjust_contrast) 209 | test_data = ListDataset(img_list) 210 | test_loader = torch.utils.data.DataLoader( 211 | test_data, 212 | batch_size=batch_size, 213 | shuffle=False, 214 | num_workers=n_workers, 215 | collate_fn=AlignCollate_normal, 216 | pin_memory=True, 217 | ) 218 | 219 | # predict first round 220 | result1 = recognizer_predict(recognizer, converter, test_loader, opt2val) 221 | 222 | # predict second round 223 | low_confident_idx = [ 224 | i for i, item in enumerate(result1) if (item[1] < contrast_ths) 225 | ] 226 | if len(low_confident_idx) > 0: 227 | img_list2 = [img_list[i] for i in low_confident_idx] 228 | AlignCollate_contrast = AlignCollate(imgH, imgW, adjust_contrast) 229 | test_data = ListDataset(img_list2) 230 | test_loader = torch.utils.data.DataLoader( 231 | test_data, 232 | batch_size=batch_size, 233 | shuffle=False, 234 | num_workers=n_workers, 235 | collate_fn=AlignCollate_contrast, 236 | pin_memory=True, 237 | ) 238 | result2 = recognizer_predict(recognizer, converter, test_loader, opt2val) 239 | 240 | result = [] 241 | for i, zipped in enumerate(zip(coord, result1)): 242 | box, pred1 = zipped 243 | if i in low_confident_idx: 244 | pred2 = result2[low_confident_idx.index(i)] 245 | if pred1[1] > pred2[1]: 246 | result.append((box, pred1[0], pred1[1])) 247 | else: 248 | result.append((box, pred2[0], pred2[1])) 249 | else: 250 | result.append((box, pred1[0], pred1[1])) 251 | 252 | return result 253 | -------------------------------------------------------------------------------- /betterocr/engines/easy_pororo_ocr/pororo/models/brainOCR/utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | this code is adapted from https://github.com/black7375/korean_ocr_using_pororo 3 | 4 | Apache License 2.0 @yunwoong7 5 | Apache License 2.0 @black7375 6 | """ 7 | 8 | """ 9 | This code is adapted from https://github.com/JaidedAI/EasyOCR/blob/8af936ba1b2f3c230968dc1022d0cd3e9ca1efbb/easyocr/utils.py 10 | """ 11 | 12 | import math 13 | import os 14 | from urllib.request import urlretrieve 15 | 16 | import cv2 17 | import numpy as np 18 | import torch 19 | from PIL import Image 20 | from torch import Tensor 21 | 22 | from .imgproc import load_image 23 | 24 | 25 | def consecutive(data, mode: str = "first", stepsize: int = 1): 26 | group = np.split(data, np.where(np.diff(data) != stepsize)[0] + 1) 27 | group = [item for item in group if len(item) > 0] 28 | 29 | if mode == "first": 30 | result = [l[0] for l in group] 31 | elif mode == "last": 32 | result = [l[-1] for l in 
group] 33 | return result 34 | 35 | 36 | def word_segmentation( 37 | mat, 38 | separator_idx={"th": [1, 2], "en": [3, 4]}, 39 | separator_idx_list=[1, 2, 3, 4], 40 | ): 41 | result = [] 42 | sep_list = [] 43 | start_idx = 0 44 | sep_lang = "" 45 | for sep_idx in separator_idx_list: 46 | if sep_idx % 2 == 0: 47 | mode = "first" 48 | else: 49 | mode = "last" 50 | a = consecutive(np.argwhere(mat == sep_idx).flatten(), mode) 51 | new_sep = [[item, sep_idx] for item in a] 52 | sep_list += new_sep 53 | sep_list = sorted(sep_list, key=lambda x: x[0]) 54 | 55 | for sep in sep_list: 56 | for lang in separator_idx.keys(): 57 | if sep[1] == separator_idx[lang][0]: # start lang 58 | sep_lang = lang 59 | sep_start_idx = sep[0] 60 | elif sep[1] == separator_idx[lang][1]: # end lang 61 | if sep_lang == lang: # check if last entry if the same start lang 62 | new_sep_pair = [lang, [sep_start_idx + 1, sep[0] - 1]] 63 | if sep_start_idx > start_idx: 64 | result.append(["", [start_idx, sep_start_idx - 1]]) 65 | start_idx = sep[0] + 1 66 | result.append(new_sep_pair) 67 | sep_lang = "" # reset 68 | 69 | if start_idx <= len(mat) - 1: 70 | result.append(["", [start_idx, len(mat) - 1]]) 71 | return result 72 | 73 | 74 | # code is based from https://github.com/githubharald/CTCDecoder/blob/master/src/BeamSearch.py 75 | class BeamEntry: 76 | "information about one single beam at specific time-step" 77 | 78 | def __init__(self): 79 | self.prTotal = 0 # blank and non-blank 80 | self.prNonBlank = 0 # non-blank 81 | self.prBlank = 0 # blank 82 | self.prText = 1 # LM score 83 | self.lmApplied = False # flag if LM was already applied to this beam 84 | self.labeling = () # beam-labeling 85 | 86 | 87 | class BeamState: 88 | "information about the beams at specific time-step" 89 | 90 | def __init__(self): 91 | self.entries = {} 92 | 93 | def norm(self): 94 | "length-normalise LM score" 95 | for k, _ in self.entries.items(): 96 | labelingLen = len(self.entries[k].labeling) 97 | self.entries[k].prText = self.entries[k].prText ** ( 98 | 1.0 / (labelingLen if labelingLen else 1.0) 99 | ) 100 | 101 | def sort(self): 102 | "return beam-labelings, sorted by probability" 103 | beams = [v for (_, v) in self.entries.items()] 104 | sortedBeams = sorted( 105 | beams, 106 | reverse=True, 107 | key=lambda x: x.prTotal * x.prText, 108 | ) 109 | return [x.labeling for x in sortedBeams] 110 | 111 | def wordsearch(self, classes, ignore_idx, maxCandidate, dict_list): 112 | beams = [v for (_, v) in self.entries.items()] 113 | sortedBeams = sorted( 114 | beams, 115 | reverse=True, 116 | key=lambda x: x.prTotal * x.prText, 117 | ) 118 | if len(sortedBeams) > maxCandidate: 119 | sortedBeams = sortedBeams[:maxCandidate] 120 | 121 | for j, candidate in enumerate(sortedBeams): 122 | idx_list = candidate.labeling 123 | text = "" 124 | for i, l in enumerate(idx_list): 125 | if l not in ignore_idx and ( 126 | not (i > 0 and idx_list[i - 1] == idx_list[i]) 127 | ): 128 | text += classes[l] 129 | 130 | if j == 0: 131 | best_text = text 132 | if text in dict_list: 133 | # print('found text: ', text) 134 | best_text = text 135 | break 136 | else: 137 | pass 138 | # print('not in dict: ', text) 139 | return best_text 140 | 141 | 142 | def applyLM(parentBeam, childBeam, classes, lm_model, lm_factor: float = 0.01): 143 | "calculate LM score of child beam by taking score from parent beam and bigram probability of last two chars" 144 | if lm_model is not None and not childBeam.lmApplied: 145 | history = parentBeam.labeling 146 | history = " ".join( 147 | 
classes[each].replace(" ", "▁") for each in history if each != 0 148 | ) 149 | 150 | current_char = classes[childBeam.labeling[-1]].replace(" ", "▁") 151 | if current_char == "[blank]": 152 | lmProb = 1 153 | else: 154 | text = history + " " + current_char 155 | lmProb = 10 ** lm_model.score(text, bos=True) * lm_factor 156 | 157 | childBeam.prText = lmProb # probability of char sequence 158 | childBeam.lmApplied = True # only apply LM once per beam entry 159 | 160 | 161 | def simplify_label(labeling, blankIdx: int = 0): 162 | labeling = np.array(labeling) 163 | 164 | # collapse blank 165 | idx = np.where(~((np.roll(labeling, 1) == labeling) & (labeling == blankIdx)))[0] 166 | labeling = labeling[idx] 167 | 168 | # get rid of blank between different characters 169 | idx = np.where( 170 | ~((np.roll(labeling, 1) != np.roll(labeling, -1)) & (labeling == blankIdx)) 171 | )[0] 172 | 173 | if len(labeling) > 0: 174 | last_idx = len(labeling) - 1 175 | if last_idx not in idx: 176 | idx = np.append(idx, [last_idx]) 177 | labeling = labeling[idx] 178 | 179 | return tuple(labeling) 180 | 181 | 182 | def addBeam(beamState, labeling): 183 | "add beam if it does not yet exist" 184 | if labeling not in beamState.entries: 185 | beamState.entries[labeling] = BeamEntry() 186 | 187 | 188 | def ctcBeamSearch( 189 | mat, 190 | classes: list, 191 | ignore_idx: int, 192 | lm_model, 193 | lm_factor: float = 0.01, 194 | beam_width: int = 5, 195 | ): 196 | blankIdx = 0 197 | maxT, maxC = mat.shape 198 | 199 | # initialise beam state 200 | last = BeamState() 201 | labeling = () 202 | last.entries[labeling] = BeamEntry() 203 | last.entries[labeling].prBlank = 1 204 | last.entries[labeling].prTotal = 1 205 | 206 | # go over all time-steps 207 | for t in range(maxT): 208 | # print("t=", t) 209 | curr = BeamState() 210 | # get beam-labelings of best beams 211 | bestLabelings = last.sort()[0:beam_width] 212 | # go over best beams 213 | for labeling in bestLabelings: 214 | # print("labeling:", labeling) 215 | # probability of paths ending with a non-blank 216 | prNonBlank = 0 217 | # in case of non-empty beam 218 | if labeling: 219 | # probability of paths with repeated last char at the end 220 | prNonBlank = last.entries[labeling].prNonBlank * mat[t, labeling[-1]] 221 | 222 | # probability of paths ending with a blank 223 | prBlank = (last.entries[labeling].prTotal) * mat[t, blankIdx] 224 | 225 | # add beam at current time-step if needed 226 | labeling = simplify_label(labeling, blankIdx) 227 | addBeam(curr, labeling) 228 | 229 | # fill in data 230 | curr.entries[labeling].labeling = labeling 231 | curr.entries[labeling].prNonBlank += prNonBlank 232 | curr.entries[labeling].prBlank += prBlank 233 | curr.entries[labeling].prTotal += prBlank + prNonBlank 234 | curr.entries[labeling].prText = last.entries[labeling].prText 235 | # beam-labeling not changed, therefore also LM score unchanged from 236 | 237 | curr.entries[ 238 | labeling 239 | ].lmApplied = ( 240 | True # LM already applied at previous time-step for this beam-labeling 241 | ) 242 | 243 | # extend current beam-labeling 244 | # char_highscore = np.argpartition(mat[t, :], -5)[-5:] # run through 5 highest probability 245 | char_highscore = np.where(mat[t, :] >= 0.5 / maxC)[ 246 | 0 247 | ] # run through all probable characters 248 | for c in char_highscore: 249 | # for c in range(maxC - 1): 250 | # add new char to current beam-labeling 251 | newLabeling = labeling + (c,) 252 | newLabeling = simplify_label(newLabeling, blankIdx) 253 | 254 | # if new labeling 
contains duplicate char at the end, only consider paths ending with a blank 255 | if labeling and labeling[-1] == c: 256 | prNonBlank = mat[t, c] * last.entries[labeling].prBlank 257 | else: 258 | prNonBlank = mat[t, c] * last.entries[labeling].prTotal 259 | 260 | # add beam at current time-step if needed 261 | addBeam(curr, newLabeling) 262 | 263 | # fill in data 264 | curr.entries[newLabeling].labeling = newLabeling 265 | curr.entries[newLabeling].prNonBlank += prNonBlank 266 | curr.entries[newLabeling].prTotal += prNonBlank 267 | 268 | # apply LM 269 | applyLM( 270 | curr.entries[labeling], 271 | curr.entries[newLabeling], 272 | classes, 273 | lm_model, 274 | lm_factor, 275 | ) 276 | 277 | # set new beam state 278 | 279 | last = curr 280 | 281 | # normalise LM scores according to beam-labeling-length 282 | last.norm() 283 | 284 | bestLabeling = last.sort()[0] # get most probable labeling 285 | res = "" 286 | for i, l in enumerate(bestLabeling): 287 | # removing repeated characters and blank. 288 | if l != ignore_idx and (not (i > 0 and bestLabeling[i - 1] == bestLabeling[i])): 289 | res += classes[l] 290 | 291 | return res 292 | 293 | 294 | class CTCLabelConverter(object): 295 | """Convert between text-label and text-index""" 296 | 297 | def __init__(self, vocab: list): 298 | self.char2idx = {char: idx for idx, char in enumerate(vocab)} 299 | self.idx2char = {idx: char for idx, char in enumerate(vocab)} 300 | self.ignored_index = 0 301 | self.vocab = vocab 302 | 303 | def encode(self, texts: list): 304 | """ 305 | Convert input texts into indices 306 | texts (list): text labels of each image. [batch_size] 307 | 308 | Returns 309 | text: concatenated text index for CTCLoss. 310 | [sum(text_lengths)] = [text_index_0 + text_index_1 + ... + text_index_(n - 1)] 311 | length: length of each text. [batch_size] 312 | """ 313 | lengths = [len(text) for text in texts] 314 | concatenated_text = "".join(texts) 315 | indices = [self.char2idx[char] for char in concatenated_text] 316 | 317 | return torch.IntTensor(indices), torch.IntTensor(lengths) 318 | 319 | def decode_greedy(self, indices: Tensor, lengths: Tensor): 320 | """convert text-index into text-label. 321 | 322 | :param indices (1D int32 Tensor): [N*length,] 323 | :param lengths (1D int32 Tensor): [N,] 324 | :return: 325 | """ 326 | texts = [] 327 | index = 0 328 | for length in lengths: 329 | text = indices[index : index + length] 330 | 331 | chars = [] 332 | for i in range(length): 333 | if (text[i] != self.ignored_index) and ( 334 | not (i > 0 and text[i - 1] == text[i]) 335 | ): # removing repeated characters and blank (and separator). 
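                    # e.g. for indices [7, 7, 0, 7] (0 = blank) only positions 0 and 3
                    # survive: the repeated 7 collapses, the blank is dropped, and the 7
                    # after the blank starts a new character, so idx2char[7] is emitted twice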
336 | chars.append(self.idx2char[text[i].item()]) 337 | texts.append("".join(chars)) 338 | index += length 339 | return texts 340 | 341 | def decode_beamsearch(self, mat, lm_model, lm_factor, beam_width: int = 5): 342 | texts = [] 343 | for i in range(mat.shape[0]): 344 | text = ctcBeamSearch( 345 | mat[i], 346 | self.vocab, 347 | self.ignored_index, 348 | lm_model, 349 | lm_factor, 350 | beam_width, 351 | ) 352 | texts.append(text) 353 | return texts 354 | 355 | 356 | def four_point_transform(image, rect): 357 | (tl, tr, br, bl) = rect 358 | 359 | widthA = np.sqrt(((br[0] - bl[0]) ** 2) + ((br[1] - bl[1]) ** 2)) 360 | widthB = np.sqrt(((tr[0] - tl[0]) ** 2) + ((tr[1] - tl[1]) ** 2)) 361 | maxWidth = max(int(widthA), int(widthB)) 362 | 363 | # compute the height of the new image, which will be the 364 | # maximum distance between the top-right and bottom-right 365 | # y-coordinates or the top-left and bottom-left y-coordinates 366 | heightA = np.sqrt(((tr[0] - br[0]) ** 2) + ((tr[1] - br[1]) ** 2)) 367 | heightB = np.sqrt(((tl[0] - bl[0]) ** 2) + ((tl[1] - bl[1]) ** 2)) 368 | maxHeight = max(int(heightA), int(heightB)) 369 | 370 | dst = np.array( 371 | [[0, 0], [maxWidth - 1, 0], [maxWidth - 1, maxHeight - 1], [0, maxHeight - 1]], 372 | dtype="float32", 373 | ) 374 | 375 | # compute the perspective transform matrix and then apply it 376 | M = cv2.getPerspectiveTransform(rect, dst) 377 | warped = cv2.warpPerspective(image, M, (maxWidth, maxHeight)) 378 | 379 | return warped 380 | 381 | 382 | def group_text_box( 383 | polys, 384 | slope_ths: float = 0.1, 385 | ycenter_ths: float = 0.5, 386 | height_ths: float = 0.5, 387 | width_ths: float = 1.0, 388 | add_margin: float = 0.05, 389 | ): 390 | # poly top-left, top-right, low-right, low-left 391 | horizontal_list, free_list, combined_list, merged_list = [], [], [], [] 392 | 393 | for poly in polys: 394 | slope_up = (poly[3] - poly[1]) / np.maximum(10, (poly[2] - poly[0])) 395 | slope_down = (poly[5] - poly[7]) / np.maximum(10, (poly[4] - poly[6])) 396 | if max(abs(slope_up), abs(slope_down)) < slope_ths: 397 | x_max = max([poly[0], poly[2], poly[4], poly[6]]) 398 | x_min = min([poly[0], poly[2], poly[4], poly[6]]) 399 | y_max = max([poly[1], poly[3], poly[5], poly[7]]) 400 | y_min = min([poly[1], poly[3], poly[5], poly[7]]) 401 | horizontal_list.append( 402 | [x_min, x_max, y_min, y_max, 0.5 * (y_min + y_max), y_max - y_min] 403 | ) 404 | else: 405 | height = np.linalg.norm([poly[6] - poly[0], poly[7] - poly[1]]) 406 | margin = int(1.44 * add_margin * height) 407 | 408 | theta13 = abs( 409 | np.arctan((poly[1] - poly[5]) / np.maximum(10, (poly[0] - poly[4]))) 410 | ) 411 | theta24 = abs( 412 | np.arctan((poly[3] - poly[7]) / np.maximum(10, (poly[2] - poly[6]))) 413 | ) 414 | # do I need to clip minimum, maximum value here? 
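            # the corner coordinates below are pushed outward by `margin`, roughly along
            # the box diagonals (theta13 / theta24), to pad the slanted text region
            # before it is appended to free_list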
415 | x1 = poly[0] - np.cos(theta13) * margin 416 | y1 = poly[1] - np.sin(theta13) * margin 417 | x2 = poly[2] + np.cos(theta24) * margin 418 | y2 = poly[3] - np.sin(theta24) * margin 419 | x3 = poly[4] + np.cos(theta13) * margin 420 | y3 = poly[5] + np.sin(theta13) * margin 421 | x4 = poly[6] - np.cos(theta24) * margin 422 | y4 = poly[7] + np.sin(theta24) * margin 423 | 424 | free_list.append([[x1, y1], [x2, y2], [x3, y3], [x4, y4]]) 425 | horizontal_list = sorted(horizontal_list, key=lambda item: item[4]) 426 | 427 | # combine box 428 | new_box = [] 429 | for poly in horizontal_list: 430 | if len(new_box) == 0: 431 | b_height = [poly[5]] 432 | b_ycenter = [poly[4]] 433 | new_box.append(poly) 434 | else: 435 | # comparable height and comparable y_center level up to ths*height 436 | if (abs(np.mean(b_height) - poly[5]) < height_ths * np.mean(b_height)) and ( 437 | abs(np.mean(b_ycenter) - poly[4]) < ycenter_ths * np.mean(b_height) 438 | ): 439 | b_height.append(poly[5]) 440 | b_ycenter.append(poly[4]) 441 | new_box.append(poly) 442 | else: 443 | b_height = [poly[5]] 444 | b_ycenter = [poly[4]] 445 | combined_list.append(new_box) 446 | new_box = [poly] 447 | combined_list.append(new_box) 448 | 449 | # merge list use sort again 450 | for boxes in combined_list: 451 | if len(boxes) == 1: # one box per line 452 | box = boxes[0] 453 | margin = int(add_margin * box[5]) 454 | merged_list.append( 455 | [box[0] - margin, box[1] + margin, box[2] - margin, box[3] + margin] 456 | ) 457 | else: # multiple boxes per line 458 | boxes = sorted(boxes, key=lambda item: item[0]) 459 | 460 | merged_box, new_box = [], [] 461 | for box in boxes: 462 | if len(new_box) == 0: 463 | x_max = box[1] 464 | new_box.append(box) 465 | else: 466 | if abs(box[0] - x_max) < width_ths * ( 467 | box[3] - box[2] 468 | ): # merge boxes 469 | x_max = box[1] 470 | new_box.append(box) 471 | else: 472 | x_max = box[1] 473 | merged_box.append(new_box) 474 | new_box = [box] 475 | if len(new_box) > 0: 476 | merged_box.append(new_box) 477 | 478 | for mbox in merged_box: 479 | if len(mbox) != 1: # adjacent box in same line 480 | # do I need to add margin here? 
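                    # merge the adjacent boxes into a single [x_min, x_max, y_min, y_max]
                    # bounding box covering all of mbox, padded by add_margin relative to
                    # the merged height computed below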
481 | x_min = min(mbox, key=lambda x: x[0])[0] 482 | x_max = max(mbox, key=lambda x: x[1])[1] 483 | y_min = min(mbox, key=lambda x: x[2])[2] 484 | y_max = max(mbox, key=lambda x: x[3])[3] 485 | 486 | margin = int(add_margin * (y_max - y_min)) 487 | 488 | merged_list.append( 489 | [x_min - margin, x_max + margin, y_min - margin, y_max + margin] 490 | ) 491 | else: # non adjacent box in same line 492 | box = mbox[0] 493 | 494 | margin = int(add_margin * (box[3] - box[2])) 495 | merged_list.append( 496 | [ 497 | box[0] - margin, 498 | box[1] + margin, 499 | box[2] - margin, 500 | box[3] + margin, 501 | ] 502 | ) 503 | # may need to check if box is really in image 504 | return merged_list, free_list 505 | 506 | 507 | def get_image_list( 508 | horizontal_list: list, free_list: list, img: np.ndarray, model_height: int = 64 509 | ): 510 | image_list = [] 511 | maximum_y, maximum_x = img.shape 512 | 513 | max_ratio_hori, max_ratio_free = 1, 1 514 | for box in free_list: 515 | rect = np.array(box, dtype="float32") 516 | transformed_img = four_point_transform(img, rect) 517 | ratio = transformed_img.shape[1] / transformed_img.shape[0] 518 | crop_img = cv2.resize( 519 | transformed_img, 520 | (int(model_height * ratio), model_height), 521 | interpolation=Image.LANCZOS, 522 | ) 523 | # box : [[x1,y1],[x2,y2],[x3,y3],[x4,y4]] 524 | image_list.append((box, crop_img)) 525 | max_ratio_free = max(ratio, max_ratio_free) 526 | 527 | max_ratio_free = math.ceil(max_ratio_free) 528 | 529 | for box in horizontal_list: 530 | x_min = max(0, box[0]) 531 | x_max = min(box[1], maximum_x) 532 | y_min = max(0, box[2]) 533 | y_max = min(box[3], maximum_y) 534 | crop_img = img[y_min:y_max, x_min:x_max] 535 | width = x_max - x_min 536 | height = y_max - y_min 537 | ratio = width / height 538 | crop_img = cv2.resize( 539 | crop_img, 540 | (int(model_height * ratio), model_height), 541 | interpolation=Image.LANCZOS, 542 | ) 543 | image_list.append( 544 | ( 545 | [ 546 | [x_min, y_min], 547 | [x_max, y_min], 548 | [x_max, y_max], 549 | [x_min, y_max], 550 | ], 551 | crop_img, 552 | ) 553 | ) 554 | max_ratio_hori = max(ratio, max_ratio_hori) 555 | 556 | max_ratio_hori = math.ceil(max_ratio_hori) 557 | max_ratio = max(max_ratio_hori, max_ratio_free) 558 | max_width = math.ceil(max_ratio) * model_height 559 | 560 | image_list = sorted( 561 | image_list, key=lambda item: item[0][0][1] 562 | ) # sort by vertical position 563 | return image_list, max_width 564 | 565 | 566 | def diff(input_list): 567 | return max(input_list) - min(input_list) 568 | 569 | 570 | def get_paragraph(raw_result, x_ths: int = 1, y_ths: float = 0.5, mode: str = "ltr"): 571 | # create basic attributes 572 | box_group = [] 573 | for box in raw_result: 574 | all_x = [int(coord[0]) for coord in box[0]] 575 | all_y = [int(coord[1]) for coord in box[0]] 576 | min_x = min(all_x) 577 | max_x = max(all_x) 578 | min_y = min(all_y) 579 | max_y = max(all_y) 580 | height = max_y - min_y 581 | box_group.append( 582 | [box[1], min_x, max_x, min_y, max_y, height, 0.5 * (min_y + max_y), 0] 583 | ) # last element indicates group 584 | # cluster boxes into paragraph 585 | current_group = 1 586 | while len([box for box in box_group if box[7] == 0]) > 0: 587 | # group0 = non-group 588 | box_group0 = [box for box in box_group if box[7] == 0] 589 | # new group 590 | if len([box for box in box_group if box[7] == current_group]) == 0: 591 | # assign first box to form new group 592 | box_group0[0][7] = current_group 593 | # try to add group 594 | else: 595 | current_box_group = 
[box for box in box_group if box[7] == current_group] 596 | mean_height = np.mean([box[5] for box in current_box_group]) 597 | # yapf: disable 598 | min_gx = min([box[1] for box in current_box_group]) - x_ths * mean_height 599 | max_gx = max([box[2] for box in current_box_group]) + x_ths * mean_height 600 | min_gy = min([box[3] for box in current_box_group]) - y_ths * mean_height 601 | max_gy = max([box[4] for box in current_box_group]) + y_ths * mean_height 602 | add_box = False 603 | for box in box_group0: 604 | same_horizontal_level = (min_gx <= box[1] <= max_gx) or (min_gx <= box[2] <= max_gx) 605 | same_vertical_level = (min_gy <= box[3] <= max_gy) or (min_gy <= box[4] <= max_gy) 606 | if same_horizontal_level and same_vertical_level: 607 | box[7] = current_group 608 | add_box = True 609 | break 610 | # cannot add more box, go to next group 611 | if not add_box: 612 | current_group += 1 613 | # yapf: enable 614 | # arrage order in paragraph 615 | result = [] 616 | for i in set(box[7] for box in box_group): 617 | current_box_group = [box for box in box_group if box[7] == i] 618 | mean_height = np.mean([box[5] for box in current_box_group]) 619 | min_gx = min([box[1] for box in current_box_group]) 620 | max_gx = max([box[2] for box in current_box_group]) 621 | min_gy = min([box[3] for box in current_box_group]) 622 | max_gy = max([box[4] for box in current_box_group]) 623 | 624 | text = "" 625 | while len(current_box_group) > 0: 626 | highest = min([box[6] for box in current_box_group]) 627 | candidates = [ 628 | box for box in current_box_group if box[6] < highest + 0.4 * mean_height 629 | ] 630 | # get the far left 631 | if mode == "ltr": 632 | most_left = min([box[1] for box in candidates]) 633 | for box in candidates: 634 | if box[1] == most_left: 635 | best_box = box 636 | elif mode == "rtl": 637 | most_right = max([box[2] for box in candidates]) 638 | for box in candidates: 639 | if box[2] == most_right: 640 | best_box = box 641 | text += " " + best_box[0] 642 | current_box_group.remove(best_box) 643 | 644 | result.append( 645 | [ 646 | [ 647 | [min_gx, min_gy], 648 | [max_gx, min_gy], 649 | [max_gx, max_gy], 650 | [min_gx, max_gy], 651 | ], 652 | text[1:], 653 | ] 654 | ) 655 | 656 | return result 657 | 658 | 659 | def printProgressBar( 660 | prefix="", 661 | suffix="", 662 | decimals: int = 1, 663 | length: int = 100, 664 | fill: str = "█", 665 | printEnd: str = "\r", 666 | ): 667 | """ 668 | Call in a loop to create terminal progress bar 669 | @params: 670 | prefix - Optional : prefix string (Str) 671 | suffix - Optional : suffix string (Str) 672 | decimals - Optional : positive number of decimals in percent complete (Int) 673 | length - Optional : character length of bar (Int) 674 | fill - Optional : bar fill character (Str) 675 | printEnd - Optional : end character (e.g. "\r", "\r\n") (Str) 676 | """ 677 | 678 | def progress_hook(count, blockSize, totalSize): 679 | progress = count * blockSize / totalSize 680 | percent = ("{0:." 
+ str(decimals) + "f}").format(progress * 100) 681 | filledLength = int(length * progress) 682 | bar = fill * filledLength + "-" * (length - filledLength) 683 | print(f"\r{prefix} |{bar}| {percent}% {suffix}", end=printEnd) 684 | 685 | return progress_hook 686 | 687 | 688 | def reformat_input(image): 689 | """ 690 | :param image: image file path or bytes or array 691 | :return: 692 | img (array): (original_image_height, original_image_width, 3) 693 | img_cv_grey (array): (original_image_height, original_image_width, 3) 694 | """ 695 | if type(image) == str: 696 | if image.startswith("http://") or image.startswith("https://"): 697 | tmp, _ = urlretrieve( 698 | image, 699 | reporthook=printProgressBar( 700 | prefix="Progress:", 701 | suffix="Complete", 702 | length=50, 703 | ), 704 | ) 705 | img_cv_grey = cv2.imread(tmp, cv2.IMREAD_GRAYSCALE) 706 | os.remove(tmp) 707 | else: 708 | img_cv_grey = cv2.imread(image, cv2.IMREAD_GRAYSCALE) 709 | image = os.path.expanduser(image) 710 | img = load_image(image) # can accept URL 711 | elif type(image) == bytes: 712 | nparr = np.frombuffer(image, np.uint8) 713 | img = cv2.imdecode(nparr, cv2.IMREAD_COLOR) 714 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 715 | img_cv_grey = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) 716 | 717 | elif type(image) == np.ndarray: 718 | if len(image.shape) == 2: # grayscale 719 | img_cv_grey = image 720 | img = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR) 721 | elif len(image.shape) == 3 and image.shape[2] == 3: # BGRscale 722 | img = image 723 | img_cv_grey = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) 724 | elif len(image.shape) == 3 and image.shape[2] == 4: # RGBAscale 725 | img = image[:, :, :3] 726 | img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR) 727 | img_cv_grey = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) 728 | 729 | return img, img_cv_grey 730 | -------------------------------------------------------------------------------- /betterocr/engines/easy_pororo_ocr/pororo/pororo.py: -------------------------------------------------------------------------------- 1 | """ 2 | this code is adapted from https://github.com/black7375/korean_ocr_using_pororo 3 | 4 | Apache License 2.0 @yunwoong7 5 | Apache License 2.0 @black7375 6 | """ 7 | 8 | """ 9 | Pororo task-specific factory class 10 | 11 | isort:skip_file 12 | 13 | """ 14 | 15 | import logging 16 | from typing import Optional 17 | from .tasks.utils.base import PororoTaskBase 18 | 19 | import torch 20 | 21 | from .tasks import ( 22 | PororoOcrFactory, 23 | ) 24 | 25 | SUPPORTED_TASKS = { 26 | "ocr": PororoOcrFactory, 27 | } 28 | 29 | LANG_ALIASES = { 30 | "english": "en", 31 | "eng": "en", 32 | "korean": "ko", 33 | "kor": "ko", 34 | "kr": "ko", 35 | "chinese": "zh", 36 | "chn": "zh", 37 | "cn": "zh", 38 | "japanese": "ja", 39 | "jap": "ja", 40 | "jp": "ja", 41 | "jejueo": "je", 42 | "jje": "je", 43 | } 44 | 45 | logging.getLogger("transformers").setLevel(logging.WARN) 46 | logging.getLogger("fairseq").setLevel(logging.WARN) 47 | logging.getLogger("sentence_transformers").setLevel(logging.WARN) 48 | logging.getLogger("youtube_dl").setLevel(logging.WARN) 49 | logging.getLogger("pydub").setLevel(logging.WARN) 50 | logging.getLogger("librosa").setLevel(logging.WARN) 51 | 52 | 53 | class Pororo: 54 | r""" 55 | This is a generic class that will return one of the task-specific model classes of the library 56 | when created with the `__new__()` method 57 | 58 | """ 59 | 60 | def __new__( 61 | cls, 62 | task: str, 63 | lang: str = "en", 64 | model: Optional[str] = None, 65 | **kwargs, 66 | ) -> 
PororoTaskBase: 67 | if task not in SUPPORTED_TASKS: 68 | raise KeyError( 69 | "Unknown task {}, available tasks are {}".format( 70 | task, 71 | list(SUPPORTED_TASKS.keys()), 72 | ) 73 | ) 74 | 75 | lang = lang.lower() 76 | lang = LANG_ALIASES[lang] if lang in LANG_ALIASES else lang 77 | 78 | # Get device information from torch API 79 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 80 | 81 | # Instantiate task-specific pipeline module, if possible 82 | task_module = SUPPORTED_TASKS[task]( 83 | task, 84 | lang, 85 | model, 86 | **kwargs, 87 | ).load(device) 88 | 89 | return task_module 90 | 91 | @staticmethod 92 | def available_tasks() -> str: 93 | """ 94 | Returns available tasks in Pororo project 95 | 96 | Returns: 97 | str: Supported task names 98 | 99 | """ 100 | return "Available tasks are {}".format(list(SUPPORTED_TASKS.keys())) 101 | 102 | @staticmethod 103 | def available_models(task: str) -> str: 104 | """ 105 | Returns available model names correponding to the user-input task 106 | 107 | Args: 108 | task (str): user-input task name 109 | 110 | Returns: 111 | str: Supported model names corresponding to the user-input task 112 | 113 | Raises: 114 | KeyError: When user-input task is not supported 115 | 116 | """ 117 | if task not in SUPPORTED_TASKS: 118 | raise KeyError( 119 | "Unknown task {} ! Please check available models via `available_tasks()`".format( 120 | task 121 | ) 122 | ) 123 | 124 | langs = SUPPORTED_TASKS[task].get_available_models() 125 | output = f"Available models for {task} are " 126 | for lang in langs: 127 | output += f"([lang]: {lang}, [model]: {', '.join(langs[lang])}), " 128 | return output[:-2] 129 | -------------------------------------------------------------------------------- /betterocr/engines/easy_pororo_ocr/pororo/tasks/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | this code is adapted from https://github.com/black7375/korean_ocr_using_pororo 3 | 4 | Apache License 2.0 @yunwoong7 5 | Apache License 2.0 @black7375 6 | """ 7 | 8 | # flake8: noqa 9 | """ 10 | __init__.py for import child .py files 11 | 12 | isort:skip_file 13 | """ 14 | 15 | # Utility classes & functions 16 | # import pororo.tasks.utils 17 | from .utils.download_utils import download_or_load 18 | from .utils.base import ( 19 | PororoBiencoderBase, 20 | PororoFactoryBase, 21 | PororoGenerationBase, 22 | PororoSimpleBase, 23 | PororoTaskGenerationBase, 24 | ) 25 | 26 | # Factory classes 27 | from .optical_character_recognition import PororoOcrFactory 28 | -------------------------------------------------------------------------------- /betterocr/engines/easy_pororo_ocr/pororo/tasks/optical_character_recognition.py: -------------------------------------------------------------------------------- 1 | """ 2 | this code is adapted from https://github.com/black7375/korean_ocr_using_pororo 3 | 4 | Apache License 2.0 @yunwoong7 5 | Apache License 2.0 @black7375 6 | """ 7 | 8 | """OCR related modeling class""" 9 | 10 | from typing import Optional 11 | 12 | from . 
import download_or_load 13 | from .utils.base import PororoFactoryBase, PororoSimpleBase 14 | 15 | 16 | class PororoOcrFactory(PororoFactoryBase): 17 | """ 18 | Recognize optical characters in image file 19 | Currently support Korean language 20 | 21 | English + Korean (`brainocr`) 22 | 23 | - dataset: Internal data + AI hub Font Image dataset 24 | - metric: TBU 25 | - ref: https://www.aihub.or.kr/aidata/133 26 | 27 | Examples: 28 | >>> ocr = Pororo(task="ocr", lang="ko") 29 | >>> ocr(IMAGE_PATH) 30 | ["사이렌'(' 신마'", "내가 말했잖아 속지열라고 이 손을 잡는 너는 위협해질 거라고"] 31 | 32 | >>> ocr = Pororo(task="ocr", lang="ko") 33 | >>> ocr(IMAGE_PATH, detail=True) 34 | { 35 | 'description': ["사이렌'(' 신마', "내가 말했잖아 속지열라고 이 손을 잡는 너는 위협해질 거라고"], 36 | 'bounding_poly': [ 37 | { 38 | 'description': "사이렌'(' 신마'", 39 | 'vertices': [ 40 | {'x': 93, 'y': 7}, 41 | {'x': 164, 'y': 7}, 42 | {'x': 164, 'y': 21}, 43 | {'x': 93, 'y': 21} 44 | ] 45 | }, 46 | { 47 | 'description': "내가 말했잖아 속지열라고 이 손을 잡는 너는 위협해질 거라고", 48 | 'vertices': [ 49 | {'x': 0, 'y': 30}, 50 | {'x': 259, 'y': 30}, 51 | {'x': 259, 'y': 194}, 52 | {'x': 0, 'y': 194}]} 53 | ] 54 | } 55 | } 56 | """ 57 | 58 | def __init__(self, task: str, lang: str, model: Optional[str]): 59 | super().__init__(task, lang, model) 60 | self.detect_model = "craft" 61 | self.ocr_opt = "ocr-opt" 62 | 63 | @staticmethod 64 | def get_available_langs(): 65 | return ["en", "ko"] 66 | 67 | @staticmethod 68 | def get_available_models(): 69 | return { 70 | "en": ["brainocr"], 71 | "ko": ["brainocr"], 72 | } 73 | 74 | def load(self, device: str): 75 | """ 76 | Load user-selected task-specific model 77 | 78 | Args: 79 | device (str): device information 80 | 81 | Returns: 82 | object: User-selected task-specific model 83 | 84 | """ 85 | if self.config.n_model == "brainocr": 86 | from ..models.brainOCR import brainocr 87 | 88 | if self.config.lang not in self.get_available_langs(): 89 | raise ValueError( 90 | f"Unsupported Language : {self.config.lang}", 91 | 'Support Languages : ["en", "ko"]', 92 | ) 93 | 94 | det_model_path = download_or_load( 95 | f"misc/{self.detect_model}.pt", 96 | self.config.lang, 97 | ) 98 | rec_model_path = download_or_load( 99 | f"misc/{self.config.n_model}.pt", 100 | self.config.lang, 101 | ) 102 | opt_fp = download_or_load( 103 | f"misc/{self.ocr_opt}.txt", 104 | self.config.lang, 105 | ) 106 | model = brainocr.Reader( 107 | self.config.lang, 108 | det_model_ckpt_fp=det_model_path, 109 | rec_model_ckpt_fp=rec_model_path, 110 | opt_fp=opt_fp, 111 | device=device, 112 | ) 113 | model.detector.to(device) 114 | model.recognizer.to(device) 115 | return PororoOCR(model, self.config) 116 | 117 | 118 | class PororoOCR(PororoSimpleBase): 119 | def __init__(self, model, config): 120 | super().__init__(config) 121 | self._model = model 122 | 123 | def _postprocess(self, ocr_results, detail: bool = False): 124 | """ 125 | Post-process for OCR result 126 | 127 | Args: 128 | ocr_results (list): list contains result of OCR 129 | detail (bool): if True, returned to include details. 
(bounding poly, vertices, etc) 130 | 131 | """ 132 | sorted_ocr_results = sorted( 133 | ocr_results, 134 | key=lambda x: ( 135 | x[0][0][1], 136 | x[0][0][0], 137 | ), 138 | ) 139 | 140 | if not detail: 141 | return [sorted_ocr_results[i][-1] for i in range(len(sorted_ocr_results))] 142 | 143 | result_dict = { 144 | "description": list(), 145 | "bounding_poly": list(), 146 | } 147 | 148 | for ocr_result in sorted_ocr_results: 149 | vertices = list() 150 | 151 | for vertice in ocr_result[0]: 152 | vertices.append( 153 | { 154 | "x": vertice[0], 155 | "y": vertice[1], 156 | } 157 | ) 158 | 159 | result_dict["description"].append(ocr_result[1]) 160 | result_dict["bounding_poly"].append( 161 | {"description": ocr_result[1], "vertices": vertices} 162 | ) 163 | 164 | return result_dict 165 | 166 | def predict(self, image_path: str, **kwargs): 167 | """ 168 | Conduct Optical Character Recognition (OCR) 169 | 170 | Args: 171 | image_path (str): the image file path 172 | detail (bool): if True, returned to include details. (bounding poly, vertices, etc) 173 | 174 | """ 175 | detail = kwargs.get("detail", False) 176 | 177 | return self._postprocess( 178 | self._model( 179 | image_path, 180 | skip_details=False, 181 | batch_size=1, 182 | paragraph=True, 183 | ), 184 | detail, 185 | ) 186 | -------------------------------------------------------------------------------- /betterocr/engines/easy_pororo_ocr/pororo/tasks/utils/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | this code is adapted from https://github.com/black7375/korean_ocr_using_pororo 3 | 4 | Apache License 2.0 @yunwoong7 5 | Apache License 2.0 @black7375 6 | """ 7 | -------------------------------------------------------------------------------- /betterocr/engines/easy_pororo_ocr/pororo/tasks/utils/base.py: -------------------------------------------------------------------------------- 1 | """ 2 | this code is adapted from https://github.com/black7375/korean_ocr_using_pororo 3 | 4 | Apache License 2.0 @yunwoong7 5 | Apache License 2.0 @black7375 6 | """ 7 | 8 | import re 9 | import unicodedata 10 | from abc import abstractmethod 11 | from dataclasses import dataclass 12 | from typing import List, Mapping, Optional, Union 13 | 14 | 15 | @dataclass 16 | class TaskConfig: 17 | task: str 18 | lang: str 19 | n_model: str 20 | 21 | 22 | class PororoTaskBase: 23 | r"""Task base class that implements basic functions for prediction""" 24 | 25 | def __init__(self, config: TaskConfig): 26 | self.config = config 27 | 28 | @property 29 | def n_model(self): 30 | return self.config.n_model 31 | 32 | @property 33 | def lang(self): 34 | return self.config.lang 35 | 36 | @abstractmethod 37 | def predict( 38 | self, 39 | text: Union[str, List[str]], 40 | **kwargs, 41 | ): 42 | raise NotImplementedError("`predict()` function is not implemented properly!") 43 | 44 | def __call__(self): 45 | raise NotImplementedError("`call()` function is not implemented properly!") 46 | 47 | def __repr__(self): 48 | return f"[TASK]: {self.config.task.upper()}\n[LANG]: {self.config.lang.upper()}\n[MODEL]: {self.config.n_model}" 49 | 50 | def _normalize(self, text: str): 51 | """Unicode normalization and whitespace removal (often needed for contexts)""" 52 | text = unicodedata.normalize("NFKC", text) 53 | text = re.sub(r"\s+", " ", text).strip() 54 | return text 55 | 56 | 57 | class PororoFactoryBase(object): 58 | r"""This is a factory base class that construct task-specific module""" 59 | 60 | def __init__( 61 | self, 
62 | task: str, 63 | lang: str, 64 | model: Optional[str] = None, 65 | ): 66 | self._available_langs = self.get_available_langs() 67 | self._available_models = self.get_available_models() 68 | self._model2lang = { 69 | v: k for k, vs in self._available_models.items() for v in vs 70 | } 71 | 72 | # Set default language as very first supported language 73 | assert ( 74 | lang in self._available_langs 75 | ), f"Following langs are supported for this task: {self._available_langs}" 76 | 77 | if lang is None: 78 | lang = self._available_langs[0] 79 | 80 | # Change language option if model is defined by user 81 | if model is not None: 82 | lang = self._model2lang[model] 83 | 84 | # Set default model 85 | if model is None: 86 | model = self.get_default_model(lang) 87 | 88 | # yapf: disable 89 | assert (model in self._available_models[lang]), f"{model} is NOT supported for {lang}" 90 | # yapf: enable 91 | 92 | self.config = TaskConfig(task, lang, model) 93 | 94 | @abstractmethod 95 | def get_available_langs(self) -> List[str]: 96 | raise NotImplementedError( 97 | "`get_available_langs()` is not implemented properly!" 98 | ) 99 | 100 | @abstractmethod 101 | def get_available_models(self) -> Mapping[str, List[str]]: 102 | raise NotImplementedError( 103 | "`get_available_models()` is not implemented properly!" 104 | ) 105 | 106 | @abstractmethod 107 | def get_default_model(self, lang: str) -> str: 108 | return self._available_models[lang][0] 109 | 110 | @classmethod 111 | def load(cls) -> PororoTaskBase: 112 | raise NotImplementedError("Model load function is not implemented properly!") 113 | 114 | 115 | class PororoSimpleBase(PororoTaskBase): 116 | r"""Simple task base wrapper class""" 117 | 118 | def __call__(self, text: str, **kwargs): 119 | return self.predict(text, **kwargs) 120 | 121 | 122 | class PororoBiencoderBase(PororoTaskBase): 123 | r"""Bi-Encoder base wrapper class""" 124 | 125 | def __call__( 126 | self, 127 | sent_a: str, 128 | sent_b: Union[str, List[str]], 129 | **kwargs, 130 | ): 131 | assert isinstance(sent_a, str), "sent_a should be string type" 132 | assert isinstance(sent_b, str) or isinstance( 133 | sent_b, list 134 | ), "sent_b should be string or list of string type" 135 | 136 | sent_a = self._normalize(sent_a) 137 | 138 | # For "Find Similar Sentence" task 139 | if isinstance(sent_b, list): 140 | sent_b = [self._normalize(t) for t in sent_b] 141 | else: 142 | sent_b = self._normalize(sent_b) 143 | 144 | return self.predict(sent_a, sent_b, **kwargs) 145 | 146 | 147 | class PororoGenerationBase(PororoTaskBase): 148 | r"""Generation task wrapper class using various generation tricks""" 149 | 150 | def __call__( 151 | self, 152 | text: str, 153 | beam: int = 5, 154 | temperature: float = 1.0, 155 | top_k: int = -1, 156 | top_p: float = -1, 157 | no_repeat_ngram_size: int = 4, 158 | len_penalty: float = 1.0, 159 | **kwargs, 160 | ): 161 | assert isinstance(text, str), "Input text should be string type" 162 | 163 | return self.predict( 164 | text, 165 | beam=beam, 166 | temperature=temperature, 167 | top_k=top_k, 168 | top_p=top_p, 169 | no_repeat_ngram_size=no_repeat_ngram_size, 170 | len_penalty=len_penalty, 171 | **kwargs, 172 | ) 173 | 174 | 175 | class PororoTaskGenerationBase(PororoTaskBase): 176 | r"""Generation task wrapper class using only beam search""" 177 | 178 | def __call__(self, text: str, beam: int = 1, **kwargs): 179 | assert isinstance(text, str), "Input text should be string type" 180 | 181 | text = self._normalize(text) 182 | 183 | return self.predict(text, 
beam=beam, **kwargs) 184 | -------------------------------------------------------------------------------- /betterocr/engines/easy_pororo_ocr/pororo/tasks/utils/config.py: -------------------------------------------------------------------------------- 1 | """ 2 | this code is adapted from https://github.com/black7375/korean_ocr_using_pororo 3 | 4 | Apache License 2.0 @yunwoong7 5 | Apache License 2.0 @black7375 6 | """ 7 | 8 | from dataclasses import dataclass 9 | from typing import Union 10 | 11 | 12 | @dataclass 13 | class TransformerConfig: 14 | src_dict: Union[str, None] 15 | tgt_dict: Union[str, None] 16 | src_tok: Union[str, None] 17 | tgt_tok: Union[str, None] 18 | 19 | 20 | CONFIGS = { 21 | "transformer.base.ko.const": TransformerConfig( 22 | "dict.transformer.base.ko.const", 23 | "dict.transformer.base.ko.const", 24 | None, 25 | None, 26 | ), 27 | "transformer.base.ko.pg": TransformerConfig( 28 | "dict.transformer.base.ko.mt", 29 | "dict.transformer.base.ko.mt", 30 | "bpe8k.ko", 31 | None, 32 | ), 33 | "transformer.base.ko.pg_long": TransformerConfig( 34 | "dict.transformer.base.ko.mt", 35 | "dict.transformer.base.ko.mt", 36 | "bpe8k.ko", 37 | None, 38 | ), 39 | "transformer.base.en.gec": TransformerConfig( 40 | "dict.transformer.base.en.mt", 41 | "dict.transformer.base.en.mt", 42 | "bpe32k.en", 43 | None, 44 | ), 45 | "transformer.base.zh.pg": TransformerConfig( 46 | "dict.transformer.base.zh.mt", 47 | "dict.transformer.base.zh.mt", 48 | None, 49 | None, 50 | ), 51 | "transformer.base.ja.pg": TransformerConfig( 52 | "dict.transformer.base.ja.mt", 53 | "dict.transformer.base.ja.mt", 54 | "bpe8k.ja", 55 | None, 56 | ), 57 | "transformer.base.zh.const": TransformerConfig( 58 | "dict.transformer.base.zh.const", 59 | "dict.transformer.base.zh.const", 60 | None, 61 | None, 62 | ), 63 | "transformer.base.en.const": TransformerConfig( 64 | "dict.transformer.base.en.const", 65 | "dict.transformer.base.en.const", 66 | None, 67 | None, 68 | ), 69 | "transformer.base.en.pg": TransformerConfig( 70 | "dict.transformer.base.en.mt", 71 | "dict.transformer.base.en.mt", 72 | "bpe32k.en", 73 | None, 74 | ), 75 | "transformer.base.ko.gec": TransformerConfig( 76 | "dict.transformer.base.ko.gec", 77 | "dict.transformer.base.ko.gec", 78 | "bpe8k.ko", 79 | None, 80 | ), 81 | "transformer.base.en.char_gec": TransformerConfig( 82 | "dict.transformer.base.en.char_gec", 83 | "dict.transformer.base.en.char_gec", 84 | None, 85 | None, 86 | ), 87 | "transformer.base.en.caption": TransformerConfig( 88 | None, 89 | None, 90 | None, 91 | None, 92 | ), 93 | "transformer.base.ja.p2g": TransformerConfig( 94 | "dict.transformer.base.ja.p2g", 95 | "dict.transformer.base.ja.p2g", 96 | None, 97 | None, 98 | ), 99 | "transformer.large.multi.mtpg": TransformerConfig( 100 | "dict.transformer.large.multi.mtpg", 101 | "dict.transformer.large.multi.mtpg", 102 | "bpe32k.en", 103 | None, 104 | ), 105 | "transformer.large.multi.fast.mtpg": TransformerConfig( 106 | "dict.transformer.large.multi.mtpg", 107 | "dict.transformer.large.multi.mtpg", 108 | "bpe32k.en", 109 | None, 110 | ), 111 | "transformer.large.ko.wsd": TransformerConfig( 112 | "dict.transformer.large.ko.wsd", 113 | "dict.transformer.large.ko.wsd", 114 | None, 115 | None, 116 | ), 117 | } 118 | -------------------------------------------------------------------------------- /betterocr/engines/easy_pororo_ocr/pororo/tasks/utils/download_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | this code is adapted from 
https://github.com/black7375/korean_ocr_using_pororo 3 | 4 | Apache License 2.0 @yunwoong7 5 | Apache License 2.0 @black7375 6 | """ 7 | 8 | """Module download related function from. Tenth""" 9 | 10 | import logging 11 | import os 12 | import platform 13 | import sys 14 | import zipfile 15 | from dataclasses import dataclass 16 | from typing import Tuple, Union 17 | 18 | import wget 19 | 20 | from .config import CONFIGS 21 | 22 | DEFAULT_PREFIX = { 23 | "model": "https://twg.kakaocdn.net/pororo/{lang}/models", 24 | "dict": "https://twg.kakaocdn.net/pororo/{lang}/dicts", 25 | } 26 | 27 | 28 | @dataclass 29 | class TransformerInfo: 30 | r"Dataclass for transformer-based model" 31 | path: str 32 | dict_path: str 33 | src_dict: str 34 | tgt_dict: str 35 | src_tok: Union[str, None] 36 | tgt_tok: Union[str, None] 37 | 38 | 39 | @dataclass 40 | class DownloadInfo: 41 | r"Download information such as defined directory, language and model name" 42 | n_model: str 43 | lang: str 44 | root_dir: str 45 | 46 | 47 | def get_save_dir(save_dir: str = None) -> str: 48 | """ 49 | Get default save directory 50 | 51 | Args: 52 | savd_dir(str): User-defined save directory 53 | 54 | Returns: 55 | str: Set save directory 56 | 57 | """ 58 | # If user wants to manually define save directory 59 | if save_dir: 60 | os.makedirs(save_dir, exist_ok=True) 61 | return save_dir 62 | 63 | pf = platform.system() 64 | 65 | if pf == "Windows": 66 | save_dir = "C:\\pororo" 67 | else: 68 | home_dir = os.path.expanduser("~") 69 | save_dir = os.path.join(home_dir, ".pororo") 70 | 71 | if not os.path.exists(save_dir): 72 | os.makedirs(save_dir) 73 | 74 | return save_dir 75 | 76 | 77 | def get_download_url(n_model: str, key: str, lang: str) -> str: 78 | """ 79 | Get download url using default prefix 80 | 81 | Args: 82 | n_model (str): model name 83 | key (str): key name either `model` or `dict` 84 | lang (str): language name 85 | 86 | Returns: 87 | str: generated download url 88 | 89 | """ 90 | default_prefix = DEFAULT_PREFIX[key].format(lang=lang) 91 | return f"{default_prefix}/{n_model}" 92 | 93 | 94 | def download_or_load_bert(info: DownloadInfo) -> str: 95 | """ 96 | Download fine-tuned BrainBert & BrainSBert model and dict 97 | 98 | Args: 99 | info (DownloadInfo): download information 100 | 101 | Returns: 102 | str: downloaded bert & sbert path 103 | 104 | """ 105 | model_path = os.path.join(info.root_dir, info.n_model) 106 | 107 | if not os.path.exists(model_path): 108 | info.n_model += ".zip" 109 | zip_path = os.path.join(info.root_dir, info.n_model) 110 | 111 | type_dir = download_from_url( 112 | info.n_model, 113 | zip_path, 114 | key="model", 115 | lang=info.lang, 116 | ) 117 | 118 | zip_file = zipfile.ZipFile(zip_path) 119 | zip_file.extractall(type_dir) 120 | zip_file.close() 121 | 122 | return model_path 123 | 124 | 125 | def download_or_load_transformer(info: DownloadInfo) -> TransformerInfo: 126 | """ 127 | Download pre-trained Transformer model and corresponding dict 128 | 129 | Args: 130 | info (DownloadInfo): download information 131 | 132 | Returns: 133 | TransformerInfo: information dataclass for transformer construction 134 | 135 | """ 136 | config = CONFIGS[info.n_model.split("/")[-1]] 137 | 138 | src_dict_in = config.src_dict 139 | tgt_dict_in = config.tgt_dict 140 | src_tok = config.src_tok 141 | tgt_tok = config.tgt_tok 142 | 143 | info.n_model += ".pt" 144 | model_path = os.path.join(info.root_dir, info.n_model) 145 | 146 | # Download or load Transformer model 147 | model_type_dir = 
"/".join(model_path.split("/")[:-1]) 148 | if not os.path.exists(model_path): 149 | model_type_dir = download_from_url( 150 | info.n_model, 151 | model_path, 152 | key="model", 153 | lang=info.lang, 154 | ) 155 | 156 | dict_type_dir = str() 157 | src_dict, tgt_dict = str(), str() 158 | 159 | # Download or load corresponding dictionary 160 | if src_dict_in: 161 | src_dict = f"{src_dict_in}.txt" 162 | src_dict_path = os.path.join(info.root_dir, f"dicts/{src_dict}") 163 | dict_type_dir = "/".join(src_dict_path.split("/")[:-1]) 164 | if not os.path.exists(src_dict_path): 165 | dict_type_dir = download_from_url( 166 | src_dict, 167 | src_dict_path, 168 | key="dict", 169 | lang=info.lang, 170 | ) 171 | 172 | if tgt_dict_in: 173 | tgt_dict = f"{tgt_dict_in}.txt" 174 | tgt_dict_path = os.path.join(info.root_dir, f"dicts/{tgt_dict}") 175 | if not os.path.exists(tgt_dict_path): 176 | download_from_url( 177 | tgt_dict, 178 | tgt_dict_path, 179 | key="dict", 180 | lang=info.lang, 181 | ) 182 | 183 | # Download or load corresponding tokenizer 184 | src_tok_path, tgt_tok_path = None, None 185 | if src_tok: 186 | src_tok_path = download_or_load( 187 | f"tokenizers/{src_tok}.zip", 188 | lang=info.lang, 189 | ) 190 | if tgt_tok: 191 | tgt_tok_path = download_or_load( 192 | f"tokenizers/{tgt_tok}.zip", 193 | lang=info.lang, 194 | ) 195 | 196 | return TransformerInfo( 197 | path=model_type_dir, 198 | dict_path=dict_type_dir, 199 | # Drop prefix "dict." and postfix ".txt" 200 | src_dict=".".join(src_dict.split(".")[1:-1]), 201 | # to follow fairseq's dictionary load process 202 | tgt_dict=".".join(tgt_dict.split(".")[1:-1]), 203 | src_tok=src_tok_path, 204 | tgt_tok=tgt_tok_path, 205 | ) 206 | 207 | 208 | def download_or_load_misc(info: DownloadInfo) -> str: 209 | """ 210 | Download (pre-trained) miscellaneous model 211 | 212 | Args: 213 | info (DownloadInfo): download information 214 | 215 | Returns: 216 | str: miscellaneous model path 217 | 218 | """ 219 | # Add postfix <.model> for sentencepiece 220 | if "sentencepiece" in info.n_model: 221 | info.n_model += ".model" 222 | 223 | # Generate target model path using root directory 224 | model_path = os.path.join(info.root_dir, info.n_model) 225 | if not os.path.exists(model_path): 226 | type_dir = download_from_url( 227 | info.n_model, 228 | model_path, 229 | key="model", 230 | lang=info.lang, 231 | ) 232 | 233 | if ".zip" in info.n_model: 234 | zip_file = zipfile.ZipFile(model_path) 235 | zip_file.extractall(type_dir) 236 | zip_file.close() 237 | 238 | if ".zip" in info.n_model: 239 | model_path = model_path[: model_path.rfind(".zip")] 240 | return model_path 241 | 242 | 243 | def download_or_load_bart(info: DownloadInfo) -> Union[str, Tuple[str, str]]: 244 | """ 245 | Download BART model 246 | 247 | Args: 248 | info (DownloadInfo): download information 249 | 250 | Returns: 251 | Union[str, Tuple[str, str]]: BART model path (with. 
corresponding SentencePiece) 252 | 253 | """ 254 | info.n_model += ".pt" 255 | 256 | model_path = os.path.join(info.root_dir, info.n_model) 257 | if not os.path.exists(model_path): 258 | download_from_url( 259 | info.n_model, 260 | model_path, 261 | key="model", 262 | lang=info.lang, 263 | ) 264 | 265 | return model_path 266 | 267 | 268 | def download_from_url( 269 | n_model: str, 270 | model_path: str, 271 | key: str, 272 | lang: str, 273 | ) -> str: 274 | """ 275 | Download specified model from Tenth 276 | 277 | Args: 278 | n_model (str): model name 279 | model_path (str): pre-defined model path 280 | key (str): type key (either model or dict) 281 | lang (str): language name 282 | 283 | Returns: 284 | str: default type directory 285 | 286 | """ 287 | # Get default type dir path 288 | type_dir = "/".join(model_path.split("/")[:-1]) 289 | os.makedirs(type_dir, exist_ok=True) 290 | 291 | # Get download tenth url 292 | url = get_download_url(n_model, key=key, lang=lang) 293 | 294 | logging.info("Downloading user-selected model...") 295 | wget.download(url, type_dir) 296 | sys.stderr.write("\n") 297 | sys.stderr.flush() 298 | 299 | return type_dir 300 | 301 | 302 | def download_or_load( 303 | n_model: str, 304 | lang: str, 305 | custom_save_dir: str = None, 306 | ) -> Union[TransformerInfo, str, Tuple[str, str]]: 307 | """ 308 | Download or load model based on model information 309 | 310 | Args: 311 | n_model (str): model name 312 | lang (str): language information 313 | custom_save_dir (str, optional): user-defined save directory path. defaults to None. 314 | 315 | Returns: 316 | Union[TransformerInfo, str, Tuple[str, str]] 317 | 318 | """ 319 | root_dir = get_save_dir(save_dir=custom_save_dir) 320 | info = DownloadInfo(n_model, lang, root_dir) 321 | 322 | if "transformer" in n_model: 323 | return download_or_load_transformer(info) 324 | if "bert" in n_model: 325 | return download_or_load_bert(info) 326 | if "bart" in n_model and "bpe" not in n_model: 327 | return download_or_load_bart(info) 328 | 329 | return download_or_load_misc(info) 330 | -------------------------------------------------------------------------------- /betterocr/engines/easy_pororo_ocr/pororo/utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | this code is adapted from https://github.com/black7375/korean_ocr_using_pororo 3 | 4 | Apache License 2.0 @yunwoong7 5 | Apache License 2.0 @black7375 6 | """ 7 | 8 | from contextlib import contextmanager 9 | from tempfile import NamedTemporaryFile 10 | 11 | from requests import get 12 | 13 | 14 | def postprocess_span(tagger, text: str) -> str: 15 | """ 16 | Postprocess NOUN span to remove unnecessary character 17 | 18 | Args: 19 | text (str): NOUN span to be processed 20 | 21 | Returns: 22 | (str): post-processed NOUN span 23 | 24 | Examples: 25 | >>> postprocess_span("강감찬 장군은") 26 | '강감찬 장군' 27 | >>> postprocess_span("그녀에게") 28 | '그녀' 29 | 30 | """ 31 | 32 | # First, strip punctuations 33 | text = text.strip("""!"\#$&'()*+,\-./:;<=>?@\^_‘{|}~《》""") 34 | 35 | # Complete imbalanced parentheses pair 36 | if text.count("(") == text.count(")") + 1: 37 | text += ")" 38 | elif text.count("(") + 1 == text.count(")"): 39 | text = "(" + text 40 | 41 | # Preserve beginning tokens since we only want to extract noun phrase of the last eojeol 42 | noun_phrase = " ".join(text.rsplit(" ", 1)[:-1]) 43 | tokens = text.split(" ") 44 | eojeols = list() 45 | for token in tokens: 46 | eojeols.append(tagger.pos(token)) 47 | last_eojeol = eojeols[-1] 
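    # e.g. for "강감찬 장군은" the last eojeol is "장군은"; the loop below scans its POS
    # tags from the end and strips the trailing josa "은", so the final span becomes
    # "강감찬 장군" (cf. the docstring example above)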
48 | 49 | # Iterate backwardly to remove unnecessary postfixes 50 | i = 0 51 | for i, token in enumerate(last_eojeol[::-1]): 52 | _, pos = token 53 | # 1. The loop breaks when you meet a noun 54 | # 2. The loop also breaks when you meet a XSN (e.g. 8/SN+일/NNB LG/SL 전/XSN) 55 | if (pos[0] in ("N", "S")) or pos.startswith("XSN"): 56 | break 57 | idx = len(last_eojeol) - i 58 | 59 | # Extract noun span from last eojeol and postpend it to beginning tokens 60 | ext_last_eojeol = "".join(morph for morph, _ in last_eojeol[:idx]) 61 | noun_phrase += " " + ext_last_eojeol 62 | return noun_phrase.strip() 63 | 64 | 65 | @contextmanager 66 | def control_temp(file_path: str): 67 | """ 68 | Download temporary file from web, then remove it after some context 69 | 70 | Args: 71 | file_path (str): web file path 72 | 73 | """ 74 | # yapf: disable 75 | assert file_path.startswith("http"), "File path should contain `http` prefix !" 76 | # yapf: enable 77 | 78 | ext = file_path[file_path.rfind(".") :] 79 | 80 | with NamedTemporaryFile("wb", suffix=ext, delete=True) as f: 81 | response = get(file_path, allow_redirects=True) 82 | f.write(response.content) 83 | yield f.name 84 | -------------------------------------------------------------------------------- /betterocr/engines/easy_pororo_ocr/utils/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | this code is adapted from https://github.com/black7375/korean_ocr_using_pororo 3 | 4 | Apache License 2.0 @yunwoong7 5 | Apache License 2.0 @black7375 6 | """ 7 | -------------------------------------------------------------------------------- /betterocr/engines/easy_pororo_ocr/utils/image_convert.py: -------------------------------------------------------------------------------- 1 | """ 2 | this code is adapted from https://github.com/black7375/korean_ocr_using_pororo 3 | 4 | Apache License 2.0 @yunwoong7 5 | Apache License 2.0 @black7375 6 | """ 7 | 8 | import cv2 9 | import numpy as np 10 | 11 | 12 | def convert_coord(coord_list): 13 | x_min, x_max, y_min, y_max = coord_list 14 | return [[x_min, y_min], [x_max, y_min], [x_max, y_max], [x_min, y_max]] 15 | 16 | 17 | # https://www.life2coding.com/cropping-polygon-or-non-rectangular-region-from-image-using-opencv-python/ 18 | # https://stackoverflow.com/questions/48301186/cropping-concave-polygon-from-image-using-opencv-python 19 | def crop(image, points): 20 | pts = np.array(points, np.int32) 21 | 22 | # Crop the bounding rect 23 | rect = cv2.boundingRect(pts) 24 | x, y, w, h = rect 25 | croped = image[y : y + h, x : x + w].copy() 26 | 27 | # make mask 28 | pts = pts - pts.min(axis=0) 29 | 30 | mask = np.zeros(croped.shape[:2], np.uint8) 31 | cv2.drawContours(mask, [pts], -1, (255, 255, 255), -1, cv2.LINE_AA) 32 | 33 | # do bit-op 34 | dst = cv2.bitwise_and(croped, croped, mask=mask) 35 | 36 | # add the white background 37 | bg = np.ones_like(croped, np.uint8) * 255 38 | cv2.bitwise_not(bg, bg, mask=mask) 39 | result = bg + dst 40 | 41 | return result 42 | -------------------------------------------------------------------------------- /betterocr/engines/easy_pororo_ocr/utils/image_util.py: -------------------------------------------------------------------------------- 1 | """ 2 | this code is adapted from https://github.com/black7375/korean_ocr_using_pororo 3 | 4 | Apache License 2.0 @yunwoong7 5 | Apache License 2.0 @black7375 6 | """ 7 | 8 | import cv2 9 | import numpy as np 10 | import platform 11 | from PIL import ImageFont, ImageDraw, Image 12 | from 
matplotlib import pyplot as plt 13 | 14 | 15 | def plt_imshow(title="image", img=None, figsize=(8, 5)): 16 | plt.figure(figsize=figsize) 17 | 18 | if type(img) is str: 19 | img = cv2.imread(img) 20 | 21 | if type(img) == list: 22 | if type(title) == list: 23 | titles = title 24 | else: 25 | titles = [] 26 | 27 | for i in range(len(img)): 28 | titles.append(title) 29 | 30 | for i in range(len(img)): 31 | if len(img[i].shape) <= 2: 32 | rgbImg = cv2.cvtColor(img[i], cv2.COLOR_GRAY2RGB) 33 | else: 34 | rgbImg = cv2.cvtColor(img[i], cv2.COLOR_BGR2RGB) 35 | 36 | plt.subplot(1, len(img), i + 1), plt.imshow(rgbImg) 37 | plt.title(titles[i]) 38 | plt.xticks([]), plt.yticks([]) 39 | 40 | plt.show() 41 | else: 42 | if len(img.shape) < 3: 43 | rgbImg = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB) 44 | else: 45 | rgbImg = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 46 | 47 | plt.imshow(rgbImg) 48 | plt.title(title) 49 | plt.xticks([]), plt.yticks([]) 50 | plt.show() 51 | 52 | 53 | def put_text(image, text, x, y, color=(0, 255, 0), font_size=22): 54 | if type(image) == np.ndarray: 55 | color_coverted = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 56 | image = Image.fromarray(color_coverted) 57 | 58 | if platform.system() == "Darwin": 59 | font = "AppleGothic.ttf" 60 | elif platform.system() == "Windows": 61 | font = "malgun.ttf" 62 | elif platform.system() == "Linux": 63 | font = "NotoSansCJK-Regular.ttc" 64 | 65 | image_font = ImageFont.truetype(font, font_size) 66 | font = ImageFont.load_default() 67 | draw = ImageDraw.Draw(image) 68 | 69 | draw.text((x, y), text, font=image_font, fill=color) 70 | 71 | numpy_image = np.array(image) 72 | opencv_image = cv2.cvtColor(numpy_image, cv2.COLOR_RGB2BGR) 73 | 74 | return opencv_image 75 | -------------------------------------------------------------------------------- /betterocr/engines/easy_pororo_ocr/utils/pre_processing.py: -------------------------------------------------------------------------------- 1 | """ 2 | this code is adapted from https://github.com/black7375/korean_ocr_using_pororo 3 | 4 | Apache License 2.0 @yunwoong7 5 | Apache License 2.0 @black7375 6 | """ 7 | 8 | import cv2 9 | import numpy as np 10 | 11 | 12 | # https://tesseract-ocr.github.io/tessdoc/ImproveQuality.html 13 | # https://nanonets.com/blog/ocr-with-tesseract/ 14 | # https://github.com/TCAT-capstone/ocr-preprocessor/blob/main/main.py 15 | # https://towardsdatascience.com/pre-processing-in-ocr-fc231c6035a7 16 | # == Main ====================================================================== 17 | def load(path): 18 | return cv2.imread(path) 19 | 20 | 21 | def image_filter(image): 22 | image = grayscale(image) 23 | # image = thresholding(image, mode="GAUSSIAN") 24 | # image = opening(image) 25 | # image = closing(image) 26 | return image 27 | 28 | 29 | def roi_filter(image): 30 | image = resize(image) 31 | return image 32 | 33 | 34 | def load_with_filter(path): 35 | image = load(path) 36 | return image_filter(image) 37 | 38 | 39 | def isEven(num): 40 | return num % 2 == 0 41 | 42 | 43 | # == Color ===================================================================== 44 | def grayscale(image, blur=False): 45 | return cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) 46 | 47 | 48 | def invert(image): 49 | return cv2.bitwise_not(image) 50 | 51 | 52 | # https://opencv-python.readthedocs.io/en/latest/doc/09.imageThresholding/imageThresholding.html 53 | def thresholding(image, mode="GENERAL", block_size=9, C=5): 54 | if isEven(block_size): 55 | print("block_size to use odd") 56 | return 57 | 58 | if mode == "MEAN": 
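        # each pixel is thresholded against the mean of its block_size x block_size
        # neighbourhood minus C (cv2.ADAPTIVE_THRESH_MEAN_C); the GAUSSIAN branch below
        # uses a Gaussian-weighted mean instead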
59 | return cv2.adaptiveThreshold( 60 | image, 255, cv2.ADAPTIVE_THRESH_MEAN_C, cv2.THRESH_BINARY, block_size, C 61 | ) 62 | elif mode == "GAUSSIAN": 63 | return cv2.adaptiveThreshold( 64 | image, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C, cv2.THRESH_BINARY, block_size, C 65 | ) 66 | else: 67 | return cv2.threshold(image, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)[1] 68 | 69 | 70 | def normalization(image, mode="COLOR", result_size=None): 71 | result_size = np.zeros(result_size) if result_size is not None else None 72 | 73 | if mode == "COLOR": 74 | return cv2.normalize(image, result_size, 0, 255, cv2.NORM_MINMAX) 75 | 76 | if mode == "GRAY": 77 | return cv2.normalize( 78 | image, result_size, 0, 1.0, cv2.NORM_MINMAX, dtype=cv2.CV_32F 79 | ) 80 | 81 | 82 | def equalization(image): 83 | return cv2.equalizeHist(image) 84 | 85 | 86 | # == Noise ===================================================================== 87 | def remove_noise(image, mode="COLOR"): 88 | if mode == "COLOR": 89 | return cv2.fastNlMeansDenoisingColored(image, None, 10, 10, 7, 15) 90 | if mode == "GRAY": 91 | return cv2.fastNlMeansDenoising(image, None, 10, 7, 21) 92 | 93 | 94 | def blur(image, kernel=(3, 3)): 95 | return cv2.GaussianBlur(image, kernel, 0) 96 | 97 | 98 | def blur_median(image, kernel_size=3): 99 | return cv2.medianBlur(image, ksize=kernel_size) 100 | 101 | 102 | # == Morphology ================================================================ 103 | def dilation(image, kernel=np.ones((3, 3), np.uint8)): 104 | return cv2.dilate(image, kernel, iterations=1) 105 | 106 | 107 | def erosion(image, kernel=np.ones((3, 3), np.uint8)): 108 | return cv2.erode(image, kernel, iterations=1) 109 | 110 | 111 | def opening(image, kernel=np.ones((3, 3), np.uint8)): 112 | return cv2.morphologyEx(image, cv2.MORPH_OPEN, kernel) 113 | 114 | 115 | def closing(image, kernel=np.ones((3, 3), np.uint8)): 116 | return cv2.morphologyEx(image, cv2.MORPH_CLOSE, kernel) 117 | 118 | 119 | def gradient(image, kernel=np.ones((3, 3), np.uint8)): 120 | return cv2.morphologyEx(image, cv2.MORPH_GRADIENT, kernel) 121 | 122 | 123 | def canny(image): 124 | return cv2.Canny(image, 100, 200) 125 | 126 | 127 | # == Others ==================================================================== 128 | def resize(image, interpolation=cv2.INTER_CUBIC): 129 | height, width = image.shape[:2] 130 | factor = max(1, (20.0 / height)) 131 | size = (int(factor * width), int(factor * height)) 132 | return cv2.resize(image, dsize=size, interpolation=interpolation) 133 | 134 | 135 | # https://becominghuman.ai/how-to-automatically-deskew-straighten-a-text-image-using-opencv-a0c30aed83df 136 | def deskew(image): 137 | coords = np.column_stack(np.where(image > 0)) 138 | angle = cv2.minAreaRect(coords)[-1] 139 | if angle < -45: 140 | angle = -(90 + angle) 141 | else: 142 | angle = -angle 143 | 144 | (h, w) = image.shape[:2] 145 | center = (w // 2, h // 2) 146 | M = cv2.getRotationMatrix2D(center, angle, 1.0) 147 | rotated = cv2.warpAffine( 148 | image, M, (w, h), flags=cv2.INTER_CUBIC, borderMode=cv2.BORDER_REPLICATE 149 | ) 150 | return rotated 151 | 152 | 153 | def match_template(image, template): 154 | return cv2.matchTemplate(image, template, cv2.TM_CCOEFF_NORMED) 155 | -------------------------------------------------------------------------------- /betterocr/parsers.py: -------------------------------------------------------------------------------- 1 | import re 2 | import json 3 | 4 | 5 | def extract_json(input_string): 6 | # Find the JSON in the string 7 | matches = 
re.findall(r'{\s*"data"\s*:\s*"(.*?)"\s*}', input_string, re.DOTALL) 8 | if matches: 9 | # Correctly escape special characters 10 | matches = [m.replace("\n", "\\n").replace('"', '\\"') for m in matches] 11 | for match in matches: 12 | # Construct JSON string 13 | json_string = f'{{"data": "{match}"}}' 14 | try: 15 | # Load the JSON and return the data 16 | json_obj = json.loads(json_string) 17 | return json_obj 18 | except json.decoder.JSONDecodeError: 19 | continue 20 | 21 | # If no JSON found, return None 22 | return None 23 | 24 | 25 | def extract_list(s): 26 | stack = [] 27 | start_position = None 28 | 29 | # Iterate through each character in the string 30 | for i, c in enumerate(s): 31 | if c == "[": 32 | if start_position is None: # First '[' found 33 | start_position = i 34 | stack.append(c) 35 | 36 | elif c == "]": 37 | if stack: 38 | stack.pop() 39 | 40 | # If stack is empty and start was marked 41 | if not stack and start_position is not None: 42 | substring = s[start_position : i + 1] 43 | try: 44 | list_obj = json.loads(substring) 45 | for item in list_obj: 46 | if "box" in item and "text" in item: 47 | return list_obj 48 | except json.decoder.JSONDecodeError: 49 | # Reset the stack and start position as this isn't a valid JSON 50 | stack = [] 51 | start_position = None 52 | continue 53 | 54 | # If no valid list found, return None 55 | return None 56 | 57 | 58 | def rectangle_corners(rect): 59 | x, y, w, h = rect 60 | return [[x, y], [x + w, y], [x + w, y + h], [x, y + h]] 61 | -------------------------------------------------------------------------------- /betterocr/wrappers/__init__.py: -------------------------------------------------------------------------------- 1 | from .easy_ocr import job_easy_ocr, job_easy_ocr_boxes 2 | from .tesseract.job import job_tesseract, job_tesseract_boxes 3 | 4 | __all__ = [ 5 | "job_easy_ocr", 6 | "job_easy_ocr_boxes", 7 | "job_tesseract", 8 | "job_tesseract_boxes", 9 | ] 10 | -------------------------------------------------------------------------------- /betterocr/wrappers/easy_ocr.py: -------------------------------------------------------------------------------- 1 | import easyocr 2 | 3 | 4 | def job_easy_ocr(_options): 5 | reader = easyocr.Reader(_options["lang"]) 6 | text = reader.readtext(_options["path"], detail=0) 7 | text = "".join(text) 8 | print("[*] job_easy_ocr", text) 9 | return text 10 | 11 | 12 | def job_easy_ocr_boxes(_options): 13 | reader = easyocr.Reader(_options["lang"]) 14 | boxes = reader.readtext(_options["path"], output_format="dict") 15 | for box in boxes: 16 | box["box"] = box.pop("boxes") 17 | return boxes 18 | -------------------------------------------------------------------------------- /betterocr/wrappers/easy_pororo_ocr.py: -------------------------------------------------------------------------------- 1 | from ..engines.easy_pororo_ocr import EasyPororoOcr, load_with_filter 2 | 3 | 4 | def parse_languages(lang: list[str]): 5 | languages = [] 6 | for l in lang: 7 | if l in ["ko", "en"]: 8 | languages.append(l) 9 | 10 | if len(languages) == 0: 11 | languages = ["ko"] 12 | 13 | return languages 14 | 15 | 16 | def default_ocr(_options): 17 | lang = parse_languages(_options["lang"]) 18 | return EasyPororoOcr(lang) 19 | 20 | 21 | def job_easy_pororo_ocr(_options): 22 | image = load_with_filter(_options["path"]) 23 | 24 | ocr = _options.get("ocr") 25 | if not ocr: 26 | ocr = default_ocr(_options) 27 | 28 | text = ocr.run_ocr(image, debug=False) 29 | 30 | if isinstance(text, list): 31 | text = 
"\\n".join(text) 32 | 33 | print("[*] job_easy_pororo_ocr", text) 34 | return text 35 | 36 | 37 | def job_easy_pororo_ocr_boxes(_options): 38 | ocr = default_ocr(_options) 39 | job_easy_pororo_ocr({**_options, "ocr": ocr}) 40 | return ocr.get_boxes() 41 | -------------------------------------------------------------------------------- /betterocr/wrappers/tesseract/__init__.py: -------------------------------------------------------------------------------- 1 | from .job import convert_to_tesseract_lang_code, job_tesseract, job_tesseract_boxes 2 | 3 | __all__ = ["convert_to_tesseract_lang_code", "job_tesseract", "job_tesseract_boxes"] 4 | -------------------------------------------------------------------------------- /betterocr/wrappers/tesseract/job.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | import pytesseract 4 | 5 | from .mapping import LANG_CODE_MAPPING 6 | 7 | 8 | def convert_to_tesseract_lang_code(langs: list[str]) -> str: 9 | return "+".join( 10 | [ 11 | LANG_CODE_MAPPING[lang] 12 | for lang in langs 13 | if lang in LANG_CODE_MAPPING and LANG_CODE_MAPPING[lang] is not None 14 | ] 15 | ) 16 | 17 | 18 | def job_tesseract(_options): 19 | lang = convert_to_tesseract_lang_code(_options["lang"]) 20 | text = pytesseract.image_to_string( 21 | _options["path"], 22 | lang=lang, 23 | **_options["tesseract"] 24 | # pass rest of tesseract options here. 25 | ) 26 | text = text.replace("\n", "\\n") 27 | print("[*] job_tesseract_ocr", text) 28 | return text 29 | 30 | 31 | def job_tesseract_boxes(_options): 32 | lang = convert_to_tesseract_lang_code(_options["lang"]) 33 | df = pytesseract.image_to_data( 34 | _options["path"], 35 | lang=lang, 36 | **_options["tesseract"], 37 | output_type=pytesseract.Output.DATAFRAME 38 | # pass rest of tesseract options here. 
39 | ) 40 | 41 | # https://stackoverflow.com/questions/74221064/draw-a-rectangle-around-a-string-of-words-using-pytesseract 42 | boxes = [] 43 | for line_num, words_per_line in df.groupby("line_num"): 44 | words_per_line = words_per_line[words_per_line["conf"] >= 5] 45 | if len(words_per_line) == 0: 46 | continue 47 | 48 | words = words_per_line["text"].values 49 | line = " ".join(words) 50 | 51 | word_boxes = [] 52 | for left, top, width, height in words_per_line[ 53 | ["left", "top", "width", "height"] 54 | ].values: 55 | word_boxes.append((left, top)) 56 | word_boxes.append((left + width, top + height)) 57 | 58 | x, y, w, h = cv2.boundingRect(np.array(word_boxes)) 59 | boxes.append( 60 | { 61 | "box": [[x, y], [x + w, y], [x + w, y + h], [x, y + h]], 62 | "text": line, 63 | } 64 | ) 65 | 66 | return boxes 67 | -------------------------------------------------------------------------------- /betterocr/wrappers/tesseract/mapping.py: -------------------------------------------------------------------------------- 1 | LANG_CODE_MAPPING = { 2 | "abq": "afr", # Not a direct match 3 | "ady": None, # Not found 4 | "af": None, # Not found 5 | "ang": None, # Not found 6 | "ar": "ara", 7 | "as": "asm", 8 | "ava": None, # Not found 9 | "az": "aze", 10 | "be": "bel", 11 | "bg": "bul", 12 | "bh": None, # Not found 13 | "bho": None, # Not found 14 | "bn": "ben", 15 | "bs": "bos", 16 | "ch_sim": "chi_sim", 17 | "ch_tra": "chi_tra", 18 | "che": None, # Not found 19 | "cs": "ces", 20 | "cy": "cym", 21 | "da": "dan", 22 | "dar": None, # Not found 23 | "de": "deu", 24 | "en": "eng", 25 | "es": "spa", 26 | "et": "est", 27 | "fa": "fas", 28 | "fr": "fra", 29 | "ga": "gle", 30 | "gom": None, # Not found 31 | "hi": "hin", 32 | "hr": "hrv", 33 | "hu": "hun", 34 | "id": "ind", 35 | "inh": None, # Not found 36 | "is": "isl", 37 | "it": "ita", 38 | "ja": "jpn", 39 | "kbd": None, # Not found 40 | "kn": "kan", 41 | "ko": "kor", 42 | "ku": None, # Not found 43 | "la": "lat", 44 | "lbe": None, # Not found 45 | "lez": None, # Not found 46 | "lt": "lit", 47 | "lv": "lav", 48 | "mah": None, # Not found 49 | "mai": None, # Not found 50 | "mi": "mri", 51 | "mn": "mon", 52 | "mr": "mar", 53 | "ms": "msa", 54 | "mt": "mlt", 55 | "ne": "nep", 56 | "new": None, # Not found 57 | "nl": "nld", 58 | "no": "nor", 59 | "oc": "oci", 60 | "pi": None, # Not found 61 | "pl": "pol", 62 | "pt": "por", 63 | "ro": "ron", 64 | "ru": "rus", 65 | "rs_cyrillic": None, # Not a direct match 66 | "rs_latin": None, # Not a direct match 67 | "sck": None, # Not found 68 | "sk": "slk", 69 | "sl": "slv", 70 | "sq": "sqi", 71 | "sv": "swe", 72 | "sw": "swa", 73 | "ta": "tam", 74 | "tab": None, # Not found 75 | "te": "tel", 76 | "th": "tha", 77 | "tjk": None, # Not found 78 | "tl": "tgl", 79 | "tr": "tur", 80 | "ug": "uig", 81 | "uk": "ukr", 82 | "ur": "urd", 83 | "uz": "uzb", 84 | "vi": "vie", 85 | } 86 | -------------------------------------------------------------------------------- /examples/detect_boxes.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | import matplotlib.pyplot as plt 4 | from PIL import Image, ImageDraw, ImageFont 5 | 6 | from betterocr import detect_boxes 7 | 8 | image_path = ".github/images/demo-1.png" 9 | result = detect_boxes( 10 | image_path, 11 | ["ko", "en"], 12 | context="퍼멘테이션 펩타인 아이케어 크림", # product name 13 | tesseract={ 14 | "config": "--psm 6 --tessdata-dir ./tessdata -c tessedit_create_boxfile=1" 15 | }, 16 | ) 17 | print(result) 18 | 19 | 
font_path = ".github/examples/Pretendard-Medium.ttf" 20 | font_size = 36 21 | font = ImageFont.truetype(font_path, size=font_size, encoding="unic") 22 | 23 | img = cv2.imread(image_path) 24 | img_rgb = cv2.cvtColor( 25 | img, cv2.COLOR_BGR2RGB 26 | ) # Convert from BGR to RGB for correct color display in matplotlib 27 | 28 | for item in result: 29 | box = item["box"] 30 | pt1, pt2, pt3, pt4 = box 31 | top_left, bottom_right = tuple(pt1), tuple(pt3) 32 | 33 | # Draw rectangle 34 | cv2.rectangle(img_rgb, top_left, bottom_right, (0, 255, 0), 2) 35 | 36 | # For text 37 | pil_img = Image.fromarray(img_rgb) 38 | draw = ImageDraw.Draw(pil_img) 39 | 40 | # Get text dimensions 41 | text_width = draw.textlength(item["text"], font=font) 42 | text_height = font_size # line height = font size 43 | 44 | # Position the text just above the rectangle 45 | text_pos = (top_left[0], top_left[1] - text_height) 46 | 47 | # Draw text background rectangle 48 | draw.rectangle( 49 | [text_pos, (text_pos[0] + text_width, text_pos[1] + text_height)], 50 | fill=(0, 255, 0), 51 | ) 52 | 53 | # Draw text 54 | draw.text(text_pos, item["text"], font=font, fill=(0, 0, 0)) 55 | 56 | # Convert the PIL image back to an OpenCV image 57 | img_rgb = cv2.cvtColor(np.array(pil_img), cv2.COLOR_RGB2BGR) 58 | 59 | 60 | height, width, _ = img.shape 61 | 62 | # Convert pixel dimensions to inches 63 | dpi = 80 # Assuming default matplotlib dpi 64 | fig_width = width / dpi 65 | fig_height = height / dpi 66 | 67 | plt.figure(figsize=(fig_width, fig_height)) 68 | plt.imshow(img_rgb) 69 | plt.axis("off") 70 | plt.show() 71 | -------------------------------------------------------------------------------- /examples/detect_text.py: -------------------------------------------------------------------------------- 1 | from betterocr import detect_text 2 | 3 | text = detect_text( 4 | ".github/images/demo-1.png", 5 | ["ko", "en"], 6 | context="퍼멘테이션 펩타인 아이케어 크림", # product name 7 | tesseract={"config": "--tessdata-dir ./tessdata"}, 8 | openai={ 9 | "model": "gpt-3.5-turbo", 10 | }, 11 | ) 12 | print("\n") 13 | print(text) 14 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "betterocr" 3 | version = "1.2.0" 4 | description = "Better text detection by combining OCR engines with LLM." 
5 | authors = ["Junho Yeo "] 6 | readme = "README.md" 7 | 8 | [tool.poetry.dependencies] 9 | python = ">=3.11,<3.13" 10 | openai = { version = "^1.28.1", allow-prereleases = true } 11 | pytesseract = "^0.3.10" 12 | easyocr = "^1.7.1" 13 | numpy = "^1.26.4" 14 | pandas = "^2.2.2" 15 | opencv-python = "^4.9.0.80" 16 | 17 | [tool.poetry.group.dev.dependencies] 18 | matplotlib = "^3.8.4" 19 | pillow = "^10.3.0" 20 | pytest = "^8.2.0" 21 | 22 | [tool.poetry.group.pororo.dependencies] 23 | torch = "^2.1.0" 24 | torchvision = "^0.16.0" 25 | wget = "^3.2" 26 | scikit-image = "^0.22.0" 27 | 28 | [build-system] 29 | requires = ["poetry-core"] 30 | build-backend = "poetry.core.masonry.api" 31 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/junhoyeo/BetterOCR/765fffe2988384e0775124b48b6d3b3ad1dbf069/tests/__init__.py -------------------------------------------------------------------------------- /tests/parsers/test_extract_json.py: -------------------------------------------------------------------------------- 1 | from betterocr.parsers import extract_json 2 | import json 3 | 4 | 5 | def is_same_dict(dict1, dict2): 6 | return json.dumps(dict1) == json.dumps(dict2) 7 | 8 | 9 | def test_extract_json__default(): 10 | assert is_same_dict(extract_json('{"data": "hello"}'), {"data": "hello"}) 11 | assert is_same_dict( 12 | extract_json( 13 | """ 14 | Using the provided context and both OCR data sets, I'll merge, correct, and provide the result: 15 | {"data": "{Optical Character Recognition}"} 16 | """ 17 | ), 18 | {"data": "{Optical Character Recognition}"}, 19 | ) 20 | 21 | 22 | def test_extract_json__newlines(): 23 | assert is_same_dict( 24 | extract_json('{"data": "hello\nworld"}'), {"data": "hello\nworld"} 25 | ) 26 | assert is_same_dict( 27 | extract_json('{"data": "hello \n\n world"}'), {"data": "hello \n\n world"} 28 | ) 29 | 30 | 31 | def test_extract_json__newlines_escaped(): 32 | assert is_same_dict( 33 | extract_json('{"data": "hello\\nworld"}'), {"data": "hello\nworld"} 34 | ) 35 | 36 | 37 | def test_extract_json__multiple_dicts_should_return_first_occurrence(): 38 | assert is_same_dict( 39 | extract_json( 40 | """ 41 | {"data": "JSON Data 1"} 42 | {"data": "JSON Data 2"} 43 | """ 44 | ), 45 | {"data": "JSON Data 1"}, 46 | ) 47 | assert is_same_dict( 48 | extract_json( 49 | """ 50 | {"invalid_key": "JSON Data 1"} 51 | {"data": "JSON Data 2"} 52 | """ 53 | ), 54 | {"data": "JSON Data 2"}, 55 | ) 56 | -------------------------------------------------------------------------------- /tests/parsers/test_extract_list.py: -------------------------------------------------------------------------------- 1 | from betterocr.parsers import extract_list 2 | import json 3 | 4 | 5 | def is_same_list(dict1, dict2): 6 | return json.dumps(dict1) == json.dumps(dict2) 7 | 8 | 9 | def test_extract_list__default(): 10 | assert is_same_list( 11 | extract_list('[{"box":[],"text":""}]'), 12 | [{"box": [], "text": ""}], 13 | ) 14 | assert is_same_list( 15 | extract_list( 16 | """ 17 | Using the provided context and both OCR data sets, I'll merge, correct, and provide the result: 18 | [ 19 | { 20 | "box": [123, 456, 789, 101112], 21 | "text": "Hello World!" 22 | } 23 | ] 24 | """ 25 | ), 26 | [{"box": [123, 456, 789, 101112], "text": "Hello World!"}], 27 | ) 28 | --------------------------------------------------------------------------------