├── .github
│   ├── examples
│   │   └── Pretendard-Medium.ttf
│   └── images
│       ├── arch.jpg
│       ├── boxes-0.png
│       ├── demo-0.webp
│       ├── demo-1.png
│       ├── demo-2.png
│       ├── demo-3.webp
│       ├── logo.png
│       └── starmie.jpg
├── .gitignore
├── .prettierignore
├── .prettierrc.js
├── .pylintrc
├── .vscode
│   ├── extensions.json
│   └── settings.json
├── LICENSE
├── README.md
├── betterocr
│   ├── __init__.py
│   ├── detect.py
│   ├── engines
│   │   └── easy_pororo_ocr
│   │       ├── __init__.py
│   │       ├── pororo
│   │       │   ├── __init__.py
│   │       │   ├── models
│   │       │   │   └── brainOCR
│   │       │   │       ├── __init__.py
│   │       │   │       ├── _dataset.py
│   │       │   │       ├── _modules.py
│   │       │   │       ├── brainocr.py
│   │       │   │       ├── craft.py
│   │       │   │       ├── craft_utils.py
│   │       │   │       ├── detection.py
│   │       │   │       ├── imgproc.py
│   │       │   │       ├── model.py
│   │       │   │       ├── modules
│   │       │   │       │   ├── __init__.py
│   │       │   │       │   ├── basenet.py
│   │       │   │       │   ├── feature_extraction.py
│   │       │   │       │   ├── prediction.py
│   │       │   │       │   ├── sequence_modeling.py
│   │       │   │       │   └── transformation.py
│   │       │   │       ├── recognition.py
│   │       │   │       └── utils.py
│   │       │   ├── pororo.py
│   │       │   ├── tasks
│   │       │   │   ├── __init__.py
│   │       │   │   ├── optical_character_recognition.py
│   │       │   │   └── utils
│   │       │   │       ├── __init__.py
│   │       │   │       ├── base.py
│   │       │   │       ├── config.py
│   │       │   │       └── download_utils.py
│   │       │   └── utils.py
│   │       └── utils
│   │           ├── __init__.py
│   │           ├── image_convert.py
│   │           ├── image_util.py
│   │           └── pre_processing.py
│   ├── parsers.py
│   └── wrappers
│       ├── __init__.py
│       ├── easy_ocr.py
│       ├── easy_pororo_ocr.py
│       └── tesseract
│           ├── __init__.py
│           ├── job.py
│           └── mapping.py
├── examples
│   ├── detect_boxes.py
│   └── detect_text.py
├── poetry.lock
├── pyproject.toml
└── tests
    ├── __init__.py
    └── parsers
        ├── test_extract_json.py
        └── test_extract_list.py
/.github/examples/Pretendard-Medium.ttf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/junhoyeo/BetterOCR/765fffe2988384e0775124b48b6d3b3ad1dbf069/.github/examples/Pretendard-Medium.ttf
--------------------------------------------------------------------------------
/.github/images/arch.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/junhoyeo/BetterOCR/765fffe2988384e0775124b48b6d3b3ad1dbf069/.github/images/arch.jpg
--------------------------------------------------------------------------------
/.github/images/boxes-0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/junhoyeo/BetterOCR/765fffe2988384e0775124b48b6d3b3ad1dbf069/.github/images/boxes-0.png
--------------------------------------------------------------------------------
/.github/images/demo-0.webp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/junhoyeo/BetterOCR/765fffe2988384e0775124b48b6d3b3ad1dbf069/.github/images/demo-0.webp
--------------------------------------------------------------------------------
/.github/images/demo-1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/junhoyeo/BetterOCR/765fffe2988384e0775124b48b6d3b3ad1dbf069/.github/images/demo-1.png
--------------------------------------------------------------------------------
/.github/images/demo-2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/junhoyeo/BetterOCR/765fffe2988384e0775124b48b6d3b3ad1dbf069/.github/images/demo-2.png
--------------------------------------------------------------------------------
/.github/images/demo-3.webp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/junhoyeo/BetterOCR/765fffe2988384e0775124b48b6d3b3ad1dbf069/.github/images/demo-3.webp
--------------------------------------------------------------------------------
/.github/images/logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/junhoyeo/BetterOCR/765fffe2988384e0775124b48b6d3b3ad1dbf069/.github/images/logo.png
--------------------------------------------------------------------------------
/.github/images/starmie.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/junhoyeo/BetterOCR/765fffe2988384e0775124b48b6d3b3ad1dbf069/.github/images/starmie.jpg
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__
2 | *.pyc
3 | .pytest_cache
4 | dist
5 |
6 | .DS_Store
7 | *.traineddata
8 |
9 | *.png
10 | !.github/images/*.png
11 |
--------------------------------------------------------------------------------
/.prettierignore:
--------------------------------------------------------------------------------
1 | *.md
2 |
--------------------------------------------------------------------------------
/.prettierrc.js:
--------------------------------------------------------------------------------
1 | module.exports = {
2 | bracketSpacing: true,
3 | jsxBracketSameLine: false,
4 | singleQuote: true,
5 | trailingComma: 'all',
6 | semi: true,
7 | };
8 |
--------------------------------------------------------------------------------
/.pylintrc:
--------------------------------------------------------------------------------
1 | # A comma-separated list of package or module names from where C extensions may
2 | # be loaded. Extensions are loading into the active Python interpreter and may
3 | # run arbitrary code.
4 | extension-pkg-whitelist=cv2
5 |
--------------------------------------------------------------------------------
/.vscode/extensions.json:
--------------------------------------------------------------------------------
1 | {
2 | "recommendations": [
3 | "esbenp.prettier-vscode",
4 | "ms-python.pylint",
5 | "ms-python.black-formatter"
6 | ]
7 | }
8 |
--------------------------------------------------------------------------------
/.vscode/settings.json:
--------------------------------------------------------------------------------
1 | {
2 | "editor.formatOnSave": true,
3 | "[json]": {
4 | "editor.defaultFormatter": "esbenp.prettier-vscode"
5 | },
6 | "[jsonc]": {
7 | "editor.defaultFormatter": "esbenp.prettier-vscode"
8 | },
9 | "[javascript]": {
10 | "editor.defaultFormatter": "esbenp.prettier-vscode"
11 | },
12 | "[typescript]": {
13 | "editor.defaultFormatter": "esbenp.prettier-vscode"
14 | },
15 | "[typescriptreact]": {
16 | "editor.defaultFormatter": "esbenp.prettier-vscode"
17 | },
18 | "[python]": {
19 | "editor.defaultFormatter": "ms-python.black-formatter"
20 | },
21 | "pylint.args": [
22 | "--disable=dangerous-default-value",
23 | "--disable=missing-module-docstring",
24 | "--disable=missing-class-docstring",
25 | "--disable=missing-function-docstring",
26 | "--disable=line-too-long",
27 | "--disable=invalid-name"
28 | ]
29 | }
30 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 Junho Yeo
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 | BetterOCR
7 |
8 |
9 |
10 |
11 |
12 |
13 | > 🔍 Better text detection by combining multiple OCR engines with 🧠 LLM.
14 |
15 | OCR _still_ sucks! ... Especially when you're from the _other side_ of the world (and face a significant lack of training data in your language) — or just not thrilled with noisy results.
16 |
17 | **BetterOCR** combines results from multiple OCR engines with an LLM to correct & reconstruct the output.
18 |
19 |
20 | 
21 | My open source work is supported by the community
22 | Special thanks to (주)한국모바일상품권 (Korea Mobile Gift Card, Inc.) and others
23 |
24 |
25 | ### 🔍 OCR Engines
26 | Currently supports [EasyOCR](https://github.com/JaidedAI/EasyOCR) (JaidedAI), [Tesseract](https://github.com/tesseract-ocr/tesseract) (Google), and [Pororo](https://github.com/kakaobrain/pororo) (KakaoBrain).
27 |
28 | - For Pororo, we're using the code from https://github.com/black7375/korean_ocr_using_pororo
29 | (Pre-processing ➡️ _Text detection_ with EasyOCR ➡️ _Text recognition_ with BrainOCR (Pororo's OCR module)).
30 | - Pororo is used only if the specified language options (`lang`) include either 🇺🇸 English (`en`) or 🇰🇷 Korean (`ko`), and the additional dependencies listed in `[tool.poetry.group.pororo.dependencies]` are available. (If not, Pororo is automatically excluded from the enabled engines; see the sketch below.)
31 |
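A minimal sketch of that selection behavior (assuming `OPENAI_API_KEY` is set and `demo.png` is a placeholder image path):

```py
import betterocr

# "ko" or "en" in `lang`: EasyOCR + Tesseract + Pororo (when its extra
# dependencies are installed) are all consulted before the LLM merge.
text = betterocr.detect_text("demo.png", ["ko", "en"])

# Neither "ko" nor "en": Pororo is skipped automatically and only the
# EasyOCR + Tesseract results are combined.
text = betterocr.detect_text("demo.png", ["hi"])
```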
32 | ### 🧠 LLM
33 | Supports [Chat models](https://github.com/openai/openai-python#chat-completions) from OpenAI.
34 |
35 | ### 📒 Custom Context
36 | Allows users to provide an optional context to use specific keywords such as proper nouns and product names. This assists in spelling correction and noise identification, ensuring accuracy even with rare or unconventional words.
37 |
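For example, a minimal sketch of passing a product name as `context` (the image path is a placeholder; the keyword is borrowed from the box-detection example below):

```py
import betterocr

# The context string is forwarded to the LLM prompt, helping it restore
# rare or domain-specific spellings in the corrected output.
text = betterocr.detect_text(
    "demo.png",
    ["ko", "en"],
    context="퍼멘테이션 펩타인 아이케어 크림",  # product name the engines tend to misread
)
print(text)
```
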
38 | ### 🛢️ Resources
39 |
40 | - Head over to [💯 Examples](#-Examples) to view performance by language (🇺🇸, 🇰🇷, 🇮🇳).
41 | - Coming Soon: ~~box detection~~ 🧪✅, improved interface 🚧, async support, and more. Contributions are welcome.
42 |
43 | > **Warning**
44 | > This package is under rapid development 🛠
45 |
46 |
47 |
48 |
49 |
50 | > Architecture
51 |
52 | ## 🚀 Usage (WIP)
53 |
54 | ```bash
55 | pip install betterocr
56 | # pip3 install betterocr
57 | ```
58 |
59 | ```py
60 | import betterocr
61 |
62 | # text detection
63 | text = betterocr.detect_text(
64 | "demo.png",
65 | ["ko", "en"], # language codes (from EasyOCR)
66 | context="", # (optional) context
67 | tesseract={
68 | # Tesseract options here
69 | "config": "--tessdata-dir ./tessdata"
70 | },
71 | openai={
72 | # OpenAI options here
73 |
74 | # `os.environ["OPENAI_API_KEY"]` is used by default
75 | "API_KEY": "sk-xxxxxxx",
76 |
77 | # rest are used to pass params to `client.chat.completions.create`
78 | # `{"model": "gpt-4"}` by default
79 | "model": "gpt-3.5-turbo",
80 | },
81 | )
82 | print(text)
83 | ```
84 |
85 | ### 📦 Box Detection
86 |
87 | | Original | Detected |
88 | |:---:|:---:|
89 | | *(original image)* | *(image with detected boxes)* |
90 |
91 | Example Script: https://github.com/junhoyeo/BetterOCR/blob/main/examples/detect_boxes.py (Uses OpenCV and Matplotlib to draw rectangles)
92 |
93 | ```py
94 | import betterocr
95 |
96 | image_path = ".github/images/demo-1.png"
97 | items = betterocr.detect_boxes(
98 | image_path,
99 | ["ko", "en"],
100 | context="퍼멘테이션 펩타인 아이케어 크림", # product name
101 | tesseract={
102 | "config": "--psm 6 --tessdata-dir ./tessdata -c tessedit_create_boxfile=1"
103 | },
104 | )
105 | print(items)
106 | ```
107 |
108 |
109 | View Output
110 |
111 | ```py
112 | [
113 | {'text': 'JUST FOR YOU', 'box': [[543, 87], [1013, 87], [1013, 151], [543, 151]]},
114 | {'text': '이런 분들께 추천드리는 퍼멘테이션 펩타인 아이케어 크림', 'box': [[240, 171], [1309, 171], [1309, 224], [240, 224]]},
115 | {'text': '매일매일 진해지는 다크서클을 개선하고 싶다면', 'box': [[123, 345], [1166, 345], [1166, 396], [123, 396]]},
116 | {'text': '축축 처지는 피부를 탄력 있게 바꾸고 싶다면', 'box': [[125, 409], [1242, 409], [1242, 470], [125, 470]]},
117 | {'text': '나날이 늘어가는 눈가 주름을 완화하고 싶다면', 'box': [[123, 479], [1112, 479], [1112, 553], [123, 553]]},
118 | {'text': 'FERMENATION', 'box': [[1216, 578], [1326, 578], [1326, 588], [1216, 588]]},
119 | {'text': '민감성 피부에도 사용할 수 있는 아이크림을 찾는다면', 'box': [[134, 534], [1071, 534], [1071, 618], [134, 618]]},
120 | {'text': '얇고 예민한 눈가 주변 피부를 관리하고 싶다면', 'box': [[173, 634], [1098, 634], [1098, 690], [173, 690]]}
121 | ]
122 | ```
123 |
124 |
125 |
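Continuing from the snippet above, a minimal sketch of drawing the returned boxes with OpenCV and Matplotlib (it reuses `image_path` and `items` from that snippet and is an illustration, not the exact example script):

```py
import cv2
import numpy as np
from matplotlib import pyplot as plt

# Load the image and draw each detected box as a green polygon.
image = cv2.cvtColor(cv2.imread(image_path), cv2.COLOR_BGR2RGB)
for item in items:
    corners = np.array(item["box"], dtype=np.int32)  # four [x, y] corner points
    cv2.polylines(image, [corners], isClosed=True, color=(0, 200, 0), thickness=2)

plt.imshow(image)
plt.axis("off")
plt.show()
```
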
126 | ## 💯 Examples
127 |
128 | > **Note**
129 | > Results may vary due to inherent variability and potential future updates to OCR engines or the OpenAI API.
130 |
131 | ### Example 1 (English with Noise)
132 |
133 |
134 |
135 | | Source | Text |
136 | | ------ | ---- |
137 | | EasyOCR | `CHAINSAWMANChapter 109:The Easy Way to Stop Bullying~BV-THTSUKIFUUIMUTU ETT` |
138 | | Tesseract | `A\ ira \| LT ge a TE ay NS\nye SE F Pa Ce YI AIG 44\nopr See aC\n; a) Ny 7S =u \|\n_ F2 SENN\n\ ZR\n3 ~ 1 A \ Ws —— “s 7 “A\n=) 24 4 = rt fl /1\n£72 7 a NS dA Chapter 109:77/ ¢ 4\nZz % = ~ oes os \| \STheEasf Way.to Stop Bullying:\n© Wa) ROT\n\n` |
139 | | Pororo | `CHAINSAWNAN\nChapter 109\nThe Easy Way.to Stop Bullying.\nCBY=TATSUKI FUJIMDTO` |
140 | | LLM | 🤖 GPT-3.5 |
141 | | **Result** | **`CHAINSAW MAN\n\nChapter 109: The Easy Way to Stop Bullying\n\nBY: TATSUKI FUJIMOTO`** |
142 |
143 | ### Example 2 (Korean+English)
144 |
145 |
146 |
147 | | Source | Text |
148 | | ------ | ---- |
149 | | EasyOCR | `JUST FOR YOU이런 분들께 추천드리는 퍼멘테이선 팬타인 아이켜어 크림매일매일 진해지논 다크서클올 개선하고 싶다면축축 처지논 피부름 탄력 잇게 바꾸고 싶다면나날이 늘어가는 눈가 주름올 완화하고 싶다면FERMENATION민감성 피부에도 사용할 수잇는 아이크림올 찾는다면얇고 예민한 눈가 주변 피부름 관리하고 싶다면` |
150 | | Tesseract | `9051 508 \ㅇ4\n이런 분들께 추천드리는 퍼멘테이션 타인 아이케어 크림\n.매일매일 진해지는 다크서클을 개선하고 싶다면 "도\nㆍ축축 처지는 피부를 탄력 있게 바꾸고 싶다면 7\nㆍ나날이 늘어가는 눈가 주름을 완화하고 싶다면 /\n-민감성 피부에도 사용할 수 있는 아이크림을 찾는다면 (프\nㆍ않고 예민한 눈가 주변 피부를 관리하고 싶다면 밸\n\n` |
151 | | Pororo | `JUST FOR YOU\n이런 분들께 추천드리는 퍼맨테이션 펩타인 아이케어 크림\n매일매일 진해지는 다크서클을 개선하고 싶다면\n촉촉 처지는 피부를 탄력 있게 바꾸고 싶다면\n나날이 늘어가는 눈가 주름을 완화하고 싶다면\nFERMENTATIOM\n민감성 피부에도 사용할 수 있는 아이크림을 찾는다면\n얇고 예민한 눈가 주변 피부를 관리하고 싶다면` |
152 | | LLM | 🤖 GPT-3.5 |
153 | | **Result** | **`JUST FOR YOU\n이런 분들께 추천드리는 퍼멘테이션 펩타인 아이케어 크림\n매일매일 진해지는 다크서클을 개선하고 싶다면\n축축 처지는 피부를 탄력 있게 바꾸고 싶다면\n나날이 늘어가는 눈가 주름을 완화하고 싶다면\nFERMENTATION\n민감성 피부에도 사용할 수 있는 아이크림을 찾는다면\n얇고 예민한 눈가 주변 피부를 관리하고 싶다면`** |
154 |
155 | ### Example 3 (Korean with custom `context`)
156 |
157 |
158 |
159 | | Source | Text |
160 | | ------ | ---- |
161 | | EasyOCR | `바이오함보#세로모공존존세럼6글로우픽 설문단 100인이꼼꼼하게 평가햇어요"#누적 판매액 40억#제품만족도 1009` |
162 | | Tesseract | `바이오힐보\n#세로모공폰폰세럼\n“글로 으피 석무다 1 00인이\n꼼꼼하게평가했어요”\n\n` |
163 | | Pororo | `바이오힐보\n#세로모공쫀쫀세럼\n'.\n'글로우픽 설문단 100인이\n꼼꼼하게 평가했어요'"\n#누적 판매액 40억\n# 제품 만족도 100%` |
164 | | Context | `[바이오힐보] 세로모공쫀쫀세럼으로 콜라겐 타이트닝! (6S)` |
165 | | LLM | 🤖 GPT-4 |
166 | | **Result** | **`바이오힐보\n#세로모공쫀쫀세럼\n글로우픽 설문단 100인이 꼼꼼하게 평가했어요\n#누적 판매액 40억\n#제품 만족도 100%`** |
167 |
168 | #### 🧠 LLM Reasoning (*Old)
169 |
170 | Based on the given OCR results and the context, here is the combined and corrected result:
171 |
172 | ```
173 | {
174 | "data": "바이오힐보\n#세로모공쫀쫀세럼\n글로우픽 설문단 100인이 꼼꼼하게 평가했어요\n#누적 판매액 40억\n#제품만족도 100%"
175 | }
176 | ```
177 |
178 | - `바이오힐보` is the correct brand name, taken from [1] and the context.
179 | - `#세로모공쫀쫀세럼` seems to be the product name and is derived from the context.
180 | - `글로우픽 설문단 100인이 꼼꼼하게 평가했어요` is extracted and corrected from both OCR results.
181 | - `#누적 판매액 40억` is taken from [0].
182 | - `#제품만족도 100%` is corrected from [0].
183 |
184 | ### Example 4 (Hindi)
185 |
186 |
187 |
188 | | Source | Text |
189 | | ------ | ---- |
190 | | EasyOCR | `` `७नवभारतटाइम्सतोक्यो ओलिंपिक के लिए भारतीय दलका थीम सॉन्ग लॉन्च कर दिया गयाबुधवार को इस सॉन्ग को किया गया लॉन्चसिंगर मोहित चौहान ने दी है आवाज7लखेल मंत्री किरण रिजिजू ने ट्विटर पर शेयरकिया थीम सॉन्ग का वीडियो0ब४0 २०२०गीत का नाम- '्लक्ष्य तेरा सामने है' , खेलमंत्री ने ५७ सेकंड का वीडियो किया शेयर `` |
191 | | Tesseract | `'8ा.\nनवभोरत टैइम्स\n\nतोक्यो ओलिंपिक के लिंए भारतीय दल\n\nका थीम सॉन्ग लॉन्च कर दिया गया\n\nबुधवार को हस सॉन्ग को किया गया लॉन्च\nसिंगर मोहित चौहान ने दी है आवाज\n\nखेल मंत्री किरण रिजिजू ने द्विटर पर शेयर\nकिया थीम सॉन्ग का वीडियो\n\nपृ 0 (९ है 0 2 0 2 0 गीत का नाम- 'लक्ष्य तेरा सामने है', खेल\n\n(2 (9९) मंत्री ने 57 सेकंड का वीडियो किया शेयर\n\n` |
192 | | LLM | 🤖 GPT-4 |
193 | | **Result** | **`नवभारत टाइम्स\nतोक्यो ओलिंपिक के लिए भारतीय दल का थीम सॉन्ग लॉन्च कर दिया गया\nबुधवार को इस सॉन्ग को किया गया लॉन्च\nसिंगर मोहित चौहान ने दी है आवाज\n\nखेल मंत्री किरण रिजिजू ने ट्विटर पर शेयर किया थीम सॉन्ग का वीडियो\n2020 गीत का नाम- 'लक्ष्य तेरा सामने है', खेल मंत्री ने 57 सेकंड का वीडियो किया शेयर`** |
194 |
195 | ## License (Starmie!)
196 |
197 |
198 | MIT © Junho Yeo
199 |
200 |
201 |
202 |
203 |
204 |
205 |
206 |
207 | If you find this project interesting, **please consider giving it a star (⭐)** and following me on [GitHub](https://github.com/junhoyeo). I code 24/7 and ship mind-breaking things on a regular basis, so your support definitely won't be in vain!
208 |
--------------------------------------------------------------------------------
/betterocr/__init__.py:
--------------------------------------------------------------------------------
1 | from .detect import (
2 | detect,
3 | detect_async,
4 | detect_text,
5 | detect_text_async,
6 | detect_boxes,
7 | detect_boxes_async,
8 | NoTextDetectedError,
9 | )
10 | from .wrappers import job_easy_ocr, job_tesseract
11 | from .parsers import extract_json
12 |
13 | __all__ = [
14 | "detect",
15 | "detect_async",
16 | "detect_text",
17 | "detect_text_async",
18 | "detect_boxes",
19 | "detect_boxes_async",
20 | "NoTextDetectedError",
21 | "job_easy_ocr",
22 | "job_tesseract",
23 | "extract_json",
24 | ]
25 |
26 | __author__ = "junhoyeo"
27 |
--------------------------------------------------------------------------------
/betterocr/detect.py:
--------------------------------------------------------------------------------
1 | from threading import Thread
2 | import json
3 | from queue import Queue
4 | import os
5 |
6 | from openai import OpenAI
7 |
8 | from .parsers import extract_json, extract_list, rectangle_corners
9 | from .wrappers import (
10 | job_easy_ocr,
11 | job_easy_ocr_boxes,
12 | job_tesseract,
13 | job_tesseract_boxes,
14 | )
15 |
16 |
17 | def wrapper(func, args, queue):
18 | queue.put(func(args))
19 |
20 |
21 | # custom error
22 | class NoTextDetectedError(Exception):
23 | pass
24 |
25 |
26 | def detect():
27 | """Unimplemented"""
28 | raise NotImplementedError
29 |
30 |
31 | def detect_async():
32 | """Unimplemented"""
33 | raise NotImplementedError
34 |
35 |
36 | def get_jobs(languages: list[str], boxes=False):
37 | jobs = [
38 | job_easy_ocr if not boxes else job_easy_ocr_boxes,
39 | job_tesseract if not boxes else job_tesseract_boxes,
40 | ]
41 | # ko or en in languages
42 | if "ko" in languages or "en" in languages:
43 | try:
44 | if not boxes:
45 | from .wrappers.easy_pororo_ocr import job_easy_pororo_ocr
46 |
47 | jobs.append(job_easy_pororo_ocr)
48 | else:
49 | from .wrappers.easy_pororo_ocr import job_easy_pororo_ocr_boxes
50 |
51 | jobs.append(job_easy_pororo_ocr_boxes)
52 | except ImportError as e:
53 | print(e)
54 | print(
55 | "[!] Pororo dependencies is not installed. Skipping Pororo (EasyPororoOCR)."
56 | )
57 | pass
58 | return jobs
59 |
60 |
61 | def detect_text(
62 | image_path: str,
63 | lang: list[str],
64 | context: str = "",
65 | tesseract: dict = {},
66 | openai: dict = {"model": "gpt-4"},
67 | ):
68 | """Detect text from an image using EasyOCR and Tesseract, then combine and correct the results using OpenAI's LLM."""
69 | options = {
70 | "path": image_path, # "demo.png",
71 | "lang": lang, # ["ko", "en"]
72 | "context": context,
73 | "tesseract": tesseract,
74 | "openai": openai,
75 | }
76 | jobs = get_jobs(languages=options["lang"], boxes=False)
77 |
78 | queues = []
79 | for job in jobs:
80 | queue = Queue()
81 | Thread(target=wrapper, args=(job, options, queue)).start()
82 | queues.append(queue)
83 |
84 | results = [queue.get() for queue in queues]
85 |
86 | result_indexes_prompt = "" # "[0][1][2]"
87 | result_prompt = "" # "[0]: result_0\n[1]: result_1\n[2]: result_2"
88 |
89 | for i in range(len(results)):
90 | result_indexes_prompt += f"[{i}]"
91 | result_prompt += f"[{i}]: {results[i]}"
92 |
93 | if i != len(results) - 1:
94 | result_prompt += "\n"
95 |
96 | optional_context_prompt = (
97 | f"[context]: {options['context']}" if options["context"] else ""
98 | )
99 |
100 | prompt = f"""Combine and correct OCR results {result_indexes_prompt}, using \\n for line breaks. Language is {'+'.join(options['lang'])}. Remove unintended noise. Refer to the [context] keywords. Answer in the JSON format {{data:}}:
101 | {result_prompt}
102 | {optional_context_prompt}"""
103 |
104 | prompt = prompt.strip()
105 |
106 | print("=====")
107 | print(prompt)
108 |
109 | # Prioritize user-specified API_KEY
110 | api_key = options["openai"].get("API_KEY", os.environ.get("OPENAI_API_KEY"))
111 |
112 | # Make a shallow copy of the openai options and remove the API_KEY
113 | openai_options = options["openai"].copy()
114 | if "API_KEY" in openai_options:
115 | del openai_options["API_KEY"]
116 |
117 | client = OpenAI(
118 | api_key=api_key,
119 | )
120 |
121 | print("=====")
122 |
123 | completion = client.chat.completions.create(
124 | messages=[
125 | {"role": "user", "content": prompt},
126 | ],
127 | **openai_options,
128 | )
129 | output = completion.choices[0].message.content
130 | print("[*] LLM", output)
131 |
132 | result = extract_json(output)
133 | print(result)
134 |
135 | if "data" in result:
136 | return result["data"]
137 | if isinstance(result, str):
138 | return result
139 | raise NoTextDetectedError("No text detected")
140 |
141 |
142 | def detect_text_async():
143 | """Unimplemented"""
144 | raise NotImplementedError
145 |
146 |
147 | def detect_boxes(
148 | image_path: str,
149 | lang: list[str],
150 | context: str = "",
151 | tesseract: dict = {},
152 | openai: dict = {"model": "gpt-4"},
153 | ):
154 | options = {
155 | "path": image_path, # "demo.png",
156 | "lang": lang, # ["ko", "en"]
157 | "context": context,
158 | "tesseract": tesseract,
159 | "openai": openai,
160 | }
161 | jobs = get_jobs(languages=options["lang"], boxes=True)
162 |
163 | queues = []
164 | for job in jobs:
165 | queue = Queue()
166 | Thread(target=wrapper, args=(job, options, queue)).start()
167 | queues.append(queue)
168 |
169 | results = [queue.get() for queue in queues]
170 |
171 | result_indexes_prompt = "" # "[0][1][2]"
172 | result_prompt = "" # "[0]: result_0\n[1]: result_1\n[2]: result_2"
173 |
174 | for i in range(len(results)):
175 | result_indexes_prompt += f"[{i}]"
176 |
177 | boxes = results[i]
178 | boxes_json = json.dumps(boxes, ensure_ascii=False, default=int)
179 |
180 | result_prompt += f"[{i}]: {boxes_json}"
181 |
182 | if i != len(results) - 1:
183 | result_prompt += "\n"
184 |
185 | optional_context_prompt = (
186 | " " + "Please refer to the keywords and spelling in [context]"
187 | if options["context"]
188 | else ""
189 | )
190 | optional_context_prompt_data = (
191 | f"[context]: {options['context']}" if options["context"] else ""
192 | )
193 |
194 | prompt = f"""Combine and correct OCR data {result_indexes_prompt}. Include as many items as possible. Language is {'+'.join(options['lang'])} (avoid arbitrary translations). Remove unintended noise.{optional_context_prompt} Answer in the JSON format. Ensure coordinates are integers (round based on confidence if necessary) and output in the same JSON format (indent=0): Array({{box:[[x,y],[x+w,y],[x+w,y+h],[x,y+h]],text:str}}):
195 | {result_prompt}
196 | {optional_context_prompt_data}"""
197 |
198 | prompt = prompt.strip()
199 |
200 | print("=====")
201 | print(prompt)
202 |
203 | # Prioritize user-specified API_KEY
204 | api_key = options["openai"].get("API_KEY", os.environ.get("OPENAI_API_KEY"))
205 |
206 | # Make a shallow copy of the openai options and remove the API_KEY
207 | openai_options = options["openai"].copy()
208 | if "API_KEY" in openai_options:
209 | del openai_options["API_KEY"]
210 |
211 | client = OpenAI(
212 | api_key=api_key,
213 | )
214 |
215 | print("=====")
216 |
217 | completion = client.chat.completions.create(
218 | messages=[
219 | {"role": "user", "content": prompt},
220 | ],
221 | **openai_options,
222 | )
223 | output = completion.choices[0].message.content
224 | output = output.replace("\n", "")
225 | print("[*] LLM", output)
226 |
227 | items = extract_list(output)
228 |
229 | for idx, item in enumerate(items):
230 | box = item["box"]
231 |
232 | # [x,y,w,h]
233 | if len(box) == 4 and isinstance(box[0], int):
234 | rect = rectangle_corners(box)
235 | items[idx]["box"] = rect
236 |
237 | # [[x,y],[w,h]]
238 | elif len(box) == 2 and isinstance(box[0], list) and len(box[0]) == 2:
239 | flattened = [i for sublist in box for i in sublist]
240 | rect = rectangle_corners(flattened)
241 | items[idx]["box"] = rect
242 |
243 | return items
244 |
245 |
246 | def detect_boxes_async():
247 | """Unimplemented"""
248 | raise NotImplementedError
249 |
--------------------------------------------------------------------------------
/betterocr/engines/easy_pororo_ocr/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | this code is adapted from https://github.com/black7375/korean_ocr_using_pororo
3 |
4 | Apache License 2.0 @yunwoong7
5 | Apache License 2.0 @black7375
6 | """
7 |
8 | import cv2
9 | from abc import ABC, abstractmethod
10 | from .pororo import Pororo
11 | from .utils.image_util import plt_imshow, put_text
12 | from .utils.image_convert import convert_coord, crop
13 | from .utils.pre_processing import load_with_filter, roi_filter
14 | from easyocr import Reader
15 | import warnings
16 |
17 | warnings.filterwarnings("ignore")
18 |
19 |
20 | class BaseOcr(ABC):
21 | def __init__(self):
22 | self.img_path = None
23 | self.ocr_result = {}
24 |
25 | def get_ocr_result(self):
26 | return self.ocr_result
27 |
28 | def get_img_path(self):
29 | return self.img_path
30 |
31 | def show_img(self):
32 | plt_imshow(img=self.img_path)
33 |
34 | def show_img_with_ocr(self, bounding, description, vertices, point):
35 | img = (
36 | cv2.imread(self.img_path)
37 | if isinstance(self.img_path, str)
38 | else self.img_path
39 | )
40 | roi_img = img.copy()
41 | color = (0, 200, 0)
42 |
43 | x, y = point
44 | ocr_result = self.ocr_result if bounding is None else self.ocr_result[bounding]
45 | for text_result in ocr_result:
46 | text = text_result[description]
47 | rect = text_result[vertices]
48 |
49 | topLeft, topRight, bottomRight, bottomLeft = [
50 | (round(point[x]), round(point[y])) for point in rect
51 | ]
52 |
53 | cv2.line(roi_img, topLeft, topRight, color, 2)
54 | cv2.line(roi_img, topRight, bottomRight, color, 2)
55 | cv2.line(roi_img, bottomRight, bottomLeft, color, 2)
56 | cv2.line(roi_img, bottomLeft, topLeft, color, 2)
57 | roi_img = put_text(roi_img, text, topLeft[0], topLeft[1] - 20, color=color)
58 |
59 | plt_imshow(["Original", "ROI"], [img, roi_img], figsize=(16, 10))
60 |
61 | @abstractmethod
62 | def run_ocr(self, img_path: str, debug: bool = False):
63 | pass
64 |
65 |
66 | class EasyPororoOcr(BaseOcr):
67 | def __init__(self, lang: list[str] = ["ko", "en"], gpu=True, **kwargs):
68 | super().__init__()
69 | self._detector = Reader(lang_list=lang, gpu=gpu, **kwargs).detect
70 | self.detect_result = None
71 | self.languages = lang
72 |
73 | def create_result(self, points):
74 | roi = crop(self.img, points)
75 | result = self._ocr(roi_filter(roi))
76 | text = " ".join(result)
77 |
78 | return [points, text]
79 |
80 | def run_ocr(self, img_path: str, debug: bool = False, **kwargs):
81 | self.img_path = img_path
82 | self.img = cv2.imread(img_path) if isinstance(img_path, str) else self.img_path
83 |
84 | lang = "ko" if "ko" in self.languages else "en"
85 | self._ocr = Pororo(task="ocr", lang=lang, model="brainocr", **kwargs)
86 |
87 | self.detect_result = self._detector(self.img, slope_ths=0.3, height_ths=1)
88 | if debug:
89 | print(self.detect_result)
90 |
91 | horizontal_list, free_list = self.detect_result
92 |
93 | rois = [convert_coord(point) for point in horizontal_list[0]] + free_list[0]
94 |
95 | self.ocr_result = list(
96 | filter(
97 | lambda result: len(result[1]) > 0,
98 | [self.create_result(roi) for roi in rois],
99 | )
100 | )
101 |
102 | if len(self.ocr_result) != 0:
103 | ocr_text = list(map(lambda result: result[1], self.ocr_result))
104 | else:
105 | ocr_text = "No text detected."
106 |
107 | if debug:
108 | self.show_img_with_ocr(None, 1, 0, [0, 1])
109 |
110 | return ocr_text
111 |
112 | def get_boxes(self):
113 | x, y = [0, 1]
114 | ocr_result = self.ocr_result
115 | description = 1
116 | vertices = 0
117 |
118 | items = []
119 |
120 | for text_result in ocr_result:
121 | text = text_result[description]
122 | rect = text_result[vertices]
123 | rect = [[round(point[x]), round(point[y])] for point in rect]
124 | items.append({"box": rect, "text": text})
125 |
126 | return items
127 |
128 |
129 | __all__ = ["EasyPororoOcr", "load_with_filter"]
130 |
--------------------------------------------------------------------------------
/betterocr/engines/easy_pororo_ocr/pororo/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | this code is adapted from https://github.com/black7375/korean_ocr_using_pororo
3 |
4 | Apache License 2.0 @yunwoong7
5 | Apache License 2.0 @black7375
6 | """
7 |
8 | from .pororo import Pororo # noqa
9 |
--------------------------------------------------------------------------------
/betterocr/engines/easy_pororo_ocr/pororo/models/brainOCR/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | this code is adapted from https://github.com/black7375/korean_ocr_using_pororo
3 |
4 | Apache License 2.0 @yunwoong7
5 | Apache License 2.0 @black7375
6 | """
7 |
8 | from .brainocr import Reader # noqa
9 |
--------------------------------------------------------------------------------
/betterocr/engines/easy_pororo_ocr/pororo/models/brainOCR/_dataset.py:
--------------------------------------------------------------------------------
1 | """
2 | this code is adapted from https://github.com/black7375/korean_ocr_using_pororo
3 |
4 | Apache License 2.0 @yunwoong7
5 | Apache License 2.0 @black7375
6 | """
7 |
8 | import os
9 |
10 | from natsort import natsorted
11 | from PIL import Image
12 | from torch.utils.data import Dataset
13 |
14 |
15 | class RawDataset(Dataset):
16 | def __init__(self, root, imgW, imgH):
17 | self.imgW = imgW
18 | self.imgH = imgH
19 | self.image_path_list = []
20 | for dirpath, _, filenames in os.walk(root):
21 | for name in filenames:
22 | _, ext = os.path.splitext(name)
23 | ext = ext.lower()
24 | if ext in (".jpg", ".jpeg", ".png"):
25 | self.image_path_list.append(os.path.join(dirpath, name))
26 |
27 | self.image_path_list = natsorted(self.image_path_list)
28 | self.nSamples = len(self.image_path_list)
29 |
30 | def __len__(self):
31 | return self.nSamples
32 |
33 | def __getitem__(self, index):
34 | try:
35 | img = Image.open(self.image_path_list[index]).convert("L")
36 |
37 | except IOError:
38 | print(f"Corrupted image for {index}")
39 | img = Image.new("L", (self.imgW, self.imgH))
40 |
41 | return img, self.image_path_list[index]
42 |
--------------------------------------------------------------------------------
/betterocr/engines/easy_pororo_ocr/pororo/models/brainOCR/_modules.py:
--------------------------------------------------------------------------------
1 | """
2 | this code is adapted from https://github.com/black7375/korean_ocr_using_pororo
3 |
4 | Apache License 2.0 @yunwoong7
5 | Apache License 2.0 @black7375
6 | """
7 |
8 | from collections import namedtuple
9 |
10 | import numpy as np
11 | import torch
12 | import torch.nn as nn
13 | import torch.nn.functional as F
14 | import torch.nn.init as init
15 | from torchvision import models
16 |
17 | # from torchvision.models.vgg import model_urls
18 |
19 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
20 |
21 |
22 | def init_weights(modules):
23 | for m in modules:
24 | if isinstance(m, nn.Conv2d):
25 | init.xavier_uniform_(m.weight.data)
26 | if m.bias is not None:
27 | m.bias.data.zero_()
28 | elif isinstance(m, nn.BatchNorm2d):
29 | m.weight.data.fill_(1)
30 | m.bias.data.zero_()
31 | elif isinstance(m, nn.Linear):
32 | m.weight.data.normal_(0, 0.01)
33 | m.bias.data.zero_()
34 |
35 |
36 | class Vgg16BN(torch.nn.Module):
37 | def __init__(self, pretrained: bool = True, freeze: bool = True):
38 | super(Vgg16BN, self).__init__()
39 | # model_urls["vgg16_bn"] = model_urls["vgg16_bn"].replace(
40 | # "https://", "http://")
41 | vgg_pretrained_features = models.vgg16_bn(pretrained=pretrained).features
42 | self.slice1 = torch.nn.Sequential()
43 | self.slice2 = torch.nn.Sequential()
44 | self.slice3 = torch.nn.Sequential()
45 | self.slice4 = torch.nn.Sequential()
46 | self.slice5 = torch.nn.Sequential()
47 | for x in range(12): # conv2_2
48 | self.slice1.add_module(str(x), vgg_pretrained_features[x])
49 | for x in range(12, 19): # conv3_3
50 | self.slice2.add_module(str(x), vgg_pretrained_features[x])
51 | for x in range(19, 29): # conv4_3
52 | self.slice3.add_module(str(x), vgg_pretrained_features[x])
53 | for x in range(29, 39): # conv5_3
54 | self.slice4.add_module(str(x), vgg_pretrained_features[x])
55 |
56 | # fc6, fc7 without atrous conv
57 | self.slice5 = torch.nn.Sequential(
58 | nn.MaxPool2d(kernel_size=3, stride=1, padding=1),
59 | nn.Conv2d(512, 1024, kernel_size=3, padding=6, dilation=6),
60 | nn.Conv2d(1024, 1024, kernel_size=1),
61 | )
62 |
63 | if not pretrained:
64 | init_weights(self.slice1.modules())
65 | init_weights(self.slice2.modules())
66 | init_weights(self.slice3.modules())
67 | init_weights(self.slice4.modules())
68 |
69 | init_weights(self.slice5.modules()) # no pretrained model for fc6 and fc7
70 |
71 | if freeze:
72 | for param in self.slice1.parameters(): # only first conv
73 | param.requires_grad = False
74 |
75 | def forward(self, x):
76 | h = self.slice1(x)
77 | h_relu2_2 = h
78 | h = self.slice2(h)
79 | h_relu3_2 = h
80 | h = self.slice3(h)
81 | h_relu4_3 = h
82 | h = self.slice4(h)
83 | h_relu5_3 = h
84 | h = self.slice5(h)
85 | h_fc7 = h
86 | vgg_outputs = namedtuple(
87 | "VggOutputs", ["fc7", "relu5_3", "relu4_3", "relu3_2", "relu2_2"]
88 | )
89 | out = vgg_outputs(h_fc7, h_relu5_3, h_relu4_3, h_relu3_2, h_relu2_2)
90 | return out
91 |
92 |
93 | class VGGFeatureExtractor(nn.Module):
94 | """FeatureExtractor of CRNN (https://arxiv.org/pdf/1507.05717.pdf)"""
95 |
96 | def __init__(self, n_input_channels: int = 1, n_output_channels: int = 512):
97 | super(VGGFeatureExtractor, self).__init__()
98 |
99 | self.output_channel = [
100 | int(n_output_channels / 8),
101 | int(n_output_channels / 4),
102 | int(n_output_channels / 2),
103 | n_output_channels,
104 | ] # [64, 128, 256, 512]
105 | self.ConvNet = nn.Sequential(
106 | nn.Conv2d(n_input_channels, self.output_channel[0], 3, 1, 1),
107 | nn.ReLU(True),
108 | nn.MaxPool2d(2, 2), # 64x16x50
109 | nn.Conv2d(self.output_channel[0], self.output_channel[1], 3, 1, 1),
110 | nn.ReLU(True),
111 | nn.MaxPool2d(2, 2), # 128x8x25
112 | nn.Conv2d(self.output_channel[1], self.output_channel[2], 3, 1, 1),
113 | nn.ReLU(True), # 256x8x25
114 | nn.Conv2d(self.output_channel[2], self.output_channel[2], 3, 1, 1),
115 | nn.ReLU(True),
116 | nn.MaxPool2d((2, 1), (2, 1)), # 256x4x25
117 | nn.Conv2d(
118 | self.output_channel[2],
119 | self.output_channel[3],
120 | 3,
121 | 1,
122 | 1,
123 | bias=False,
124 | ),
125 | nn.BatchNorm2d(self.output_channel[3]),
126 | nn.ReLU(True), # 512x4x25
127 | nn.Conv2d(
128 | self.output_channel[3],
129 | self.output_channel[3],
130 | 3,
131 | 1,
132 | 1,
133 | bias=False,
134 | ),
135 | nn.BatchNorm2d(self.output_channel[3]),
136 | nn.ReLU(True),
137 | nn.MaxPool2d((2, 1), (2, 1)), # 512x2x25
138 | nn.Conv2d(self.output_channel[3], self.output_channel[3], 2, 1, 0),
139 | nn.ReLU(True),
140 | ) # 512x1x24
141 |
142 | def forward(self, x):
143 | return self.ConvNet(x)
144 |
145 |
146 | class BidirectionalLSTM(nn.Module):
147 | def __init__(self, input_size: int, hidden_size: int, output_size: int):
148 | super(BidirectionalLSTM, self).__init__()
149 | self.rnn = nn.LSTM(
150 | input_size,
151 | hidden_size,
152 | bidirectional=True,
153 | batch_first=True,
154 | )
155 | self.linear = nn.Linear(hidden_size * 2, output_size)
156 |
157 | def forward(self, x):
158 | """
159 | x : visual feature [batch_size x T x input_size]
160 | output : contextual feature [batch_size x T x output_size]
161 | """
162 | self.rnn.flatten_parameters()
163 | recurrent, _ = self.rnn(
164 | x
165 | ) # batch_size x T x input_size -> batch_size x T x (2*hidden_size)
166 | output = self.linear(recurrent) # batch_size x T x output_size
167 | return output
168 |
169 |
170 | class ResNetFeatureExtractor(nn.Module):
171 | """
172 | FeatureExtractor of FAN
173 | (http://openaccess.thecvf.com/content_ICCV_2017/papers/Cheng_Focusing_Attention_Towards_ICCV_2017_paper.pdf)
174 |
175 | """
176 |
177 | def __init__(self, n_input_channels: int = 1, n_output_channels: int = 512):
178 | super(ResNetFeatureExtractor, self).__init__()
179 | self.ConvNet = ResNet(
180 | n_input_channels,
181 | n_output_channels,
182 | BasicBlock,
183 | [1, 2, 5, 3],
184 | )
185 |
186 | def forward(self, inputs):
187 | return self.ConvNet(inputs)
188 |
189 |
190 | class BasicBlock(nn.Module):
191 | expansion = 1
192 |
193 | def __init__(self, inplanes: int, planes: int, stride: int = 1, downsample=None):
194 | super(BasicBlock, self).__init__()
195 | self.conv1 = self._conv3x3(inplanes, planes)
196 | self.bn1 = nn.BatchNorm2d(planes)
197 | self.conv2 = self._conv3x3(planes, planes)
198 | self.bn2 = nn.BatchNorm2d(planes)
199 | self.relu = nn.ReLU(inplace=True)
200 | self.downsample = downsample
201 | self.stride = stride
202 |
203 | def _conv3x3(self, in_planes, out_planes, stride=1):
204 | "3x3 convolution with padding"
205 | return nn.Conv2d(
206 | in_planes,
207 | out_planes,
208 | kernel_size=3,
209 | stride=stride,
210 | padding=1,
211 | bias=False,
212 | )
213 |
214 | def forward(self, x):
215 | residual = x
216 |
217 | out = self.conv1(x)
218 | out = self.bn1(out)
219 | out = self.relu(out)
220 |
221 | out = self.conv2(out)
222 | out = self.bn2(out)
223 |
224 | if self.downsample is not None:
225 | residual = self.downsample(x)
226 | out += residual
227 | out = self.relu(out)
228 |
229 | return out
230 |
231 |
232 | class ResNet(nn.Module):
233 | def __init__(
234 | self,
235 | n_input_channels: int,
236 | n_output_channels: int,
237 | block,
238 | layers,
239 | ):
240 | """
241 | :param n_input_channels (int): The number of input channels of the feature extractor
242 | :param n_output_channels (int): The number of output channels of the feature extractor
243 | :param block:
244 | :param layers:
245 | """
246 | super(ResNet, self).__init__()
247 |
248 | self.output_channel_blocks = [
249 | int(n_output_channels / 4),
250 | int(n_output_channels / 2),
251 | n_output_channels,
252 | n_output_channels,
253 | ]
254 |
255 | self.inplanes = int(n_output_channels / 8)
256 | self.conv0_1 = nn.Conv2d(
257 | n_input_channels,
258 | int(n_output_channels / 16),
259 | kernel_size=3,
260 | stride=1,
261 | padding=1,
262 | bias=False,
263 | )
264 | self.bn0_1 = nn.BatchNorm2d(int(n_output_channels / 16))
265 | self.conv0_2 = nn.Conv2d(
266 | int(n_output_channels / 16),
267 | self.inplanes,
268 | kernel_size=3,
269 | stride=1,
270 | padding=1,
271 | bias=False,
272 | )
273 | self.bn0_2 = nn.BatchNorm2d(self.inplanes)
274 | self.relu = nn.ReLU(inplace=True)
275 |
276 | self.maxpool1 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
277 | self.layer1 = self._make_layer(
278 | block,
279 | self.output_channel_blocks[0],
280 | layers[0],
281 | )
282 | self.conv1 = nn.Conv2d(
283 | self.output_channel_blocks[0],
284 | self.output_channel_blocks[0],
285 | kernel_size=3,
286 | stride=1,
287 | padding=1,
288 | bias=False,
289 | )
290 | self.bn1 = nn.BatchNorm2d(self.output_channel_blocks[0])
291 |
292 | self.maxpool2 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
293 | self.layer2 = self._make_layer(
294 | block,
295 | self.output_channel_blocks[1],
296 | layers[1],
297 | stride=1,
298 | )
299 | self.conv2 = nn.Conv2d(
300 | self.output_channel_blocks[1],
301 | self.output_channel_blocks[1],
302 | kernel_size=3,
303 | stride=1,
304 | padding=1,
305 | bias=False,
306 | )
307 | self.bn2 = nn.BatchNorm2d(self.output_channel_blocks[1])
308 |
309 | self.maxpool3 = nn.MaxPool2d(
310 | kernel_size=2,
311 | stride=(2, 1),
312 | padding=(0, 1),
313 | )
314 | self.layer3 = self._make_layer(
315 | block,
316 | self.output_channel_blocks[2],
317 | layers[2],
318 | stride=1,
319 | )
320 | self.conv3 = nn.Conv2d(
321 | self.output_channel_blocks[2],
322 | self.output_channel_blocks[2],
323 | kernel_size=3,
324 | stride=1,
325 | padding=1,
326 | bias=False,
327 | )
328 | self.bn3 = nn.BatchNorm2d(self.output_channel_blocks[2])
329 |
330 | self.layer4 = self._make_layer(
331 | block,
332 | self.output_channel_blocks[3],
333 | layers[3],
334 | stride=1,
335 | )
336 | self.conv4_1 = nn.Conv2d(
337 | self.output_channel_blocks[3],
338 | self.output_channel_blocks[3],
339 | kernel_size=2,
340 | stride=(2, 1),
341 | padding=(0, 1),
342 | bias=False,
343 | )
344 | self.bn4_1 = nn.BatchNorm2d(self.output_channel_blocks[3])
345 | self.conv4_2 = nn.Conv2d(
346 | self.output_channel_blocks[3],
347 | self.output_channel_blocks[3],
348 | kernel_size=2,
349 | stride=1,
350 | padding=0,
351 | bias=False,
352 | )
353 | self.bn4_2 = nn.BatchNorm2d(self.output_channel_blocks[3])
354 |
355 | def _make_layer(self, block, planes, blocks, stride=1):
356 | downsample = None
357 | if stride != 1 or self.inplanes != planes * block.expansion:
358 | downsample = nn.Sequential(
359 | nn.Conv2d(
360 | self.inplanes,
361 | planes * block.expansion,
362 | kernel_size=1,
363 | stride=stride,
364 | bias=False,
365 | ),
366 | nn.BatchNorm2d(planes * block.expansion),
367 | )
368 |
369 | layers = []
370 | layers.append(block(self.inplanes, planes, stride, downsample))
371 | self.inplanes = planes * block.expansion
372 | for i in range(1, blocks):
373 | layers.append(block(self.inplanes, planes))
374 |
375 | return nn.Sequential(*layers)
376 |
377 | def forward(self, x):
378 | x = self.conv0_1(x)
379 | x = self.bn0_1(x)
380 | x = self.relu(x)
381 | x = self.conv0_2(x)
382 | x = self.bn0_2(x)
383 | x = self.relu(x)
384 |
385 | x = self.maxpool1(x)
386 | x = self.layer1(x)
387 | x = self.conv1(x)
388 | x = self.bn1(x)
389 | x = self.relu(x)
390 |
391 | x = self.maxpool2(x)
392 | x = self.layer2(x)
393 | x = self.conv2(x)
394 | x = self.bn2(x)
395 | x = self.relu(x)
396 |
397 | x = self.maxpool3(x)
398 | x = self.layer3(x)
399 | x = self.conv3(x)
400 | x = self.bn3(x)
401 | x = self.relu(x)
402 |
403 | x = self.layer4(x)
404 | x = self.conv4_1(x)
405 | x = self.bn4_1(x)
406 | x = self.relu(x)
407 | x = self.conv4_2(x)
408 | x = self.bn4_2(x)
409 | x = self.relu(x)
410 |
411 | return x
412 |
413 |
414 | class TpsSpatialTransformerNetwork(nn.Module):
415 | """Rectification Network of RARE, namely TPS based STN"""
416 |
417 | def __init__(self, F, I_size, I_r_size, I_channel_num: int = 1):
418 | """Based on RARE TPS
419 | input:
420 | batch_I: Batch Input Image [batch_size x I_channel_num x I_height x I_width]
421 | I_size : (height, width) of the input image I
422 | I_r_size : (height, width) of the rectified image I_r
423 | I_channel_num : the number of channels of the input image I
424 | output:
425 | batch_I_r: rectified image [batch_size x I_channel_num x I_r_height x I_r_width]
426 | """
427 | super(TpsSpatialTransformerNetwork, self).__init__()
428 | self.F = F
429 | self.I_size = I_size
430 | self.I_r_size = I_r_size # = (I_r_height, I_r_width)
431 | self.I_channel_num = I_channel_num
432 | self.LocalizationNetwork = LocalizationNetwork(
433 | self.F,
434 | self.I_channel_num,
435 | )
436 | self.GridGenerator = GridGenerator(self.F, self.I_r_size)
437 |
438 | def forward(self, batch_I):
439 | batch_C_prime = self.LocalizationNetwork(batch_I) # batch_size x K x 2
440 | build_P_prime = self.GridGenerator.build_P_prime(
441 | batch_C_prime
442 | ) # batch_size x n (= I_r_width x I_r_height) x 2
443 | build_P_prime_reshape = build_P_prime.reshape(
444 | [build_P_prime.size(0), self.I_r_size[0], self.I_r_size[1], 2]
445 | )
446 |
447 | batch_I_r = F.grid_sample(
448 | batch_I,
449 | build_P_prime_reshape,
450 | padding_mode="border",
451 | )
452 |
453 | return batch_I_r
454 |
455 |
456 | class LocalizationNetwork(nn.Module):
457 | """Localization Network of RARE, which predicts C' (K x 2) from I (I_width x I_height)"""
458 |
459 | def __init__(self, F, I_channel_num: int):
460 | super(LocalizationNetwork, self).__init__()
461 | self.F = F
462 | self.I_channel_num = I_channel_num
463 | self.conv = nn.Sequential(
464 | nn.Conv2d(
465 | in_channels=self.I_channel_num,
466 | out_channels=64,
467 | kernel_size=3,
468 | stride=1,
469 | padding=1,
470 | bias=False,
471 | ),
472 | nn.BatchNorm2d(64),
473 | nn.ReLU(True),
474 | nn.MaxPool2d(2, 2), # batch_size x 64 x I_height/2 x I_width/2
475 | nn.Conv2d(64, 128, 3, 1, 1, bias=False),
476 | nn.BatchNorm2d(128),
477 | nn.ReLU(True),
478 | nn.MaxPool2d(2, 2), # batch_size x 128 x I_height/4 x I_width/4
479 | nn.Conv2d(128, 256, 3, 1, 1, bias=False),
480 | nn.BatchNorm2d(256),
481 | nn.ReLU(True),
482 | nn.MaxPool2d(2, 2), # batch_size x 256 x I_height/8 x I_width/8
483 | nn.Conv2d(256, 512, 3, 1, 1, bias=False),
484 | nn.BatchNorm2d(512),
485 | nn.ReLU(True),
486 | nn.AdaptiveAvgPool2d(1), # batch_size x 512
487 | )
488 |
489 | self.localization_fc1 = nn.Sequential(nn.Linear(512, 256), nn.ReLU(True))
490 | self.localization_fc2 = nn.Linear(256, self.F * 2)
491 |
492 | # Init fc2 in LocalizationNetwork
493 | self.localization_fc2.weight.data.fill_(0)
494 |
495 | # see RARE paper Fig. 6 (a)
496 | ctrl_pts_x = np.linspace(-1.0, 1.0, int(F / 2))
497 | ctrl_pts_y_top = np.linspace(0.0, -1.0, num=int(F / 2))
498 | ctrl_pts_y_bottom = np.linspace(1.0, 0.0, num=int(F / 2))
499 | ctrl_pts_top = np.stack([ctrl_pts_x, ctrl_pts_y_top], axis=1)
500 | ctrl_pts_bottom = np.stack([ctrl_pts_x, ctrl_pts_y_bottom], axis=1)
501 | initial_bias = np.concatenate([ctrl_pts_top, ctrl_pts_bottom], axis=0)
502 | self.localization_fc2.bias.data = (
503 | torch.from_numpy(initial_bias).float().view(-1)
504 | )
505 |
506 | def forward(self, batch_I):
507 | """
508 | :param batch_I : Batch Input Image [batch_size x I_channel_num x I_height x I_width]
509 | :return: batch_C_prime : Predicted coordinates of fiducial points for input batch [batch_size x F x 2]
510 | """
511 | batch_size = batch_I.size(0)
512 | features = self.conv(batch_I).view(batch_size, -1)
513 | batch_C_prime = self.localization_fc2(self.localization_fc1(features)).view(
514 | batch_size, self.F, 2
515 | )
516 | return batch_C_prime
517 |
518 |
519 | class GridGenerator(nn.Module):
520 | """Grid Generator of RARE, which produces P_prime by multipling T with P"""
521 |
522 | def __init__(self, F, I_r_size):
523 | """Generate P_hat and inv_delta_C for later"""
524 | super(GridGenerator, self).__init__()
525 | self.eps = 1e-6
526 | self.I_r_height, self.I_r_width = I_r_size
527 | self.F = F
528 | self.C = self._build_C(self.F) # F x 2
529 | self.P = self._build_P(self.I_r_width, self.I_r_height)
530 |
531 | # for multi-gpu, you need register buffer
532 | self.register_buffer(
533 | "inv_delta_C",
534 | torch.tensor(
535 | self._build_inv_delta_C(
536 | self.F,
537 | self.C,
538 | )
539 | ).float(),
540 | ) # F+3 x F+3
541 | self.register_buffer(
542 | "P_hat",
543 | torch.tensor(
544 | self._build_P_hat(
545 | self.F,
546 | self.C,
547 | self.P,
548 | )
549 | ).float(),
550 | ) # n x F+3
551 |
552 | def _build_C(self, F):
553 | """Return coordinates of fiducial points in I_r; C"""
554 | ctrl_pts_x = np.linspace(-1.0, 1.0, int(F / 2))
555 | ctrl_pts_y_top = -1 * np.ones(int(F / 2))
556 | ctrl_pts_y_bottom = np.ones(int(F / 2))
557 | ctrl_pts_top = np.stack([ctrl_pts_x, ctrl_pts_y_top], axis=1)
558 | ctrl_pts_bottom = np.stack([ctrl_pts_x, ctrl_pts_y_bottom], axis=1)
559 | C = np.concatenate([ctrl_pts_top, ctrl_pts_bottom], axis=0)
560 | return C # F x 2
561 |
562 | def _build_inv_delta_C(self, F, C):
563 | """Return inv_delta_C which is needed to calculate T"""
564 | hat_C = np.zeros((F, F), dtype=float) # F x F
565 | for i in range(0, F):
566 | for j in range(i, F):
567 | r = np.linalg.norm(C[i] - C[j])
568 | hat_C[i, j] = r
569 | hat_C[j, i] = r
570 | np.fill_diagonal(hat_C, 1)
571 | hat_C = (hat_C**2) * np.log(hat_C)
572 | # print(C.shape, hat_C.shape)
573 | delta_C = np.concatenate( # F+3 x F+3
574 | [
575 | np.concatenate([np.ones((F, 1)), C, hat_C], axis=1), # F x F+3
576 | np.concatenate([np.zeros((2, 3)), np.transpose(C)], axis=1), # 2 x F+3
577 | np.concatenate([np.zeros((1, 3)), np.ones((1, F))], axis=1), # 1 x F+3
578 | ],
579 | axis=0,
580 | )
581 | inv_delta_C = np.linalg.inv(delta_C)
582 | return inv_delta_C # F+3 x F+3
583 |
584 | def _build_P(self, I_r_width, I_r_height):
585 | I_r_grid_x = (
586 | np.arange(-I_r_width, I_r_width, 2) + 1.0
587 | ) / I_r_width # self.I_r_width
588 | I_r_grid_y = (
589 | np.arange(-I_r_height, I_r_height, 2) + 1.0
590 | ) / I_r_height # self.I_r_height
591 | P = np.stack( # self.I_r_width x self.I_r_height x 2
592 | np.meshgrid(I_r_grid_x, I_r_grid_y), axis=2
593 | )
594 | return P.reshape([-1, 2]) # n (= self.I_r_width x self.I_r_height) x 2
595 |
596 | def _build_P_hat(self, F, C, P):
597 | n = P.shape[0] # n (= self.I_r_width x self.I_r_height)
598 | P_tile = np.tile(
599 | np.expand_dims(P, axis=1), (1, F, 1)
600 | ) # n x 2 -> n x 1 x 2 -> n x F x 2
601 | C_tile = np.expand_dims(C, axis=0) # 1 x F x 2
602 | P_diff = P_tile - C_tile # n x F x 2
603 | rbf_norm = np.linalg.norm(
604 | P_diff,
605 | ord=2,
606 | axis=2,
607 | keepdims=False,
608 | ) # n x F
609 | rbf = np.multiply(
610 | np.square(rbf_norm),
611 | np.log(rbf_norm + self.eps),
612 | ) # n x F
613 | P_hat = np.concatenate([np.ones((n, 1)), P, rbf], axis=1)
614 | return P_hat # n x F+3
615 |
616 | def build_P_prime(self, batch_C_prime):
617 | """Generate Grid from batch_C_prime [batch_size x F x 2]"""
618 | batch_size = batch_C_prime.size(0)
619 | batch_inv_delta_C = self.inv_delta_C.repeat(batch_size, 1, 1)
620 | batch_P_hat = self.P_hat.repeat(batch_size, 1, 1)
621 | batch_C_prime_with_zeros = torch.cat(
622 | (batch_C_prime, torch.zeros(batch_size, 3, 2).float().to(device)), dim=1
623 | ) # batch_size x F+3 x 2
624 | batch_T = torch.bmm(
625 | batch_inv_delta_C,
626 | batch_C_prime_with_zeros,
627 | ) # batch_size x F+3 x 2
628 | batch_P_prime = torch.bmm(batch_P_hat, batch_T) # batch_size x n x 2
629 | return batch_P_prime # batch_size x n x 2
630 |
--------------------------------------------------------------------------------
/betterocr/engines/easy_pororo_ocr/pororo/models/brainOCR/brainocr.py:
--------------------------------------------------------------------------------
1 | """
2 | this code is adapted from https://github.com/black7375/korean_ocr_using_pororo
3 |
4 | Apache License 2.0 @yunwoong7
5 | Apache License 2.0 @black7375
6 | """
7 |
8 | """
9 | This code is primarily based on the following:
10 | https://github.com/JaidedAI/EasyOCR/blob/8af936ba1b2f3c230968dc1022d0cd3e9ca1efbb/easyocr/easyocr.py
11 |
12 | Basic usage:
13 | >>> from pororo import Pororo
14 | >>> ocr = Pororo(task="ocr", lang="ko")
15 | >>> ocr("IMAGE_FILE")
16 | """
17 |
18 | import ast
19 | from logging import getLogger
20 | from typing import List
21 |
22 | import cv2
23 | import numpy as np
24 | from PIL import Image
25 |
26 | from .detection import get_detector, get_textbox
27 | from .recognition import get_recognizer, get_text
28 | from .utils import (
29 | diff,
30 | get_image_list,
31 | get_paragraph,
32 | group_text_box,
33 | reformat_input,
34 | )
35 |
36 | LOGGER = getLogger(__name__)
37 |
38 |
39 | class Reader(object):
40 | def __init__(
41 | self,
42 | lang: str,
43 | det_model_ckpt_fp: str,
44 | rec_model_ckpt_fp: str,
45 | opt_fp: str,
46 | device: str,
47 | ) -> None:
48 | """
49 | TODO @karter: modify this such that you download the pretrained checkpoint files
50 | Parameters:
51 | lang: language code. e.g, "en" or "ko"
52 | det_model_ckpt_fp: Detection model's checkpoint path e.g., 'craft_mlt_25k.pth'
53 | rec_model_ckpt_fp: Recognition model's checkpoint path
54 | opt_fp: option file path
55 | """
56 | # Plug options in the dictionary
57 | opt2val = self.parse_options(opt_fp) # e.g., {"imgH": 64, ...}
58 | opt2val["vocab"] = self.build_vocab(opt2val["character"])
59 | opt2val["vocab_size"] = len(opt2val["vocab"])
60 | opt2val["device"] = device
61 | opt2val["lang"] = lang
62 | opt2val["det_model_ckpt_fp"] = det_model_ckpt_fp
63 | opt2val["rec_model_ckpt_fp"] = rec_model_ckpt_fp
64 |
65 | # Get model objects
66 | self.detector = get_detector(det_model_ckpt_fp, opt2val["device"])
67 | self.recognizer, self.converter = get_recognizer(opt2val)
68 | self.opt2val = opt2val
69 |
70 | @staticmethod
71 | def parse_options(opt_fp: str) -> dict:
72 | opt2val = dict()
73 | for line in open(opt_fp, "r", encoding="utf8"):
74 | line = line.strip()
75 | if ": " in line:
76 | opt, val = line.split(": ", 1)
77 | try:
78 | opt2val[opt] = ast.literal_eval(val)
79 | except:
80 | opt2val[opt] = val
81 |
82 | return opt2val
83 |
84 | @staticmethod
85 | def build_vocab(character: str) -> List[str]:
86 | """Returns vocabulary (=list of characters)"""
87 | vocab = ["[blank]"] + list(
88 | character
89 | ) # dummy '[blank]' token for CTCLoss (index 0)
90 | return vocab
91 |
92 | def detect(self, img: np.ndarray, opt2val: dict):
93 | """
94 | :return:
95 | horizontal_list (list): e.g., [[613, 1496, 51, 190], [136, 1544, 134, 508]]
96 | free_list (list): e.g., []
97 | """
98 | text_box = get_textbox(self.detector, img, opt2val)
99 | horizontal_list, free_list = group_text_box(
100 | text_box,
101 | opt2val["slope_ths"],
102 | opt2val["ycenter_ths"],
103 | opt2val["height_ths"],
104 | opt2val["width_ths"],
105 | opt2val["add_margin"],
106 | )
107 |
108 | min_size = opt2val["min_size"]
109 | if min_size:
110 | horizontal_list = [
111 | i for i in horizontal_list if max(i[1] - i[0], i[3] - i[2]) > min_size
112 | ]
113 | free_list = [
114 | i
115 | for i in free_list
116 | if max(diff([c[0] for c in i]), diff([c[1] for c in i])) > min_size
117 | ]
118 |
119 | return horizontal_list, free_list
120 |
121 | def recognize(
122 | self,
123 | img_cv_grey: np.ndarray,
124 | horizontal_list: list,
125 | free_list: list,
126 | opt2val: dict,
127 | ):
128 | """
129 | Read text in the image
130 | :return:
131 | result (list): bounding box, text and confident score
132 | e.g., [([[189, 75], [469, 75], [469, 165], [189, 165]], '愚园路', 0.3754989504814148),
133 | ([[86, 80], [134, 80], [134, 128], [86, 128]], '西', 0.40452659130096436),
134 | ([[517, 81], [565, 81], [565, 123], [517, 123]], '东', 0.9989598989486694),
135 | ([[78, 126], [136, 126], [136, 156], [78, 156]], '315', 0.8125889301300049),
136 | ([[514, 126], [574, 126], [574, 156], [514, 156]], '309', 0.4971577227115631),
137 | ([[226, 170], [414, 170], [414, 220], [226, 220]], 'Yuyuan Rd.', 0.8261902332305908),
138 | ([[79, 173], [125, 173], [125, 213], [79, 213]], 'W', 0.9848111271858215),
139 | ([[529, 173], [569, 173], [569, 213], [529, 213]], 'E', 0.8405593633651733)]
140 | or list of texts (if skip_details is True)
141 | e.g., ['愚园路', '西', '东', '315', '309', 'Yuyuan Rd.', 'W', 'E']
142 | """
143 | imgH = opt2val["imgH"]
144 | paragraph = opt2val["paragraph"]
145 | skip_details = opt2val["skip_details"]
146 |
147 | if (horizontal_list is None) and (free_list is None):
148 | y_max, x_max = img_cv_grey.shape
149 | ratio = x_max / y_max
150 | max_width = int(imgH * ratio)
151 | crop_img = cv2.resize(
152 | img_cv_grey,
153 | (max_width, imgH),
154 | interpolation=Image.LANCZOS,
155 | )
156 | image_list = [([[0, 0], [x_max, 0], [x_max, y_max], [0, y_max]], crop_img)]
157 | else:
158 | image_list, max_width = get_image_list(
159 | horizontal_list,
160 | free_list,
161 | img_cv_grey,
162 | model_height=imgH,
163 | )
164 |
165 | result = get_text(image_list, self.recognizer, self.converter, opt2val)
166 |
167 | if paragraph:
168 | result = get_paragraph(result, mode="ltr")
169 |
170 | if skip_details: # texts only
171 | return [item[1] for item in result]
172 |         else:  # full outputs: bounding box, text and confidence score
173 | return result
174 |
175 | def __call__(
176 | self,
177 | image,
178 | batch_size: int = 1,
179 | n_workers: int = 0,
180 | skip_details: bool = False,
181 | paragraph: bool = False,
182 | min_size: int = 20,
183 | contrast_ths: float = 0.1,
184 | adjust_contrast: float = 0.5,
185 | filter_ths: float = 0.003,
186 | text_threshold: float = 0.7,
187 | low_text: float = 0.4,
188 | link_threshold: float = 0.4,
189 | canvas_size: int = 2560,
190 | mag_ratio: float = 1.0,
191 | slope_ths: float = 0.1,
192 | ycenter_ths: float = 0.5,
193 | height_ths: float = 0.5,
194 | width_ths: float = 0.5,
195 | add_margin: float = 0.1,
196 | ):
197 | """
198 | Detect text in the image and then recognize it.
199 | :param image: file path or numpy-array or a byte stream object
200 | :param batch_size:
201 | :param n_workers:
202 | :param skip_details:
203 | :param paragraph:
204 | :param min_size:
205 | :param contrast_ths:
206 | :param adjust_contrast:
207 | :param filter_ths:
208 | :param text_threshold:
209 | :param low_text:
210 | :param link_threshold:
211 | :param canvas_size:
212 | :param mag_ratio:
213 | :param slope_ths:
214 | :param ycenter_ths:
215 | :param height_ths:
216 | :param width_ths:
217 | :param add_margin:
218 | :return:
219 | """
220 | # update `opt2val`
221 | self.opt2val["batch_size"] = batch_size
222 | self.opt2val["n_workers"] = n_workers
223 | self.opt2val["skip_details"] = skip_details
224 | self.opt2val["paragraph"] = paragraph
225 | self.opt2val["min_size"] = min_size
226 | self.opt2val["contrast_ths"] = contrast_ths
227 | self.opt2val["adjust_contrast"] = adjust_contrast
228 | self.opt2val["filter_ths"] = filter_ths
229 | self.opt2val["text_threshold"] = text_threshold
230 | self.opt2val["low_text"] = low_text
231 | self.opt2val["link_threshold"] = link_threshold
232 | self.opt2val["canvas_size"] = canvas_size
233 | self.opt2val["mag_ratio"] = mag_ratio
234 | self.opt2val["slope_ths"] = slope_ths
235 | self.opt2val["ycenter_ths"] = ycenter_ths
236 | self.opt2val["height_ths"] = height_ths
237 | self.opt2val["width_ths"] = width_ths
238 | self.opt2val["add_margin"] = add_margin
239 |
240 | img, img_cv_grey = reformat_input(image) # img, img_cv_grey: array
241 |
242 | horizontal_list, free_list = self.detect(img, self.opt2val)
243 | result = self.recognize(
244 | img_cv_grey,
245 | horizontal_list,
246 | free_list,
247 | self.opt2val,
248 | )
249 |
250 | return result
251 |
--------------------------------------------------------------------------------
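A minimal usage sketch for the __call__ API above (not part of the repository): "reader" is assumed to be an already-initialized Reader instance, and "sign.png" is a placeholder image path.

    texts = reader("sign.png", skip_details=True, paragraph=True)  # list of recognized strings
    detailed = reader("sign.png")  # list of (bounding box, text, confidence score) tuples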
/betterocr/engines/easy_pororo_ocr/pororo/models/brainOCR/craft.py:
--------------------------------------------------------------------------------
1 | """
2 | this code is adapted from https://github.com/black7375/korean_ocr_using_pororo
3 |
4 | Apache License 2.0 @yunwoong7
5 | Apache License 2.0 @black7375
6 | """
7 |
8 | """
9 | This code is adapted from https://github.com/clovaai/CRAFT-pytorch/blob/master/craft.py.
10 | Copyright (c) 2019-present NAVER Corp.
11 | MIT License
12 | """
13 |
14 | import torch
15 | import torch.nn as nn
16 | import torch.nn.functional as F
17 | from torch import Tensor
18 |
19 | from ._modules import Vgg16BN, init_weights
20 |
21 |
22 | class DoubleConv(nn.Module):
23 | def __init__(self, in_ch: int, mid_ch: int, out_ch: int) -> None:
24 | super(DoubleConv, self).__init__()
25 | self.conv = nn.Sequential(
26 | nn.Conv2d(in_ch + mid_ch, mid_ch, kernel_size=1),
27 | nn.BatchNorm2d(mid_ch),
28 | nn.ReLU(inplace=True),
29 | nn.Conv2d(mid_ch, out_ch, kernel_size=3, padding=1),
30 | nn.BatchNorm2d(out_ch),
31 | nn.ReLU(inplace=True),
32 | )
33 |
34 | def forward(self, x: Tensor):
35 | x = self.conv(x)
36 | return x
37 |
38 |
39 | class CRAFT(nn.Module):
40 | def __init__(self, pretrained: bool = False, freeze: bool = False) -> None:
41 | super(CRAFT, self).__init__()
42 |
43 | # Base network
44 | self.basenet = Vgg16BN(pretrained, freeze)
45 |
46 | # U network
47 | self.upconv1 = DoubleConv(1024, 512, 256)
48 | self.upconv2 = DoubleConv(512, 256, 128)
49 | self.upconv3 = DoubleConv(256, 128, 64)
50 | self.upconv4 = DoubleConv(128, 64, 32)
51 |
52 | num_class = 2
53 | self.conv_cls = nn.Sequential(
54 | nn.Conv2d(32, 32, kernel_size=3, padding=1),
55 | nn.ReLU(inplace=True),
56 | nn.Conv2d(32, 32, kernel_size=3, padding=1),
57 | nn.ReLU(inplace=True),
58 | nn.Conv2d(32, 16, kernel_size=3, padding=1),
59 | nn.ReLU(inplace=True),
60 | nn.Conv2d(16, 16, kernel_size=1),
61 | nn.ReLU(inplace=True),
62 | nn.Conv2d(16, num_class, kernel_size=1),
63 | )
64 |
65 | init_weights(self.upconv1.modules())
66 | init_weights(self.upconv2.modules())
67 | init_weights(self.upconv3.modules())
68 | init_weights(self.upconv4.modules())
69 | init_weights(self.conv_cls.modules())
70 |
71 | def forward(self, x: Tensor):
72 | # Base network
73 | sources = self.basenet(x)
74 |
75 | # U network
76 | y = torch.cat([sources[0], sources[1]], dim=1)
77 | y = self.upconv1(y)
78 |
79 | y = F.interpolate(
80 | y,
81 | size=sources[2].size()[2:],
82 | mode="bilinear",
83 | align_corners=False,
84 | )
85 | y = torch.cat([y, sources[2]], dim=1)
86 | y = self.upconv2(y)
87 |
88 | y = F.interpolate(
89 | y,
90 | size=sources[3].size()[2:],
91 | mode="bilinear",
92 | align_corners=False,
93 | )
94 | y = torch.cat([y, sources[3]], dim=1)
95 | y = self.upconv3(y)
96 |
97 | y = F.interpolate(
98 | y,
99 | size=sources[4].size()[2:],
100 | mode="bilinear",
101 | align_corners=False,
102 | )
103 | y = torch.cat([y, sources[4]], dim=1)
104 | feature = self.upconv4(y)
105 |
106 | y = self.conv_cls(feature)
107 |
108 | return y.permute(0, 2, 3, 1), feature
109 |
--------------------------------------------------------------------------------
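A shape-level sketch of the CRAFT forward pass (illustrative only, untrained weights, assuming compatible torch/torchvision versions): the two output channels are the region and affinity score maps at half the input resolution.

    import torch
    from betterocr.engines.easy_pororo_ocr.pororo.models.brainOCR.craft import CRAFT

    model = CRAFT(pretrained=False)
    model.eval()
    with torch.no_grad():
        y, feature = model(torch.randn(1, 3, 768, 768))
    # y: (1, 384, 384, 2) region/affinity maps, feature: (1, 32, 384, 384) shared features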
/betterocr/engines/easy_pororo_ocr/pororo/models/brainOCR/craft_utils.py:
--------------------------------------------------------------------------------
1 | """
2 | this code is adapted from https://github.com/black7375/korean_ocr_using_pororo
3 |
4 | Apache License 2.0 @yunwoong7
5 | Apache License 2.0 @black7375
6 | """
7 |
8 | """
9 | This code is adapted from https://github.com/clovaai/CRAFT-pytorch/blob/master/craft_utils.py
10 | MIT License
11 | """
12 |
13 | import math
14 |
15 | import cv2
16 | import numpy as np
17 |
18 |
19 | def warp_coord(Minv, pt):
20 |     """Auxiliary function: unwarp coordinates."""
21 | out = np.matmul(Minv, (pt[0], pt[1], 1))
22 | return np.array([out[0] / out[2], out[1] / out[2]])
23 |
24 |
25 | def get_det_boxes_core(textmap, linkmap, text_threshold, link_threshold, low_text):
26 | # prepare data
27 | linkmap = linkmap.copy()
28 | textmap = textmap.copy()
29 | img_h, img_w = textmap.shape
30 |
31 | # labeling method
32 | ret, text_score = cv2.threshold(textmap, low_text, 1, 0)
33 | ret, link_score = cv2.threshold(linkmap, link_threshold, 1, 0)
34 |
35 | text_score_comb = np.clip(text_score + link_score, 0, 1)
36 | nLabels, labels, stats, centroids = cv2.connectedComponentsWithStats(
37 | text_score_comb.astype(np.uint8), connectivity=4
38 | )
39 |
40 | det = []
41 | mapper = []
42 | for k in range(1, nLabels):
43 | # size filtering
44 | size = stats[k, cv2.CC_STAT_AREA]
45 | if size < 10:
46 | continue
47 |
48 | # thresholding
49 | if np.max(textmap[labels == k]) < text_threshold:
50 | continue
51 |
52 | # make segmentation map
53 | segmap = np.zeros(textmap.shape, dtype=np.uint8)
54 | segmap[labels == k] = 255
55 | segmap[np.logical_and(link_score == 1, text_score == 0)] = 0 # remove link area
56 | x, y = stats[k, cv2.CC_STAT_LEFT], stats[k, cv2.CC_STAT_TOP]
57 | w, h = stats[k, cv2.CC_STAT_WIDTH], stats[k, cv2.CC_STAT_HEIGHT]
58 | niter = int(math.sqrt(size * min(w, h) / (w * h)) * 2)
59 | sx, ex, sy, ey = x - niter, x + w + niter + 1, y - niter, y + h + niter + 1
60 | # boundary check
61 | if sx < 0:
62 | sx = 0
63 | if sy < 0:
64 | sy = 0
65 | if ex >= img_w:
66 | ex = img_w
67 | if ey >= img_h:
68 | ey = img_h
69 | kernel = cv2.getStructuringElement(
70 | cv2.MORPH_RECT,
71 | (1 + niter, 1 + niter),
72 | )
73 | segmap[sy:ey, sx:ex] = cv2.dilate(segmap[sy:ey, sx:ex], kernel)
74 |
75 | # make box
76 | np_contours = (
77 | np.roll(np.array(np.where(segmap != 0)), 1, axis=0)
78 | .transpose()
79 | .reshape(-1, 2)
80 | )
81 | rectangle = cv2.minAreaRect(np_contours)
82 | box = cv2.boxPoints(rectangle)
83 |
84 | # align diamond-shape
85 | w, h = np.linalg.norm(box[0] - box[1]), np.linalg.norm(box[1] - box[2])
86 | box_ratio = max(w, h) / (min(w, h) + 1e-5)
87 | if abs(1 - box_ratio) <= 0.1:
88 | l, r = min(np_contours[:, 0]), max(np_contours[:, 0])
89 | t, b = min(np_contours[:, 1]), max(np_contours[:, 1])
90 | box = np.array([[l, t], [r, t], [r, b], [l, b]], dtype=np.float32)
91 |
92 | # make clock-wise order
93 | startidx = box.sum(axis=1).argmin()
94 | box = np.roll(box, 4 - startidx, 0)
95 | box = np.array(box)
96 |
97 | det.append(box)
98 | mapper.append(k)
99 |
100 | return det, labels, mapper
101 |
102 |
103 | def get_poly_core(boxes, labels, mapper, linkmap):
104 | # configs
105 | num_cp = 5
106 | max_len_ratio = 0.7
107 | expand_ratio = 1.45
108 | max_r = 2.0
109 | step_r = 0.2
110 |
111 | polys = []
112 | for k, box in enumerate(boxes):
113 | # size filter for small instance
114 | w, h = int(np.linalg.norm(box[0] - box[1]) + 1), int(
115 | np.linalg.norm(box[1] - box[2]) + 1
116 | )
117 | if w < 10 or h < 10:
118 | polys.append(None)
119 | continue
120 |
121 | # warp image
122 | tar = np.float32([[0, 0], [w, 0], [w, h], [0, h]])
123 | M = cv2.getPerspectiveTransform(box, tar)
124 | word_label = cv2.warpPerspective(
125 | labels,
126 | M,
127 | (w, h),
128 | flags=cv2.INTER_NEAREST,
129 | )
130 | try:
131 | Minv = np.linalg.inv(M)
132 |         except np.linalg.LinAlgError:  # singular perspective transform; skip this box
133 | polys.append(None)
134 | continue
135 |
136 | # binarization for selected label
137 | cur_label = mapper[k]
138 | word_label[word_label != cur_label] = 0
139 | word_label[word_label > 0] = 1
140 |
141 | # Polygon generation: find top/bottom contours
142 | cp = []
143 | max_len = -1
144 | for i in range(w):
145 | region = np.where(word_label[:, i] != 0)[0]
146 | if len(region) < 2:
147 | continue
148 | cp.append((i, region[0], region[-1]))
149 | length = region[-1] - region[0] + 1
150 | if length > max_len:
151 | max_len = length
152 |
153 | # pass if max_len is similar to h
154 | if h * max_len_ratio < max_len:
155 | polys.append(None)
156 | continue
157 |
158 | # get pivot points with fixed length
159 | tot_seg = num_cp * 2 + 1
160 | seg_w = w / tot_seg # segment width
161 | pp = [None] * num_cp # init pivot points
162 | cp_section = [[0, 0]] * tot_seg
163 | seg_height = [0] * num_cp
164 | seg_num = 0
165 | num_sec = 0
166 | prev_h = -1
167 | for i in range(0, len(cp)):
168 | (x, sy, ey) = cp[i]
169 | if (seg_num + 1) * seg_w <= x and seg_num <= tot_seg:
170 | # average previous segment
171 | if num_sec == 0:
172 | break
173 | cp_section[seg_num] = [
174 | cp_section[seg_num][0] / num_sec,
175 | cp_section[seg_num][1] / num_sec,
176 | ]
177 | num_sec = 0
178 |
179 | # reset variables
180 | seg_num += 1
181 | prev_h = -1
182 |
183 | # accumulate center points
184 | cy = (sy + ey) * 0.5
185 | cur_h = ey - sy + 1
186 | cp_section[seg_num] = [
187 | cp_section[seg_num][0] + x,
188 | cp_section[seg_num][1] + cy,
189 | ]
190 | num_sec += 1
191 |
192 | if seg_num % 2 == 0:
193 | continue # No polygon area
194 |
195 | if prev_h < cur_h:
196 | pp[int((seg_num - 1) / 2)] = (x, cy)
197 | seg_height[int((seg_num - 1) / 2)] = cur_h
198 | prev_h = cur_h
199 |
200 | # processing last segment
201 | if num_sec != 0:
202 | cp_section[-1] = [cp_section[-1][0] / num_sec, cp_section[-1][1] / num_sec]
203 |
204 | # pass if num of pivots is not sufficient or segment width is smaller than character height
205 | if None in pp or seg_w < np.max(seg_height) * 0.25:
206 | polys.append(None)
207 | continue
208 |
209 |         # half character height: median of pivot heights, scaled by expand_ratio
210 | half_char_h = np.median(seg_height) * expand_ratio / 2
211 |
212 |         # calc gradient and apply to make horizontal pivots
213 | new_pp = []
214 | for i, (x, cy) in enumerate(pp):
215 | dx = cp_section[i * 2 + 2][0] - cp_section[i * 2][0]
216 | dy = cp_section[i * 2 + 2][1] - cp_section[i * 2][1]
217 |             if dx == 0:  # gradient is zero
218 | new_pp.append([x, cy - half_char_h, x, cy + half_char_h])
219 | continue
220 | rad = -math.atan2(dy, dx)
221 | c, s = half_char_h * math.cos(rad), half_char_h * math.sin(rad)
222 | new_pp.append([x - s, cy - c, x + s, cy + c])
223 |
224 | # get edge points to cover character heatmaps
225 | isSppFound, isEppFound = False, False
226 | grad_s = (pp[1][1] - pp[0][1]) / (pp[1][0] - pp[0][0]) + (
227 | pp[2][1] - pp[1][1]
228 | ) / (pp[2][0] - pp[1][0])
229 | grad_e = (pp[-2][1] - pp[-1][1]) / (pp[-2][0] - pp[-1][0]) + (
230 | pp[-3][1] - pp[-2][1]
231 | ) / (pp[-3][0] - pp[-2][0])
232 | for r in np.arange(0.5, max_r, step_r):
233 | dx = 2 * half_char_h * r
234 | if not isSppFound:
235 | line_img = np.zeros(word_label.shape, dtype=np.uint8)
236 | dy = grad_s * dx
237 | p = np.array(new_pp[0]) - np.array([dx, dy, dx, dy])
238 | cv2.line(
239 | line_img,
240 | (int(p[0]), int(p[1])),
241 | (int(p[2]), int(p[3])),
242 | 1,
243 | thickness=1,
244 | )
245 | if (
246 | np.sum(np.logical_and(word_label, line_img)) == 0
247 | or r + 2 * step_r >= max_r
248 | ):
249 | spp = p
250 | isSppFound = True
251 | if not isEppFound:
252 | line_img = np.zeros(word_label.shape, dtype=np.uint8)
253 | dy = grad_e * dx
254 | p = np.array(new_pp[-1]) + np.array([dx, dy, dx, dy])
255 | cv2.line(
256 | line_img,
257 | (int(p[0]), int(p[1])),
258 | (int(p[2]), int(p[3])),
259 | 1,
260 | thickness=1,
261 | )
262 | if (
263 | np.sum(np.logical_and(word_label, line_img)) == 0
264 | or r + 2 * step_r >= max_r
265 | ):
266 | epp = p
267 | isEppFound = True
268 | if isSppFound and isEppFound:
269 | break
270 |
271 | # pass if boundary of polygon is not found
272 | if not (isSppFound and isEppFound):
273 | polys.append(None)
274 | continue
275 |
276 | # make final polygon
277 | poly = []
278 | poly.append(warp_coord(Minv, (spp[0], spp[1])))
279 | for p in new_pp:
280 | poly.append(warp_coord(Minv, (p[0], p[1])))
281 | poly.append(warp_coord(Minv, (epp[0], epp[1])))
282 | poly.append(warp_coord(Minv, (epp[2], epp[3])))
283 | for p in reversed(new_pp):
284 | poly.append(warp_coord(Minv, (p[2], p[3])))
285 | poly.append(warp_coord(Minv, (spp[2], spp[3])))
286 |
287 | # add to final result
288 | polys.append(np.array(poly))
289 |
290 | return polys
291 |
292 |
293 | def get_det_boxes(
294 | textmap,
295 | linkmap,
296 | text_threshold,
297 | link_threshold,
298 | low_text,
299 | poly=False,
300 | ):
301 | boxes, labels, mapper = get_det_boxes_core(
302 | textmap,
303 | linkmap,
304 | text_threshold,
305 | link_threshold,
306 | low_text,
307 | )
308 |
309 | if poly:
310 | polys = get_poly_core(boxes, labels, mapper, linkmap)
311 | else:
312 | polys = [None] * len(boxes)
313 |
314 | return boxes, polys
315 |
316 |
317 | def adjust_result_coordinates(polys, ratio_w, ratio_h, ratio_net=2):
318 | if len(polys) > 0:
319 | polys = np.array(polys)
320 | for k in range(len(polys)):
321 | if polys[k] is not None:
322 | polys[k] *= (ratio_w * ratio_net, ratio_h * ratio_net)
323 | return polys
324 |
--------------------------------------------------------------------------------
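A small, self-contained sketch of get_det_boxes on synthetic score maps (illustrative only): a single bright rectangle in the text map comes back as one detected box.

    import numpy as np
    from betterocr.engines.easy_pororo_ocr.pororo.models.brainOCR.craft_utils import get_det_boxes

    textmap = np.zeros((100, 100), dtype=np.float32)
    linkmap = np.zeros((100, 100), dtype=np.float32)
    textmap[40:60, 20:80] = 0.9  # one synthetic "word" region
    boxes, polys = get_det_boxes(
        textmap, linkmap, text_threshold=0.7, link_threshold=0.4, low_text=0.4
    )
    # boxes: one 4x2 array of corner points; polys: [None] since poly=False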
/betterocr/engines/easy_pororo_ocr/pororo/models/brainOCR/detection.py:
--------------------------------------------------------------------------------
1 | """
2 | this code is adapted from https://github.com/black7375/korean_ocr_using_pororo
3 |
4 | Apache License 2.0 @yunwoong7
5 | Apache License 2.0 @black7375
6 | """
7 |
8 | """
9 | This code is adapted from https://github.com/JaidedAI/EasyOCR/blob/master/easyocr/detection.py
10 | """
11 |
12 | from collections import OrderedDict
13 |
14 | import cv2
15 | import numpy as np
16 | import torch
17 | import torch.backends.cudnn as cudnn
18 | from torch.autograd import Variable
19 |
20 | from .craft import CRAFT
21 | from .craft_utils import adjust_result_coordinates, get_det_boxes
22 | from .imgproc import normalize_mean_variance, resize_aspect_ratio
23 |
24 |
25 | def copy_state_dict(state_dict):
26 | if list(state_dict.keys())[0].startswith("module"):
27 | start_idx = 1
28 | else:
29 | start_idx = 0
30 | new_state_dict = OrderedDict()
31 | for k, v in state_dict.items():
32 | name = ".".join(k.split(".")[start_idx:])
33 | new_state_dict[name] = v
34 | return new_state_dict
35 |
36 |
37 | def test_net(image: np.ndarray, net, opt2val: dict):
38 | canvas_size = opt2val["canvas_size"]
39 | mag_ratio = opt2val["mag_ratio"]
40 | text_threshold = opt2val["text_threshold"]
41 | link_threshold = opt2val["link_threshold"]
42 | low_text = opt2val["low_text"]
43 | device = opt2val["device"]
44 |
45 | # resize
46 | img_resized, target_ratio, size_heatmap = resize_aspect_ratio(
47 | image, canvas_size, interpolation=cv2.INTER_LINEAR, mag_ratio=mag_ratio
48 | )
49 | ratio_h = ratio_w = 1 / target_ratio
50 |
51 | # preprocessing
52 | x = normalize_mean_variance(img_resized)
53 | x = torch.from_numpy(x).permute(2, 0, 1) # [h, w, c] to [c, h, w]
54 | x = Variable(x.unsqueeze(0)) # [c, h, w] to [b, c, h, w]
55 | x = x.to(device)
56 |
57 | # forward pass
58 | with torch.no_grad():
59 | y, feature = net(x)
60 |
61 | # make score and link map
62 | score_text = y[0, :, :, 0].cpu().data.numpy()
63 | score_link = y[0, :, :, 1].cpu().data.numpy()
64 |
65 | # Post-processing
66 | boxes, polys = get_det_boxes(
67 | score_text,
68 | score_link,
69 | text_threshold,
70 | link_threshold,
71 | low_text,
72 | )
73 |
74 | # coordinate adjustment
75 | boxes = adjust_result_coordinates(boxes, ratio_w, ratio_h)
76 | polys = adjust_result_coordinates(polys, ratio_w, ratio_h)
77 | for k in range(len(polys)):
78 | if polys[k] is None:
79 | polys[k] = boxes[k]
80 |
81 | return boxes, polys
82 |
83 |
84 | def get_detector(det_model_ckpt_fp: str, device: str = "cpu"):
85 | net = CRAFT()
86 |
87 | net.load_state_dict(
88 | copy_state_dict(torch.load(det_model_ckpt_fp, map_location=device))
89 | )
90 | if device == "cuda":
91 | net = torch.nn.DataParallel(net).to(device)
92 | cudnn.benchmark = False
93 |
94 | net.eval()
95 | return net
96 |
97 |
98 | def get_textbox(detector, image: np.ndarray, opt2val: dict):
99 | bboxes, polys = test_net(image, detector, opt2val)
100 | result = []
101 | for i, box in enumerate(polys):
102 | poly = np.array(box).astype(np.int32).reshape((-1))
103 | result.append(poly)
104 |
105 | return result
106 |
--------------------------------------------------------------------------------
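A shape-level sketch of test_net showing the opt2val keys it reads (illustrative only; a real pipeline loads trained weights through get_detector with a local checkpoint path, and the untrained network below will typically detect nothing).

    import numpy as np
    from betterocr.engines.easy_pororo_ocr.pororo.models.brainOCR.craft import CRAFT
    from betterocr.engines.easy_pororo_ocr.pororo.models.brainOCR.detection import test_net

    net = CRAFT()  # untrained, for illustration only
    net.eval()
    opt2val = {"canvas_size": 768, "mag_ratio": 1.0, "text_threshold": 0.7,
               "link_threshold": 0.4, "low_text": 0.4, "device": "cpu"}
    image = np.random.randint(0, 255, (480, 640, 3), dtype=np.uint8)
    boxes, polys = test_net(image, net, opt2val)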
/betterocr/engines/easy_pororo_ocr/pororo/models/brainOCR/imgproc.py:
--------------------------------------------------------------------------------
1 | """
2 | this code is adapted from https://github.com/black7375/korean_ocr_using_pororo
3 |
4 | Apache License 2.0 @yunwoong7
5 | Apache License 2.0 @black7375
6 | """
7 |
8 | """
9 | This is adapted from https://github.com/clovaai/CRAFT-pytorch/blob/master/imgproc.py
10 | Copyright (c) 2019-present NAVER Corp.
11 | MIT License
12 | """
13 |
14 | import cv2
15 | import numpy as np
16 | from skimage import io
17 |
18 |
19 | def load_image(img_file):
20 | img = io.imread(img_file) # RGB order
21 | if img.shape[0] == 2:
22 | img = img[0]
23 | if len(img.shape) == 2:
24 | img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
25 | if img.shape[2] == 4:
26 | img = img[:, :, :3]
27 | img = np.array(img)
28 |
29 | return img
30 |
31 |
32 | def normalize_mean_variance(
33 | in_img,
34 | mean=(0.485, 0.456, 0.406),
35 | variance=(0.229, 0.224, 0.225),
36 | ):
37 | # should be RGB order
38 | img = in_img.copy().astype(np.float32)
39 |
40 | img -= np.array(
41 | [mean[0] * 255.0, mean[1] * 255.0, mean[2] * 255.0], dtype=np.float32
42 | )
43 | img /= np.array(
44 | [variance[0] * 255.0, variance[1] * 255.0, variance[2] * 255.0],
45 | dtype=np.float32,
46 | )
47 | return img
48 |
49 |
50 | def denormalize_mean_variance(
51 | in_img,
52 | mean=(0.485, 0.456, 0.406),
53 | variance=(0.229, 0.224, 0.225),
54 | ):
55 | # should be RGB order
56 | img = in_img.copy()
57 | img *= variance
58 | img += mean
59 | img *= 255.0
60 | img = np.clip(img, 0, 255).astype(np.uint8)
61 | return img
62 |
63 |
64 | def resize_aspect_ratio(
65 | img: np.ndarray,
66 | square_size: int,
67 | interpolation: int,
68 | mag_ratio: float = 1.0,
69 | ):
70 | height, width, channel = img.shape
71 |
72 | # magnify image size
73 | target_size = mag_ratio * max(height, width)
74 |
75 | # set original image size
76 | if target_size > square_size:
77 | target_size = square_size
78 |
79 | ratio = target_size / max(height, width)
80 |
81 | target_h, target_w = int(height * ratio), int(width * ratio)
82 | proc = cv2.resize(img, (target_w, target_h), interpolation=interpolation)
83 |
84 | # make canvas and paste image
85 | target_h32, target_w32 = target_h, target_w
86 | if target_h % 32 != 0:
87 | target_h32 = target_h + (32 - target_h % 32)
88 | if target_w % 32 != 0:
89 | target_w32 = target_w + (32 - target_w % 32)
90 | resized = np.zeros((target_h32, target_w32, channel), dtype=np.float32)
91 | resized[0:target_h, 0:target_w, :] = proc
92 | target_h, target_w = target_h32, target_w32
93 |
94 | size_heatmap = (int(target_w / 2), int(target_h / 2))
95 |
96 | return resized, ratio, size_heatmap
97 |
98 |
99 | def cvt2heatmap_img(img):
100 | img = (np.clip(img, 0, 1) * 255).astype(np.uint8)
101 | img = cv2.applyColorMap(img, cv2.COLORMAP_JET)
102 | return img
103 |
--------------------------------------------------------------------------------
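A worked example of resize_aspect_ratio (illustrative only): the image is magnified by mag_ratio (capped at square_size), padded so both sides are multiples of 32, and the heatmap size is half the padded canvas.

    import cv2
    import numpy as np
    from betterocr.engines.easy_pororo_ocr.pororo.models.brainOCR.imgproc import resize_aspect_ratio

    img = np.zeros((300, 500, 3), dtype=np.float32)
    resized, ratio, size_heatmap = resize_aspect_ratio(
        img, square_size=1280, interpolation=cv2.INTER_LINEAR, mag_ratio=1.5
    )
    # resized.shape == (480, 768, 3), ratio == 1.5, size_heatmap == (384, 240)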
/betterocr/engines/easy_pororo_ocr/pororo/models/brainOCR/model.py:
--------------------------------------------------------------------------------
1 | """
2 | this code is adapted from https://github.com/black7375/korean_ocr_using_pororo
3 |
4 | Apache License 2.0 @yunwoong7
5 | Apache License 2.0 @black7375
6 | """
7 |
8 | """
9 | This code is adapted from
10 | https://github.com/clovaai/deep-text-recognition-benchmark/blob/master/model.py
11 | """
12 |
13 | import torch.nn as nn
14 | from torch import Tensor
15 |
16 | from .modules.feature_extraction import (
17 | ResNetFeatureExtractor,
18 | VGGFeatureExtractor,
19 | )
20 | from .modules.prediction import Attention
21 | from .modules.sequence_modeling import BidirectionalLSTM
22 | from .modules.transformation import TpsSpatialTransformerNetwork
23 |
24 |
25 | class Model(nn.Module):
26 | def __init__(self, opt2val: dict):
27 | super(Model, self).__init__()
28 |
29 | input_channel = opt2val["input_channel"]
30 | output_channel = opt2val["output_channel"]
31 | hidden_size = opt2val["hidden_size"]
32 | vocab_size = opt2val["vocab_size"]
33 | num_fiducial = opt2val["num_fiducial"]
34 | imgH = opt2val["imgH"]
35 | imgW = opt2val["imgW"]
36 | FeatureExtraction = opt2val["FeatureExtraction"]
37 | Transformation = opt2val["Transformation"]
38 | SequenceModeling = opt2val["SequenceModeling"]
39 | Prediction = opt2val["Prediction"]
40 |
41 | # Transformation
42 | if Transformation == "TPS":
43 | self.Transformation = TpsSpatialTransformerNetwork(
44 | F=num_fiducial,
45 | I_size=(imgH, imgW),
46 | I_r_size=(imgH, imgW),
47 | I_channel_num=input_channel,
48 | )
49 | else:
50 | print("No Transformation module specified")
51 |
52 | # FeatureExtraction
53 | if FeatureExtraction == "VGG":
54 | extractor = VGGFeatureExtractor
55 | else: # ResNet
56 | extractor = ResNetFeatureExtractor
57 | self.FeatureExtraction = extractor(
58 | input_channel,
59 | output_channel,
60 | opt2val,
61 | )
62 | self.FeatureExtraction_output = output_channel # int(imgH/16-1) * 512
63 | self.AdaptiveAvgPool = nn.AdaptiveAvgPool2d(
64 | (None, 1)
65 | ) # Transform final (imgH/16-1) -> 1
66 |
67 | # Sequence modeling
68 | if SequenceModeling == "BiLSTM":
69 | self.SequenceModeling = nn.Sequential(
70 | BidirectionalLSTM(
71 | self.FeatureExtraction_output,
72 | hidden_size,
73 | hidden_size,
74 | ),
75 | BidirectionalLSTM(hidden_size, hidden_size, hidden_size),
76 | )
77 | self.SequenceModeling_output = hidden_size
78 | else:
79 | print("No SequenceModeling module specified")
80 | self.SequenceModeling_output = self.FeatureExtraction_output
81 |
82 | # Prediction
83 | if Prediction == "CTC":
84 | self.Prediction = nn.Linear(
85 | self.SequenceModeling_output,
86 | vocab_size,
87 | )
88 | elif Prediction == "Attn":
89 | self.Prediction = Attention(
90 | self.SequenceModeling_output,
91 | hidden_size,
92 | vocab_size,
93 | )
94 | elif Prediction == "Transformer": # TODO
95 | pass
96 | else:
97 |             raise Exception("Prediction is neither CTC nor Attn")
98 |
99 | def forward(self, x: Tensor):
100 | """
101 | :param x: (batch, input_channel, height, width)
102 | :return:
103 | """
104 | # Transformation stage
105 | x = self.Transformation(x)
106 |
107 | # Feature extraction stage
108 | visual_feature = self.FeatureExtraction(x) # (b, output_channel=512, h=3, w)
109 | visual_feature = self.AdaptiveAvgPool(
110 | visual_feature.permute(0, 3, 1, 2)
111 | ) # (b, w, channel=512, h=1)
112 | visual_feature = visual_feature.squeeze(3) # (b, w, channel=512)
113 |
114 | # Sequence modeling stage
115 | self.SequenceModeling.eval()
116 | contextual_feature = self.SequenceModeling(visual_feature)
117 |
118 | # Prediction stage
119 | prediction = self.Prediction(
120 | contextual_feature.contiguous()
121 | ) # (b, T, num_classes)
122 |
123 | return prediction
124 |
--------------------------------------------------------------------------------
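A minimal construction sketch for Model with illustrative option values (the real ones come from the bundled opt file and downloaded checkpoints): a TPS + VGG + BiLSTM + CTC configuration run on a random batch.

    import torch
    from betterocr.engines.easy_pororo_ocr.pororo.models.brainOCR.model import Model

    opt2val = {
        "input_channel": 1, "output_channel": 512, "hidden_size": 256,
        "vocab_size": 1000, "num_fiducial": 20, "imgH": 64, "imgW": 256,
        "Transformation": "TPS", "FeatureExtraction": "VGG",
        "SequenceModeling": "BiLSTM", "Prediction": "CTC",
        "rec_model_ckpt_fp": "",  # consulted by VGGFeatureExtractor to pick a variant
    }
    model = Model(opt2val)
    preds = model(torch.randn(2, 1, 64, 256))  # (batch, sequence length, vocab_size)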
/betterocr/engines/easy_pororo_ocr/pororo/models/brainOCR/modules/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | this code is adapted from https://github.com/black7375/korean_ocr_using_pororo
3 |
4 | Apache License 2.0 @yunwoong7
5 | Apache License 2.0 @black7375
6 | """
7 |
--------------------------------------------------------------------------------
/betterocr/engines/easy_pororo_ocr/pororo/models/brainOCR/modules/basenet.py:
--------------------------------------------------------------------------------
1 | """
2 | this code is adapted from https://github.com/black7375/korean_ocr_using_pororo
3 |
4 | Apache License 2.0 @yunwoong7
5 | Apache License 2.0 @black7375
6 | """
7 |
8 | from collections import namedtuple
9 |
10 | import torch
11 | import torch.nn as nn
12 | import torch.nn.init as init
13 | from torchvision import models
14 | from torchvision.models.vgg import model_urls
15 |
16 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
17 |
18 |
19 | def init_weights(modules):
20 | for m in modules:
21 | if isinstance(m, nn.Conv2d):
22 | init.xavier_uniform_(m.weight.data)
23 | if m.bias is not None:
24 | m.bias.data.zero_()
25 | elif isinstance(m, nn.BatchNorm2d):
26 | m.weight.data.fill_(1)
27 | m.bias.data.zero_()
28 | elif isinstance(m, nn.Linear):
29 | m.weight.data.normal_(0, 0.01)
30 | m.bias.data.zero_()
31 |
32 |
33 | class Vgg16BN(torch.nn.Module):
34 | def __init__(self, pretrained: bool = True, freeze: bool = True):
35 | super(Vgg16BN, self).__init__()
36 | model_urls["vgg16_bn"] = model_urls["vgg16_bn"].replace("https://", "http://")
37 | vgg_pretrained_features = models.vgg16_bn(pretrained=pretrained).features
38 | self.slice1 = torch.nn.Sequential()
39 | self.slice2 = torch.nn.Sequential()
40 | self.slice3 = torch.nn.Sequential()
41 | self.slice4 = torch.nn.Sequential()
42 | self.slice5 = torch.nn.Sequential()
43 | for x in range(12): # conv2_2
44 | self.slice1.add_module(str(x), vgg_pretrained_features[x])
45 | for x in range(12, 19): # conv3_3
46 | self.slice2.add_module(str(x), vgg_pretrained_features[x])
47 | for x in range(19, 29): # conv4_3
48 | self.slice3.add_module(str(x), vgg_pretrained_features[x])
49 | for x in range(29, 39): # conv5_3
50 | self.slice4.add_module(str(x), vgg_pretrained_features[x])
51 |
52 | # fc6, fc7 without atrous conv
53 | self.slice5 = torch.nn.Sequential(
54 | nn.MaxPool2d(kernel_size=3, stride=1, padding=1),
55 | nn.Conv2d(512, 1024, kernel_size=3, padding=6, dilation=6),
56 | nn.Conv2d(1024, 1024, kernel_size=1),
57 | )
58 |
59 | if not pretrained:
60 | init_weights(self.slice1.modules())
61 | init_weights(self.slice2.modules())
62 | init_weights(self.slice3.modules())
63 | init_weights(self.slice4.modules())
64 |
65 | init_weights(self.slice5.modules()) # no pretrained model for fc6 and fc7
66 |
67 | if freeze:
68 | for param in self.slice1.parameters(): # only first conv
69 | param.requires_grad = False
70 |
71 | def forward(self, x):
72 | h = self.slice1(x)
73 | h_relu2_2 = h
74 | h = self.slice2(h)
75 | h_relu3_2 = h
76 | h = self.slice3(h)
77 | h_relu4_3 = h
78 | h = self.slice4(h)
79 | h_relu5_3 = h
80 | h = self.slice5(h)
81 | h_fc7 = h
82 | vgg_outputs = namedtuple(
83 | "VggOutputs", ["fc7", "relu5_3", "relu4_3", "relu3_2", "relu2_2"]
84 | )
85 | out = vgg_outputs(h_fc7, h_relu5_3, h_relu4_3, h_relu3_2, h_relu2_2)
86 | return out
87 |
--------------------------------------------------------------------------------
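A shape sketch of Vgg16BN (illustrative only; note the module imports model_urls, which only exists in older torchvision releases): the returned namedtuple exposes five feature maps, from coarse fc7 at 1/16 resolution to fine relu2_2 at 1/2 resolution.

    import torch
    from betterocr.engines.easy_pororo_ocr.pororo.models.brainOCR.modules.basenet import Vgg16BN

    net = Vgg16BN(pretrained=False, freeze=False)
    out = net(torch.randn(1, 3, 224, 224))
    # out.fc7: (1, 1024, 14, 14), out.relu4_3: (1, 512, 28, 28), out.relu2_2: (1, 128, 112, 112)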
/betterocr/engines/easy_pororo_ocr/pororo/models/brainOCR/modules/feature_extraction.py:
--------------------------------------------------------------------------------
1 | """
2 | this code is adapted from https://github.com/black7375/korean_ocr_using_pororo
3 |
4 | Apache License 2.0 @yunwoong7
5 | Apache License 2.0 @black7375
6 | """
7 |
8 | import torch.nn as nn
9 |
10 |
11 | class VGGFeatureExtractor(nn.Module):
12 | """FeatureExtractor of CRNN (https://arxiv.org/pdf/1507.05717.pdf)"""
13 |
14 | def __init__(
15 | self, n_input_channels: int = 1, n_output_channels: int = 512, opt2val=None
16 | ):
17 | super(VGGFeatureExtractor, self).__init__()
18 |
19 | self.output_channel = [
20 | int(n_output_channels / 8),
21 | int(n_output_channels / 4),
22 | int(n_output_channels / 2),
23 | n_output_channels,
24 | ] # [64, 128, 256, 512]
25 |
26 | rec_model_ckpt_fp = opt2val["rec_model_ckpt_fp"]
27 | if "baseline" in rec_model_ckpt_fp:
28 | self.ConvNet = nn.Sequential(
29 | nn.Conv2d(n_input_channels, self.output_channel[0], 3, 1, 1),
30 | nn.ReLU(True),
31 | nn.MaxPool2d(2, 2), # 64x16x50
32 | nn.Conv2d(self.output_channel[0], self.output_channel[1], 3, 1, 1),
33 | nn.ReLU(True),
34 | nn.MaxPool2d(2, 2), # 128x8x25
35 | nn.Conv2d(self.output_channel[1], self.output_channel[2], 3, 1, 1),
36 | nn.ReLU(True), # 256x8x25
37 | nn.Conv2d(self.output_channel[2], self.output_channel[2], 3, 1, 1),
38 | nn.ReLU(True),
39 | nn.MaxPool2d((2, 1), (2, 1)), # 256x4x25
40 | nn.Conv2d(
41 | self.output_channel[2], self.output_channel[3], 3, 1, 1, bias=False
42 | ),
43 | nn.BatchNorm2d(self.output_channel[3]),
44 | nn.ReLU(True), # 512x4x25
45 | nn.Conv2d(
46 | self.output_channel[3], self.output_channel[3], 3, 1, 1, bias=False
47 | ),
48 | nn.BatchNorm2d(self.output_channel[3]),
49 | nn.ReLU(True),
50 | nn.MaxPool2d((2, 1), (2, 1)), # 512x2x25
51 | # nn.Conv2d(self.output_channel[3], self.output_channel[3], 2, 1, 0), nn.ReLU(True)) # 512x1x24
52 | nn.ConvTranspose2d(
53 | self.output_channel[3], self.output_channel[3], 2, 2
54 | ),
55 | nn.ReLU(True),
56 | ) # 512x4x50
57 | else:
58 | self.ConvNet = nn.Sequential(
59 | nn.Conv2d(n_input_channels, self.output_channel[0], 3, 1, 1),
60 | nn.ReLU(True),
61 | nn.MaxPool2d(2, 2), # 64x16x50
62 | nn.Conv2d(self.output_channel[0], self.output_channel[1], 3, 1, 1),
63 | nn.ReLU(True),
64 | nn.MaxPool2d(2, 2), # 128x8x25
65 | nn.Conv2d(self.output_channel[1], self.output_channel[2], 3, 1, 1),
66 | nn.ReLU(True), # 256x8x25
67 | nn.Conv2d(self.output_channel[2], self.output_channel[2], 3, 1, 1),
68 | nn.ReLU(True),
69 | nn.MaxPool2d((2, 1), (2, 1)), # 256x4x25
70 | nn.Conv2d(
71 | self.output_channel[2], self.output_channel[3], 3, 1, 1, bias=False
72 | ),
73 | nn.BatchNorm2d(self.output_channel[3]),
74 | nn.ReLU(True), # 512x4x25
75 | nn.Conv2d(
76 | self.output_channel[3], self.output_channel[3], 3, 1, 1, bias=False
77 | ),
78 | nn.BatchNorm2d(self.output_channel[3]),
79 | nn.ReLU(True),
80 | nn.MaxPool2d((2, 1), (2, 1)), # 512x2x25
81 | # nn.Conv2d(self.output_channel[3], self.output_channel[3], 2, 1, 0), nn.ReLU(True)) # 512x1x24
82 | nn.ConvTranspose2d(
83 | self.output_channel[3], self.output_channel[3], 2, 2
84 | ),
85 | nn.ReLU(True), # 512x4x50
86 | nn.ConvTranspose2d(
87 | self.output_channel[3], self.output_channel[3], 2, 2
88 | ),
89 | nn.ReLU(True),
90 | ) # 512x4x50
91 |
92 | def forward(self, x):
93 | return self.ConvNet(x)
94 |
95 |
96 | class ResNetFeatureExtractor(nn.Module):
97 | """
98 | FeatureExtractor of FAN
99 | (http://openaccess.thecvf.com/content_ICCV_2017/papers/Cheng_Focusing_Attention_Towards_ICCV_2017_paper.pdf)
100 | """
101 |
102 | def __init__(self, n_input_channels: int = 1, n_output_channels: int = 512):
103 | super(ResNetFeatureExtractor, self).__init__()
104 | self.ConvNet = ResNet(
105 | n_input_channels, n_output_channels, BasicBlock, [1, 2, 5, 3]
106 | )
107 |
108 | def forward(self, inputs):
109 | return self.ConvNet(inputs)
110 |
111 |
112 | class BasicBlock(nn.Module):
113 | expansion = 1
114 |
115 | def __init__(self, inplanes: int, planes: int, stride: int = 1, downsample=None):
116 | super(BasicBlock, self).__init__()
117 | self.conv1 = self._conv3x3(inplanes, planes)
118 | self.bn1 = nn.BatchNorm2d(planes)
119 | self.conv2 = self._conv3x3(planes, planes)
120 | self.bn2 = nn.BatchNorm2d(planes)
121 | self.relu = nn.ReLU(inplace=True)
122 | self.downsample = downsample
123 | self.stride = stride
124 |
125 | def _conv3x3(self, in_planes, out_planes, stride=1):
126 | "3x3 convolution with padding"
127 | return nn.Conv2d(
128 | in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False
129 | )
130 |
131 | def forward(self, x):
132 | residual = x
133 |
134 | out = self.conv1(x)
135 | out = self.bn1(out)
136 | out = self.relu(out)
137 |
138 | out = self.conv2(out)
139 | out = self.bn2(out)
140 |
141 | if self.downsample is not None:
142 | residual = self.downsample(x)
143 | out += residual
144 | out = self.relu(out)
145 |
146 | return out
147 |
148 |
149 | class ResNet(nn.Module):
150 | def __init__(self, n_input_channels: int, n_output_channels: int, block, layers):
151 | """
152 | :param n_input_channels (int): The number of input channels of the feature extractor
153 | :param n_output_channels (int): The number of output channels of the feature extractor
154 | :param block:
155 | :param layers:
156 | """
157 | super(ResNet, self).__init__()
158 |
159 | self.output_channel_blocks = [
160 | int(n_output_channels / 4),
161 | int(n_output_channels / 2),
162 | n_output_channels,
163 | n_output_channels,
164 | ]
165 |
166 | self.inplanes = int(n_output_channels / 8)
167 | self.conv0_1 = nn.Conv2d(
168 | n_input_channels,
169 | int(n_output_channels / 16),
170 | kernel_size=3,
171 | stride=1,
172 | padding=1,
173 | bias=False,
174 | )
175 | self.bn0_1 = nn.BatchNorm2d(int(n_output_channels / 16))
176 | self.conv0_2 = nn.Conv2d(
177 | int(n_output_channels / 16),
178 | self.inplanes,
179 | kernel_size=3,
180 | stride=1,
181 | padding=1,
182 | bias=False,
183 | )
184 | self.bn0_2 = nn.BatchNorm2d(self.inplanes)
185 | self.relu = nn.ReLU(inplace=True)
186 |
187 | self.maxpool1 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
188 | self.layer1 = self._make_layer(block, self.output_channel_blocks[0], layers[0])
189 | self.conv1 = nn.Conv2d(
190 | self.output_channel_blocks[0],
191 | self.output_channel_blocks[0],
192 | kernel_size=3,
193 | stride=1,
194 | padding=1,
195 | bias=False,
196 | )
197 | self.bn1 = nn.BatchNorm2d(self.output_channel_blocks[0])
198 |
199 | self.maxpool2 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
200 | self.layer2 = self._make_layer(
201 | block, self.output_channel_blocks[1], layers[1], stride=1
202 | )
203 | self.conv2 = nn.Conv2d(
204 | self.output_channel_blocks[1],
205 | self.output_channel_blocks[1],
206 | kernel_size=3,
207 | stride=1,
208 | padding=1,
209 | bias=False,
210 | )
211 | self.bn2 = nn.BatchNorm2d(self.output_channel_blocks[1])
212 |
213 | self.maxpool3 = nn.MaxPool2d(kernel_size=2, stride=(2, 1), padding=(0, 1))
214 | self.layer3 = self._make_layer(
215 | block, self.output_channel_blocks[2], layers[2], stride=1
216 | )
217 | self.conv3 = nn.Conv2d(
218 | self.output_channel_blocks[2],
219 | self.output_channel_blocks[2],
220 | kernel_size=3,
221 | stride=1,
222 | padding=1,
223 | bias=False,
224 | )
225 | self.bn3 = nn.BatchNorm2d(self.output_channel_blocks[2])
226 |
227 | self.layer4 = self._make_layer(
228 | block, self.output_channel_blocks[3], layers[3], stride=1
229 | )
230 | self.conv4_1 = nn.Conv2d(
231 | self.output_channel_blocks[3],
232 | self.output_channel_blocks[3],
233 | kernel_size=2,
234 | stride=(2, 1),
235 | padding=(0, 1),
236 | bias=False,
237 | )
238 | self.bn4_1 = nn.BatchNorm2d(self.output_channel_blocks[3])
239 | self.conv4_2 = nn.Conv2d(
240 | self.output_channel_blocks[3],
241 | self.output_channel_blocks[3],
242 | kernel_size=2,
243 | stride=1,
244 | padding=0,
245 | bias=False,
246 | )
247 | self.bn4_2 = nn.BatchNorm2d(self.output_channel_blocks[3])
248 |
249 | def _make_layer(self, block, planes, blocks, stride=1):
250 | downsample = None
251 | if stride != 1 or self.inplanes != planes * block.expansion:
252 | downsample = nn.Sequential(
253 | nn.Conv2d(
254 | self.inplanes,
255 | planes * block.expansion,
256 | kernel_size=1,
257 | stride=stride,
258 | bias=False,
259 | ),
260 | nn.BatchNorm2d(planes * block.expansion),
261 | )
262 |
263 | layers = []
264 | layers.append(block(self.inplanes, planes, stride, downsample))
265 | self.inplanes = planes * block.expansion
266 | for i in range(1, blocks):
267 | layers.append(block(self.inplanes, planes))
268 |
269 | return nn.Sequential(*layers)
270 |
271 | def forward(self, x):
272 | x = self.conv0_1(x)
273 | x = self.bn0_1(x)
274 | x = self.relu(x)
275 | x = self.conv0_2(x)
276 | x = self.bn0_2(x)
277 | x = self.relu(x)
278 |
279 | x = self.maxpool1(x)
280 | x = self.layer1(x)
281 | x = self.conv1(x)
282 | x = self.bn1(x)
283 | x = self.relu(x)
284 |
285 | x = self.maxpool2(x)
286 | x = self.layer2(x)
287 | x = self.conv2(x)
288 | x = self.bn2(x)
289 | x = self.relu(x)
290 |
291 | x = self.maxpool3(x)
292 | x = self.layer3(x)
293 | x = self.conv3(x)
294 | x = self.bn3(x)
295 | x = self.relu(x)
296 |
297 | x = self.layer4(x)
298 | x = self.conv4_1(x)
299 | x = self.bn4_1(x)
300 | x = self.relu(x)
301 | x = self.conv4_2(x)
302 | x = self.bn4_2(x)
303 | x = self.relu(x)
304 |
305 | return x
306 |
--------------------------------------------------------------------------------
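A shape sketch of VGGFeatureExtractor (illustrative only; the empty rec_model_ckpt_fp string simply selects the non-"baseline" branch): a 64x100 grayscale crop becomes a 512-channel feature map.

    import torch
    from betterocr.engines.easy_pororo_ocr.pororo.models.brainOCR.modules.feature_extraction import VGGFeatureExtractor

    fe = VGGFeatureExtractor(n_input_channels=1, n_output_channels=512,
                             opt2val={"rec_model_ckpt_fp": ""})
    feat = fe(torch.randn(1, 1, 64, 100))
    # feat.shape == torch.Size([1, 512, 16, 100])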
/betterocr/engines/easy_pororo_ocr/pororo/models/brainOCR/modules/prediction.py:
--------------------------------------------------------------------------------
1 | """
2 | this code is adapted from https://github.com/black7375/korean_ocr_using_pororo
3 |
4 | Apache License 2.0 @yunwoong7
5 | Apache License 2.0 @black7375
6 | """
7 |
8 | import torch
9 | import torch.nn as nn
10 | import torch.nn.functional as F
11 |
12 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
13 |
14 |
15 | class Attention(nn.Module):
16 | def __init__(self, input_size, hidden_size, num_classes):
17 | super(Attention, self).__init__()
18 | self.attention_cell = AttentionCell(input_size, hidden_size, num_classes)
19 | self.hidden_size = hidden_size
20 | self.num_classes = num_classes
21 | self.generator = nn.Linear(hidden_size, num_classes)
22 |
23 | def _char_to_onehot(self, input_char, onehot_dim=38):
24 | input_char = input_char.unsqueeze(1)
25 | batch_size = input_char.size(0)
26 | one_hot = torch.FloatTensor(batch_size, onehot_dim).zero_().to(device)
27 | one_hot = one_hot.scatter_(1, input_char, 1)
28 | return one_hot
29 |
30 | def forward(self, batch_H, text, is_train=True, batch_max_length=25):
31 | """
32 | input:
33 | batch_H : contextual_feature H = hidden state of encoder. [batch_size x num_steps x contextual_feature_channels]
34 | text : the text-index of each image. [batch_size x (max_length+1)]. +1 for [GO] token. text[:, 0] = [GO].
35 | output: probability distribution at each step [batch_size x num_steps x num_classes]
36 | """
37 | batch_size = batch_H.size(0)
38 | num_steps = batch_max_length + 1 # +1 for [s] at end of sentence.
39 |
40 | output_hiddens = (
41 | torch.FloatTensor(batch_size, num_steps, self.hidden_size)
42 | .fill_(0)
43 | .to(device)
44 | )
45 | hidden = (
46 | torch.FloatTensor(batch_size, self.hidden_size).fill_(0).to(device),
47 | torch.FloatTensor(batch_size, self.hidden_size).fill_(0).to(device),
48 | )
49 |
50 | if is_train:
51 | for i in range(num_steps):
52 |                 # one-hot vectors for the i-th char in the batch
53 | char_onehots = self._char_to_onehot(
54 | text[:, i], onehot_dim=self.num_classes
55 | )
56 | # hidden : decoder's hidden s_{t-1}, batch_H : encoder's hidden H, char_onehots : one-hot(y_{t-1})
57 | hidden, alpha = self.attention_cell(hidden, batch_H, char_onehots)
58 | output_hiddens[:, i, :] = hidden[
59 | 0
60 | ] # LSTM hidden index (0: hidden, 1: Cell)
61 | probs = self.generator(output_hiddens)
62 |
63 | else:
64 | targets = torch.LongTensor(batch_size).fill_(0).to(device) # [GO] token
65 | probs = (
66 | torch.FloatTensor(batch_size, num_steps, self.num_classes)
67 | .fill_(0)
68 | .to(device)
69 | )
70 |
71 | for i in range(num_steps):
72 | char_onehots = self._char_to_onehot(
73 | targets, onehot_dim=self.num_classes
74 | )
75 | hidden, alpha = self.attention_cell(hidden, batch_H, char_onehots)
76 | probs_step = self.generator(hidden[0])
77 | probs[:, i, :] = probs_step
78 | _, next_input = probs_step.max(1)
79 | targets = next_input
80 |
81 | return probs # batch_size x num_steps x num_classes
82 |
83 |
84 | class AttentionCell(nn.Module):
85 | def __init__(self, input_size, hidden_size, num_embeddings):
86 | super(AttentionCell, self).__init__()
87 | self.i2h = nn.Linear(input_size, hidden_size, bias=False)
88 | self.h2h = nn.Linear(
89 | hidden_size, hidden_size
90 | ) # either i2i or h2h should have bias
91 | self.score = nn.Linear(hidden_size, 1, bias=False)
92 | self.rnn = nn.LSTMCell(input_size + num_embeddings, hidden_size)
93 | self.hidden_size = hidden_size
94 |
95 | def forward(self, prev_hidden, batch_H, char_onehots):
96 | # [batch_size x num_encoder_step x num_channel] -> [batch_size x num_encoder_step x hidden_size]
97 | batch_H_proj = self.i2h(batch_H)
98 | prev_hidden_proj = self.h2h(prev_hidden[0]).unsqueeze(1)
99 | e = self.score(
100 | torch.tanh(batch_H_proj + prev_hidden_proj)
101 |         )  # batch_size x num_encoder_step x 1
102 |
103 | alpha = F.softmax(e, dim=1)
104 | context = torch.bmm(alpha.permute(0, 2, 1), batch_H).squeeze(
105 | 1
106 | ) # batch_size x num_channel
107 | concat_context = torch.cat(
108 | [context, char_onehots], 1
109 | ) # batch_size x (num_channel + num_embedding)
110 | cur_hidden = self.rnn(concat_context, prev_hidden)
111 | return cur_hidden, alpha
112 |
--------------------------------------------------------------------------------
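A minimal inference-mode sketch of the Attention head (illustrative only; run on a CPU-only setup, since the module pins new tensors to a global device): with is_train=False it decodes greedily starting from the [GO] token, so no target text is needed.

    import torch
    from betterocr.engines.easy_pororo_ocr.pororo.models.brainOCR.modules.prediction import Attention

    attn = Attention(input_size=256, hidden_size=256, num_classes=38)
    batch_H = torch.randn(2, 26, 256)  # encoder features: batch x num_steps x channels
    probs = attn(batch_H, text=None, is_train=False, batch_max_length=25)
    # probs.shape == torch.Size([2, 26, 38])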
/betterocr/engines/easy_pororo_ocr/pororo/models/brainOCR/modules/sequence_modeling.py:
--------------------------------------------------------------------------------
1 | """
2 | this code is adapted from https://github.com/black7375/korean_ocr_using_pororo
3 |
4 | Apache License 2.0 @yunwoong7
5 | Apache License 2.0 @black7375
6 | """
7 |
8 | import torch.nn as nn
9 |
10 |
11 | class BidirectionalLSTM(nn.Module):
12 | def __init__(self, input_size: int, hidden_size: int, output_size: int):
13 | super(BidirectionalLSTM, self).__init__()
14 | self.rnn = nn.LSTM(
15 | input_size, hidden_size, bidirectional=True, batch_first=True
16 | )
17 | self.linear = nn.Linear(hidden_size * 2, output_size)
18 |
19 | def forward(self, x):
20 | """
21 | x : visual feature [batch_size x T=24 x input_size=512]
22 | output : contextual feature [batch_size x T x output_size]
23 | """
24 | self.rnn.flatten_parameters()
25 | recurrent, _ = self.rnn(
26 | x
27 | ) # batch_size x T x input_size -> batch_size x T x (2*hidden_size)
28 | output = self.linear(recurrent) # batch_size x T x output_size
29 | return output
30 |
--------------------------------------------------------------------------------
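A one-line shape check for BidirectionalLSTM (illustrative only): the bidirectional hidden states are projected back to output_size per time step.

    import torch
    from betterocr.engines.easy_pororo_ocr.pororo.models.brainOCR.modules.sequence_modeling import BidirectionalLSTM

    bilstm = BidirectionalLSTM(input_size=512, hidden_size=256, output_size=256)
    out = bilstm(torch.randn(2, 24, 512))
    # out.shape == torch.Size([2, 24, 256])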
/betterocr/engines/easy_pororo_ocr/pororo/models/brainOCR/modules/transformation.py:
--------------------------------------------------------------------------------
1 | """
2 | this code is adapted from https://github.com/black7375/korean_ocr_using_pororo
3 |
4 | Apache License 2.0 @yunwoong7
5 | Apache License 2.0 @black7375
6 | """
7 |
8 | import numpy as np
9 | import torch
10 | import torch.nn as nn
11 | import torch.nn.functional as F
12 |
13 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
14 |
15 |
16 | class TpsSpatialTransformerNetwork(nn.Module):
17 | """Rectification Network of RARE, namely TPS based STN"""
18 |
19 | def __init__(self, F, I_size, I_r_size, I_channel_num: int = 1):
20 | """Based on RARE TPS
21 | input:
22 | batch_I: Batch Input Image [batch_size x I_channel_num x I_height x I_width]
23 | I_size : (height, width) of the input image I
24 | I_r_size : (height, width) of the rectified image I_r
25 | I_channel_num : the number of channels of the input image I
26 | output:
27 | batch_I_r: rectified image [batch_size x I_channel_num x I_r_height x I_r_width]
28 | """
29 | super(TpsSpatialTransformerNetwork, self).__init__()
30 | self.F = F
31 | self.I_size = I_size
32 | self.I_r_size = I_r_size # = (I_r_height, I_r_width)
33 | self.I_channel_num = I_channel_num
34 | self.LocalizationNetwork = LocalizationNetwork(self.F, self.I_channel_num)
35 | self.GridGenerator = GridGenerator(self.F, self.I_r_size)
36 |
37 | def forward(self, batch_I):
38 | batch_C_prime = self.LocalizationNetwork(batch_I) # batch_size x K x 2
39 | build_P_prime = self.GridGenerator.build_P_prime(
40 | batch_C_prime
41 | ) # batch_size x n (= I_r_width x I_r_height) x 2
42 | build_P_prime_reshape = build_P_prime.reshape(
43 | [build_P_prime.size(0), self.I_r_size[0], self.I_r_size[1], 2]
44 | )
45 |
46 | # if torch.__version__ > "1.2.0":
47 | # batch_I_r = F.grid_sample(batch_I, build_P_prime_reshape, padding_mode='border', align_corners=True)
48 | # else:
49 | batch_I_r = F.grid_sample(batch_I, build_P_prime_reshape, padding_mode="border")
50 |
51 | return batch_I_r
52 |
53 |
54 | class LocalizationNetwork(nn.Module):
55 | """Localization Network of RARE, which predicts C' (K x 2) from I (I_width x I_height)"""
56 |
57 | def __init__(self, F, I_channel_num: int):
58 | super(LocalizationNetwork, self).__init__()
59 | self.F = F
60 | self.I_channel_num = I_channel_num
61 | self.conv = nn.Sequential(
62 | nn.Conv2d(
63 | in_channels=self.I_channel_num,
64 | out_channels=64,
65 | kernel_size=3,
66 | stride=1,
67 | padding=1,
68 | bias=False,
69 | ),
70 | nn.BatchNorm2d(64),
71 | nn.ReLU(True),
72 | nn.MaxPool2d(2, 2), # batch_size x 64 x I_height/2 x I_width/2
73 | nn.Conv2d(64, 128, 3, 1, 1, bias=False),
74 | nn.BatchNorm2d(128),
75 | nn.ReLU(True),
76 | nn.MaxPool2d(2, 2), # batch_size x 128 x I_height/4 x I_width/4
77 | nn.Conv2d(128, 256, 3, 1, 1, bias=False),
78 | nn.BatchNorm2d(256),
79 | nn.ReLU(True),
80 | nn.MaxPool2d(2, 2), # batch_size x 256 x I_height/8 x I_width/8
81 | nn.Conv2d(256, 512, 3, 1, 1, bias=False),
82 | nn.BatchNorm2d(512),
83 | nn.ReLU(True),
84 | nn.AdaptiveAvgPool2d(1), # batch_size x 512
85 | )
86 |
87 | self.localization_fc1 = nn.Sequential(nn.Linear(512, 256), nn.ReLU(True))
88 | self.localization_fc2 = nn.Linear(256, self.F * 2)
89 |
90 | # Init fc2 in LocalizationNetwork
91 | self.localization_fc2.weight.data.fill_(0)
92 |
93 | # see RARE paper Fig. 6 (a)
94 | ctrl_pts_x = np.linspace(-1.0, 1.0, int(F / 2))
95 | ctrl_pts_y_top = np.linspace(0.0, -1.0, num=int(F / 2))
96 | ctrl_pts_y_bottom = np.linspace(1.0, 0.0, num=int(F / 2))
97 | ctrl_pts_top = np.stack([ctrl_pts_x, ctrl_pts_y_top], axis=1)
98 | ctrl_pts_bottom = np.stack([ctrl_pts_x, ctrl_pts_y_bottom], axis=1)
99 | initial_bias = np.concatenate([ctrl_pts_top, ctrl_pts_bottom], axis=0)
100 | self.localization_fc2.bias.data = (
101 | torch.from_numpy(initial_bias).float().view(-1)
102 | )
103 |
104 | def forward(self, batch_I):
105 | """
106 | :param batch_I : Batch Input Image [batch_size x I_channel_num x I_height x I_width]
107 | :return: batch_C_prime : Predicted coordinates of fiducial points for input batch [batch_size x F x 2]
108 | """
109 | batch_size = batch_I.size(0)
110 | features = self.conv(batch_I).view(batch_size, -1)
111 | batch_C_prime = self.localization_fc2(self.localization_fc1(features)).view(
112 | batch_size, self.F, 2
113 | )
114 | return batch_C_prime
115 |
116 |
117 | class GridGenerator(nn.Module):
118 |     """Grid Generator of RARE, which produces P_prime by multiplying T with P"""
119 |
120 | def __init__(self, F, I_r_size):
121 | """Generate P_hat and inv_delta_C for later"""
122 | super(GridGenerator, self).__init__()
123 | self.eps = 1e-6
124 | self.I_r_height, self.I_r_width = I_r_size
125 | self.F = F
126 | self.C = self._build_C(self.F) # F x 2
127 | self.P = self._build_P(self.I_r_width, self.I_r_height)
128 |
129 | # for multi-gpu, you need register buffer
130 | self.register_buffer(
131 | "inv_delta_C", torch.tensor(self._build_inv_delta_C(self.F, self.C)).float()
132 | ) # F+3 x F+3
133 | self.register_buffer(
134 | "P_hat", torch.tensor(self._build_P_hat(self.F, self.C, self.P)).float()
135 | ) # n x F+3
136 |
137 | def _build_C(self, F):
138 | """Return coordinates of fiducial points in I_r; C"""
139 | ctrl_pts_x = np.linspace(-1.0, 1.0, int(F / 2))
140 | ctrl_pts_y_top = -1 * np.ones(int(F / 2))
141 | ctrl_pts_y_bottom = np.ones(int(F / 2))
142 | ctrl_pts_top = np.stack([ctrl_pts_x, ctrl_pts_y_top], axis=1)
143 | ctrl_pts_bottom = np.stack([ctrl_pts_x, ctrl_pts_y_bottom], axis=1)
144 | C = np.concatenate([ctrl_pts_top, ctrl_pts_bottom], axis=0)
145 | return C # F x 2
146 |
147 | def _build_inv_delta_C(self, F, C):
148 | """Return inv_delta_C which is needed to calculate T"""
149 | hat_C = np.zeros((F, F), dtype=float) # F x F
150 | for i in range(0, F):
151 | for j in range(i, F):
152 | r = np.linalg.norm(C[i] - C[j])
153 | hat_C[i, j] = r
154 | hat_C[j, i] = r
155 | np.fill_diagonal(hat_C, 1)
156 | hat_C = (hat_C**2) * np.log(hat_C)
157 | # print(C.shape, hat_C.shape)
158 | delta_C = np.concatenate( # F+3 x F+3
159 | [
160 | np.concatenate([np.ones((F, 1)), C, hat_C], axis=1), # F x F+3
161 | np.concatenate([np.zeros((2, 3)), np.transpose(C)], axis=1), # 2 x F+3
162 | np.concatenate([np.zeros((1, 3)), np.ones((1, F))], axis=1), # 1 x F+3
163 | ],
164 | axis=0,
165 | )
166 | inv_delta_C = np.linalg.inv(delta_C)
167 | return inv_delta_C # F+3 x F+3
168 |
169 | def _build_P(self, I_r_width, I_r_height):
170 | I_r_grid_x = (
171 | np.arange(-I_r_width, I_r_width, 2) + 1.0
172 | ) / I_r_width # self.I_r_width
173 | I_r_grid_y = (
174 | np.arange(-I_r_height, I_r_height, 2) + 1.0
175 | ) / I_r_height # self.I_r_height
176 | P = np.stack( # self.I_r_width x self.I_r_height x 2
177 | np.meshgrid(I_r_grid_x, I_r_grid_y), axis=2
178 | )
179 | return P.reshape([-1, 2]) # n (= self.I_r_width x self.I_r_height) x 2
180 |
181 | def _build_P_hat(self, F, C, P):
182 | n = P.shape[0] # n (= self.I_r_width x self.I_r_height)
183 | P_tile = np.tile(
184 | np.expand_dims(P, axis=1), (1, F, 1)
185 | ) # n x 2 -> n x 1 x 2 -> n x F x 2
186 | C_tile = np.expand_dims(C, axis=0) # 1 x F x 2
187 | P_diff = P_tile - C_tile # n x F x 2
188 | rbf_norm = np.linalg.norm(P_diff, ord=2, axis=2, keepdims=False) # n x F
189 | rbf = np.multiply(np.square(rbf_norm), np.log(rbf_norm + self.eps)) # n x F
190 | P_hat = np.concatenate([np.ones((n, 1)), P, rbf], axis=1)
191 | return P_hat # n x F+3
192 |
193 | def build_P_prime(self, batch_C_prime):
194 | """Generate Grid from batch_C_prime [batch_size x F x 2]"""
195 | batch_size = batch_C_prime.size(0)
196 | batch_inv_delta_C = self.inv_delta_C.repeat(batch_size, 1, 1)
197 | batch_P_hat = self.P_hat.repeat(batch_size, 1, 1)
198 | batch_C_prime_with_zeros = torch.cat(
199 | (batch_C_prime, torch.zeros(batch_size, 3, 2).float().to(device)), dim=1
200 | ) # batch_size x F+3 x 2
201 | batch_T = torch.bmm(
202 | batch_inv_delta_C, batch_C_prime_with_zeros
203 | ) # batch_size x F+3 x 2
204 | batch_P_prime = torch.bmm(batch_P_hat, batch_T) # batch_size x n x 2
205 | return batch_P_prime # batch_size x n x 2
206 |
--------------------------------------------------------------------------------
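A minimal sketch of the TPS spatial transformer (illustrative only; CPU-only for the same global-device reason as above): the rectified output keeps the shape of the input batch.

    import torch
    from betterocr.engines.easy_pororo_ocr.pororo.models.brainOCR.modules.transformation import TpsSpatialTransformerNetwork

    tps = TpsSpatialTransformerNetwork(F=20, I_size=(32, 100), I_r_size=(32, 100), I_channel_num=1)
    rectified = tps(torch.randn(2, 1, 32, 100))
    # rectified.shape == torch.Size([2, 1, 32, 100])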
/betterocr/engines/easy_pororo_ocr/pororo/models/brainOCR/recognition.py:
--------------------------------------------------------------------------------
1 | """
2 | this code is adapted from https://github.com/black7375/korean_ocr_using_pororo
3 |
4 | Apache License 2.0 @yunwoong7
5 | Apache License 2.0 @black7375
6 | """
7 |
8 | """
9 | This code is adapted from https://github.com/JaidedAI/EasyOCR/blob/8af936ba1b2f3c230968dc1022d0cd3e9ca1efbb/easyocr/recognition.py
10 | """
11 |
12 | import math
13 |
14 | import numpy as np
15 | import torch
16 | import torch.nn.functional as F
17 | import torch.utils.data
18 | import torchvision.transforms as transforms
19 | from PIL import Image
20 |
21 | from .model import Model
22 | from .utils import CTCLabelConverter
23 |
24 |
25 | def contrast_grey(img):
26 | high = np.percentile(img, 90)
27 | low = np.percentile(img, 10)
28 | return (high - low) / np.maximum(10, high + low), high, low
29 |
30 |
31 | def adjust_contrast_grey(img, target: float = 0.4):
32 | contrast, high, low = contrast_grey(img)
33 | if contrast < target:
34 | img = img.astype(int)
35 | ratio = 200.0 / np.maximum(10, high - low)
36 | img = (img - low + 25) * ratio
37 | img = np.maximum(
38 | np.full(img.shape, 0),
39 | np.minimum(
40 | np.full(img.shape, 255),
41 | img,
42 | ),
43 | ).astype(np.uint8)
44 | return img
45 |
46 |
47 | class NormalizePAD(object):
48 | def __init__(self, max_size, PAD_type: str = "right"):
49 | self.toTensor = transforms.ToTensor()
50 | self.max_size = max_size
51 | self.max_width_half = math.floor(max_size[2] / 2)
52 | self.PAD_type = PAD_type
53 |
54 | def __call__(self, img):
55 | img = self.toTensor(img)
56 | img.sub_(0.5).div_(0.5)
57 | c, h, w = img.size()
58 | Pad_img = torch.FloatTensor(*self.max_size).fill_(0)
59 | Pad_img[:, :, :w] = img # right pad
60 | if self.max_size[2] != w: # add border Pad
61 | Pad_img[:, :, w:] = (
62 | img[:, :, w - 1]
63 | .unsqueeze(2)
64 | .expand(
65 | c,
66 | h,
67 | self.max_size[2] - w,
68 | )
69 | )
70 |
71 | return Pad_img
72 |
73 |
74 | class ListDataset(torch.utils.data.Dataset):
75 | def __init__(self, image_list: list):
76 | self.image_list = image_list
77 | self.nSamples = len(image_list)
78 |
79 | def __len__(self):
80 | return self.nSamples
81 |
82 | def __getitem__(self, index):
83 | img = self.image_list[index]
84 | return Image.fromarray(img, "L")
85 |
86 |
87 | class AlignCollate(object):
88 | def __init__(self, imgH: int, imgW: int, adjust_contrast: float):
89 | self.imgH = imgH
90 | self.imgW = imgW
91 | self.keep_ratio_with_pad = True # Do Not Change
92 | self.adjust_contrast = adjust_contrast
93 |
94 | def __call__(self, batch):
95 | batch = filter(lambda x: x is not None, batch)
96 | images = batch
97 |
98 | resized_max_w = self.imgW
99 | input_channel = 1
100 | transform = NormalizePAD((input_channel, self.imgH, resized_max_w))
101 |
102 | resized_images = []
103 | for image in images:
104 | w, h = image.size
105 | # augmentation here - change contrast
106 | if self.adjust_contrast > 0:
107 | image = np.array(image.convert("L"))
108 | image = adjust_contrast_grey(image, target=self.adjust_contrast)
109 | image = Image.fromarray(image, "L")
110 |
111 | ratio = w / float(h)
112 | if math.ceil(self.imgH * ratio) > self.imgW:
113 | resized_w = self.imgW
114 | else:
115 | resized_w = math.ceil(self.imgH * ratio)
116 |
117 | resized_image = image.resize((resized_w, self.imgH), Image.BICUBIC)
118 | resized_images.append(transform(resized_image))
119 |
120 | image_tensors = torch.cat([t.unsqueeze(0) for t in resized_images], 0)
121 | return image_tensors
122 |
123 |
124 | def recognizer_predict(model, converter, test_loader, opt2val: dict):
125 | device = opt2val["device"]
126 |
127 | model.eval()
128 | result = []
129 | with torch.no_grad():
130 | for image_tensors in test_loader:
131 | batch_size = image_tensors.size(0)
132 | inputs = image_tensors.to(device)
133 | preds = model(inputs) # (N, length, num_classes)
134 |
135 | # rebalance
136 | preds_prob = F.softmax(preds, dim=2)
137 | preds_prob = preds_prob.cpu().detach().numpy()
138 | pred_norm = preds_prob.sum(axis=2)
139 | preds_prob = preds_prob / np.expand_dims(pred_norm, axis=-1)
140 | preds_prob = torch.from_numpy(preds_prob).float().to(device)
141 |
142 | # Select max probability (greedy decoding), then decode index to character
143 | preds_lengths = torch.IntTensor([preds.size(1)] * batch_size) # (N,)
144 | _, preds_indices = preds_prob.max(2) # (N, length)
145 | preds_indices = preds_indices.view(-1) # (N*length)
146 | preds_str = converter.decode_greedy(preds_indices, preds_lengths)
147 |
148 | preds_max_prob, _ = preds_prob.max(dim=2)
149 |
150 | for pred, pred_max_prob in zip(preds_str, preds_max_prob):
151 | confidence_score = pred_max_prob.cumprod(dim=0)[-1]
152 | result.append([pred, confidence_score.item()])
153 |
154 | return result
155 |
156 |
157 | def get_recognizer(opt2val: dict):
158 | """
159 | :return:
160 | recognizer: recognition net
161 | converter: CTCLabelConverter
162 | """
163 | # converter
164 | vocab = opt2val["vocab"]
165 | converter = CTCLabelConverter(vocab)
166 |
167 | # recognizer
168 | recognizer = Model(opt2val)
169 |
170 | # state_dict
171 | rec_model_ckpt_fp = opt2val["rec_model_ckpt_fp"]
172 | device = opt2val["device"]
173 | state_dict = torch.load(rec_model_ckpt_fp, map_location=device)
174 |
175 | if device == "cuda":
176 | recognizer = torch.nn.DataParallel(recognizer).to(device)
177 | else:
178 | # TODO temporary: ckpt loading issue after multi-GPU training
179 | from collections import OrderedDict
180 |
181 | def _sync_tensor_name(state_dict):
182 | state_dict_ = OrderedDict()
183 | for name, val in state_dict.items():
184 | name = name.replace("module.", "")
185 | state_dict_[name] = val
186 | return state_dict_
187 |
188 | state_dict = _sync_tensor_name(state_dict)
189 |
190 | recognizer.load_state_dict(state_dict)
191 |
192 | return recognizer, converter
193 |
194 |
195 | def get_text(image_list, recognizer, converter, opt2val: dict):
196 | imgW = opt2val["imgW"]
197 | imgH = opt2val["imgH"]
198 | adjust_contrast = opt2val["adjust_contrast"]
199 | batch_size = opt2val["batch_size"]
200 | n_workers = opt2val["n_workers"]
201 | contrast_ths = opt2val["contrast_ths"]
202 |
203 | # TODO: figure out what this is for
204 | # batch_max_length = int(imgW / 10)
205 |
206 | coord = [item[0] for item in image_list]
207 | img_list = [item[1] for item in image_list]
208 | AlignCollate_normal = AlignCollate(imgH, imgW, adjust_contrast)
209 | test_data = ListDataset(img_list)
210 | test_loader = torch.utils.data.DataLoader(
211 | test_data,
212 | batch_size=batch_size,
213 | shuffle=False,
214 | num_workers=n_workers,
215 | collate_fn=AlignCollate_normal,
216 | pin_memory=True,
217 | )
218 |
219 | # predict first round
220 | result1 = recognizer_predict(recognizer, converter, test_loader, opt2val)
221 |
222 | # predict second round
223 | low_confident_idx = [
224 | i for i, item in enumerate(result1) if (item[1] < contrast_ths)
225 | ]
226 | if len(low_confident_idx) > 0:
227 | img_list2 = [img_list[i] for i in low_confident_idx]
228 | AlignCollate_contrast = AlignCollate(imgH, imgW, adjust_contrast)
229 | test_data = ListDataset(img_list2)
230 | test_loader = torch.utils.data.DataLoader(
231 | test_data,
232 | batch_size=batch_size,
233 | shuffle=False,
234 | num_workers=n_workers,
235 | collate_fn=AlignCollate_contrast,
236 | pin_memory=True,
237 | )
238 | result2 = recognizer_predict(recognizer, converter, test_loader, opt2val)
239 |
240 | result = []
241 | for i, zipped in enumerate(zip(coord, result1)):
242 | box, pred1 = zipped
243 | if i in low_confident_idx:
244 | pred2 = result2[low_confident_idx.index(i)]
245 | if pred1[1] > pred2[1]:
246 | result.append((box, pred1[0], pred1[1]))
247 | else:
248 | result.append((box, pred2[0], pred2[1]))
249 | else:
250 | result.append((box, pred1[0], pred1[1]))
251 |
252 | return result
253 |
--------------------------------------------------------------------------------
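A minimal sketch (not part of the repository) of the two-pass merging performed by get_text above: boxes whose first-pass confidence falls below contrast_ths are recognized a second time, and the higher-confidence prediction wins. All names below are illustrative.

    def merge_two_pass(coords, result1, result2, low_confident_idx):
        """Keep the better of the first-/second-pass predictions per box (illustrative)."""
        merged = []
        for i, (box, pred1) in enumerate(zip(coords, result1)):
            if i in low_confident_idx:
                pred2 = result2[low_confident_idx.index(i)]
                best = pred1 if pred1[1] > pred2[1] else pred2
            else:
                best = pred1
            merged.append((box, best[0], best[1]))
        return merged

    coords = ["box0", "box1"]
    r1 = [("HELLO", 0.92), ("W0RLD", 0.31)]      # second box is low-confidence
    r2 = [("WORLD", 0.78)]                       # re-run covers only the low-confidence crops
    print(merge_two_pass(coords, r1, r2, [1]))   # [('box0', 'HELLO', 0.92), ('box1', 'WORLD', 0.78)]
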
/betterocr/engines/easy_pororo_ocr/pororo/models/brainOCR/utils.py:
--------------------------------------------------------------------------------
1 | """
2 | this code is adapted from https://github.com/black7375/korean_ocr_using_pororo
3 |
4 | Apache License 2.0 @yunwoong7
5 | Apache License 2.0 @black7375
6 | """
7 |
8 | """
9 | This code is adapted from https://github.com/JaidedAI/EasyOCR/blob/8af936ba1b2f3c230968dc1022d0cd3e9ca1efbb/easyocr/utils.py
10 | """
11 |
12 | import math
13 | import os
14 | from urllib.request import urlretrieve
15 |
16 | import cv2
17 | import numpy as np
18 | import torch
19 | from PIL import Image
20 | from torch import Tensor
21 |
22 | from .imgproc import load_image
23 |
24 |
25 | def consecutive(data, mode: str = "first", stepsize: int = 1):
26 | group = np.split(data, np.where(np.diff(data) != stepsize)[0] + 1)
27 | group = [item for item in group if len(item) > 0]
28 |
29 | if mode == "first":
30 | result = [l[0] for l in group]
31 | elif mode == "last":
32 | result = [l[-1] for l in group]
33 | return result
34 |
35 |
36 | def word_segmentation(
37 | mat,
38 | separator_idx={"th": [1, 2], "en": [3, 4]},
39 | separator_idx_list=[1, 2, 3, 4],
40 | ):
41 | result = []
42 | sep_list = []
43 | start_idx = 0
44 | sep_lang = ""
45 | for sep_idx in separator_idx_list:
46 | if sep_idx % 2 == 0:
47 | mode = "first"
48 | else:
49 | mode = "last"
50 | a = consecutive(np.argwhere(mat == sep_idx).flatten(), mode)
51 | new_sep = [[item, sep_idx] for item in a]
52 | sep_list += new_sep
53 | sep_list = sorted(sep_list, key=lambda x: x[0])
54 |
55 | for sep in sep_list:
56 | for lang in separator_idx.keys():
57 | if sep[1] == separator_idx[lang][0]: # start lang
58 | sep_lang = lang
59 | sep_start_idx = sep[0]
60 | elif sep[1] == separator_idx[lang][1]: # end lang
61 | if sep_lang == lang: # check if last entry is the same start lang
62 | new_sep_pair = [lang, [sep_start_idx + 1, sep[0] - 1]]
63 | if sep_start_idx > start_idx:
64 | result.append(["", [start_idx, sep_start_idx - 1]])
65 | start_idx = sep[0] + 1
66 | result.append(new_sep_pair)
67 | sep_lang = "" # reset
68 |
69 | if start_idx <= len(mat) - 1:
70 | result.append(["", [start_idx, len(mat) - 1]])
71 | return result
72 |
73 |
74 | # code is based on https://github.com/githubharald/CTCDecoder/blob/master/src/BeamSearch.py
75 | class BeamEntry:
76 | "information about one single beam at specific time-step"
77 |
78 | def __init__(self):
79 | self.prTotal = 0 # blank and non-blank
80 | self.prNonBlank = 0 # non-blank
81 | self.prBlank = 0 # blank
82 | self.prText = 1 # LM score
83 | self.lmApplied = False # flag if LM was already applied to this beam
84 | self.labeling = () # beam-labeling
85 |
86 |
87 | class BeamState:
88 | "information about the beams at specific time-step"
89 |
90 | def __init__(self):
91 | self.entries = {}
92 |
93 | def norm(self):
94 | "length-normalise LM score"
95 | for k, _ in self.entries.items():
96 | labelingLen = len(self.entries[k].labeling)
97 | self.entries[k].prText = self.entries[k].prText ** (
98 | 1.0 / (labelingLen if labelingLen else 1.0)
99 | )
100 |
101 | def sort(self):
102 | "return beam-labelings, sorted by probability"
103 | beams = [v for (_, v) in self.entries.items()]
104 | sortedBeams = sorted(
105 | beams,
106 | reverse=True,
107 | key=lambda x: x.prTotal * x.prText,
108 | )
109 | return [x.labeling for x in sortedBeams]
110 |
111 | def wordsearch(self, classes, ignore_idx, maxCandidate, dict_list):
112 | beams = [v for (_, v) in self.entries.items()]
113 | sortedBeams = sorted(
114 | beams,
115 | reverse=True,
116 | key=lambda x: x.prTotal * x.prText,
117 | )
118 | if len(sortedBeams) > maxCandidate:
119 | sortedBeams = sortedBeams[:maxCandidate]
120 |
121 | for j, candidate in enumerate(sortedBeams):
122 | idx_list = candidate.labeling
123 | text = ""
124 | for i, l in enumerate(idx_list):
125 | if l not in ignore_idx and (
126 | not (i > 0 and idx_list[i - 1] == idx_list[i])
127 | ):
128 | text += classes[l]
129 |
130 | if j == 0:
131 | best_text = text
132 | if text in dict_list:
133 | # print('found text: ', text)
134 | best_text = text
135 | break
136 | else:
137 | pass
138 | # print('not in dict: ', text)
139 | return best_text
140 |
141 |
142 | def applyLM(parentBeam, childBeam, classes, lm_model, lm_factor: float = 0.01):
143 | "calculate LM score of child beam by taking score from parent beam and bigram probability of last two chars"
144 | if lm_model is not None and not childBeam.lmApplied:
145 | history = parentBeam.labeling
146 | history = " ".join(
147 | classes[each].replace(" ", "▁") for each in history if each != 0
148 | )
149 |
150 | current_char = classes[childBeam.labeling[-1]].replace(" ", "▁")
151 | if current_char == "[blank]":
152 | lmProb = 1
153 | else:
154 | text = history + " " + current_char
155 | lmProb = 10 ** lm_model.score(text, bos=True) * lm_factor
156 |
157 | childBeam.prText = lmProb # probability of char sequence
158 | childBeam.lmApplied = True # only apply LM once per beam entry
159 |
160 |
161 | def simplify_label(labeling, blankIdx: int = 0):
162 | labeling = np.array(labeling)
163 |
164 | # collapse blank
165 | idx = np.where(~((np.roll(labeling, 1) == labeling) & (labeling == blankIdx)))[0]
166 | labeling = labeling[idx]
167 |
168 | # get rid of blank between different characters
169 | idx = np.where(
170 | ~((np.roll(labeling, 1) != np.roll(labeling, -1)) & (labeling == blankIdx))
171 | )[0]
172 |
173 | if len(labeling) > 0:
174 | last_idx = len(labeling) - 1
175 | if last_idx not in idx:
176 | idx = np.append(idx, [last_idx])
177 | labeling = labeling[idx]
178 |
179 | return tuple(labeling)
180 |
181 |
182 | def addBeam(beamState, labeling):
183 | "add beam if it does not yet exist"
184 | if labeling not in beamState.entries:
185 | beamState.entries[labeling] = BeamEntry()
186 |
187 |
188 | def ctcBeamSearch(
189 | mat,
190 | classes: list,
191 | ignore_idx: int,
192 | lm_model,
193 | lm_factor: float = 0.01,
194 | beam_width: int = 5,
195 | ):
196 | blankIdx = 0
197 | maxT, maxC = mat.shape
198 |
199 | # initialise beam state
200 | last = BeamState()
201 | labeling = ()
202 | last.entries[labeling] = BeamEntry()
203 | last.entries[labeling].prBlank = 1
204 | last.entries[labeling].prTotal = 1
205 |
206 | # go over all time-steps
207 | for t in range(maxT):
208 | # print("t=", t)
209 | curr = BeamState()
210 | # get beam-labelings of best beams
211 | bestLabelings = last.sort()[0:beam_width]
212 | # go over best beams
213 | for labeling in bestLabelings:
214 | # print("labeling:", labeling)
215 | # probability of paths ending with a non-blank
216 | prNonBlank = 0
217 | # in case of non-empty beam
218 | if labeling:
219 | # probability of paths with repeated last char at the end
220 | prNonBlank = last.entries[labeling].prNonBlank * mat[t, labeling[-1]]
221 |
222 | # probability of paths ending with a blank
223 | prBlank = (last.entries[labeling].prTotal) * mat[t, blankIdx]
224 |
225 | # add beam at current time-step if needed
226 | labeling = simplify_label(labeling, blankIdx)
227 | addBeam(curr, labeling)
228 |
229 | # fill in data
230 | curr.entries[labeling].labeling = labeling
231 | curr.entries[labeling].prNonBlank += prNonBlank
232 | curr.entries[labeling].prBlank += prBlank
233 | curr.entries[labeling].prTotal += prBlank + prNonBlank
234 | curr.entries[labeling].prText = last.entries[labeling].prText
235 | # beam-labeling not changed, therefore also LM score unchanged from the previous time-step
236 |
237 | curr.entries[
238 | labeling
239 | ].lmApplied = (
240 | True # LM already applied at previous time-step for this beam-labeling
241 | )
242 |
243 | # extend current beam-labeling
244 | # char_highscore = np.argpartition(mat[t, :], -5)[-5:] # run through 5 highest probability
245 | char_highscore = np.where(mat[t, :] >= 0.5 / maxC)[
246 | 0
247 | ] # run through all probable characters
248 | for c in char_highscore:
249 | # for c in range(maxC - 1):
250 | # add new char to current beam-labeling
251 | newLabeling = labeling + (c,)
252 | newLabeling = simplify_label(newLabeling, blankIdx)
253 |
254 | # if new labeling contains duplicate char at the end, only consider paths ending with a blank
255 | if labeling and labeling[-1] == c:
256 | prNonBlank = mat[t, c] * last.entries[labeling].prBlank
257 | else:
258 | prNonBlank = mat[t, c] * last.entries[labeling].prTotal
259 |
260 | # add beam at current time-step if needed
261 | addBeam(curr, newLabeling)
262 |
263 | # fill in data
264 | curr.entries[newLabeling].labeling = newLabeling
265 | curr.entries[newLabeling].prNonBlank += prNonBlank
266 | curr.entries[newLabeling].prTotal += prNonBlank
267 |
268 | # apply LM
269 | applyLM(
270 | curr.entries[labeling],
271 | curr.entries[newLabeling],
272 | classes,
273 | lm_model,
274 | lm_factor,
275 | )
276 |
277 | # set new beam state
278 |
279 | last = curr
280 |
281 | # normalise LM scores according to beam-labeling-length
282 | last.norm()
283 |
284 | bestLabeling = last.sort()[0] # get most probable labeling
285 | res = ""
286 | for i, l in enumerate(bestLabeling):
287 | # removing repeated characters and blank.
288 | if l != ignore_idx and (not (i > 0 and bestLabeling[i - 1] == bestLabeling[i])):
289 | res += classes[l]
290 |
291 | return res
292 |
293 |
294 | class CTCLabelConverter(object):
295 | """Convert between text-label and text-index"""
296 |
297 | def __init__(self, vocab: list):
298 | self.char2idx = {char: idx for idx, char in enumerate(vocab)}
299 | self.idx2char = {idx: char for idx, char in enumerate(vocab)}
300 | self.ignored_index = 0
301 | self.vocab = vocab
302 |
303 | def encode(self, texts: list):
304 | """
305 | Convert input texts into indices
306 | texts (list): text labels of each image. [batch_size]
307 |
308 | Returns
309 | text: concatenated text index for CTCLoss.
310 | [sum(text_lengths)] = [text_index_0 + text_index_1 + ... + text_index_(n - 1)]
311 | length: length of each text. [batch_size]
312 | """
313 | lengths = [len(text) for text in texts]
314 | concatenated_text = "".join(texts)
315 | indices = [self.char2idx[char] for char in concatenated_text]
316 |
317 | return torch.IntTensor(indices), torch.IntTensor(lengths)
318 |
319 | def decode_greedy(self, indices: Tensor, lengths: Tensor):
320 | """convert text-index into text-label.
321 |
322 | :param indices (1D int32 Tensor): [N*length,]
323 | :param lengths (1D int32 Tensor): [N,]
324 | :return:
325 | """
326 | texts = []
327 | index = 0
328 | for length in lengths:
329 | text = indices[index : index + length]
330 |
331 | chars = []
332 | for i in range(length):
333 | if (text[i] != self.ignored_index) and (
334 | not (i > 0 and text[i - 1] == text[i])
335 | ): # removing repeated characters and blank (and separator).
336 | chars.append(self.idx2char[text[i].item()])
337 | texts.append("".join(chars))
338 | index += length
339 | return texts
340 |
341 | def decode_beamsearch(self, mat, lm_model, lm_factor, beam_width: int = 5):
342 | texts = []
343 | for i in range(mat.shape[0]):
344 | text = ctcBeamSearch(
345 | mat[i],
346 | self.vocab,
347 | self.ignored_index,
348 | lm_model,
349 | lm_factor,
350 | beam_width,
351 | )
352 | texts.append(text)
353 | return texts
354 |
355 |
356 | def four_point_transform(image, rect):
357 | (tl, tr, br, bl) = rect
358 |
359 | widthA = np.sqrt(((br[0] - bl[0]) ** 2) + ((br[1] - bl[1]) ** 2))
360 | widthB = np.sqrt(((tr[0] - tl[0]) ** 2) + ((tr[1] - tl[1]) ** 2))
361 | maxWidth = max(int(widthA), int(widthB))
362 |
363 | # compute the height of the new image, which will be the
364 | # maximum distance between the top-right and bottom-right
365 | # y-coordinates or the top-left and bottom-left y-coordinates
366 | heightA = np.sqrt(((tr[0] - br[0]) ** 2) + ((tr[1] - br[1]) ** 2))
367 | heightB = np.sqrt(((tl[0] - bl[0]) ** 2) + ((tl[1] - bl[1]) ** 2))
368 | maxHeight = max(int(heightA), int(heightB))
369 |
370 | dst = np.array(
371 | [[0, 0], [maxWidth - 1, 0], [maxWidth - 1, maxHeight - 1], [0, maxHeight - 1]],
372 | dtype="float32",
373 | )
374 |
375 | # compute the perspective transform matrix and then apply it
376 | M = cv2.getPerspectiveTransform(rect, dst)
377 | warped = cv2.warpPerspective(image, M, (maxWidth, maxHeight))
378 |
379 | return warped
380 |
381 |
382 | def group_text_box(
383 | polys,
384 | slope_ths: float = 0.1,
385 | ycenter_ths: float = 0.5,
386 | height_ths: float = 0.5,
387 | width_ths: float = 1.0,
388 | add_margin: float = 0.05,
389 | ):
390 | # poly top-left, top-right, low-right, low-left
391 | horizontal_list, free_list, combined_list, merged_list = [], [], [], []
392 |
393 | for poly in polys:
394 | slope_up = (poly[3] - poly[1]) / np.maximum(10, (poly[2] - poly[0]))
395 | slope_down = (poly[5] - poly[7]) / np.maximum(10, (poly[4] - poly[6]))
396 | if max(abs(slope_up), abs(slope_down)) < slope_ths:
397 | x_max = max([poly[0], poly[2], poly[4], poly[6]])
398 | x_min = min([poly[0], poly[2], poly[4], poly[6]])
399 | y_max = max([poly[1], poly[3], poly[5], poly[7]])
400 | y_min = min([poly[1], poly[3], poly[5], poly[7]])
401 | horizontal_list.append(
402 | [x_min, x_max, y_min, y_max, 0.5 * (y_min + y_max), y_max - y_min]
403 | )
404 | else:
405 | height = np.linalg.norm([poly[6] - poly[0], poly[7] - poly[1]])
406 | margin = int(1.44 * add_margin * height)
407 |
408 | theta13 = abs(
409 | np.arctan((poly[1] - poly[5]) / np.maximum(10, (poly[0] - poly[4])))
410 | )
411 | theta24 = abs(
412 | np.arctan((poly[3] - poly[7]) / np.maximum(10, (poly[2] - poly[6])))
413 | )
414 | # do I need to clip minimum, maximum value here?
415 | x1 = poly[0] - np.cos(theta13) * margin
416 | y1 = poly[1] - np.sin(theta13) * margin
417 | x2 = poly[2] + np.cos(theta24) * margin
418 | y2 = poly[3] - np.sin(theta24) * margin
419 | x3 = poly[4] + np.cos(theta13) * margin
420 | y3 = poly[5] + np.sin(theta13) * margin
421 | x4 = poly[6] - np.cos(theta24) * margin
422 | y4 = poly[7] + np.sin(theta24) * margin
423 |
424 | free_list.append([[x1, y1], [x2, y2], [x3, y3], [x4, y4]])
425 | horizontal_list = sorted(horizontal_list, key=lambda item: item[4])
426 |
427 | # combine box
428 | new_box = []
429 | for poly in horizontal_list:
430 | if len(new_box) == 0:
431 | b_height = [poly[5]]
432 | b_ycenter = [poly[4]]
433 | new_box.append(poly)
434 | else:
435 | # comparable height and comparable y_center level up to ths*height
436 | if (abs(np.mean(b_height) - poly[5]) < height_ths * np.mean(b_height)) and (
437 | abs(np.mean(b_ycenter) - poly[4]) < ycenter_ths * np.mean(b_height)
438 | ):
439 | b_height.append(poly[5])
440 | b_ycenter.append(poly[4])
441 | new_box.append(poly)
442 | else:
443 | b_height = [poly[5]]
444 | b_ycenter = [poly[4]]
445 | combined_list.append(new_box)
446 | new_box = [poly]
447 | combined_list.append(new_box)
448 |
449 | # build merged_list: within each line, sort boxes by x and merge adjacent ones
450 | for boxes in combined_list:
451 | if len(boxes) == 1: # one box per line
452 | box = boxes[0]
453 | margin = int(add_margin * box[5])
454 | merged_list.append(
455 | [box[0] - margin, box[1] + margin, box[2] - margin, box[3] + margin]
456 | )
457 | else: # multiple boxes per line
458 | boxes = sorted(boxes, key=lambda item: item[0])
459 |
460 | merged_box, new_box = [], []
461 | for box in boxes:
462 | if len(new_box) == 0:
463 | x_max = box[1]
464 | new_box.append(box)
465 | else:
466 | if abs(box[0] - x_max) < width_ths * (
467 | box[3] - box[2]
468 | ): # merge boxes
469 | x_max = box[1]
470 | new_box.append(box)
471 | else:
472 | x_max = box[1]
473 | merged_box.append(new_box)
474 | new_box = [box]
475 | if len(new_box) > 0:
476 | merged_box.append(new_box)
477 |
478 | for mbox in merged_box:
479 | if len(mbox) != 1: # adjacent box in same line
480 | # do I need to add margin here?
481 | x_min = min(mbox, key=lambda x: x[0])[0]
482 | x_max = max(mbox, key=lambda x: x[1])[1]
483 | y_min = min(mbox, key=lambda x: x[2])[2]
484 | y_max = max(mbox, key=lambda x: x[3])[3]
485 |
486 | margin = int(add_margin * (y_max - y_min))
487 |
488 | merged_list.append(
489 | [x_min - margin, x_max + margin, y_min - margin, y_max + margin]
490 | )
491 | else: # non adjacent box in same line
492 | box = mbox[0]
493 |
494 | margin = int(add_margin * (box[3] - box[2]))
495 | merged_list.append(
496 | [
497 | box[0] - margin,
498 | box[1] + margin,
499 | box[2] - margin,
500 | box[3] + margin,
501 | ]
502 | )
503 | # may need to check if box is really in image
504 | return merged_list, free_list
505 |
506 |
507 | def get_image_list(
508 | horizontal_list: list, free_list: list, img: np.ndarray, model_height: int = 64
509 | ):
510 | image_list = []
511 | maximum_y, maximum_x = img.shape
512 |
513 | max_ratio_hori, max_ratio_free = 1, 1
514 | for box in free_list:
515 | rect = np.array(box, dtype="float32")
516 | transformed_img = four_point_transform(img, rect)
517 | ratio = transformed_img.shape[1] / transformed_img.shape[0]
518 | crop_img = cv2.resize(
519 | transformed_img,
520 | (int(model_height * ratio), model_height),
521 | interpolation=Image.LANCZOS,
522 | )
523 | # box : [[x1,y1],[x2,y2],[x3,y3],[x4,y4]]
524 | image_list.append((box, crop_img))
525 | max_ratio_free = max(ratio, max_ratio_free)
526 |
527 | max_ratio_free = math.ceil(max_ratio_free)
528 |
529 | for box in horizontal_list:
530 | x_min = max(0, box[0])
531 | x_max = min(box[1], maximum_x)
532 | y_min = max(0, box[2])
533 | y_max = min(box[3], maximum_y)
534 | crop_img = img[y_min:y_max, x_min:x_max]
535 | width = x_max - x_min
536 | height = y_max - y_min
537 | ratio = width / height
538 | crop_img = cv2.resize(
539 | crop_img,
540 | (int(model_height * ratio), model_height),
541 | interpolation=Image.LANCZOS,
542 | )
543 | image_list.append(
544 | (
545 | [
546 | [x_min, y_min],
547 | [x_max, y_min],
548 | [x_max, y_max],
549 | [x_min, y_max],
550 | ],
551 | crop_img,
552 | )
553 | )
554 | max_ratio_hori = max(ratio, max_ratio_hori)
555 |
556 | max_ratio_hori = math.ceil(max_ratio_hori)
557 | max_ratio = max(max_ratio_hori, max_ratio_free)
558 | max_width = math.ceil(max_ratio) * model_height
559 |
560 | image_list = sorted(
561 | image_list, key=lambda item: item[0][0][1]
562 | ) # sort by vertical position
563 | return image_list, max_width
564 |
565 |
566 | def diff(input_list):
567 | return max(input_list) - min(input_list)
568 |
569 |
570 | def get_paragraph(raw_result, x_ths: int = 1, y_ths: float = 0.5, mode: str = "ltr"):
571 | # create basic attributes
572 | box_group = []
573 | for box in raw_result:
574 | all_x = [int(coord[0]) for coord in box[0]]
575 | all_y = [int(coord[1]) for coord in box[0]]
576 | min_x = min(all_x)
577 | max_x = max(all_x)
578 | min_y = min(all_y)
579 | max_y = max(all_y)
580 | height = max_y - min_y
581 | box_group.append(
582 | [box[1], min_x, max_x, min_y, max_y, height, 0.5 * (min_y + max_y), 0]
583 | ) # last element indicates group
584 | # cluster boxes into paragraph
585 | current_group = 1
586 | while len([box for box in box_group if box[7] == 0]) > 0:
587 | # group0 = non-group
588 | box_group0 = [box for box in box_group if box[7] == 0]
589 | # new group
590 | if len([box for box in box_group if box[7] == current_group]) == 0:
591 | # assign first box to form new group
592 | box_group0[0][7] = current_group
593 | # try to add group
594 | else:
595 | current_box_group = [box for box in box_group if box[7] == current_group]
596 | mean_height = np.mean([box[5] for box in current_box_group])
597 | # yapf: disable
598 | min_gx = min([box[1] for box in current_box_group]) - x_ths * mean_height
599 | max_gx = max([box[2] for box in current_box_group]) + x_ths * mean_height
600 | min_gy = min([box[3] for box in current_box_group]) - y_ths * mean_height
601 | max_gy = max([box[4] for box in current_box_group]) + y_ths * mean_height
602 | add_box = False
603 | for box in box_group0:
604 | same_horizontal_level = (min_gx <= box[1] <= max_gx) or (min_gx <= box[2] <= max_gx)
605 | same_vertical_level = (min_gy <= box[3] <= max_gy) or (min_gy <= box[4] <= max_gy)
606 | if same_horizontal_level and same_vertical_level:
607 | box[7] = current_group
608 | add_box = True
609 | break
610 | # cannot add more box, go to next group
611 | if not add_box:
612 | current_group += 1
613 | # yapf: enable
614 | # arrange order in paragraph
615 | result = []
616 | for i in set(box[7] for box in box_group):
617 | current_box_group = [box for box in box_group if box[7] == i]
618 | mean_height = np.mean([box[5] for box in current_box_group])
619 | min_gx = min([box[1] for box in current_box_group])
620 | max_gx = max([box[2] for box in current_box_group])
621 | min_gy = min([box[3] for box in current_box_group])
622 | max_gy = max([box[4] for box in current_box_group])
623 |
624 | text = ""
625 | while len(current_box_group) > 0:
626 | highest = min([box[6] for box in current_box_group])
627 | candidates = [
628 | box for box in current_box_group if box[6] < highest + 0.4 * mean_height
629 | ]
630 | # get the far left
631 | if mode == "ltr":
632 | most_left = min([box[1] for box in candidates])
633 | for box in candidates:
634 | if box[1] == most_left:
635 | best_box = box
636 | elif mode == "rtl":
637 | most_right = max([box[2] for box in candidates])
638 | for box in candidates:
639 | if box[2] == most_right:
640 | best_box = box
641 | text += " " + best_box[0]
642 | current_box_group.remove(best_box)
643 |
644 | result.append(
645 | [
646 | [
647 | [min_gx, min_gy],
648 | [max_gx, min_gy],
649 | [max_gx, max_gy],
650 | [min_gx, max_gy],
651 | ],
652 | text[1:],
653 | ]
654 | )
655 |
656 | return result
657 |
658 |
659 | def printProgressBar(
660 | prefix="",
661 | suffix="",
662 | decimals: int = 1,
663 | length: int = 100,
664 | fill: str = "█",
665 | printEnd: str = "\r",
666 | ):
667 | """
668 | Call in a loop to create terminal progress bar
669 | @params:
670 | prefix - Optional : prefix string (Str)
671 | suffix - Optional : suffix string (Str)
672 | decimals - Optional : positive number of decimals in percent complete (Int)
673 | length - Optional : character length of bar (Int)
674 | fill - Optional : bar fill character (Str)
675 | printEnd - Optional : end character (e.g. "\r", "\r\n") (Str)
676 | """
677 |
678 | def progress_hook(count, blockSize, totalSize):
679 | progress = count * blockSize / totalSize
680 | percent = ("{0:." + str(decimals) + "f}").format(progress * 100)
681 | filledLength = int(length * progress)
682 | bar = fill * filledLength + "-" * (length - filledLength)
683 | print(f"\r{prefix} |{bar}| {percent}% {suffix}", end=printEnd)
684 |
685 | return progress_hook
686 |
687 |
688 | def reformat_input(image):
689 | """
690 | :param image: image file path or bytes or array
691 | :return:
692 | img (array): (original_image_height, original_image_width, 3)
693 | img_cv_grey (array): (original_image_height, original_image_width)
694 | """
695 | if type(image) == str:
696 | if image.startswith("http://") or image.startswith("https://"):
697 | tmp, _ = urlretrieve(
698 | image,
699 | reporthook=printProgressBar(
700 | prefix="Progress:",
701 | suffix="Complete",
702 | length=50,
703 | ),
704 | )
705 | img_cv_grey = cv2.imread(tmp, cv2.IMREAD_GRAYSCALE)
706 | os.remove(tmp)
707 | else:
708 | img_cv_grey = cv2.imread(image, cv2.IMREAD_GRAYSCALE)
709 | image = os.path.expanduser(image)
710 | img = load_image(image) # can accept URL
711 | elif type(image) == bytes:
712 | nparr = np.frombuffer(image, np.uint8)
713 | img = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
714 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
715 | img_cv_grey = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
716 |
717 | elif type(image) == np.ndarray:
718 | if len(image.shape) == 2: # grayscale
719 | img_cv_grey = image
720 | img = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
721 | elif len(image.shape) == 3 and image.shape[2] == 3: # BGRscale
722 | img = image
723 | img_cv_grey = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
724 | elif len(image.shape) == 3 and image.shape[2] == 4: # RGBAscale
725 | img = image[:, :, :3]
726 | img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
727 | img_cv_grey = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
728 |
729 | return img, img_cv_grey
730 |
--------------------------------------------------------------------------------
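For orientation, a small sketch of the CTCLabelConverter defined above (not part of the repository; it assumes the package is importable from the repo root): encoding maps characters to vocabulary indices, and greedy decoding collapses repeated indices and drops the blank at index 0.

    import torch

    from betterocr.engines.easy_pororo_ocr.pororo.models.brainOCR.utils import CTCLabelConverter

    vocab = ["[blank]", "a", "b", "c"]         # index 0 is reserved for the CTC blank
    converter = CTCLabelConverter(vocab)

    indices, lengths = converter.encode(["ab", "cc"])
    print(indices.tolist(), lengths.tolist())  # [1, 2, 3, 3] [2, 2]

    # Greedy decoding: [1, 1, 0, 2] -> "ab" (repeats collapsed, blank removed)
    print(converter.decode_greedy(torch.IntTensor([1, 1, 0, 2]), torch.IntTensor([4])))  # ['ab']
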
/betterocr/engines/easy_pororo_ocr/pororo/pororo.py:
--------------------------------------------------------------------------------
1 | """
2 | this code is adapted from https://github.com/black7375/korean_ocr_using_pororo
3 |
4 | Apache License 2.0 @yunwoong7
5 | Apache License 2.0 @black7375
6 | """
7 |
8 | """
9 | Pororo task-specific factory class
10 |
11 | isort:skip_file
12 |
13 | """
14 |
15 | import logging
16 | from typing import Optional
17 | from .tasks.utils.base import PororoTaskBase
18 |
19 | import torch
20 |
21 | from .tasks import (
22 | PororoOcrFactory,
23 | )
24 |
25 | SUPPORTED_TASKS = {
26 | "ocr": PororoOcrFactory,
27 | }
28 |
29 | LANG_ALIASES = {
30 | "english": "en",
31 | "eng": "en",
32 | "korean": "ko",
33 | "kor": "ko",
34 | "kr": "ko",
35 | "chinese": "zh",
36 | "chn": "zh",
37 | "cn": "zh",
38 | "japanese": "ja",
39 | "jap": "ja",
40 | "jp": "ja",
41 | "jejueo": "je",
42 | "jje": "je",
43 | }
44 |
45 | logging.getLogger("transformers").setLevel(logging.WARN)
46 | logging.getLogger("fairseq").setLevel(logging.WARN)
47 | logging.getLogger("sentence_transformers").setLevel(logging.WARN)
48 | logging.getLogger("youtube_dl").setLevel(logging.WARN)
49 | logging.getLogger("pydub").setLevel(logging.WARN)
50 | logging.getLogger("librosa").setLevel(logging.WARN)
51 |
52 |
53 | class Pororo:
54 | r"""
55 | This is a generic class that will return one of the task-specific model classes of the library
56 | when created with the `__new__()` method
57 |
58 | """
59 |
60 | def __new__(
61 | cls,
62 | task: str,
63 | lang: str = "en",
64 | model: Optional[str] = None,
65 | **kwargs,
66 | ) -> PororoTaskBase:
67 | if task not in SUPPORTED_TASKS:
68 | raise KeyError(
69 | "Unknown task {}, available tasks are {}".format(
70 | task,
71 | list(SUPPORTED_TASKS.keys()),
72 | )
73 | )
74 |
75 | lang = lang.lower()
76 | lang = LANG_ALIASES[lang] if lang in LANG_ALIASES else lang
77 |
78 | # Get device information from torch API
79 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
80 |
81 | # Instantiate task-specific pipeline module, if possible
82 | task_module = SUPPORTED_TASKS[task](
83 | task,
84 | lang,
85 | model,
86 | **kwargs,
87 | ).load(device)
88 |
89 | return task_module
90 |
91 | @staticmethod
92 | def available_tasks() -> str:
93 | """
94 | Returns available tasks in Pororo project
95 |
96 | Returns:
97 | str: Supported task names
98 |
99 | """
100 | return "Available tasks are {}".format(list(SUPPORTED_TASKS.keys()))
101 |
102 | @staticmethod
103 | def available_models(task: str) -> str:
104 | """
105 | Returns available model names corresponding to the user-input task
106 |
107 | Args:
108 | task (str): user-input task name
109 |
110 | Returns:
111 | str: Supported model names corresponding to the user-input task
112 |
113 | Raises:
114 | KeyError: When user-input task is not supported
115 |
116 | """
117 | if task not in SUPPORTED_TASKS:
118 | raise KeyError(
119 | "Unknown task {} ! Please check available models via `available_tasks()`".format(
120 | task
121 | )
122 | )
123 |
124 | langs = SUPPORTED_TASKS[task].get_available_models()
125 | output = f"Available models for {task} are "
126 | for lang in langs:
127 | output += f"([lang]: {lang}, [model]: {', '.join(langs[lang])}), "
128 | return output[:-2]
129 |
--------------------------------------------------------------------------------
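A usage sketch of the factory above (illustrative; the image path is hypothetical, and the first OCR call downloads the CRAFT and brainocr checkpoints):

    from betterocr.engines.easy_pororo_ocr.pororo.pororo import Pororo

    print(Pororo.available_tasks())        # "Available tasks are ['ocr']"
    print(Pororo.available_models("ocr"))  # lists the brainocr model for 'en' and 'ko'

    ocr = Pororo(task="ocr", lang="ko")    # __new__ returns the loaded task module
    print(ocr("path/to/image.png"))        # hypothetical image path
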
/betterocr/engines/easy_pororo_ocr/pororo/tasks/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | this code is adapted from https://github.com/black7375/korean_ocr_using_pororo
3 |
4 | Apache License 2.0 @yunwoong7
5 | Apache License 2.0 @black7375
6 | """
7 |
8 | # flake8: noqa
9 | """
10 | __init__.py for import child .py files
11 |
12 | isort:skip_file
13 | """
14 |
15 | # Utility classes & functions
16 | # import pororo.tasks.utils
17 | from .utils.download_utils import download_or_load
18 | from .utils.base import (
19 | PororoBiencoderBase,
20 | PororoFactoryBase,
21 | PororoGenerationBase,
22 | PororoSimpleBase,
23 | PororoTaskGenerationBase,
24 | )
25 |
26 | # Factory classes
27 | from .optical_character_recognition import PororoOcrFactory
28 |
--------------------------------------------------------------------------------
/betterocr/engines/easy_pororo_ocr/pororo/tasks/optical_character_recognition.py:
--------------------------------------------------------------------------------
1 | """
2 | this code is adapted from https://github.com/black7375/korean_ocr_using_pororo
3 |
4 | Apache License 2.0 @yunwoong7
5 | Apache License 2.0 @black7375
6 | """
7 |
8 | """OCR related modeling class"""
9 |
10 | from typing import Optional
11 |
12 | from . import download_or_load
13 | from .utils.base import PororoFactoryBase, PororoSimpleBase
14 |
15 |
16 | class PororoOcrFactory(PororoFactoryBase):
17 | """
18 | Recognize optical characters in image file
19 | Currently supports the Korean language
20 |
21 | English + Korean (`brainocr`)
22 |
23 | - dataset: Internal data + AI hub Font Image dataset
24 | - metric: TBU
25 | - ref: https://www.aihub.or.kr/aidata/133
26 |
27 | Examples:
28 | >>> ocr = Pororo(task="ocr", lang="ko")
29 | >>> ocr(IMAGE_PATH)
30 | ["사이렌'(' 신마'", "내가 말했잖아 속지열라고 이 손을 잡는 너는 위협해질 거라고"]
31 |
32 | >>> ocr = Pororo(task="ocr", lang="ko")
33 | >>> ocr(IMAGE_PATH, detail=True)
34 | {
35 | 'description': ["사이렌'(' 신마'", "내가 말했잖아 속지열라고 이 손을 잡는 너는 위협해질 거라고"],
36 | 'bounding_poly': [
37 | {
38 | 'description': "사이렌'(' 신마'",
39 | 'vertices': [
40 | {'x': 93, 'y': 7},
41 | {'x': 164, 'y': 7},
42 | {'x': 164, 'y': 21},
43 | {'x': 93, 'y': 21}
44 | ]
45 | },
46 | {
47 | 'description': "내가 말했잖아 속지열라고 이 손을 잡는 너는 위협해질 거라고",
48 | 'vertices': [
49 | {'x': 0, 'y': 30},
50 | {'x': 259, 'y': 30},
51 | {'x': 259, 'y': 194},
52 | {'x': 0, 'y': 194}]}
53 | ]
54 | }
55 | }
56 | """
57 |
58 | def __init__(self, task: str, lang: str, model: Optional[str]):
59 | super().__init__(task, lang, model)
60 | self.detect_model = "craft"
61 | self.ocr_opt = "ocr-opt"
62 |
63 | @staticmethod
64 | def get_available_langs():
65 | return ["en", "ko"]
66 |
67 | @staticmethod
68 | def get_available_models():
69 | return {
70 | "en": ["brainocr"],
71 | "ko": ["brainocr"],
72 | }
73 |
74 | def load(self, device: str):
75 | """
76 | Load user-selected task-specific model
77 |
78 | Args:
79 | device (str): device information
80 |
81 | Returns:
82 | object: User-selected task-specific model
83 |
84 | """
85 | if self.config.n_model == "brainocr":
86 | from ..models.brainOCR import brainocr
87 |
88 | if self.config.lang not in self.get_available_langs():
89 | raise ValueError(
90 | f"Unsupported Language : {self.config.lang}",
91 | 'Support Languages : ["en", "ko"]',
92 | )
93 |
94 | det_model_path = download_or_load(
95 | f"misc/{self.detect_model}.pt",
96 | self.config.lang,
97 | )
98 | rec_model_path = download_or_load(
99 | f"misc/{self.config.n_model}.pt",
100 | self.config.lang,
101 | )
102 | opt_fp = download_or_load(
103 | f"misc/{self.ocr_opt}.txt",
104 | self.config.lang,
105 | )
106 | model = brainocr.Reader(
107 | self.config.lang,
108 | det_model_ckpt_fp=det_model_path,
109 | rec_model_ckpt_fp=rec_model_path,
110 | opt_fp=opt_fp,
111 | device=device,
112 | )
113 | model.detector.to(device)
114 | model.recognizer.to(device)
115 | return PororoOCR(model, self.config)
116 |
117 |
118 | class PororoOCR(PororoSimpleBase):
119 | def __init__(self, model, config):
120 | super().__init__(config)
121 | self._model = model
122 |
123 | def _postprocess(self, ocr_results, detail: bool = False):
124 | """
125 | Post-process for OCR result
126 |
127 | Args:
128 | ocr_results (list): list containing the OCR results
129 | detail (bool): if True, the result includes details (bounding poly, vertices, etc.)
130 |
131 | """
132 | sorted_ocr_results = sorted(
133 | ocr_results,
134 | key=lambda x: (
135 | x[0][0][1],
136 | x[0][0][0],
137 | ),
138 | )
139 |
140 | if not detail:
141 | return [sorted_ocr_results[i][-1] for i in range(len(sorted_ocr_results))]
142 |
143 | result_dict = {
144 | "description": list(),
145 | "bounding_poly": list(),
146 | }
147 |
148 | for ocr_result in sorted_ocr_results:
149 | vertices = list()
150 |
151 | for vertice in ocr_result[0]:
152 | vertices.append(
153 | {
154 | "x": vertice[0],
155 | "y": vertice[1],
156 | }
157 | )
158 |
159 | result_dict["description"].append(ocr_result[1])
160 | result_dict["bounding_poly"].append(
161 | {"description": ocr_result[1], "vertices": vertices}
162 | )
163 |
164 | return result_dict
165 |
166 | def predict(self, image_path: str, **kwargs):
167 | """
168 | Conduct Optical Character Recognition (OCR)
169 |
170 | Args:
171 | image_path (str): the image file path
172 | detail (bool): if True, the result includes details (bounding poly, vertices, etc.)
173 |
174 | """
175 | detail = kwargs.get("detail", False)
176 |
177 | return self._postprocess(
178 | self._model(
179 | image_path,
180 | skip_details=False,
181 | batch_size=1,
182 | paragraph=True,
183 | ),
184 | detail,
185 | )
186 |
--------------------------------------------------------------------------------
/betterocr/engines/easy_pororo_ocr/pororo/tasks/utils/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | this code is adapted from https://github.com/black7375/korean_ocr_using_pororo
3 |
4 | Apache License 2.0 @yunwoong7
5 | Apache License 2.0 @black7375
6 | """
7 |
--------------------------------------------------------------------------------
/betterocr/engines/easy_pororo_ocr/pororo/tasks/utils/base.py:
--------------------------------------------------------------------------------
1 | """
2 | this code is adapted from https://github.com/black7375/korean_ocr_using_pororo
3 |
4 | Apache License 2.0 @yunwoong7
5 | Apache License 2.0 @black7375
6 | """
7 |
8 | import re
9 | import unicodedata
10 | from abc import abstractmethod
11 | from dataclasses import dataclass
12 | from typing import List, Mapping, Optional, Union
13 |
14 |
15 | @dataclass
16 | class TaskConfig:
17 | task: str
18 | lang: str
19 | n_model: str
20 |
21 |
22 | class PororoTaskBase:
23 | r"""Task base class that implements basic functions for prediction"""
24 |
25 | def __init__(self, config: TaskConfig):
26 | self.config = config
27 |
28 | @property
29 | def n_model(self):
30 | return self.config.n_model
31 |
32 | @property
33 | def lang(self):
34 | return self.config.lang
35 |
36 | @abstractmethod
37 | def predict(
38 | self,
39 | text: Union[str, List[str]],
40 | **kwargs,
41 | ):
42 | raise NotImplementedError("`predict()` function is not implemented properly!")
43 |
44 | def __call__(self):
45 | raise NotImplementedError("`call()` function is not implemented properly!")
46 |
47 | def __repr__(self):
48 | return f"[TASK]: {self.config.task.upper()}\n[LANG]: {self.config.lang.upper()}\n[MODEL]: {self.config.n_model}"
49 |
50 | def _normalize(self, text: str):
51 | """Unicode normalization and whitespace removal (often needed for contexts)"""
52 | text = unicodedata.normalize("NFKC", text)
53 | text = re.sub(r"\s+", " ", text).strip()
54 | return text
55 |
56 |
57 | class PororoFactoryBase(object):
58 | r"""This is a factory base class that construct task-specific module"""
59 |
60 | def __init__(
61 | self,
62 | task: str,
63 | lang: str,
64 | model: Optional[str] = None,
65 | ):
66 | self._available_langs = self.get_available_langs()
67 | self._available_models = self.get_available_models()
68 | self._model2lang = {
69 | v: k for k, vs in self._available_models.items() for v in vs
70 | }
71 |
72 | # Set default language as very first supported language
73 | assert (
74 | lang in self._available_langs
75 | ), f"Following langs are supported for this task: {self._available_langs}"
76 |
77 | if lang is None:
78 | lang = self._available_langs[0]
79 |
80 | # Change language option if model is defined by user
81 | if model is not None:
82 | lang = self._model2lang[model]
83 |
84 | # Set default model
85 | if model is None:
86 | model = self.get_default_model(lang)
87 |
88 | # yapf: disable
89 | assert (model in self._available_models[lang]), f"{model} is NOT supported for {lang}"
90 | # yapf: enable
91 |
92 | self.config = TaskConfig(task, lang, model)
93 |
94 | @abstractmethod
95 | def get_available_langs(self) -> List[str]:
96 | raise NotImplementedError(
97 | "`get_available_langs()` is not implemented properly!"
98 | )
99 |
100 | @abstractmethod
101 | def get_available_models(self) -> Mapping[str, List[str]]:
102 | raise NotImplementedError(
103 | "`get_available_models()` is not implemented properly!"
104 | )
105 |
106 | @abstractmethod
107 | def get_default_model(self, lang: str) -> str:
108 | return self._available_models[lang][0]
109 |
110 | @classmethod
111 | def load(cls) -> PororoTaskBase:
112 | raise NotImplementedError("Model load function is not implemented properly!")
113 |
114 |
115 | class PororoSimpleBase(PororoTaskBase):
116 | r"""Simple task base wrapper class"""
117 |
118 | def __call__(self, text: str, **kwargs):
119 | return self.predict(text, **kwargs)
120 |
121 |
122 | class PororoBiencoderBase(PororoTaskBase):
123 | r"""Bi-Encoder base wrapper class"""
124 |
125 | def __call__(
126 | self,
127 | sent_a: str,
128 | sent_b: Union[str, List[str]],
129 | **kwargs,
130 | ):
131 | assert isinstance(sent_a, str), "sent_a should be string type"
132 | assert isinstance(sent_b, str) or isinstance(
133 | sent_b, list
134 | ), "sent_b should be string or list of string type"
135 |
136 | sent_a = self._normalize(sent_a)
137 |
138 | # For "Find Similar Sentence" task
139 | if isinstance(sent_b, list):
140 | sent_b = [self._normalize(t) for t in sent_b]
141 | else:
142 | sent_b = self._normalize(sent_b)
143 |
144 | return self.predict(sent_a, sent_b, **kwargs)
145 |
146 |
147 | class PororoGenerationBase(PororoTaskBase):
148 | r"""Generation task wrapper class using various generation tricks"""
149 |
150 | def __call__(
151 | self,
152 | text: str,
153 | beam: int = 5,
154 | temperature: float = 1.0,
155 | top_k: int = -1,
156 | top_p: float = -1,
157 | no_repeat_ngram_size: int = 4,
158 | len_penalty: float = 1.0,
159 | **kwargs,
160 | ):
161 | assert isinstance(text, str), "Input text should be string type"
162 |
163 | return self.predict(
164 | text,
165 | beam=beam,
166 | temperature=temperature,
167 | top_k=top_k,
168 | top_p=top_p,
169 | no_repeat_ngram_size=no_repeat_ngram_size,
170 | len_penalty=len_penalty,
171 | **kwargs,
172 | )
173 |
174 |
175 | class PororoTaskGenerationBase(PororoTaskBase):
176 | r"""Generation task wrapper class using only beam search"""
177 |
178 | def __call__(self, text: str, beam: int = 1, **kwargs):
179 | assert isinstance(text, str), "Input text should be string type"
180 |
181 | text = self._normalize(text)
182 |
183 | return self.predict(text, beam=beam, **kwargs)
184 |
--------------------------------------------------------------------------------
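To illustrate the base classes above, here is a toy subclass (not part of the repository) showing how PororoSimpleBase.__call__ forwards to predict() and how _normalize is typically used:

    from betterocr.engines.easy_pororo_ocr.pororo.tasks.utils.base import (
        PororoSimpleBase,
        TaskConfig,
    )

    class EchoTask(PororoSimpleBase):
        def predict(self, text: str, **kwargs):
            return self._normalize(text).upper()

    task = EchoTask(TaskConfig(task="echo", lang="en", n_model="demo"))
    print(task("hello   world"))  # __call__ -> predict -> "HELLO WORLD"
    print(task)                   # __repr__ prints the [TASK]/[LANG]/[MODEL] summary
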
/betterocr/engines/easy_pororo_ocr/pororo/tasks/utils/config.py:
--------------------------------------------------------------------------------
1 | """
2 | this code is adapted from https://github.com/black7375/korean_ocr_using_pororo
3 |
4 | Apache License 2.0 @yunwoong7
5 | Apache License 2.0 @black7375
6 | """
7 |
8 | from dataclasses import dataclass
9 | from typing import Union
10 |
11 |
12 | @dataclass
13 | class TransformerConfig:
14 | src_dict: Union[str, None]
15 | tgt_dict: Union[str, None]
16 | src_tok: Union[str, None]
17 | tgt_tok: Union[str, None]
18 |
19 |
20 | CONFIGS = {
21 | "transformer.base.ko.const": TransformerConfig(
22 | "dict.transformer.base.ko.const",
23 | "dict.transformer.base.ko.const",
24 | None,
25 | None,
26 | ),
27 | "transformer.base.ko.pg": TransformerConfig(
28 | "dict.transformer.base.ko.mt",
29 | "dict.transformer.base.ko.mt",
30 | "bpe8k.ko",
31 | None,
32 | ),
33 | "transformer.base.ko.pg_long": TransformerConfig(
34 | "dict.transformer.base.ko.mt",
35 | "dict.transformer.base.ko.mt",
36 | "bpe8k.ko",
37 | None,
38 | ),
39 | "transformer.base.en.gec": TransformerConfig(
40 | "dict.transformer.base.en.mt",
41 | "dict.transformer.base.en.mt",
42 | "bpe32k.en",
43 | None,
44 | ),
45 | "transformer.base.zh.pg": TransformerConfig(
46 | "dict.transformer.base.zh.mt",
47 | "dict.transformer.base.zh.mt",
48 | None,
49 | None,
50 | ),
51 | "transformer.base.ja.pg": TransformerConfig(
52 | "dict.transformer.base.ja.mt",
53 | "dict.transformer.base.ja.mt",
54 | "bpe8k.ja",
55 | None,
56 | ),
57 | "transformer.base.zh.const": TransformerConfig(
58 | "dict.transformer.base.zh.const",
59 | "dict.transformer.base.zh.const",
60 | None,
61 | None,
62 | ),
63 | "transformer.base.en.const": TransformerConfig(
64 | "dict.transformer.base.en.const",
65 | "dict.transformer.base.en.const",
66 | None,
67 | None,
68 | ),
69 | "transformer.base.en.pg": TransformerConfig(
70 | "dict.transformer.base.en.mt",
71 | "dict.transformer.base.en.mt",
72 | "bpe32k.en",
73 | None,
74 | ),
75 | "transformer.base.ko.gec": TransformerConfig(
76 | "dict.transformer.base.ko.gec",
77 | "dict.transformer.base.ko.gec",
78 | "bpe8k.ko",
79 | None,
80 | ),
81 | "transformer.base.en.char_gec": TransformerConfig(
82 | "dict.transformer.base.en.char_gec",
83 | "dict.transformer.base.en.char_gec",
84 | None,
85 | None,
86 | ),
87 | "transformer.base.en.caption": TransformerConfig(
88 | None,
89 | None,
90 | None,
91 | None,
92 | ),
93 | "transformer.base.ja.p2g": TransformerConfig(
94 | "dict.transformer.base.ja.p2g",
95 | "dict.transformer.base.ja.p2g",
96 | None,
97 | None,
98 | ),
99 | "transformer.large.multi.mtpg": TransformerConfig(
100 | "dict.transformer.large.multi.mtpg",
101 | "dict.transformer.large.multi.mtpg",
102 | "bpe32k.en",
103 | None,
104 | ),
105 | "transformer.large.multi.fast.mtpg": TransformerConfig(
106 | "dict.transformer.large.multi.mtpg",
107 | "dict.transformer.large.multi.mtpg",
108 | "bpe32k.en",
109 | None,
110 | ),
111 | "transformer.large.ko.wsd": TransformerConfig(
112 | "dict.transformer.large.ko.wsd",
113 | "dict.transformer.large.ko.wsd",
114 | None,
115 | None,
116 | ),
117 | }
118 |
--------------------------------------------------------------------------------
/betterocr/engines/easy_pororo_ocr/pororo/tasks/utils/download_utils.py:
--------------------------------------------------------------------------------
1 | """
2 | this code is adapted from https://github.com/black7375/korean_ocr_using_pororo
3 |
4 | Apache License 2.0 @yunwoong7
5 | Apache License 2.0 @black7375
6 | """
7 |
8 | """Module download related function from. Tenth"""
9 |
10 | import logging
11 | import os
12 | import platform
13 | import sys
14 | import zipfile
15 | from dataclasses import dataclass
16 | from typing import Tuple, Union
17 |
18 | import wget
19 |
20 | from .config import CONFIGS
21 |
22 | DEFAULT_PREFIX = {
23 | "model": "https://twg.kakaocdn.net/pororo/{lang}/models",
24 | "dict": "https://twg.kakaocdn.net/pororo/{lang}/dicts",
25 | }
26 |
27 |
28 | @dataclass
29 | class TransformerInfo:
30 | r"Dataclass for transformer-based model"
31 | path: str
32 | dict_path: str
33 | src_dict: str
34 | tgt_dict: str
35 | src_tok: Union[str, None]
36 | tgt_tok: Union[str, None]
37 |
38 |
39 | @dataclass
40 | class DownloadInfo:
41 | r"Download information such as defined directory, language and model name"
42 | n_model: str
43 | lang: str
44 | root_dir: str
45 |
46 |
47 | def get_save_dir(save_dir: str = None) -> str:
48 | """
49 | Get default save directory
50 |
51 | Args:
52 | save_dir (str): User-defined save directory
53 |
54 | Returns:
55 | str: Set save directory
56 |
57 | """
58 | # If user wants to manually define save directory
59 | if save_dir:
60 | os.makedirs(save_dir, exist_ok=True)
61 | return save_dir
62 |
63 | pf = platform.system()
64 |
65 | if pf == "Windows":
66 | save_dir = "C:\\pororo"
67 | else:
68 | home_dir = os.path.expanduser("~")
69 | save_dir = os.path.join(home_dir, ".pororo")
70 |
71 | if not os.path.exists(save_dir):
72 | os.makedirs(save_dir)
73 |
74 | return save_dir
75 |
76 |
77 | def get_download_url(n_model: str, key: str, lang: str) -> str:
78 | """
79 | Get download url using default prefix
80 |
81 | Args:
82 | n_model (str): model name
83 | key (str): key name either `model` or `dict`
84 | lang (str): language name
85 |
86 | Returns:
87 | str: generated download url
88 |
89 | """
90 | default_prefix = DEFAULT_PREFIX[key].format(lang=lang)
91 | return f"{default_prefix}/{n_model}"
92 |
93 |
94 | def download_or_load_bert(info: DownloadInfo) -> str:
95 | """
96 | Download fine-tuned BrainBert & BrainSBert model and dict
97 |
98 | Args:
99 | info (DownloadInfo): download information
100 |
101 | Returns:
102 | str: downloaded bert & sbert path
103 |
104 | """
105 | model_path = os.path.join(info.root_dir, info.n_model)
106 |
107 | if not os.path.exists(model_path):
108 | info.n_model += ".zip"
109 | zip_path = os.path.join(info.root_dir, info.n_model)
110 |
111 | type_dir = download_from_url(
112 | info.n_model,
113 | zip_path,
114 | key="model",
115 | lang=info.lang,
116 | )
117 |
118 | zip_file = zipfile.ZipFile(zip_path)
119 | zip_file.extractall(type_dir)
120 | zip_file.close()
121 |
122 | return model_path
123 |
124 |
125 | def download_or_load_transformer(info: DownloadInfo) -> TransformerInfo:
126 | """
127 | Download pre-trained Transformer model and corresponding dict
128 |
129 | Args:
130 | info (DownloadInfo): download information
131 |
132 | Returns:
133 | TransformerInfo: information dataclass for transformer construction
134 |
135 | """
136 | config = CONFIGS[info.n_model.split("/")[-1]]
137 |
138 | src_dict_in = config.src_dict
139 | tgt_dict_in = config.tgt_dict
140 | src_tok = config.src_tok
141 | tgt_tok = config.tgt_tok
142 |
143 | info.n_model += ".pt"
144 | model_path = os.path.join(info.root_dir, info.n_model)
145 |
146 | # Download or load Transformer model
147 | model_type_dir = "/".join(model_path.split("/")[:-1])
148 | if not os.path.exists(model_path):
149 | model_type_dir = download_from_url(
150 | info.n_model,
151 | model_path,
152 | key="model",
153 | lang=info.lang,
154 | )
155 |
156 | dict_type_dir = str()
157 | src_dict, tgt_dict = str(), str()
158 |
159 | # Download or load corresponding dictionary
160 | if src_dict_in:
161 | src_dict = f"{src_dict_in}.txt"
162 | src_dict_path = os.path.join(info.root_dir, f"dicts/{src_dict}")
163 | dict_type_dir = "/".join(src_dict_path.split("/")[:-1])
164 | if not os.path.exists(src_dict_path):
165 | dict_type_dir = download_from_url(
166 | src_dict,
167 | src_dict_path,
168 | key="dict",
169 | lang=info.lang,
170 | )
171 |
172 | if tgt_dict_in:
173 | tgt_dict = f"{tgt_dict_in}.txt"
174 | tgt_dict_path = os.path.join(info.root_dir, f"dicts/{tgt_dict}")
175 | if not os.path.exists(tgt_dict_path):
176 | download_from_url(
177 | tgt_dict,
178 | tgt_dict_path,
179 | key="dict",
180 | lang=info.lang,
181 | )
182 |
183 | # Download or load corresponding tokenizer
184 | src_tok_path, tgt_tok_path = None, None
185 | if src_tok:
186 | src_tok_path = download_or_load(
187 | f"tokenizers/{src_tok}.zip",
188 | lang=info.lang,
189 | )
190 | if tgt_tok:
191 | tgt_tok_path = download_or_load(
192 | f"tokenizers/{tgt_tok}.zip",
193 | lang=info.lang,
194 | )
195 |
196 | return TransformerInfo(
197 | path=model_type_dir,
198 | dict_path=dict_type_dir,
199 | # Drop prefix "dict." and postfix ".txt"
200 | src_dict=".".join(src_dict.split(".")[1:-1]),
201 | # to follow fairseq's dictionary load process
202 | tgt_dict=".".join(tgt_dict.split(".")[1:-1]),
203 | src_tok=src_tok_path,
204 | tgt_tok=tgt_tok_path,
205 | )
206 |
207 |
208 | def download_or_load_misc(info: DownloadInfo) -> str:
209 | """
210 | Download (pre-trained) miscellaneous model
211 |
212 | Args:
213 | info (DownloadInfo): download information
214 |
215 | Returns:
216 | str: miscellaneous model path
217 |
218 | """
219 | # Add postfix <.model> for sentencepiece
220 | if "sentencepiece" in info.n_model:
221 | info.n_model += ".model"
222 |
223 | # Generate target model path using root directory
224 | model_path = os.path.join(info.root_dir, info.n_model)
225 | if not os.path.exists(model_path):
226 | type_dir = download_from_url(
227 | info.n_model,
228 | model_path,
229 | key="model",
230 | lang=info.lang,
231 | )
232 |
233 | if ".zip" in info.n_model:
234 | zip_file = zipfile.ZipFile(model_path)
235 | zip_file.extractall(type_dir)
236 | zip_file.close()
237 |
238 | if ".zip" in info.n_model:
239 | model_path = model_path[: model_path.rfind(".zip")]
240 | return model_path
241 |
242 |
243 | def download_or_load_bart(info: DownloadInfo) -> Union[str, Tuple[str, str]]:
244 | """
245 | Download BART model
246 |
247 | Args:
248 | info (DownloadInfo): download information
249 |
250 | Returns:
251 | Union[str, Tuple[str, str]]: BART model path (with. corresponding SentencePiece)
252 |
253 | """
254 | info.n_model += ".pt"
255 |
256 | model_path = os.path.join(info.root_dir, info.n_model)
257 | if not os.path.exists(model_path):
258 | download_from_url(
259 | info.n_model,
260 | model_path,
261 | key="model",
262 | lang=info.lang,
263 | )
264 |
265 | return model_path
266 |
267 |
268 | def download_from_url(
269 | n_model: str,
270 | model_path: str,
271 | key: str,
272 | lang: str,
273 | ) -> str:
274 | """
275 | Download specified model from Tenth
276 |
277 | Args:
278 | n_model (str): model name
279 | model_path (str): pre-defined model path
280 | key (str): type key (either model or dict)
281 | lang (str): language name
282 |
283 | Returns:
284 | str: default type directory
285 |
286 | """
287 | # Get default type dir path
288 | type_dir = "/".join(model_path.split("/")[:-1])
289 | os.makedirs(type_dir, exist_ok=True)
290 |
291 | # Get download tenth url
292 | url = get_download_url(n_model, key=key, lang=lang)
293 |
294 | logging.info("Downloading user-selected model...")
295 | wget.download(url, type_dir)
296 | sys.stderr.write("\n")
297 | sys.stderr.flush()
298 |
299 | return type_dir
300 |
301 |
302 | def download_or_load(
303 | n_model: str,
304 | lang: str,
305 | custom_save_dir: str = None,
306 | ) -> Union[TransformerInfo, str, Tuple[str, str]]:
307 | """
308 | Download or load model based on model information
309 |
310 | Args:
311 | n_model (str): model name
312 | lang (str): language information
313 | custom_save_dir (str, optional): user-defined save directory path. defaults to None.
314 |
315 | Returns:
316 | Union[TransformerInfo, str, Tuple[str, str]]
317 |
318 | """
319 | root_dir = get_save_dir(save_dir=custom_save_dir)
320 | info = DownloadInfo(n_model, lang, root_dir)
321 |
322 | if "transformer" in n_model:
323 | return download_or_load_transformer(info)
324 | if "bert" in n_model:
325 | return download_or_load_bert(info)
326 | if "bart" in n_model and "bpe" not in n_model:
327 | return download_or_load_bart(info)
328 |
329 | return download_or_load_misc(info)
330 |
--------------------------------------------------------------------------------
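A usage sketch of download_or_load above (illustrative; it needs network access and the wget package). The file names mirror the OCR factory, which requests "misc/craft.pt", "misc/brainocr.pt" and "misc/ocr-opt.txt":

    from betterocr.engines.easy_pororo_ocr.pororo.tasks.utils.download_utils import (
        download_or_load,
        get_save_dir,
    )

    print(get_save_dir())  # default cache dir: ~/.pororo (C:\pororo on Windows)

    # "misc/..." names fall through to download_or_load_misc(), which fetches the file
    # into <save_dir>/misc/ on first use and reuses the cached copy afterwards.
    det_model_path = download_or_load("misc/craft.pt", lang="ko")
    rec_model_path = download_or_load("misc/brainocr.pt", lang="ko")
    print(det_model_path, rec_model_path)
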
/betterocr/engines/easy_pororo_ocr/pororo/utils.py:
--------------------------------------------------------------------------------
1 | """
2 | this code is adapted from https://github.com/black7375/korean_ocr_using_pororo
3 |
4 | Apache License 2.0 @yunwoong7
5 | Apache License 2.0 @black7375
6 | """
7 |
8 | from contextlib import contextmanager
9 | from tempfile import NamedTemporaryFile
10 |
11 | from requests import get
12 |
13 |
14 | def postprocess_span(tagger, text: str) -> str:
15 | """
16 | Postprocess NOUN span to remove unnecessary character
17 |
18 | Args:
19 | text (str): NOUN span to be processed
20 |
21 | Returns:
22 | (str): post-processed NOUN span
23 |
24 | Examples:
25 | >>> postprocess_span("강감찬 장군은")
26 | '강감찬 장군'
27 | >>> postprocess_span("그녀에게")
28 | '그녀'
29 |
30 | """
31 |
32 | # First, strip punctuations
33 | text = text.strip("""!"\#$&'()*+,\-./:;<=>?@\^_‘{|}~《》""")
34 |
35 | # Complete imbalanced parentheses pair
36 | if text.count("(") == text.count(")") + 1:
37 | text += ")"
38 | elif text.count("(") + 1 == text.count(")"):
39 | text = "(" + text
40 |
41 | # Preserve beginning tokens since we only want to extract noun phrase of the last eojeol
42 | noun_phrase = " ".join(text.rsplit(" ", 1)[:-1])
43 | tokens = text.split(" ")
44 | eojeols = list()
45 | for token in tokens:
46 | eojeols.append(tagger.pos(token))
47 | last_eojeol = eojeols[-1]
48 |
49 | # Iterate backwardly to remove unnecessary postfixes
50 | i = 0
51 | for i, token in enumerate(last_eojeol[::-1]):
52 | _, pos = token
53 | # 1. The loop breaks when you meet a noun
54 | # 2. The loop also breaks when you meet a XSN (e.g. 8/SN+일/NNB LG/SL 전/XSN)
55 | if (pos[0] in ("N", "S")) or pos.startswith("XSN"):
56 | break
57 | idx = len(last_eojeol) - i
58 |
59 | # Extract noun span from last eojeol and postpend it to beginning tokens
60 | ext_last_eojeol = "".join(morph for morph, _ in last_eojeol[:idx])
61 | noun_phrase += " " + ext_last_eojeol
62 | return noun_phrase.strip()
63 |
64 |
65 | @contextmanager
66 | def control_temp(file_path: str):
67 | """
68 |     Download a temporary file from the web and remove it when the context exits
69 |
70 | Args:
71 | file_path (str): web file path
72 |
73 | """
74 | # yapf: disable
75 |     assert file_path.startswith("http"), "File path should start with the `http` prefix!"
76 | # yapf: enable
77 |
78 | ext = file_path[file_path.rfind(".") :]
79 |
80 | with NamedTemporaryFile("wb", suffix=ext, delete=True) as f:
81 | response = get(file_path, allow_redirects=True)
82 | f.write(response.content)
83 | yield f.name
84 |
--------------------------------------------------------------------------------
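A minimal usage sketch for `control_temp` (not part of the repository): it downloads a remote file into a named temporary file, yields the local path, and removes the file when the `with` block exits. The URL below is a placeholder.

# Hypothetical usage; "https://example.com/sample.png" is a placeholder URL.
from betterocr.engines.easy_pororo_ocr.pororo.utils import control_temp

with control_temp("https://example.com/sample.png") as local_path:
    # local_path points to a temporary file that is deleted on exit
    with open(local_path, "rb") as f:
        data = f.read()

print(len(data), "bytes downloaded")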
/betterocr/engines/easy_pororo_ocr/utils/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | this code is adapted from https://github.com/black7375/korean_ocr_using_pororo
3 |
4 | Apache License 2.0 @yunwoong7
5 | Apache License 2.0 @black7375
6 | """
7 |
--------------------------------------------------------------------------------
/betterocr/engines/easy_pororo_ocr/utils/image_convert.py:
--------------------------------------------------------------------------------
1 | """
2 | this code is adapted from https://github.com/black7375/korean_ocr_using_pororo
3 |
4 | Apache License 2.0 @yunwoong7
5 | Apache License 2.0 @black7375
6 | """
7 |
8 | import cv2
9 | import numpy as np
10 |
11 |
12 | def convert_coord(coord_list):
13 | x_min, x_max, y_min, y_max = coord_list
14 | return [[x_min, y_min], [x_max, y_min], [x_max, y_max], [x_min, y_max]]
15 |
16 |
17 | # https://www.life2coding.com/cropping-polygon-or-non-rectangular-region-from-image-using-opencv-python/
18 | # https://stackoverflow.com/questions/48301186/cropping-concave-polygon-from-image-using-opencv-python
19 | def crop(image, points):
20 | pts = np.array(points, np.int32)
21 |
22 | # Crop the bounding rect
23 | rect = cv2.boundingRect(pts)
24 | x, y, w, h = rect
25 |     cropped = image[y : y + h, x : x + w].copy()
26 |
27 |     # Make the polygon mask
28 |     pts = pts - pts.min(axis=0)
29 |
30 |     mask = np.zeros(cropped.shape[:2], np.uint8)
31 |     cv2.drawContours(mask, [pts], -1, (255, 255, 255), -1, cv2.LINE_AA)
32 |
33 |     # Keep only the pixels inside the polygon
34 |     dst = cv2.bitwise_and(cropped, cropped, mask=mask)
35 |
36 |     # Add a white background outside the polygon
37 |     bg = np.ones_like(cropped, np.uint8) * 255
38 | cv2.bitwise_not(bg, bg, mask=mask)
39 | result = bg + dst
40 |
41 | return result
42 |
--------------------------------------------------------------------------------
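A minimal sketch (not part of the repository) showing how these two helpers fit together: `convert_coord` expands an `[x_min, x_max, y_min, y_max]` list into four corner points, and `crop` keeps only the pixels inside that polygon, filling the rest with white. The image path is a placeholder.

# Hypothetical example; "input.png" is a placeholder image path.
import cv2
from betterocr.engines.easy_pororo_ocr.utils.image_convert import convert_coord, crop

image = cv2.imread("input.png")

# [x_min, x_max, y_min, y_max] -> [[x_min, y_min], [x_max, y_min], [x_max, y_max], [x_min, y_max]]
points = convert_coord([10, 110, 20, 60])

# Pixels outside the polygon are masked out and replaced with white
region = crop(image, points)
cv2.imwrite("region.png", region)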
/betterocr/engines/easy_pororo_ocr/utils/image_util.py:
--------------------------------------------------------------------------------
1 | """
2 | this code is adapted from https://github.com/black7375/korean_ocr_using_pororo
3 |
4 | Apache License 2.0 @yunwoong7
5 | Apache License 2.0 @black7375
6 | """
7 |
8 | import cv2
9 | import numpy as np
10 | import platform
11 | from PIL import ImageFont, ImageDraw, Image
12 | from matplotlib import pyplot as plt
13 |
14 |
15 | def plt_imshow(title="image", img=None, figsize=(8, 5)):
16 | plt.figure(figsize=figsize)
17 |
18 | if type(img) is str:
19 | img = cv2.imread(img)
20 |
21 | if type(img) == list:
22 | if type(title) == list:
23 | titles = title
24 | else:
25 | titles = []
26 |
27 | for i in range(len(img)):
28 | titles.append(title)
29 |
30 | for i in range(len(img)):
31 | if len(img[i].shape) <= 2:
32 | rgbImg = cv2.cvtColor(img[i], cv2.COLOR_GRAY2RGB)
33 | else:
34 | rgbImg = cv2.cvtColor(img[i], cv2.COLOR_BGR2RGB)
35 |
36 | plt.subplot(1, len(img), i + 1), plt.imshow(rgbImg)
37 | plt.title(titles[i])
38 | plt.xticks([]), plt.yticks([])
39 |
40 | plt.show()
41 | else:
42 | if len(img.shape) < 3:
43 | rgbImg = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
44 | else:
45 | rgbImg = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
46 |
47 | plt.imshow(rgbImg)
48 | plt.title(title)
49 | plt.xticks([]), plt.yticks([])
50 | plt.show()
51 |
52 |
53 | def put_text(image, text, x, y, color=(0, 255, 0), font_size=22):
54 | if type(image) == np.ndarray:
55 |         color_converted = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
56 |         image = Image.fromarray(color_converted)
57 |
58 | if platform.system() == "Darwin":
59 | font = "AppleGothic.ttf"
60 | elif platform.system() == "Windows":
61 | font = "malgun.ttf"
62 |     else:
63 |         # Fall back to a CJK-capable font on Linux and other systems
64 |         font = "NotoSansCJK-Regular.ttc"
65 |
66 |     image_font = ImageFont.truetype(font, font_size)
67 | draw = ImageDraw.Draw(image)
68 |
69 | draw.text((x, y), text, font=image_font, fill=color)
70 |
71 | numpy_image = np.array(image)
72 | opencv_image = cv2.cvtColor(numpy_image, cv2.COLOR_RGB2BGR)
73 |
74 | return opencv_image
75 |
--------------------------------------------------------------------------------
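A minimal sketch (not part of the repository) of `put_text`, which round-trips a BGR ndarray through PIL so it can draw CJK text with a platform-specific font, then returns a BGR ndarray again. The image path is a placeholder and a CJK-capable font is assumed to be installed.

# Hypothetical example; "input.png" is a placeholder and a CJK font is assumed installed.
import cv2
from betterocr.engines.easy_pororo_ocr.utils.image_util import put_text, plt_imshow

image = cv2.imread("input.png")
annotated = put_text(image, "안녕하세요", x=10, y=10, color=(255, 0, 0), font_size=24)
plt_imshow("annotated", annotated)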
/betterocr/engines/easy_pororo_ocr/utils/pre_processing.py:
--------------------------------------------------------------------------------
1 | """
2 | this code is adapted from https://github.com/black7375/korean_ocr_using_pororo
3 |
4 | Apache License 2.0 @yunwoong7
5 | Apache License 2.0 @black7375
6 | """
7 |
8 | import cv2
9 | import numpy as np
10 |
11 |
12 | # https://tesseract-ocr.github.io/tessdoc/ImproveQuality.html
13 | # https://nanonets.com/blog/ocr-with-tesseract/
14 | # https://github.com/TCAT-capstone/ocr-preprocessor/blob/main/main.py
15 | # https://towardsdatascience.com/pre-processing-in-ocr-fc231c6035a7
16 | # == Main ======================================================================
17 | def load(path):
18 | return cv2.imread(path)
19 |
20 |
21 | def image_filter(image):
22 | image = grayscale(image)
23 | # image = thresholding(image, mode="GAUSSIAN")
24 | # image = opening(image)
25 | # image = closing(image)
26 | return image
27 |
28 |
29 | def roi_filter(image):
30 | image = resize(image)
31 | return image
32 |
33 |
34 | def load_with_filter(path):
35 | image = load(path)
36 | return image_filter(image)
37 |
38 |
39 | def isEven(num):
40 | return num % 2 == 0
41 |
42 |
43 | # == Color =====================================================================
44 | def grayscale(image, blur=False):
45 | return cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
46 |
47 |
48 | def invert(image):
49 | return cv2.bitwise_not(image)
50 |
51 |
52 | # https://opencv-python.readthedocs.io/en/latest/doc/09.imageThresholding/imageThresholding.html
53 | def thresholding(image, mode="GENERAL", block_size=9, C=5):
54 | if isEven(block_size):
55 |         print("block_size must be an odd number")
56 | return
57 |
58 | if mode == "MEAN":
59 | return cv2.adaptiveThreshold(
60 | image, 255, cv2.ADAPTIVE_THRESH_MEAN_C, cv2.THRESH_BINARY, block_size, C
61 | )
62 | elif mode == "GAUSSIAN":
63 | return cv2.adaptiveThreshold(
64 | image, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C, cv2.THRESH_BINARY, block_size, C
65 | )
66 | else:
67 | return cv2.threshold(image, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)[1]
68 |
69 |
70 | def normalization(image, mode="COLOR", result_size=None):
71 | result_size = np.zeros(result_size) if result_size is not None else None
72 |
73 | if mode == "COLOR":
74 | return cv2.normalize(image, result_size, 0, 255, cv2.NORM_MINMAX)
75 |
76 | if mode == "GRAY":
77 | return cv2.normalize(
78 | image, result_size, 0, 1.0, cv2.NORM_MINMAX, dtype=cv2.CV_32F
79 | )
80 |
81 |
82 | def equalization(image):
83 | return cv2.equalizeHist(image)
84 |
85 |
86 | # == Noise =====================================================================
87 | def remove_noise(image, mode="COLOR"):
88 | if mode == "COLOR":
89 | return cv2.fastNlMeansDenoisingColored(image, None, 10, 10, 7, 15)
90 | if mode == "GRAY":
91 | return cv2.fastNlMeansDenoising(image, None, 10, 7, 21)
92 |
93 |
94 | def blur(image, kernel=(3, 3)):
95 | return cv2.GaussianBlur(image, kernel, 0)
96 |
97 |
98 | def blur_median(image, kernel_size=3):
99 | return cv2.medianBlur(image, ksize=kernel_size)
100 |
101 |
102 | # == Morphology ================================================================
103 | def dilation(image, kernel=np.ones((3, 3), np.uint8)):
104 | return cv2.dilate(image, kernel, iterations=1)
105 |
106 |
107 | def erosion(image, kernel=np.ones((3, 3), np.uint8)):
108 | return cv2.erode(image, kernel, iterations=1)
109 |
110 |
111 | def opening(image, kernel=np.ones((3, 3), np.uint8)):
112 | return cv2.morphologyEx(image, cv2.MORPH_OPEN, kernel)
113 |
114 |
115 | def closing(image, kernel=np.ones((3, 3), np.uint8)):
116 | return cv2.morphologyEx(image, cv2.MORPH_CLOSE, kernel)
117 |
118 |
119 | def gradient(image, kernel=np.ones((3, 3), np.uint8)):
120 | return cv2.morphologyEx(image, cv2.MORPH_GRADIENT, kernel)
121 |
122 |
123 | def canny(image):
124 | return cv2.Canny(image, 100, 200)
125 |
126 |
127 | # == Others ====================================================================
128 | def resize(image, interpolation=cv2.INTER_CUBIC):
129 | height, width = image.shape[:2]
130 | factor = max(1, (20.0 / height))
131 | size = (int(factor * width), int(factor * height))
132 | return cv2.resize(image, dsize=size, interpolation=interpolation)
133 |
134 |
135 | # https://becominghuman.ai/how-to-automatically-deskew-straighten-a-text-image-using-opencv-a0c30aed83df
136 | def deskew(image):
137 | coords = np.column_stack(np.where(image > 0))
138 | angle = cv2.minAreaRect(coords)[-1]
139 | if angle < -45:
140 | angle = -(90 + angle)
141 | else:
142 | angle = -angle
143 |
144 | (h, w) = image.shape[:2]
145 | center = (w // 2, h // 2)
146 | M = cv2.getRotationMatrix2D(center, angle, 1.0)
147 | rotated = cv2.warpAffine(
148 | image, M, (w, h), flags=cv2.INTER_CUBIC, borderMode=cv2.BORDER_REPLICATE
149 | )
150 | return rotated
151 |
152 |
153 | def match_template(image, template):
154 | return cv2.matchTemplate(image, template, cv2.TM_CCOEFF_NORMED)
155 |
--------------------------------------------------------------------------------
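A minimal preprocessing sketch (not part of the repository) chaining a few of the helpers above; the input path is a placeholder and the parameter values are only illustrative.

# Hypothetical pipeline; "scan.png" is a placeholder image path.
from betterocr.engines.easy_pororo_ocr.utils.pre_processing import (
    load,
    grayscale,
    thresholding,
    remove_noise,
    deskew,
)

image = load("scan.png")                                         # BGR image
gray = grayscale(image)                                          # single channel
binary = thresholding(gray, mode="GAUSSIAN", block_size=9, C=5)  # adaptive threshold
clean = remove_noise(binary, mode="GRAY")                        # denoise
straight = deskew(clean)                                         # straighten text lines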
/betterocr/parsers.py:
--------------------------------------------------------------------------------
1 | import re
2 | import json
3 |
4 |
5 | def extract_json(input_string):
6 | # Find the JSON in the string
7 | matches = re.findall(r'{\s*"data"\s*:\s*"(.*?)"\s*}', input_string, re.DOTALL)
8 | if matches:
9 |         # Escape newlines and double quotes so the reconstructed string parses as JSON
10 | matches = [m.replace("\n", "\\n").replace('"', '\\"') for m in matches]
11 | for match in matches:
12 | # Construct JSON string
13 | json_string = f'{{"data": "{match}"}}'
14 | try:
15 | # Load the JSON and return the data
16 | json_obj = json.loads(json_string)
17 | return json_obj
18 | except json.decoder.JSONDecodeError:
19 | continue
20 |
21 | # If no JSON found, return None
22 | return None
23 |
24 |
25 | def extract_list(s):
26 | stack = []
27 | start_position = None
28 |
29 | # Iterate through each character in the string
30 | for i, c in enumerate(s):
31 | if c == "[":
32 | if start_position is None: # First '[' found
33 | start_position = i
34 | stack.append(c)
35 |
36 | elif c == "]":
37 | if stack:
38 | stack.pop()
39 |
40 | # If stack is empty and start was marked
41 | if not stack and start_position is not None:
42 | substring = s[start_position : i + 1]
43 | try:
44 | list_obj = json.loads(substring)
45 | for item in list_obj:
46 | if "box" in item and "text" in item:
47 | return list_obj
48 | except json.decoder.JSONDecodeError:
49 | # Reset the stack and start position as this isn't a valid JSON
50 | stack = []
51 | start_position = None
52 | continue
53 |
54 | # If no valid list found, return None
55 | return None
56 |
57 |
58 | def rectangle_corners(rect):
59 | x, y, w, h = rect
60 | return [[x, y], [x + w, y], [x + w, y + h], [x, y + h]]
61 |
--------------------------------------------------------------------------------
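A small sketch (not part of the repository) of what these parsers return for raw LLM-style output; the strings below are made-up examples.

# Hypothetical LLM outputs, used only to illustrate the return values.
from betterocr.parsers import extract_json, extract_list, rectangle_corners

llm_text = 'Sure, here is the result: {"data": "Hello World"}'
print(extract_json(llm_text))   # {'data': 'Hello World'}

llm_boxes = 'Result: [{"box": [1, 2, 3, 4], "text": "Hello"}]'
print(extract_list(llm_boxes))  # [{'box': [1, 2, 3, 4], 'text': 'Hello'}]

# (x, y, w, h) -> four corners: top-left, top-right, bottom-right, bottom-left
print(rectangle_corners((0, 0, 10, 5)))  # [[0, 0], [10, 0], [10, 5], [0, 5]]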
/betterocr/wrappers/__init__.py:
--------------------------------------------------------------------------------
1 | from .easy_ocr import job_easy_ocr, job_easy_ocr_boxes
2 | from .tesseract.job import job_tesseract, job_tesseract_boxes
3 |
4 | __all__ = [
5 | "job_easy_ocr",
6 | "job_easy_ocr_boxes",
7 | "job_tesseract",
8 | "job_tesseract_boxes",
9 | ]
10 |
--------------------------------------------------------------------------------
/betterocr/wrappers/easy_ocr.py:
--------------------------------------------------------------------------------
1 | import easyocr
2 |
3 |
4 | def job_easy_ocr(_options):
5 | reader = easyocr.Reader(_options["lang"])
6 | text = reader.readtext(_options["path"], detail=0)
7 | text = "".join(text)
8 | print("[*] job_easy_ocr", text)
9 | return text
10 |
11 |
12 | def job_easy_ocr_boxes(_options):
13 | reader = easyocr.Reader(_options["lang"])
14 | boxes = reader.readtext(_options["path"], output_format="dict")
15 | for box in boxes:
16 | box["box"] = box.pop("boxes")
17 | return boxes
18 |
--------------------------------------------------------------------------------
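A minimal call sketch (not part of the repository): these jobs read only the "lang" and "path" keys from the options dict. The image path is a placeholder.

# Hypothetical call; "receipt.png" is a placeholder image path.
from betterocr.wrappers.easy_ocr import job_easy_ocr, job_easy_ocr_boxes

options = {"lang": ["ko", "en"], "path": "receipt.png"}

text = job_easy_ocr(options)         # recognized text, concatenated without separators
boxes = job_easy_ocr_boxes(options)  # list of {"box": ..., "text": ...} dicts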
/betterocr/wrappers/easy_pororo_ocr.py:
--------------------------------------------------------------------------------
1 | from ..engines.easy_pororo_ocr import EasyPororoOcr, load_with_filter
2 |
3 |
4 | def parse_languages(lang: list[str]):
5 | languages = []
6 | for l in lang:
7 | if l in ["ko", "en"]:
8 | languages.append(l)
9 |
10 | if len(languages) == 0:
11 | languages = ["ko"]
12 |
13 | return languages
14 |
15 |
16 | def default_ocr(_options):
17 | lang = parse_languages(_options["lang"])
18 | return EasyPororoOcr(lang)
19 |
20 |
21 | def job_easy_pororo_ocr(_options):
22 | image = load_with_filter(_options["path"])
23 |
24 | ocr = _options.get("ocr")
25 | if not ocr:
26 | ocr = default_ocr(_options)
27 |
28 | text = ocr.run_ocr(image, debug=False)
29 |
30 | if isinstance(text, list):
31 | text = "\\n".join(text)
32 |
33 | print("[*] job_easy_pororo_ocr", text)
34 | return text
35 |
36 |
37 | def job_easy_pororo_ocr_boxes(_options):
38 | ocr = default_ocr(_options)
39 | job_easy_pororo_ocr({**_options, "ocr": ocr})
40 | return ocr.get_boxes()
41 |
--------------------------------------------------------------------------------
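A minimal call sketch (not part of the repository): only "ko" and "en" are forwarded to the Pororo engine, and anything else falls back to Korean. The image path is a placeholder.

# Hypothetical call; "document.png" is a placeholder image path.
from betterocr.wrappers.easy_pororo_ocr import job_easy_pororo_ocr, job_easy_pororo_ocr_boxes

options = {"lang": ["ko", "en"], "path": "document.png"}

text = job_easy_pororo_ocr(options)         # lines joined with a literal "\n" sequence
boxes = job_easy_pororo_ocr_boxes(options)  # boxes from the same EasyPororoOcr instance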
/betterocr/wrappers/tesseract/__init__.py:
--------------------------------------------------------------------------------
1 | from .job import convert_to_tesseract_lang_code, job_tesseract, job_tesseract_boxes
2 |
3 | __all__ = ["convert_to_tesseract_lang_code", "job_tesseract", "job_tesseract_boxes"]
4 |
--------------------------------------------------------------------------------
/betterocr/wrappers/tesseract/job.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import numpy as np
3 | import pytesseract
4 |
5 | from .mapping import LANG_CODE_MAPPING
6 |
7 |
8 | def convert_to_tesseract_lang_code(langs: list[str]) -> str:
9 | return "+".join(
10 | [
11 | LANG_CODE_MAPPING[lang]
12 | for lang in langs
13 | if lang in LANG_CODE_MAPPING and LANG_CODE_MAPPING[lang] is not None
14 | ]
15 | )
16 |
17 |
18 | def job_tesseract(_options):
19 | lang = convert_to_tesseract_lang_code(_options["lang"])
20 | text = pytesseract.image_to_string(
21 | _options["path"],
22 | lang=lang,
23 | **_options["tesseract"]
24 | # pass rest of tesseract options here.
25 | )
26 | text = text.replace("\n", "\\n")
27 | print("[*] job_tesseract_ocr", text)
28 | return text
29 |
30 |
31 | def job_tesseract_boxes(_options):
32 | lang = convert_to_tesseract_lang_code(_options["lang"])
33 | df = pytesseract.image_to_data(
34 | _options["path"],
35 | lang=lang,
36 | **_options["tesseract"],
37 | output_type=pytesseract.Output.DATAFRAME
38 | # pass rest of tesseract options here.
39 | )
40 |
41 | # https://stackoverflow.com/questions/74221064/draw-a-rectangle-around-a-string-of-words-using-pytesseract
42 | boxes = []
43 | for line_num, words_per_line in df.groupby("line_num"):
44 | words_per_line = words_per_line[words_per_line["conf"] >= 5]
45 | if len(words_per_line) == 0:
46 | continue
47 |
48 | words = words_per_line["text"].values
49 | line = " ".join(words)
50 |
51 | word_boxes = []
52 | for left, top, width, height in words_per_line[
53 | ["left", "top", "width", "height"]
54 | ].values:
55 | word_boxes.append((left, top))
56 | word_boxes.append((left + width, top + height))
57 |
58 | x, y, w, h = cv2.boundingRect(np.array(word_boxes))
59 | boxes.append(
60 | {
61 | "box": [[x, y], [x + w, y], [x + w, y + h], [x, y + h]],
62 | "text": line,
63 | }
64 | )
65 |
66 | return boxes
67 |
--------------------------------------------------------------------------------
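A minimal sketch (not part of the repository): language codes are mapped to Tesseract's codes and joined with "+", and unmapped codes are dropped. The image path and config are placeholders, and the kor/eng traineddata files are assumed to be installed.

# Hypothetical call; "document.png" and the config are placeholders.
from betterocr.wrappers.tesseract import convert_to_tesseract_lang_code, job_tesseract_boxes

print(convert_to_tesseract_lang_code(["ko", "en", "ady"]))  # "kor+eng" (unmapped codes dropped)

options = {
    "lang": ["ko", "en"],
    "path": "document.png",
    "tesseract": {"config": "--psm 6"},
}
boxes = job_tesseract_boxes(options)  # [{"box": [[x, y], ...], "text": "..."}, ...]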
/betterocr/wrappers/tesseract/mapping.py:
--------------------------------------------------------------------------------
1 | LANG_CODE_MAPPING = {
2 | "abq": "afr", # Not a direct match
3 | "ady": None, # Not found
4 | "af": None, # Not found
5 | "ang": None, # Not found
6 | "ar": "ara",
7 | "as": "asm",
8 | "ava": None, # Not found
9 | "az": "aze",
10 | "be": "bel",
11 | "bg": "bul",
12 | "bh": None, # Not found
13 | "bho": None, # Not found
14 | "bn": "ben",
15 | "bs": "bos",
16 | "ch_sim": "chi_sim",
17 | "ch_tra": "chi_tra",
18 | "che": None, # Not found
19 | "cs": "ces",
20 | "cy": "cym",
21 | "da": "dan",
22 | "dar": None, # Not found
23 | "de": "deu",
24 | "en": "eng",
25 | "es": "spa",
26 | "et": "est",
27 | "fa": "fas",
28 | "fr": "fra",
29 | "ga": "gle",
30 | "gom": None, # Not found
31 | "hi": "hin",
32 | "hr": "hrv",
33 | "hu": "hun",
34 | "id": "ind",
35 | "inh": None, # Not found
36 | "is": "isl",
37 | "it": "ita",
38 | "ja": "jpn",
39 | "kbd": None, # Not found
40 | "kn": "kan",
41 | "ko": "kor",
42 | "ku": None, # Not found
43 | "la": "lat",
44 | "lbe": None, # Not found
45 | "lez": None, # Not found
46 | "lt": "lit",
47 | "lv": "lav",
48 | "mah": None, # Not found
49 | "mai": None, # Not found
50 | "mi": "mri",
51 | "mn": "mon",
52 | "mr": "mar",
53 | "ms": "msa",
54 | "mt": "mlt",
55 | "ne": "nep",
56 | "new": None, # Not found
57 | "nl": "nld",
58 | "no": "nor",
59 | "oc": "oci",
60 | "pi": None, # Not found
61 | "pl": "pol",
62 | "pt": "por",
63 | "ro": "ron",
64 | "ru": "rus",
65 | "rs_cyrillic": None, # Not a direct match
66 | "rs_latin": None, # Not a direct match
67 | "sck": None, # Not found
68 | "sk": "slk",
69 | "sl": "slv",
70 | "sq": "sqi",
71 | "sv": "swe",
72 | "sw": "swa",
73 | "ta": "tam",
74 | "tab": None, # Not found
75 | "te": "tel",
76 | "th": "tha",
77 | "tjk": None, # Not found
78 | "tl": "tgl",
79 | "tr": "tur",
80 | "ug": "uig",
81 | "uk": "ukr",
82 | "ur": "urd",
83 | "uz": "uzb",
84 | "vi": "vie",
85 | }
86 |
--------------------------------------------------------------------------------
/examples/detect_boxes.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import cv2
3 | import matplotlib.pyplot as plt
4 | from PIL import Image, ImageDraw, ImageFont
5 |
6 | from betterocr import detect_boxes
7 |
8 | image_path = ".github/images/demo-1.png"
9 | result = detect_boxes(
10 | image_path,
11 | ["ko", "en"],
12 | context="퍼멘테이션 펩타인 아이케어 크림", # product name
13 | tesseract={
14 | "config": "--psm 6 --tessdata-dir ./tessdata -c tessedit_create_boxfile=1"
15 | },
16 | )
17 | print(result)
18 |
19 | font_path = ".github/examples/Pretendard-Medium.ttf"
20 | font_size = 36
21 | font = ImageFont.truetype(font_path, size=font_size, encoding="unic")
22 |
23 | img = cv2.imread(image_path)
24 | img_rgb = cv2.cvtColor(
25 | img, cv2.COLOR_BGR2RGB
26 | ) # Convert from BGR to RGB for correct color display in matplotlib
27 |
28 | for item in result:
29 | box = item["box"]
30 | pt1, pt2, pt3, pt4 = box
31 | top_left, bottom_right = tuple(pt1), tuple(pt3)
32 |
33 | # Draw rectangle
34 | cv2.rectangle(img_rgb, top_left, bottom_right, (0, 255, 0), 2)
35 |
36 | # For text
37 | pil_img = Image.fromarray(img_rgb)
38 | draw = ImageDraw.Draw(pil_img)
39 |
40 | # Get text dimensions
41 | text_width = draw.textlength(item["text"], font=font)
42 | text_height = font_size # line height = font size
43 |
44 | # Position the text just above the rectangle
45 | text_pos = (top_left[0], top_left[1] - text_height)
46 |
47 | # Draw text background rectangle
48 | draw.rectangle(
49 | [text_pos, (text_pos[0] + text_width, text_pos[1] + text_height)],
50 | fill=(0, 255, 0),
51 | )
52 |
53 | # Draw text
54 | draw.text(text_pos, item["text"], font=font, fill=(0, 0, 0))
55 |
56 | # Convert the PIL image back to an OpenCV image
57 | img_rgb = cv2.cvtColor(np.array(pil_img), cv2.COLOR_RGB2BGR)
58 |
59 |
60 | height, width, _ = img.shape
61 |
62 | # Convert pixel dimensions to inches
63 | dpi = 80  # dots per inch used to convert pixel dimensions into a figure size
64 | fig_width = width / dpi
65 | fig_height = height / dpi
66 |
67 | plt.figure(figsize=(fig_width, fig_height))
68 | plt.imshow(img_rgb)
69 | plt.axis("off")
70 | plt.show()
71 |
--------------------------------------------------------------------------------
/examples/detect_text.py:
--------------------------------------------------------------------------------
1 | from betterocr import detect_text
2 |
3 | text = detect_text(
4 | ".github/images/demo-1.png",
5 | ["ko", "en"],
6 | context="퍼멘테이션 펩타인 아이케어 크림", # product name
7 | tesseract={"config": "--tessdata-dir ./tessdata"},
8 | openai={
9 | "model": "gpt-3.5-turbo",
10 | },
11 | )
12 | print("\n")
13 | print(text)
14 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [tool.poetry]
2 | name = "betterocr"
3 | version = "1.2.0"
4 | description = "Better text detection by combining OCR engines with LLM."
5 | authors = ["Junho Yeo "]
6 | readme = "README.md"
7 |
8 | [tool.poetry.dependencies]
9 | python = ">=3.11,<3.13"
10 | openai = { version = "^1.28.1", allow-prereleases = true }
11 | pytesseract = "^0.3.10"
12 | easyocr = "^1.7.1"
13 | numpy = "^1.26.4"
14 | pandas = "^2.2.2"
15 | opencv-python = "^4.9.0.80"
16 |
17 | [tool.poetry.group.dev.dependencies]
18 | matplotlib = "^3.8.4"
19 | pillow = "^10.3.0"
20 | pytest = "^8.2.0"
21 |
22 | [tool.poetry.group.pororo.dependencies]
23 | torch = "^2.1.0"
24 | torchvision = "^0.16.0"
25 | wget = "^3.2"
26 | scikit-image = "^0.22.0"
27 |
28 | [build-system]
29 | requires = ["poetry-core"]
30 | build-backend = "poetry.core.masonry.api"
31 |
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/junhoyeo/BetterOCR/765fffe2988384e0775124b48b6d3b3ad1dbf069/tests/__init__.py
--------------------------------------------------------------------------------
/tests/parsers/test_extract_json.py:
--------------------------------------------------------------------------------
1 | from betterocr.parsers import extract_json
2 | import json
3 |
4 |
5 | def is_same_dict(dict1, dict2):
6 | return json.dumps(dict1) == json.dumps(dict2)
7 |
8 |
9 | def test_extract_json__default():
10 | assert is_same_dict(extract_json('{"data": "hello"}'), {"data": "hello"})
11 | assert is_same_dict(
12 | extract_json(
13 | """
14 | Using the provided context and both OCR data sets, I'll merge, correct, and provide the result:
15 | {"data": "{Optical Character Recognition}"}
16 | """
17 | ),
18 | {"data": "{Optical Character Recognition}"},
19 | )
20 |
21 |
22 | def test_extract_json__newlines():
23 | assert is_same_dict(
24 | extract_json('{"data": "hello\nworld"}'), {"data": "hello\nworld"}
25 | )
26 | assert is_same_dict(
27 | extract_json('{"data": "hello \n\n world"}'), {"data": "hello \n\n world"}
28 | )
29 |
30 |
31 | def test_extract_json__newlines_escaped():
32 | assert is_same_dict(
33 | extract_json('{"data": "hello\\nworld"}'), {"data": "hello\nworld"}
34 | )
35 |
36 |
37 | def test_extract_json__multiple_dicts_should_return_first_occurrence():
38 | assert is_same_dict(
39 | extract_json(
40 | """
41 | {"data": "JSON Data 1"}
42 | {"data": "JSON Data 2"}
43 | """
44 | ),
45 | {"data": "JSON Data 1"},
46 | )
47 | assert is_same_dict(
48 | extract_json(
49 | """
50 | {"invalid_key": "JSON Data 1"}
51 | {"data": "JSON Data 2"}
52 | """
53 | ),
54 | {"data": "JSON Data 2"},
55 | )
56 |
--------------------------------------------------------------------------------
/tests/parsers/test_extract_list.py:
--------------------------------------------------------------------------------
1 | from betterocr.parsers import extract_list
2 | import json
3 |
4 |
5 | def is_same_list(list1, list2):
6 |     return json.dumps(list1) == json.dumps(list2)
7 |
8 |
9 | def test_extract_list__default():
10 | assert is_same_list(
11 | extract_list('[{"box":[],"text":""}]'),
12 | [{"box": [], "text": ""}],
13 | )
14 | assert is_same_list(
15 | extract_list(
16 | """
17 | Using the provided context and both OCR data sets, I'll merge, correct, and provide the result:
18 | [
19 | {
20 | "box": [123, 456, 789, 101112],
21 | "text": "Hello World!"
22 | }
23 | ]
24 | """
25 | ),
26 | [{"box": [123, 456, 789, 101112], "text": "Hello World!"}],
27 | )
28 |
--------------------------------------------------------------------------------