├── utils
│   ├── __init__.py
│   └── image_util.py
├── pororo
│   ├── tasks
│   │   ├── utils
│   │   │   ├── __init__.py
│   │   │   ├── tokenizer.py
│   │   │   ├── config.py
│   │   │   ├── base.py
│   │   │   └── download_utils.py
│   │   ├── __init__.py
│   │   └── optical_character_recognition.py
│   ├── __version__.py
│   ├── models
│   │   └── brainOCR
│   │       ├── modules
│   │       │   ├── __init__.py
│   │       │   ├── sequence_modeling.py
│   │       │   ├── basenet.py
│   │       │   ├── prediction.py
│   │       │   ├── transformation.py
│   │       │   └── feature_extraction.py
│   │       ├── __init__.py
│   │       ├── _dataset.py
│   │       ├── imgproc.py
│   │       ├── detection.py
│   │       ├── craft.py
│   │       ├── model.py
│   │       ├── recognition.py
│   │       ├── brainocr.py
│   │       ├── craft_utils.py
│   │       ├── _modules.py
│   │       └── utils.py
│   ├── __init__.py
│   ├── utils.py
│   └── pororo.py
├── assets
│   └── images
│       ├── test_image_1.jpg
│       ├── test_image_2.jpg
│       └── test_image_3.jpg
├── requirements.txt
├── README.md
├── main.py
└── LICENSE
/utils/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/pororo/tasks/utils/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/pororo/__version__.py:
--------------------------------------------------------------------------------
1 | version = "0.4.1"
2 |
--------------------------------------------------------------------------------
/pororo/models/brainOCR/modules/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/pororo/models/brainOCR/__init__.py:
--------------------------------------------------------------------------------
1 | from .brainocr import Reader # noqa
2 |
--------------------------------------------------------------------------------
/assets/images/test_image_1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yunwoong7/korean_ocr_using_pororo/HEAD/assets/images/test_image_1.jpg
--------------------------------------------------------------------------------
/assets/images/test_image_2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yunwoong7/korean_ocr_using_pororo/HEAD/assets/images/test_image_2.jpg
--------------------------------------------------------------------------------
/assets/images/test_image_3.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yunwoong7/korean_ocr_using_pororo/HEAD/assets/images/test_image_3.jpg
--------------------------------------------------------------------------------
/pororo/__init__.py:
--------------------------------------------------------------------------------
1 | from pororo.__version__ import version as __version__ # noqa
2 | from pororo.pororo import Pororo # noqa
3 |
--------------------------------------------------------------------------------
/pororo/tasks/__init__.py:
--------------------------------------------------------------------------------
1 | # flake8: noqa
2 | """
3 | __init__.py for import child .py files
4 |
5 | isort:skip_file
6 | """
7 |
8 | # Utility classes & functions
9 | import pororo.tasks.utils
10 | from pororo.tasks.utils.download_utils import download_or_load
11 | from pororo.tasks.utils.base import (
12 | PororoBiencoderBase,
13 | PororoFactoryBase,
14 | PororoGenerationBase,
15 | PororoSimpleBase,
16 | PororoTaskGenerationBase,
17 | )
18 |
19 | # Factory classes
20 | from pororo.tasks.optical_character_recognition import PororoOcrFactory
21 |
--------------------------------------------------------------------------------
/pororo/models/brainOCR/modules/sequence_modeling.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 |
4 | class BidirectionalLSTM(nn.Module):
5 |
6 | def __init__(self, input_size: int, hidden_size: int, output_size: int):
7 | super(BidirectionalLSTM, self).__init__()
8 | self.rnn = nn.LSTM(input_size,
9 | hidden_size,
10 | bidirectional=True,
11 | batch_first=True)
12 | self.linear = nn.Linear(hidden_size * 2, output_size)
13 |
14 | def forward(self, x):
15 | """
16 | x : visual feature [batch_size x T=24 x input_size=512]
17 | output : contextual feature [batch_size x T x output_size]
18 | """
19 | self.rnn.flatten_parameters()
20 | recurrent, _ = self.rnn(
21 | x
22 | ) # batch_size x T x input_size -> batch_size x T x (2*hidden_size)
23 | output = self.linear(recurrent) # batch_size x T x output_size
24 | return output
25 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | appnope==0.1.3
2 | backcall==0.2.0
3 | certifi==2022.12.7
4 | charset-normalizer==3.0.1
5 | cycler==0.11.0
6 | debugpy==1.6.6
7 | decorator==5.1.1
8 | entrypoints==0.4
9 | fonttools==4.38.0
10 | idna==3.4
11 | imageio==2.25.0
12 | ipykernel==6.16.2
13 | ipython==7.34.0
14 | jedi==0.18.2
15 | jupyter_client==7.4.9
16 | jupyter_core==4.12.0
17 | kiwisolver==1.4.4
18 | matplotlib==3.5.3
19 | matplotlib-inline==0.1.6
20 | nest-asyncio==1.5.6
21 | networkx==2.6.3
22 | numpy==1.21.6
23 | opencv-python==4.7.0.68
24 | packaging==23.0
25 | parso==0.8.3
26 | pexpect==4.8.0
27 | pickleshare==0.7.5
28 | Pillow==9.4.0
29 | pip==22.3.1
30 | prompt-toolkit==3.0.36
31 | psutil==5.9.4
32 | ptyprocess==0.7.0
33 | Pygments==2.14.0
34 | pyparsing==3.0.9
35 | python-dateutil==2.8.2
36 | PyWavelets==1.3.0
37 | pyzmq==25.0.0
38 | requests==2.28.2
39 | scikit-image==0.19.3
40 | scipy==1.7.3
41 | setuptools==65.6.3
42 | six==1.16.0
43 | tifffile==2021.11.2
44 | torch==1.13.1
45 | torchvision==0.14.1
46 | tornado==6.2
47 | traitlets==5.9.0
48 | typing_extensions==4.4.0
49 | urllib3==1.26.14
50 | wcwidth==0.2.6
51 | wget==3.2
52 | wheel==0.37.1
53 |
--------------------------------------------------------------------------------
/pororo/models/brainOCR/_dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | from natsort import natsorted
4 | from PIL import Image
5 | from torch.utils.data import Dataset
6 |
7 |
8 | class RawDataset(Dataset):
9 |
10 | def __init__(self, root, imgW, imgH):
11 | self.imgW = imgW
12 | self.imgH = imgH
13 | self.image_path_list = []
14 | for dirpath, _, filenames in os.walk(root):
15 | for name in filenames:
16 | _, ext = os.path.splitext(name)
17 | ext = ext.lower()
18 | if ext in (".jpg", ".jpeg", ".png"):
19 | self.image_path_list.append(os.path.join(dirpath, name))
20 |
21 | self.image_path_list = natsorted(self.image_path_list)
22 | self.nSamples = len(self.image_path_list)
23 |
24 | def __len__(self):
25 | return self.nSamples
26 |
27 | def __getitem__(self, index):
28 | try:
29 | img = Image.open(self.image_path_list[index]).convert("L")
30 |
31 | except IOError:
32 | print(f"Corrupted image for {index}")
33 | img = Image.new("L", (self.imgW, self.imgH))
34 |
35 | return img, self.image_path_list[index]
36 |
--------------------------------------------------------------------------------
/utils/image_util.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import numpy as np
3 | import platform
4 | from PIL import ImageFont, ImageDraw, Image
5 | from matplotlib import pyplot as plt
6 |
7 |
8 | def plt_imshow(title='image', img=None, figsize=(8, 5)):
9 | plt.figure(figsize=figsize)
10 |
11 | if type(img) is str:
12 | img = cv2.imread(img)
13 |
14 | if type(img) == list:
15 | if type(title) == list:
16 | titles = title
17 | else:
18 | titles = []
19 |
20 | for i in range(len(img)):
21 | titles.append(title)
22 |
23 | for i in range(len(img)):
24 | if len(img[i].shape) <= 2:
25 | rgbImg = cv2.cvtColor(img[i], cv2.COLOR_GRAY2RGB)
26 | else:
27 | rgbImg = cv2.cvtColor(img[i], cv2.COLOR_BGR2RGB)
28 |
29 | plt.subplot(1, len(img), i + 1), plt.imshow(rgbImg)
30 | plt.title(titles[i])
31 | plt.xticks([]), plt.yticks([])
32 |
33 | plt.show()
34 | else:
35 | if len(img.shape) < 3:
36 | rgbImg = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
37 | else:
38 | rgbImg = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
39 |
40 | plt.imshow(rgbImg)
41 | plt.title(title)
42 | plt.xticks([]), plt.yticks([])
43 | plt.show()
44 |
45 |
46 | def put_text(image, text, x, y, color=(0, 255, 0), font_size=22):
47 | if type(image) == np.ndarray:
48 |         color_converted = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
49 |         image = Image.fromarray(color_converted)
50 | 
51 |     if platform.system() == 'Darwin':
52 |         font = 'AppleGothic.ttf'
53 |     elif platform.system() == 'Windows':
54 |         font = 'malgun.ttf'
55 |     else:
56 |         font = 'NanumGothic.ttf'  # assumed fallback for Linux; requires a Korean-capable font to be installed
57 |     image_font = ImageFont.truetype(font, font_size)
58 | draw = ImageDraw.Draw(image)
59 |
60 | draw.text((x, y), text, font=image_font, fill=color)
61 |
62 | numpy_image = np.array(image)
63 | opencv_image = cv2.cvtColor(numpy_image, cv2.COLOR_RGB2BGR)
64 |
65 | return opencv_image
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 | # Korean OCR using pororo
3 |
4 |
5 |
11 |
12 | This is Korean OCR code written in Python using the Pororo library.
13 |
14 |
15 |

16 |
17 |
18 | ## Requirements
19 |
20 | - torch
21 | - torchvision
22 | - opencv-python
23 |
24 | You can install them from PyPI:
25 |
26 | ```sh
27 | pip install torch
28 | pip install torchvision
29 | pip install opencv-python
30 | ```
31 |
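The repository also ships a `requirements.txt` with pinned versions, which you can install instead:

```sh
pip install -r requirements.txt
```
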
32 | ## PORORO: Platform Of neuRal mOdels for natuRal language prOcessing
33 |
34 | [pororo](https://github.com/kakaobrain/pororo) is a library developed by KakaoBrain for performing natural language processing and speech-related tasks.
35 |
36 | This repository is configured to only include the OCR functionality from the pororo library. If you wish to use other pororo features such as natural language processing, please install pororo through `pip install pororo`.
37 |
38 | ## Usage
39 |
40 | ```python
41 | from main import PororoOcr
42 |
43 | ocr = PororoOcr()
44 | image_path = input("Enter image path: ")
45 | text = ocr.run_ocr(image_path, debug=True)
46 | print('Result :', text)
47 | ```
48 |
49 | Output:
50 |
51 | ```sh
52 | ['메이크업존 MAKEUP ZONE', '드레스 피팅룸 DRESS FITTING ROOM', '포토존 PHOTO ZONE']
53 | ```
54 |
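To also retrieve the detected text regions, you can read back the detailed result that `run_ocr` stores internally. A minimal sketch (the `description` / `bounding_poly` keys follow the `detail=True` output documented in `pororo/tasks/optical_character_recognition.py`):

```python
from main import PororoOcr

ocr = PororoOcr()
ocr.run_ocr("assets/images/test_image_1.jpg")

# Detailed result: {'description': [...], 'bounding_poly': [...]}
detail = ocr.get_ocr_result()
for region in detail["bounding_poly"]:
    print(region["description"], region["vertices"])
```
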
55 | ------
56 |
57 |
58 |
59 |
60 |

61 |
62 |
63 | ```sh
64 | ["Life is ot a spectator sport. If you're going to spend your whole life in the grandstand just watching what goes on, in my apinion you're wasting your life.",
65 | "인생은 구경거리가 아니다. 무슨 일이 일어나는지 보기만 하는 것은 인생을 낭비하고 있는 것이다.",
66 | 'Jackie Robinson']
67 | ```
--------------------------------------------------------------------------------
/pororo/utils.py:
--------------------------------------------------------------------------------
1 | from contextlib import contextmanager
2 | from tempfile import NamedTemporaryFile
3 |
4 | from requests import get
5 |
6 |
7 | def postprocess_span(tagger, text: str) -> str:
8 | """
9 | Postprocess NOUN span to remove unnecessary character
10 |
11 | Args:
12 | text (str): NOUN span to be processed
13 |
14 | Returns:
15 | (str): post-processed NOUN span
16 |
17 | Examples:
18 | >>> postprocess_span("강감찬 장군은")
19 | '강감찬 장군'
20 | >>> postprocess_span("그녀에게")
21 | '그녀'
22 |
23 | """
24 |
25 | # First, strip punctuations
26 | text = text.strip("""!"\#$&'()*+,\-./:;<=>?@\^_‘{|}~《》""")
27 |
28 | # Complete imbalanced parentheses pair
29 | if text.count("(") == text.count(")") + 1:
30 | text += ")"
31 | elif text.count("(") + 1 == text.count(")"):
32 | text = "(" + text
33 |
34 | # Preserve beginning tokens since we only want to extract noun phrase of the last eojeol
35 | noun_phrase = " ".join(text.rsplit(" ", 1)[:-1])
36 | tokens = text.split(" ")
37 | eojeols = list()
38 | for token in tokens:
39 | eojeols.append(tagger.pos(token))
40 | last_eojeol = eojeols[-1]
41 |
42 | # Iterate backwardly to remove unnecessary postfixes
43 | i = 0
44 | for i, token in enumerate(last_eojeol[::-1]):
45 | _, pos = token
46 | # 1. The loop breaks when you meet a noun
47 | # 2. The loop also breaks when you meet a XSN (e.g. 8/SN+일/NNB LG/SL 전/XSN)
48 | if (pos[0] in ("N", "S")) or pos.startswith("XSN"):
49 | break
50 | idx = len(last_eojeol) - i
51 |
52 | # Extract noun span from last eojeol and postpend it to beginning tokens
53 | ext_last_eojeol = "".join(morph for morph, _ in last_eojeol[:idx])
54 | noun_phrase += " " + ext_last_eojeol
55 | return noun_phrase.strip()
56 |
57 |
58 | @contextmanager
59 | def control_temp(file_path: str):
60 | """
61 | Download temporary file from web, then remove it after some context
62 |
63 | Args:
64 | file_path (str): web file path
65 |
66 | """
67 | # yapf: disable
68 | assert file_path.startswith("http"), "File path should contain `http` prefix !"
69 | # yapf: enable
70 |
71 | ext = file_path[file_path.rfind("."):]
72 |
73 | with NamedTemporaryFile("wb", suffix=ext, delete=True) as f:
74 | response = get(file_path, allow_redirects=True)
75 | f.write(response.content)
76 | yield f.name
77 |
--------------------------------------------------------------------------------
/pororo/models/brainOCR/imgproc.py:
--------------------------------------------------------------------------------
1 | """
2 | This is adapted from https://github.com/clovaai/CRAFT-pytorch/blob/master/imgproc.py
3 | Copyright (c) 2019-present NAVER Corp.
4 | MIT License
5 | """
6 |
7 | import cv2
8 | import numpy as np
9 | from skimage import io
10 |
11 |
12 | def load_image(img_file):
13 | img = io.imread(img_file) # RGB order
14 | if img.shape[0] == 2:
15 | img = img[0]
16 | if len(img.shape) == 2:
17 | img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
18 | if img.shape[2] == 4:
19 | img = img[:, :, :3]
20 | img = np.array(img)
21 |
22 | return img
23 |
24 |
25 | def normalize_mean_variance(
26 | in_img,
27 | mean=(0.485, 0.456, 0.406),
28 | variance=(0.229, 0.224, 0.225),
29 | ):
30 | # should be RGB order
31 | img = in_img.copy().astype(np.float32)
32 |
33 | img -= np.array([mean[0] * 255.0, mean[1] * 255.0, mean[2] * 255.0],
34 | dtype=np.float32)
35 | img /= np.array(
36 | [variance[0] * 255.0, variance[1] * 255.0, variance[2] * 255.0],
37 | dtype=np.float32,
38 | )
39 | return img
40 |
41 |
42 | def denormalize_mean_variance(
43 | in_img,
44 | mean=(0.485, 0.456, 0.406),
45 | variance=(0.229, 0.224, 0.225),
46 | ):
47 | # should be RGB order
48 | img = in_img.copy()
49 | img *= variance
50 | img += mean
51 | img *= 255.0
52 | img = np.clip(img, 0, 255).astype(np.uint8)
53 | return img
54 |
55 |
56 | def resize_aspect_ratio(
57 | img: np.ndarray,
58 | square_size: int,
59 | interpolation: int,
60 | mag_ratio: float = 1.0,
61 | ):
62 | height, width, channel = img.shape
63 |
64 | # magnify image size
65 | target_size = mag_ratio * max(height, width)
66 |
67 | # set original image size
68 | if target_size > square_size:
69 | target_size = square_size
70 |
71 | ratio = target_size / max(height, width)
72 |
73 | target_h, target_w = int(height * ratio), int(width * ratio)
74 | proc = cv2.resize(img, (target_w, target_h), interpolation=interpolation)
75 |
76 | # make canvas and paste image
77 | target_h32, target_w32 = target_h, target_w
78 | if target_h % 32 != 0:
79 | target_h32 = target_h + (32 - target_h % 32)
80 | if target_w % 32 != 0:
81 | target_w32 = target_w + (32 - target_w % 32)
82 | resized = np.zeros((target_h32, target_w32, channel), dtype=np.float32)
83 | resized[0:target_h, 0:target_w, :] = proc
84 | target_h, target_w = target_h32, target_w32
85 |
86 | size_heatmap = (int(target_w / 2), int(target_h / 2))
87 |
88 | return resized, ratio, size_heatmap
89 |
90 |
91 | def cvt2heatmap_img(img):
92 | img = (np.clip(img, 0, 1) * 255).astype(np.uint8)
93 | img = cv2.applyColorMap(img, cv2.COLORMAP_JET)
94 | return img
95 |
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | from pororo import Pororo
3 | from pororo.pororo import SUPPORTED_TASKS
4 | from utils.image_util import plt_imshow, put_text
5 | import warnings
6 |
7 | warnings.filterwarnings('ignore')
8 |
9 |
10 | class PororoOcr:
11 | def __init__(self, model: str = "brainocr", lang: str = "ko", **kwargs):
12 | self.model = model
13 | self.lang = lang
14 | self._ocr = Pororo(task="ocr", lang=lang, model=model, **kwargs)
15 | self.img_path = None
16 | self.ocr_result = {}
17 |
18 | def run_ocr(self, img_path: str, debug: bool = False):
19 | self.img_path = img_path
20 | self.ocr_result = self._ocr(img_path, detail=True)
21 |
22 | if self.ocr_result['description']:
23 | ocr_text = self.ocr_result["description"]
24 | else:
25 | ocr_text = "No text detected."
26 |
27 | if debug:
28 | self.show_img_with_ocr()
29 |
30 | return ocr_text
31 |
32 | @staticmethod
33 | def get_available_langs():
34 | return SUPPORTED_TASKS["ocr"].get_available_langs()
35 |
36 | @staticmethod
37 | def get_available_models():
38 | return SUPPORTED_TASKS["ocr"].get_available_models()
39 |
40 | def get_ocr_result(self):
41 | return self.ocr_result
42 |
43 | def get_img_path(self):
44 | return self.img_path
45 |
46 | def show_img(self):
47 | plt_imshow(img=self.img_path)
48 |
49 | def show_img_with_ocr(self):
50 | img = cv2.imread(self.img_path)
51 | roi_img = img.copy()
52 |
53 | for text_result in self.ocr_result['bounding_poly']:
54 | text = text_result['description']
55 | tlX = text_result['vertices'][0]['x']
56 | tlY = text_result['vertices'][0]['y']
57 | trX = text_result['vertices'][1]['x']
58 | trY = text_result['vertices'][1]['y']
59 | brX = text_result['vertices'][2]['x']
60 | brY = text_result['vertices'][2]['y']
61 | blX = text_result['vertices'][3]['x']
62 | blY = text_result['vertices'][3]['y']
63 |
64 | pts = ((tlX, tlY), (trX, trY), (brX, brY), (blX, blY))
65 |
66 | topLeft = pts[0]
67 | topRight = pts[1]
68 | bottomRight = pts[2]
69 | bottomLeft = pts[3]
70 |
71 | cv2.line(roi_img, topLeft, topRight, (0, 255, 0), 2)
72 | cv2.line(roi_img, topRight, bottomRight, (0, 255, 0), 2)
73 | cv2.line(roi_img, bottomRight, bottomLeft, (0, 255, 0), 2)
74 | cv2.line(roi_img, bottomLeft, topLeft, (0, 255, 0), 2)
75 | roi_img = put_text(roi_img, text, topLeft[0], topLeft[1] - 20, font_size=15)
76 |
77 | # print(text)
78 |
79 | plt_imshow(["Original", "ROI"], [img, roi_img], figsize=(16, 10))
80 |
81 |
82 | if __name__ == "__main__":
83 | ocr = PororoOcr()
84 | image_path = input("Enter image path: ")
85 | text = ocr.run_ocr(image_path, debug=True)
86 | print('Result :', text)
87 |
--------------------------------------------------------------------------------
/pororo/models/brainOCR/detection.py:
--------------------------------------------------------------------------------
1 | """
2 | This code is adapted from https://github.com/JaidedAI/EasyOCR/blob/master/easyocr/detection.py
3 | """
4 |
5 | from collections import OrderedDict
6 |
7 | import cv2
8 | import numpy as np
9 | import torch
10 | import torch.backends.cudnn as cudnn
11 | from torch.autograd import Variable
12 |
13 | from .craft import CRAFT
14 | from .craft_utils import adjust_result_coordinates, get_det_boxes
15 | from .imgproc import normalize_mean_variance, resize_aspect_ratio
16 |
17 |
18 | def copy_state_dict(state_dict):
19 | if list(state_dict.keys())[0].startswith("module"):
20 | start_idx = 1
21 | else:
22 | start_idx = 0
23 | new_state_dict = OrderedDict()
24 | for k, v in state_dict.items():
25 | name = ".".join(k.split(".")[start_idx:])
26 | new_state_dict[name] = v
27 | return new_state_dict
28 |
29 |
30 | def test_net(image: np.ndarray, net, opt2val: dict):
31 | canvas_size = opt2val["canvas_size"]
32 | mag_ratio = opt2val["mag_ratio"]
33 | text_threshold = opt2val["text_threshold"]
34 | link_threshold = opt2val["link_threshold"]
35 | low_text = opt2val["low_text"]
36 | device = opt2val["device"]
37 |
38 | # resize
39 | img_resized, target_ratio, size_heatmap = resize_aspect_ratio(
40 | image, canvas_size, interpolation=cv2.INTER_LINEAR, mag_ratio=mag_ratio)
41 | ratio_h = ratio_w = 1 / target_ratio
42 |
43 | # preprocessing
44 | x = normalize_mean_variance(img_resized)
45 | x = torch.from_numpy(x).permute(2, 0, 1) # [h, w, c] to [c, h, w]
46 | x = Variable(x.unsqueeze(0)) # [c, h, w] to [b, c, h, w]
47 | x = x.to(device)
48 |
49 | # forward pass
50 | with torch.no_grad():
51 | y, feature = net(x)
52 |
53 | # make score and link map
54 | score_text = y[0, :, :, 0].cpu().data.numpy()
55 | score_link = y[0, :, :, 1].cpu().data.numpy()
56 |
57 | # Post-processing
58 | boxes, polys = get_det_boxes(
59 | score_text,
60 | score_link,
61 | text_threshold,
62 | link_threshold,
63 | low_text,
64 | )
65 |
66 | # coordinate adjustment
67 | boxes = adjust_result_coordinates(boxes, ratio_w, ratio_h)
68 | polys = adjust_result_coordinates(polys, ratio_w, ratio_h)
69 | for k in range(len(polys)):
70 | if polys[k] is None:
71 | polys[k] = boxes[k]
72 |
73 | return boxes, polys
74 |
75 |
76 | def get_detector(det_model_ckpt_fp: str, device: str = "cpu"):
77 | net = CRAFT()
78 |
79 | net.load_state_dict(
80 | copy_state_dict(torch.load(det_model_ckpt_fp, map_location=device)))
81 | if device == "cuda":
82 | net = torch.nn.DataParallel(net).to(device)
83 | cudnn.benchmark = False
84 |
85 | net.eval()
86 | return net
87 |
88 |
89 | def get_textbox(detector, image: np.ndarray, opt2val: dict):
90 | bboxes, polys = test_net(image, detector, opt2val)
91 | result = []
92 | for i, box in enumerate(polys):
93 | poly = np.array(box).astype(np.int32).reshape((-1))
94 | result.append(poly)
95 |
96 | return result
97 |
--------------------------------------------------------------------------------
/pororo/models/brainOCR/modules/basenet.py:
--------------------------------------------------------------------------------
1 | from collections import namedtuple
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.init as init
6 | from torchvision import models
7 | from torchvision.models.vgg import model_urls
8 |
9 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
10 |
11 |
12 | def init_weights(modules):
13 | for m in modules:
14 | if isinstance(m, nn.Conv2d):
15 | init.xavier_uniform_(m.weight.data)
16 | if m.bias is not None:
17 | m.bias.data.zero_()
18 | elif isinstance(m, nn.BatchNorm2d):
19 | m.weight.data.fill_(1)
20 | m.bias.data.zero_()
21 | elif isinstance(m, nn.Linear):
22 | m.weight.data.normal_(0, 0.01)
23 | m.bias.data.zero_()
24 |
25 |
26 | class Vgg16BN(torch.nn.Module):
27 |
28 | def __init__(self, pretrained: bool = True, freeze: bool = True):
29 | super(Vgg16BN, self).__init__()
30 | model_urls["vgg16_bn"] = model_urls["vgg16_bn"].replace(
31 | "https://", "http://")
32 | vgg_pretrained_features = models.vgg16_bn(
33 | pretrained=pretrained).features
34 | self.slice1 = torch.nn.Sequential()
35 | self.slice2 = torch.nn.Sequential()
36 | self.slice3 = torch.nn.Sequential()
37 | self.slice4 = torch.nn.Sequential()
38 | self.slice5 = torch.nn.Sequential()
39 | for x in range(12): # conv2_2
40 | self.slice1.add_module(str(x), vgg_pretrained_features[x])
41 | for x in range(12, 19): # conv3_3
42 | self.slice2.add_module(str(x), vgg_pretrained_features[x])
43 | for x in range(19, 29): # conv4_3
44 | self.slice3.add_module(str(x), vgg_pretrained_features[x])
45 | for x in range(29, 39): # conv5_3
46 | self.slice4.add_module(str(x), vgg_pretrained_features[x])
47 |
48 | # fc6, fc7 without atrous conv
49 | self.slice5 = torch.nn.Sequential(
50 | nn.MaxPool2d(kernel_size=3, stride=1, padding=1),
51 | nn.Conv2d(512, 1024, kernel_size=3, padding=6, dilation=6),
52 | nn.Conv2d(1024, 1024, kernel_size=1),
53 | )
54 |
55 | if not pretrained:
56 | init_weights(self.slice1.modules())
57 | init_weights(self.slice2.modules())
58 | init_weights(self.slice3.modules())
59 | init_weights(self.slice4.modules())
60 |
61 | init_weights(
62 | self.slice5.modules()) # no pretrained model for fc6 and fc7
63 |
64 | if freeze:
65 | for param in self.slice1.parameters(): # only first conv
66 | param.requires_grad = False
67 |
68 | def forward(self, x):
69 | h = self.slice1(x)
70 | h_relu2_2 = h
71 | h = self.slice2(h)
72 | h_relu3_2 = h
73 | h = self.slice3(h)
74 | h_relu4_3 = h
75 | h = self.slice4(h)
76 | h_relu5_3 = h
77 | h = self.slice5(h)
78 | h_fc7 = h
79 | vgg_outputs = namedtuple(
80 | "VggOutputs", ["fc7", "relu5_3", "relu4_3", "relu3_2", "relu2_2"])
81 | out = vgg_outputs(h_fc7, h_relu5_3, h_relu4_3, h_relu3_2, h_relu2_2)
82 | return out
83 |
--------------------------------------------------------------------------------
/pororo/tasks/utils/tokenizer.py:
--------------------------------------------------------------------------------
1 | from typing import List, Optional, Union
2 |
3 | from tokenizers import Tokenizer, decoders, pre_tokenizers
4 | from tokenizers.implementations import BaseTokenizer
5 | from tokenizers.models import BPE, Unigram
6 | from tokenizers.normalizers import NFKC
7 |
8 |
9 | class CustomTokenizer(BaseTokenizer):
10 |
11 | def __init__(
12 | self,
13 | vocab: Union[str, List],
14 | merges: Union[str, None],
15 |         unk_token: str = "<unk>",
16 | replacement: str = "▁",
17 | add_prefix_space: bool = True,
18 | dropout: Optional[float] = None,
19 | normalize: bool = True,
20 | ):
21 | if merges:
22 | n_model = "BPE"
23 | tokenizer = Tokenizer(
24 | BPE(
25 | vocab, # type: ignore
26 | merges,
27 | unk_token=unk_token,
28 | fuse_unk=True,
29 | ))
30 | else:
31 | n_model = "Unigram"
32 | tokenizer = Tokenizer(Unigram(vocab, 1)) # type: ignore
33 |
34 | if normalize:
35 | tokenizer.normalizer = NFKC()
36 |
37 | tokenizer.pre_tokenizer = pre_tokenizers.Metaspace(
38 | replacement=replacement,
39 | add_prefix_space=add_prefix_space,
40 | )
41 |
42 | tokenizer.decoder = decoders.Metaspace(
43 | replacement=replacement,
44 | add_prefix_space=add_prefix_space,
45 | )
46 |
47 | parameters = {
48 | "model": f"SentencePiece{n_model}",
49 | "unk_token": unk_token,
50 | "replacement": replacement,
51 | "add_prefix_space": add_prefix_space,
52 | "dropout": dropout,
53 | }
54 | super().__init__(tokenizer, parameters)
55 |
56 | @staticmethod
57 | def from_file(
58 | vocab_filename: str,
59 | merges_filename: Union[str, None],
60 | **kwargs,
61 | ):
62 | # BPE
63 | if merges_filename:
64 | vocab, merges = BPE.read_file(vocab_filename, merges_filename)
65 |
66 | # Unigram
67 | else:
68 | vocab = []
69 | merges = None
70 | with open(vocab_filename, "r") as f_in:
71 | for line in f_in.readlines():
72 | token, score = line.strip().split("\t")
73 | vocab.append((token, float(score)))
74 |
75 | return CustomTokenizer(vocab, merges, **kwargs)
76 |
77 | def segment(self, text: str) -> List[str]:
78 | """
79 | Segment text into subword list
80 |
81 | Args:
82 | text (str): input text to be segmented
83 |
84 | Returns:
85 | List[str]: segmented subword list
86 |
87 | """
88 | encoding = self.encode(text)
89 |
90 | offsets = encoding.offsets
91 | tokens = encoding.tokens
92 |
93 | result = []
94 | for offset, token in zip(offsets, tokens):
95 |             if token != "<unk>":
96 | result.append(token)
97 | continue
98 | s, e = offset
99 | result.append(text[s:e])
100 | return result
101 |
--------------------------------------------------------------------------------
/pororo/models/brainOCR/craft.py:
--------------------------------------------------------------------------------
1 | """
2 | This code is adapted from https://github.com/clovaai/CRAFT-pytorch/blob/master/craft.py.
3 | Copyright (c) 2019-present NAVER Corp.
4 | MIT License
5 | """
6 |
7 | import torch
8 | import torch.nn as nn
9 | import torch.nn.functional as F
10 | from torch import Tensor
11 |
12 | from ._modules import Vgg16BN, init_weights
13 |
14 |
15 | class DoubleConv(nn.Module):
16 |
17 | def __init__(self, in_ch: int, mid_ch: int, out_ch: int) -> None:
18 | super(DoubleConv, self).__init__()
19 | self.conv = nn.Sequential(
20 | nn.Conv2d(in_ch + mid_ch, mid_ch, kernel_size=1),
21 | nn.BatchNorm2d(mid_ch),
22 | nn.ReLU(inplace=True),
23 | nn.Conv2d(mid_ch, out_ch, kernel_size=3, padding=1),
24 | nn.BatchNorm2d(out_ch),
25 | nn.ReLU(inplace=True),
26 | )
27 |
28 | def forward(self, x: Tensor):
29 | x = self.conv(x)
30 | return x
31 |
32 |
33 | class CRAFT(nn.Module):
34 |
35 | def __init__(self, pretrained: bool = False, freeze: bool = False) -> None:
36 | super(CRAFT, self).__init__()
37 |
38 | # Base network
39 | self.basenet = Vgg16BN(pretrained, freeze)
40 |
41 | # U network
42 | self.upconv1 = DoubleConv(1024, 512, 256)
43 | self.upconv2 = DoubleConv(512, 256, 128)
44 | self.upconv3 = DoubleConv(256, 128, 64)
45 | self.upconv4 = DoubleConv(128, 64, 32)
46 |
47 | num_class = 2
48 | self.conv_cls = nn.Sequential(
49 | nn.Conv2d(32, 32, kernel_size=3, padding=1),
50 | nn.ReLU(inplace=True),
51 | nn.Conv2d(32, 32, kernel_size=3, padding=1),
52 | nn.ReLU(inplace=True),
53 | nn.Conv2d(32, 16, kernel_size=3, padding=1),
54 | nn.ReLU(inplace=True),
55 | nn.Conv2d(16, 16, kernel_size=1),
56 | nn.ReLU(inplace=True),
57 | nn.Conv2d(16, num_class, kernel_size=1),
58 | )
59 |
60 | init_weights(self.upconv1.modules())
61 | init_weights(self.upconv2.modules())
62 | init_weights(self.upconv3.modules())
63 | init_weights(self.upconv4.modules())
64 | init_weights(self.conv_cls.modules())
65 |
66 | def forward(self, x: Tensor):
67 | # Base network
68 | sources = self.basenet(x)
69 |
70 | # U network
71 | y = torch.cat([sources[0], sources[1]], dim=1)
72 | y = self.upconv1(y)
73 |
74 | y = F.interpolate(
75 | y,
76 | size=sources[2].size()[2:],
77 | mode="bilinear",
78 | align_corners=False,
79 | )
80 | y = torch.cat([y, sources[2]], dim=1)
81 | y = self.upconv2(y)
82 |
83 | y = F.interpolate(
84 | y,
85 | size=sources[3].size()[2:],
86 | mode="bilinear",
87 | align_corners=False,
88 | )
89 | y = torch.cat([y, sources[3]], dim=1)
90 | y = self.upconv3(y)
91 |
92 | y = F.interpolate(
93 | y,
94 | size=sources[4].size()[2:],
95 | mode="bilinear",
96 | align_corners=False,
97 | )
98 | y = torch.cat([y, sources[4]], dim=1)
99 | feature = self.upconv4(y)
100 |
101 | y = self.conv_cls(feature)
102 |
103 | return y.permute(0, 2, 3, 1), feature
104 |
--------------------------------------------------------------------------------
/pororo/pororo.py:
--------------------------------------------------------------------------------
1 | """
2 | Pororo task-specific factory class
3 |
4 | isort:skip_file
5 |
6 | """
7 |
8 | import logging
9 | from typing import Optional
10 | from pororo.tasks.utils.base import PororoTaskBase
11 |
12 | import torch
13 |
14 | from pororo.tasks import (
15 | PororoOcrFactory,
16 | )
17 |
18 | SUPPORTED_TASKS = {
19 | "ocr": PororoOcrFactory,
20 | }
21 |
22 | LANG_ALIASES = {
23 | "english": "en",
24 | "eng": "en",
25 | "korean": "ko",
26 | "kor": "ko",
27 | "kr": "ko",
28 | "chinese": "zh",
29 | "chn": "zh",
30 | "cn": "zh",
31 | "japanese": "ja",
32 | "jap": "ja",
33 | "jp": "ja",
34 | "jejueo": "je",
35 | "jje": "je",
36 | }
37 |
38 | logging.getLogger("transformers").setLevel(logging.WARN)
39 | logging.getLogger("fairseq").setLevel(logging.WARN)
40 | logging.getLogger("sentence_transformers").setLevel(logging.WARN)
41 | logging.getLogger("youtube_dl").setLevel(logging.WARN)
42 | logging.getLogger("pydub").setLevel(logging.WARN)
43 | logging.getLogger("librosa").setLevel(logging.WARN)
44 |
45 |
46 | class Pororo:
47 | r"""
48 | This is a generic class that will return one of the task-specific model classes of the library
49 | when created with the `__new__()` method
50 |
51 | """
52 |
53 | def __new__(
54 | cls,
55 | task: str,
56 | lang: str = "en",
57 | model: Optional[str] = None,
58 | **kwargs,
59 | ) -> PororoTaskBase:
60 | if task not in SUPPORTED_TASKS:
61 | raise KeyError("Unknown task {}, available tasks are {}".format(
62 | task,
63 | list(SUPPORTED_TASKS.keys()),
64 | ))
65 |
66 | lang = lang.lower()
67 | lang = LANG_ALIASES[lang] if lang in LANG_ALIASES else lang
68 |
69 | # Get device information from torch API
70 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
71 |
72 | # Instantiate task-specific pipeline module, if possible
73 | task_module = SUPPORTED_TASKS[task](
74 | task,
75 | lang,
76 | model,
77 | **kwargs,
78 | ).load(device)
79 |
80 | return task_module
81 |
82 | @staticmethod
83 | def available_tasks() -> str:
84 | """
85 | Returns available tasks in Pororo project
86 |
87 | Returns:
88 | str: Supported task names
89 |
90 | """
91 | return "Available tasks are {}".format(list(SUPPORTED_TASKS.keys()))
92 |
93 | @staticmethod
94 | def available_models(task: str) -> str:
95 | """
96 |         Returns available model names corresponding to the user-input task
97 |
98 | Args:
99 | task (str): user-input task name
100 |
101 | Returns:
102 | str: Supported model names corresponding to the user-input task
103 |
104 | Raises:
105 | KeyError: When user-input task is not supported
106 |
107 | """
108 | if task not in SUPPORTED_TASKS:
109 | raise KeyError(
110 |                 "Unknown task {} ! Please check available tasks via `available_tasks()`"
111 | .format(task))
112 |
113 | langs = SUPPORTED_TASKS[task].get_available_models()
114 | output = f"Available models for {task} are "
115 | for lang in langs:
116 | output += f"([lang]: {lang}, [model]: {', '.join(langs[lang])}), "
117 | return output[:-2]
118 |
--------------------------------------------------------------------------------
/pororo/tasks/utils/config.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | from typing import Union
3 |
4 |
5 | @dataclass
6 | class TransformerConfig:
7 | src_dict: Union[str, None]
8 | tgt_dict: Union[str, None]
9 | src_tok: Union[str, None]
10 | tgt_tok: Union[str, None]
11 |
12 |
13 | CONFIGS = {
14 | "transformer.base.ko.const":
15 | TransformerConfig(
16 | "dict.transformer.base.ko.const",
17 | "dict.transformer.base.ko.const",
18 | None,
19 | None,
20 | ),
21 | "transformer.base.ko.pg":
22 | TransformerConfig(
23 | "dict.transformer.base.ko.mt",
24 | "dict.transformer.base.ko.mt",
25 | "bpe8k.ko",
26 | None,
27 | ),
28 | "transformer.base.ko.pg_long":
29 | TransformerConfig(
30 | "dict.transformer.base.ko.mt",
31 | "dict.transformer.base.ko.mt",
32 | "bpe8k.ko",
33 | None,
34 | ),
35 | "transformer.base.en.gec":
36 | TransformerConfig(
37 | "dict.transformer.base.en.mt",
38 | "dict.transformer.base.en.mt",
39 | "bpe32k.en",
40 | None,
41 | ),
42 | "transformer.base.zh.pg":
43 | TransformerConfig(
44 | "dict.transformer.base.zh.mt",
45 | "dict.transformer.base.zh.mt",
46 | None,
47 | None,
48 | ),
49 | "transformer.base.ja.pg":
50 | TransformerConfig(
51 | "dict.transformer.base.ja.mt",
52 | "dict.transformer.base.ja.mt",
53 | "bpe8k.ja",
54 | None,
55 | ),
56 | "transformer.base.zh.const":
57 | TransformerConfig(
58 | "dict.transformer.base.zh.const",
59 | "dict.transformer.base.zh.const",
60 | None,
61 | None,
62 | ),
63 | "transformer.base.en.const":
64 | TransformerConfig(
65 | "dict.transformer.base.en.const",
66 | "dict.transformer.base.en.const",
67 | None,
68 | None,
69 | ),
70 | "transformer.base.en.pg":
71 | TransformerConfig(
72 | "dict.transformer.base.en.mt",
73 | "dict.transformer.base.en.mt",
74 | "bpe32k.en",
75 | None,
76 | ),
77 | "transformer.base.ko.gec":
78 | TransformerConfig(
79 | "dict.transformer.base.ko.gec",
80 | "dict.transformer.base.ko.gec",
81 | "bpe8k.ko",
82 | None,
83 | ),
84 | "transformer.base.en.char_gec":
85 | TransformerConfig(
86 | "dict.transformer.base.en.char_gec",
87 | "dict.transformer.base.en.char_gec",
88 | None,
89 | None,
90 | ),
91 | "transformer.base.en.caption":
92 | TransformerConfig(
93 | None,
94 | None,
95 | None,
96 | None,
97 | ),
98 | "transformer.base.ja.p2g":
99 | TransformerConfig(
100 | "dict.transformer.base.ja.p2g",
101 | "dict.transformer.base.ja.p2g",
102 | None,
103 | None,
104 | ),
105 | "transformer.large.multi.mtpg":
106 | TransformerConfig(
107 | "dict.transformer.large.multi.mtpg",
108 | "dict.transformer.large.multi.mtpg",
109 | "bpe32k.en",
110 | None,
111 | ),
112 | "transformer.large.multi.fast.mtpg":
113 | TransformerConfig(
114 | "dict.transformer.large.multi.mtpg",
115 | "dict.transformer.large.multi.mtpg",
116 | "bpe32k.en",
117 | None,
118 | ),
119 | "transformer.large.ko.wsd":
120 | TransformerConfig(
121 | "dict.transformer.large.ko.wsd",
122 | "dict.transformer.large.ko.wsd",
123 | None,
124 | None,
125 | ),
126 | }
127 |
--------------------------------------------------------------------------------
/pororo/models/brainOCR/model.py:
--------------------------------------------------------------------------------
1 | """
2 | This code is adapted from
3 | https://github.com/clovaai/deep-text-recognition-benchmark/blob/master/model.py
4 | """
5 |
6 | import torch.nn as nn
7 | from torch import Tensor
8 |
9 | from .modules.feature_extraction import (
10 | ResNetFeatureExtractor,
11 | VGGFeatureExtractor,
12 | )
13 | from .modules.prediction import Attention
14 | from .modules.sequence_modeling import BidirectionalLSTM
15 | from .modules.transformation import TpsSpatialTransformerNetwork
16 |
17 |
18 | class Model(nn.Module):
19 |
20 | def __init__(self, opt2val: dict):
21 | super(Model, self).__init__()
22 |
23 | input_channel = opt2val["input_channel"]
24 | output_channel = opt2val["output_channel"]
25 | hidden_size = opt2val["hidden_size"]
26 | vocab_size = opt2val["vocab_size"]
27 | num_fiducial = opt2val["num_fiducial"]
28 | imgH = opt2val["imgH"]
29 | imgW = opt2val["imgW"]
30 | FeatureExtraction = opt2val["FeatureExtraction"]
31 | Transformation = opt2val["Transformation"]
32 | SequenceModeling = opt2val["SequenceModeling"]
33 | Prediction = opt2val["Prediction"]
34 |
35 | # Transformation
36 | if Transformation == "TPS":
37 | self.Transformation = TpsSpatialTransformerNetwork(
38 | F=num_fiducial,
39 | I_size=(imgH, imgW),
40 | I_r_size=(imgH, imgW),
41 | I_channel_num=input_channel,
42 | )
43 | else:
44 | print("No Transformation module specified")
45 |
46 | # FeatureExtraction
47 | if FeatureExtraction == "VGG":
48 | extractor = VGGFeatureExtractor
49 | else: # ResNet
50 | extractor = ResNetFeatureExtractor
51 | self.FeatureExtraction = extractor(
52 | input_channel,
53 | output_channel,
54 | opt2val,
55 | )
56 | self.FeatureExtraction_output = output_channel # int(imgH/16-1) * 512
57 | self.AdaptiveAvgPool = nn.AdaptiveAvgPool2d(
58 | (None, 1)) # Transform final (imgH/16-1) -> 1
59 |
60 | # Sequence modeling
61 | if SequenceModeling == "BiLSTM":
62 | self.SequenceModeling = nn.Sequential(
63 | BidirectionalLSTM(
64 | self.FeatureExtraction_output,
65 | hidden_size,
66 | hidden_size,
67 | ),
68 | BidirectionalLSTM(hidden_size, hidden_size, hidden_size),
69 | )
70 | self.SequenceModeling_output = hidden_size
71 | else:
72 | print("No SequenceModeling module specified")
73 | self.SequenceModeling_output = self.FeatureExtraction_output
74 |
75 | # Prediction
76 | if Prediction == "CTC":
77 | self.Prediction = nn.Linear(
78 | self.SequenceModeling_output,
79 | vocab_size,
80 | )
81 | elif Prediction == "Attn":
82 | self.Prediction = Attention(
83 | self.SequenceModeling_output,
84 | hidden_size,
85 | vocab_size,
86 | )
87 | elif Prediction == "Transformer": # TODO
88 | pass
89 | else:
90 |             raise Exception("Prediction is neither CTC nor Attn")
91 |
92 | def forward(self, x: Tensor):
93 | """
94 | :param x: (batch, input_channel, height, width)
95 | :return:
96 | """
97 | # Transformation stage
98 | x = self.Transformation(x)
99 |
100 | # Feature extraction stage
101 | visual_feature = self.FeatureExtraction(
102 | x) # (b, output_channel=512, h=3, w)
103 | visual_feature = self.AdaptiveAvgPool(visual_feature.permute(
104 | 0, 3, 1, 2)) # (b, w, channel=512, h=1)
105 | visual_feature = visual_feature.squeeze(3) # (b, w, channel=512)
106 |
107 | # Sequence modeling stage
108 | self.SequenceModeling.eval()
109 | contextual_feature = self.SequenceModeling(visual_feature)
110 |
111 | # Prediction stage
112 | prediction = self.Prediction(
113 | contextual_feature.contiguous()) # (b, T, num_classes)
114 |
115 | return prediction
116 |
--------------------------------------------------------------------------------
/pororo/models/brainOCR/modules/prediction.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
6 |
7 |
8 | class Attention(nn.Module):
9 |
10 | def __init__(self, input_size, hidden_size, num_classes):
11 | super(Attention, self).__init__()
12 | self.attention_cell = AttentionCell(input_size, hidden_size,
13 | num_classes)
14 | self.hidden_size = hidden_size
15 | self.num_classes = num_classes
16 | self.generator = nn.Linear(hidden_size, num_classes)
17 |
18 | def _char_to_onehot(self, input_char, onehot_dim=38):
19 | input_char = input_char.unsqueeze(1)
20 | batch_size = input_char.size(0)
21 | one_hot = torch.FloatTensor(batch_size, onehot_dim).zero_().to(device)
22 | one_hot = one_hot.scatter_(1, input_char, 1)
23 | return one_hot
24 |
25 | def forward(self, batch_H, text, is_train=True, batch_max_length=25):
26 | """
27 | input:
28 | batch_H : contextual_feature H = hidden state of encoder. [batch_size x num_steps x contextual_feature_channels]
29 | text : the text-index of each image. [batch_size x (max_length+1)]. +1 for [GO] token. text[:, 0] = [GO].
30 | output: probability distribution at each step [batch_size x num_steps x num_classes]
31 | """
32 | batch_size = batch_H.size(0)
33 | num_steps = batch_max_length + 1 # +1 for [s] at end of sentence.
34 |
35 | output_hiddens = (torch.FloatTensor(
36 | batch_size, num_steps, self.hidden_size).fill_(0).to(device))
37 | hidden = (
38 | torch.FloatTensor(batch_size, self.hidden_size).fill_(0).to(device),
39 | torch.FloatTensor(batch_size, self.hidden_size).fill_(0).to(device),
40 | )
41 |
42 | if is_train:
43 | for i in range(num_steps):
44 | # one-hot vectors for a i-th char. in a batch
45 | char_onehots = self._char_to_onehot(text[:, i],
46 | onehot_dim=self.num_classes)
47 | # hidden : decoder's hidden s_{t-1}, batch_H : encoder's hidden H, char_onehots : one-hot(y_{t-1})
48 | hidden, alpha = self.attention_cell(hidden, batch_H,
49 | char_onehots)
50 | output_hiddens[:, i, :] = hidden[
51 | 0] # LSTM hidden index (0: hidden, 1: Cell)
52 | probs = self.generator(output_hiddens)
53 |
54 | else:
55 | targets = torch.LongTensor(batch_size).fill_(0).to(
56 | device) # [GO] token
57 | probs = (torch.FloatTensor(batch_size, num_steps,
58 | self.num_classes).fill_(0).to(device))
59 |
60 | for i in range(num_steps):
61 | char_onehots = self._char_to_onehot(targets,
62 | onehot_dim=self.num_classes)
63 | hidden, alpha = self.attention_cell(hidden, batch_H,
64 | char_onehots)
65 | probs_step = self.generator(hidden[0])
66 | probs[:, i, :] = probs_step
67 | _, next_input = probs_step.max(1)
68 | targets = next_input
69 |
70 | return probs # batch_size x num_steps x num_classes
71 |
72 |
73 | class AttentionCell(nn.Module):
74 |
75 | def __init__(self, input_size, hidden_size, num_embeddings):
76 | super(AttentionCell, self).__init__()
77 | self.i2h = nn.Linear(input_size, hidden_size, bias=False)
78 | self.h2h = nn.Linear(hidden_size,
79 | hidden_size) # either i2i or h2h should have bias
80 | self.score = nn.Linear(hidden_size, 1, bias=False)
81 | self.rnn = nn.LSTMCell(input_size + num_embeddings, hidden_size)
82 | self.hidden_size = hidden_size
83 |
84 | def forward(self, prev_hidden, batch_H, char_onehots):
85 | # [batch_size x num_encoder_step x num_channel] -> [batch_size x num_encoder_step x hidden_size]
86 | batch_H_proj = self.i2h(batch_H)
87 | prev_hidden_proj = self.h2h(prev_hidden[0]).unsqueeze(1)
88 | e = self.score(
89 | torch.tanh(batch_H_proj +
90 | prev_hidden_proj)) # batch_size x num_encoder_step * 1
91 |
92 | alpha = F.softmax(e, dim=1)
93 | context = torch.bmm(alpha.permute(0, 2, 1),
94 | batch_H).squeeze(1) # batch_size x num_channel
95 | concat_context = torch.cat(
96 | [context, char_onehots],
97 | 1) # batch_size x (num_channel + num_embedding)
98 | cur_hidden = self.rnn(concat_context, prev_hidden)
99 | return cur_hidden, alpha
100 |
--------------------------------------------------------------------------------
/pororo/tasks/utils/base.py:
--------------------------------------------------------------------------------
1 | import re
2 | import unicodedata
3 | from abc import abstractmethod
4 | from dataclasses import dataclass
5 | from typing import List, Mapping, Optional, Union
6 |
7 |
8 | @dataclass
9 | class TaskConfig:
10 | task: str
11 | lang: str
12 | n_model: str
13 |
14 |
15 | class PororoTaskBase:
16 | r"""Task base class that implements basic functions for prediction"""
17 |
18 | def __init__(self, config: TaskConfig):
19 | self.config = config
20 |
21 | @property
22 | def n_model(self):
23 | return self.config.n_model
24 |
25 | @property
26 | def lang(self):
27 | return self.config.lang
28 |
29 | @abstractmethod
30 | def predict(
31 | self,
32 | text: Union[str, List[str]],
33 | **kwargs,
34 | ):
35 | raise NotImplementedError(
36 | "`predict()` function is not implemented properly!")
37 |
38 | def __call__(self):
39 | raise NotImplementedError(
40 |             "`__call__()` function is not implemented properly!")
41 |
42 | def __repr__(self):
43 | return f"[TASK]: {self.config.task.upper()}\n[LANG]: {self.config.lang.upper()}\n[MODEL]: {self.config.n_model}"
44 |
45 | def _normalize(self, text: str):
46 | """Unicode normalization and whitespace removal (often needed for contexts)"""
47 | text = unicodedata.normalize("NFKC", text)
48 | text = re.sub(r"\s+", " ", text).strip()
49 | return text
50 |
51 |
52 | class PororoFactoryBase(object):
53 | r"""This is a factory base class that construct task-specific module"""
54 |
55 | def __init__(
56 | self,
57 | task: str,
58 | lang: str,
59 | model: Optional[str] = None,
60 | ):
61 | self._available_langs = self.get_available_langs()
62 | self._available_models = self.get_available_models()
63 | self._model2lang = {
64 | v: k for k, vs in self._available_models.items() for v in vs
65 | }
66 |
67 |         # Set default language as very first supported language
68 |         if lang is None:
69 |             lang = self._available_langs[0]
70 | 
71 |         assert (
72 |             lang in self._available_langs
73 |         ), f"Following langs are supported for this task: {self._available_langs}"
74 |
75 | # Change language option if model is defined by user
76 | if model is not None:
77 | lang = self._model2lang[model]
78 |
79 | # Set default model
80 | if model is None:
81 | model = self.get_default_model(lang)
82 |
83 | # yapf: disable
84 | assert (model in self._available_models[lang]), f"{model} is NOT supported for {lang}"
85 | # yapf: enable
86 |
87 | self.config = TaskConfig(task, lang, model)
88 |
89 | @abstractmethod
90 | def get_available_langs(self) -> List[str]:
91 | raise NotImplementedError(
92 | "`get_available_langs()` is not implemented properly!")
93 |
94 | @abstractmethod
95 | def get_available_models(self) -> Mapping[str, List[str]]:
96 | raise NotImplementedError(
97 | "`get_available_models()` is not implemented properly!")
98 |
99 | @abstractmethod
100 | def get_default_model(self, lang: str) -> str:
101 | return self._available_models[lang][0]
102 |
103 | @classmethod
104 | def load(cls) -> PororoTaskBase:
105 | raise NotImplementedError(
106 | "Model load function is not implemented properly!")
107 |
108 |
109 | class PororoSimpleBase(PororoTaskBase):
110 | r"""Simple task base wrapper class"""
111 |
112 | def __call__(self, text: str, **kwargs):
113 | return self.predict(text, **kwargs)
114 |
115 |
116 | class PororoBiencoderBase(PororoTaskBase):
117 | r"""Bi-Encoder base wrapper class"""
118 |
119 | def __call__(
120 | self,
121 | sent_a: str,
122 | sent_b: Union[str, List[str]],
123 | **kwargs,
124 | ):
125 | assert isinstance(sent_a, str), "sent_a should be string type"
126 | assert isinstance(sent_b, str) or isinstance(
127 | sent_b, list), "sent_b should be string or list of string type"
128 |
129 | sent_a = self._normalize(sent_a)
130 |
131 | # For "Find Similar Sentence" task
132 | if isinstance(sent_b, list):
133 | sent_b = [self._normalize(t) for t in sent_b]
134 | else:
135 | sent_b = self._normalize(sent_b)
136 |
137 | return self.predict(sent_a, sent_b, **kwargs)
138 |
139 |
140 | class PororoGenerationBase(PororoTaskBase):
141 | r"""Generation task wrapper class using various generation tricks"""
142 |
143 | def __call__(
144 | self,
145 | text: str,
146 | beam: int = 5,
147 | temperature: float = 1.0,
148 | top_k: int = -1,
149 | top_p: float = -1,
150 | no_repeat_ngram_size: int = 4,
151 | len_penalty: float = 1.0,
152 | **kwargs,
153 | ):
154 | assert isinstance(text, str), "Input text should be string type"
155 |
156 | return self.predict(
157 | text,
158 | beam=beam,
159 | temperature=temperature,
160 | top_k=top_k,
161 | top_p=top_p,
162 | no_repeat_ngram_size=no_repeat_ngram_size,
163 | len_penalty=len_penalty,
164 | **kwargs,
165 | )
166 |
167 |
168 | class PororoTaskGenerationBase(PororoTaskBase):
169 | r"""Generation task wrapper class using only beam search"""
170 |
171 | def __call__(self, text: str, beam: int = 1, **kwargs):
172 | assert isinstance(text, str), "Input text should be string type"
173 |
174 | text = self._normalize(text)
175 |
176 | return self.predict(text, beam=beam, **kwargs)
177 |
--------------------------------------------------------------------------------
/pororo/tasks/optical_character_recognition.py:
--------------------------------------------------------------------------------
1 | """OCR related modeling class"""
2 |
3 | from typing import Optional
4 |
5 | from pororo.tasks import download_or_load
6 | from pororo.tasks.utils.base import PororoFactoryBase, PororoSimpleBase
7 |
8 |
9 | class PororoOcrFactory(PororoFactoryBase):
10 | """
11 | Recognize optical characters in image file
12 | Currently support Korean language
13 |
14 | English + Korean (`brainocr`)
15 |
16 | - dataset: Internal data + AI hub Font Image dataset
17 | - metric: TBU
18 | - ref: https://www.aihub.or.kr/aidata/133
19 |
20 | Examples:
21 | >>> ocr = Pororo(task="ocr", lang="ko")
22 | >>> ocr(IMAGE_PATH)
23 | ["사이렌'(' 신마'", "내가 말했잖아 속지열라고 이 손을 잡는 너는 위협해질 거라고"]
24 |
25 | >>> ocr = Pororo(task="ocr", lang="ko")
26 | >>> ocr(IMAGE_PATH, detail=True)
27 | {
28 |             'description': ["사이렌'(' 신마'", "내가 말했잖아 속지열라고 이 손을 잡는 너는 위협해질 거라고"],
29 | 'bounding_poly': [
30 | {
31 | 'description': "사이렌'(' 신마'",
32 | 'vertices': [
33 | {'x': 93, 'y': 7},
34 | {'x': 164, 'y': 7},
35 | {'x': 164, 'y': 21},
36 | {'x': 93, 'y': 21}
37 | ]
38 | },
39 | {
40 | 'description': "내가 말했잖아 속지열라고 이 손을 잡는 너는 위협해질 거라고",
41 | 'vertices': [
42 | {'x': 0, 'y': 30},
43 | {'x': 259, 'y': 30},
44 | {'x': 259, 'y': 194},
45 | {'x': 0, 'y': 194}]}
46 | ]
47 | }
48 | }
49 | """
50 |
51 | def __init__(self, task: str, lang: str, model: Optional[str]):
52 | super().__init__(task, lang, model)
53 | self.detect_model = "craft"
54 | self.ocr_opt = "ocr-opt"
55 |
56 | @staticmethod
57 | def get_available_langs():
58 | return ["en", "ko"]
59 |
60 | @staticmethod
61 | def get_available_models():
62 | return {
63 | "en": ["brainocr"],
64 | "ko": ["brainocr"],
65 | }
66 |
67 | def load(self, device: str):
68 | """
69 | Load user-selected task-specific model
70 |
71 | Args:
72 | device (str): device information
73 |
74 | Returns:
75 | object: User-selected task-specific model
76 |
77 | """
78 | if self.config.n_model == "brainocr":
79 | from pororo.models.brainOCR import brainocr
80 |
81 | if self.config.lang not in self.get_available_langs():
82 | raise ValueError(
83 | f"Unsupported Language : {self.config.lang}",
84 | 'Support Languages : ["en", "ko"]',
85 | )
86 |
87 | det_model_path = download_or_load(
88 | f"misc/{self.detect_model}.pt",
89 | self.config.lang,
90 | )
91 | rec_model_path = download_or_load(
92 | f"misc/{self.config.n_model}.pt",
93 | self.config.lang,
94 | )
95 | opt_fp = download_or_load(
96 | f"misc/{self.ocr_opt}.txt",
97 | self.config.lang,
98 | )
99 | model = brainocr.Reader(
100 | self.config.lang,
101 | det_model_ckpt_fp=det_model_path,
102 | rec_model_ckpt_fp=rec_model_path,
103 | opt_fp=opt_fp,
104 | device=device,
105 | )
106 | model.detector.to(device)
107 | model.recognizer.to(device)
108 | return PororoOCR(model, self.config)
109 |
110 |
111 | class PororoOCR(PororoSimpleBase):
112 |
113 | def __init__(self, model, config):
114 | super().__init__(config)
115 | self._model = model
116 |
117 | def _postprocess(self, ocr_results, detail: bool = False):
118 | """
119 | Post-process for OCR result
120 |
121 | Args:
122 |             ocr_results (list): list containing OCR results
123 |             detail (bool): if True, include details (bounding poly, vertices, etc.) in the result
124 |
125 | """
126 | sorted_ocr_results = sorted(
127 | ocr_results,
128 | key=lambda x: (
129 | x[0][0][1],
130 | x[0][0][0],
131 | ),
132 | )
133 |
134 | if not detail:
135 | return [
136 | sorted_ocr_results[i][-1]
137 | for i in range(len(sorted_ocr_results))
138 | ]
139 |
140 | result_dict = {
141 | "description": list(),
142 | "bounding_poly": list(),
143 | }
144 |
145 | for ocr_result in sorted_ocr_results:
146 | vertices = list()
147 |
148 | for vertice in ocr_result[0]:
149 | vertices.append({
150 | "x": vertice[0],
151 | "y": vertice[1],
152 | })
153 |
154 | result_dict["description"].append(ocr_result[1])
155 | result_dict["bounding_poly"].append({
156 | "description": ocr_result[1],
157 | "vertices": vertices
158 | })
159 |
160 | return result_dict
161 |
162 | def predict(self, image_path: str, **kwargs):
163 | """
164 | Conduct Optical Character Recognition (OCR)
165 |
166 | Args:
167 | image_path (str): the image file path
168 |             detail (bool): if True, include details (bounding poly, vertices, etc.) in the result
169 |
170 | """
171 | detail = kwargs.get("detail", False)
172 |
173 | return self._postprocess(
174 | self._model(
175 | image_path,
176 | skip_details=False,
177 | batch_size=1,
178 | paragraph=True,
179 | ),
180 | detail,
181 | )
182 |
--------------------------------------------------------------------------------
/pororo/models/brainOCR/recognition.py:
--------------------------------------------------------------------------------
1 | """
2 | This code is adapted from https://github.com/JaidedAI/EasyOCR/blob/8af936ba1b2f3c230968dc1022d0cd3e9ca1efbb/easyocr/recognition.py
3 | """
4 |
5 | import math
6 |
7 | import numpy as np
8 | import torch
9 | import torch.nn.functional as F
10 | import torch.utils.data
11 | import torchvision.transforms as transforms
12 | from PIL import Image
13 |
14 | from .model import Model
15 | from .utils import CTCLabelConverter
16 |
17 |
18 | def contrast_grey(img):
19 | high = np.percentile(img, 90)
20 | low = np.percentile(img, 10)
21 | return (high - low) / np.maximum(10, high + low), high, low
22 |
23 |
24 | def adjust_contrast_grey(img, target: float = 0.4):
25 | contrast, high, low = contrast_grey(img)
26 | if contrast < target:
27 | img = img.astype(int)
28 | ratio = 200.0 / np.maximum(10, high - low)
29 | img = (img - low + 25) * ratio
30 | img = np.maximum(
31 | np.full(img.shape, 0),
32 | np.minimum(
33 | np.full(img.shape, 255),
34 | img,
35 | ),
36 | ).astype(np.uint8)
37 | return img
38 |
39 |
40 | class NormalizePAD(object):
41 |
42 | def __init__(self, max_size, PAD_type: str = "right"):
43 | self.toTensor = transforms.ToTensor()
44 | self.max_size = max_size
45 | self.max_width_half = math.floor(max_size[2] / 2)
46 | self.PAD_type = PAD_type
47 |
48 | def __call__(self, img):
49 | img = self.toTensor(img)
50 | img.sub_(0.5).div_(0.5)
51 | c, h, w = img.size()
52 | Pad_img = torch.FloatTensor(*self.max_size).fill_(0)
53 | Pad_img[:, :, :w] = img # right pad
54 | if self.max_size[2] != w: # add border Pad
55 | Pad_img[:, :, w:] = (img[:, :, w - 1].unsqueeze(2).expand(
56 | c,
57 | h,
58 | self.max_size[2] - w,
59 | ))
60 |
61 | return Pad_img
62 |
63 |
64 | class ListDataset(torch.utils.data.Dataset):
65 |
66 | def __init__(self, image_list: list):
67 | self.image_list = image_list
68 | self.nSamples = len(image_list)
69 |
70 | def __len__(self):
71 | return self.nSamples
72 |
73 | def __getitem__(self, index):
74 | img = self.image_list[index]
75 | return Image.fromarray(img, "L")
76 |
77 |
78 | class AlignCollate(object):
79 |
80 | def __init__(self, imgH: int, imgW: int, adjust_contrast: float):
81 | self.imgH = imgH
82 | self.imgW = imgW
83 | self.keep_ratio_with_pad = True # Do Not Change
84 | self.adjust_contrast = adjust_contrast
85 |
86 | def __call__(self, batch):
87 | batch = filter(lambda x: x is not None, batch)
88 | images = batch
89 |
90 | resized_max_w = self.imgW
91 | input_channel = 1
92 | transform = NormalizePAD((input_channel, self.imgH, resized_max_w))
93 |
94 | resized_images = []
95 | for image in images:
96 | w, h = image.size
97 | # augmentation here - change contrast
98 | if self.adjust_contrast > 0:
99 | image = np.array(image.convert("L"))
100 | image = adjust_contrast_grey(image, target=self.adjust_contrast)
101 | image = Image.fromarray(image, "L")
102 |
103 | ratio = w / float(h)
104 | if math.ceil(self.imgH * ratio) > self.imgW:
105 | resized_w = self.imgW
106 | else:
107 | resized_w = math.ceil(self.imgH * ratio)
108 |
109 | resized_image = image.resize((resized_w, self.imgH), Image.BICUBIC)
110 | resized_images.append(transform(resized_image))
111 |
112 | image_tensors = torch.cat([t.unsqueeze(0) for t in resized_images], 0)
113 | return image_tensors
114 |
115 |
116 | def recognizer_predict(model, converter, test_loader, opt2val: dict):
117 | device = opt2val["device"]
118 |
119 | model.eval()
120 | result = []
121 | with torch.no_grad():
122 | for image_tensors in test_loader:
123 | batch_size = image_tensors.size(0)
124 | inputs = image_tensors.to(device)
125 | preds = model(inputs) # (N, length, num_classes)
126 |
127 | # rebalance
128 | preds_prob = F.softmax(preds, dim=2)
129 | preds_prob = preds_prob.cpu().detach().numpy()
130 | pred_norm = preds_prob.sum(axis=2)
131 | preds_prob = preds_prob / np.expand_dims(pred_norm, axis=-1)
132 | preds_prob = torch.from_numpy(preds_prob).float().to(device)
133 |
134 |             # Select max probability (greedy decoding), then decode index to character
135 | preds_lengths = torch.IntTensor([preds.size(1)] *
136 | batch_size) # (N,)
137 | _, preds_indices = preds_prob.max(2) # (N, length)
138 | preds_indices = preds_indices.view(-1) # (N*length)
139 | preds_str = converter.decode_greedy(preds_indices, preds_lengths)
140 |
141 | preds_max_prob, _ = preds_prob.max(dim=2)
142 |
143 | for pred, pred_max_prob in zip(preds_str, preds_max_prob):
144 | confidence_score = pred_max_prob.cumprod(dim=0)[-1]
145 | result.append([pred, confidence_score.item()])
146 |
147 | return result
148 |
149 |
150 | def get_recognizer(opt2val: dict):
151 | """
152 | :return:
153 | recognizer: recognition net
154 | converter: CTCLabelConverter
155 | """
156 | # converter
157 | vocab = opt2val["vocab"]
158 | converter = CTCLabelConverter(vocab)
159 |
160 | # recognizer
161 | recognizer = Model(opt2val)
162 |
163 | # state_dict
164 | rec_model_ckpt_fp = opt2val["rec_model_ckpt_fp"]
165 | device = opt2val["device"]
166 | state_dict = torch.load(rec_model_ckpt_fp, map_location=device)
167 |
168 | if device == "cuda":
169 | recognizer = torch.nn.DataParallel(recognizer).to(device)
170 | else:
171 |         # TODO temporary: ckpt loading issue for models trained with multi-GPU (DataParallel)
172 | from collections import OrderedDict
173 |
174 | def _sync_tensor_name(state_dict):
175 | state_dict_ = OrderedDict()
176 | for name, val in state_dict.items():
177 | name = name.replace("module.", "")
178 | state_dict_[name] = val
179 | return state_dict_
180 |
181 | state_dict = _sync_tensor_name(state_dict)
182 |
183 | recognizer.load_state_dict(state_dict)
184 |
185 | return recognizer, converter
186 |
187 |
188 | def get_text(image_list, recognizer, converter, opt2val: dict):
189 | imgW = opt2val["imgW"]
190 | imgH = opt2val["imgH"]
191 | adjust_contrast = opt2val["adjust_contrast"]
192 | batch_size = opt2val["batch_size"]
193 | n_workers = opt2val["n_workers"]
194 | contrast_ths = opt2val["contrast_ths"]
195 |
196 |     # TODO: figure out what this is for
197 | # batch_max_length = int(imgW / 10)
198 |
199 | coord = [item[0] for item in image_list]
200 | img_list = [item[1] for item in image_list]
201 | AlignCollate_normal = AlignCollate(imgH, imgW, adjust_contrast)
202 | test_data = ListDataset(img_list)
203 | test_loader = torch.utils.data.DataLoader(
204 | test_data,
205 | batch_size=batch_size,
206 | shuffle=False,
207 | num_workers=n_workers,
208 | collate_fn=AlignCollate_normal,
209 | pin_memory=True,
210 | )
211 |
212 | # predict first round
213 | result1 = recognizer_predict(recognizer, converter, test_loader, opt2val)
214 |
215 | # predict second round
216 | low_confident_idx = [
217 | i for i, item in enumerate(result1) if (item[1] < contrast_ths)
218 | ]
219 | if len(low_confident_idx) > 0:
220 | img_list2 = [img_list[i] for i in low_confident_idx]
221 | AlignCollate_contrast = AlignCollate(imgH, imgW, adjust_contrast)
222 | test_data = ListDataset(img_list2)
223 | test_loader = torch.utils.data.DataLoader(
224 | test_data,
225 | batch_size=batch_size,
226 | shuffle=False,
227 | num_workers=n_workers,
228 | collate_fn=AlignCollate_contrast,
229 | pin_memory=True,
230 | )
231 | result2 = recognizer_predict(recognizer, converter, test_loader,
232 | opt2val)
233 |
234 | result = []
235 | for i, zipped in enumerate(zip(coord, result1)):
236 | box, pred1 = zipped
237 | if i in low_confident_idx:
238 | pred2 = result2[low_confident_idx.index(i)]
239 | if pred1[1] > pred2[1]:
240 | result.append((box, pred1[0], pred1[1]))
241 | else:
242 | result.append((box, pred2[0], pred2[1]))
243 | else:
244 | result.append((box, pred1[0], pred1[1]))
245 |
246 | return result
247 |
--------------------------------------------------------------------------------
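A small sketch (assuming this module is importable as pororo.models.brainOCR.recognition) of the preprocessing path used by get_text() above: ListDataset wraps raw grayscale crops as PIL "L" images, and AlignCollate resizes them to a fixed height and right-pads them to a common width. imgH=64 and imgW=256 are illustrative values, not the shipped configuration.

import numpy as np
import torch

from pororo.models.brainOCR.recognition import AlignCollate, ListDataset

# Two dummy grayscale word crops of different sizes (pixel values are arbitrary).
crops = [
    np.random.randint(0, 256, (30, 120), dtype=np.uint8),
    np.random.randint(0, 256, (20, 60), dtype=np.uint8),
]

dataset = ListDataset(crops)  # yields PIL "L" images
collate = AlignCollate(imgH=64, imgW=256, adjust_contrast=0.5)
loader = torch.utils.data.DataLoader(dataset, batch_size=2, collate_fn=collate)

batch = next(iter(loader))
print(batch.shape)  # torch.Size([2, 1, 64, 256]): fixed height, width padded to imgW

--------------------------------------------------------------------------------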
/pororo/models/brainOCR/brainocr.py:
--------------------------------------------------------------------------------
1 | """
2 | This code is primarily based on the following:
3 | https://github.com/JaidedAI/EasyOCR/blob/8af936ba1b2f3c230968dc1022d0cd3e9ca1efbb/easyocr/easyocr.py
4 |
5 | Basic usage:
6 | >>> from pororo import Pororo
7 | >>> ocr = Pororo(task="ocr", lang="ko")
8 | >>> ocr("IMAGE_FILE")
9 | """
10 |
11 | import ast
12 | from logging import getLogger
13 | from typing import List
14 |
15 | import cv2
16 | import numpy as np
17 | from PIL import Image
18 |
19 | from .detection import get_detector, get_textbox
20 | from .recognition import get_recognizer, get_text
21 | from .utils import (
22 | diff,
23 | get_image_list,
24 | get_paragraph,
25 | group_text_box,
26 | reformat_input,
27 | )
28 |
29 | LOGGER = getLogger(__name__)
30 |
31 |
32 | class Reader(object):
33 |
34 | def __init__(
35 | self,
36 | lang: str,
37 | det_model_ckpt_fp: str,
38 | rec_model_ckpt_fp: str,
39 | opt_fp: str,
40 | device: str,
41 | ) -> None:
42 | """
43 | TODO @karter: modify this such that you download the pretrained checkpoint files
44 | Parameters:
45 |             lang: language code, e.g., "en" or "ko"
46 | det_model_ckpt_fp: Detection model's checkpoint path e.g., 'craft_mlt_25k.pth'
47 | rec_model_ckpt_fp: Recognition model's checkpoint path
48 | opt_fp: option file path
49 | """
50 | # Plug options in the dictionary
51 | opt2val = self.parse_options(opt_fp) # e.g., {"imgH": 64, ...}
52 | opt2val["vocab"] = self.build_vocab(opt2val["character"])
53 | opt2val["vocab_size"] = len(opt2val["vocab"])
54 | opt2val["device"] = device
55 | opt2val["lang"] = lang
56 | opt2val["det_model_ckpt_fp"] = det_model_ckpt_fp
57 | opt2val["rec_model_ckpt_fp"] = rec_model_ckpt_fp
58 |
59 | # Get model objects
60 | self.detector = get_detector(det_model_ckpt_fp, opt2val["device"])
61 | self.recognizer, self.converter = get_recognizer(opt2val)
62 | self.opt2val = opt2val
63 |
64 | @staticmethod
65 | def parse_options(opt_fp: str) -> dict:
66 | opt2val = dict()
67 | for line in open(opt_fp, "r", encoding="utf8"):
68 | line = line.strip()
69 | if ": " in line:
70 | opt, val = line.split(": ", 1)
71 | try:
72 | opt2val[opt] = ast.literal_eval(val)
73 |                 except (ValueError, SyntaxError):  # not a Python literal; keep the raw string
74 | opt2val[opt] = val
75 |
76 | return opt2val
77 |
78 | @staticmethod
79 | def build_vocab(character: str) -> List[str]:
80 | """Returns vocabulary (=list of characters)"""
81 | vocab = ["[blank]"] + list(
82 | character) # dummy '[blank]' token for CTCLoss (index 0)
83 | return vocab
84 |
85 | def detect(self, img: np.ndarray, opt2val: dict):
86 | """
87 | :return:
88 | horizontal_list (list): e.g., [[613, 1496, 51, 190], [136, 1544, 134, 508]]
89 | free_list (list): e.g., []
90 | """
91 | text_box = get_textbox(self.detector, img, opt2val)
92 | horizontal_list, free_list = group_text_box(
93 | text_box,
94 | opt2val["slope_ths"],
95 | opt2val["ycenter_ths"],
96 | opt2val["height_ths"],
97 | opt2val["width_ths"],
98 | opt2val["add_margin"],
99 | )
100 |
101 | min_size = opt2val["min_size"]
102 | if min_size:
103 | horizontal_list = [
104 | i for i in horizontal_list
105 | if max(i[1] - i[0], i[3] - i[2]) > min_size
106 | ]
107 | free_list = [
108 | i for i in free_list
109 | if max(diff([c[0] for c in i]), diff([c[1]
110 | for c in i])) > min_size
111 | ]
112 |
113 | return horizontal_list, free_list
114 |
115 | def recognize(
116 | self,
117 | img_cv_grey: np.ndarray,
118 | horizontal_list: list,
119 | free_list: list,
120 | opt2val: dict,
121 | ):
122 | """
123 | Read text in the image
124 | :return:
125 |             result (list): bounding box, text and confidence score
126 | e.g., [([[189, 75], [469, 75], [469, 165], [189, 165]], '愚园路', 0.3754989504814148),
127 | ([[86, 80], [134, 80], [134, 128], [86, 128]], '西', 0.40452659130096436),
128 | ([[517, 81], [565, 81], [565, 123], [517, 123]], '东', 0.9989598989486694),
129 | ([[78, 126], [136, 126], [136, 156], [78, 156]], '315', 0.8125889301300049),
130 | ([[514, 126], [574, 126], [574, 156], [514, 156]], '309', 0.4971577227115631),
131 | ([[226, 170], [414, 170], [414, 220], [226, 220]], 'Yuyuan Rd.', 0.8261902332305908),
132 | ([[79, 173], [125, 173], [125, 213], [79, 213]], 'W', 0.9848111271858215),
133 | ([[529, 173], [569, 173], [569, 213], [529, 213]], 'E', 0.8405593633651733)]
134 | or list of texts (if skip_details is True)
135 | e.g., ['愚园路', '西', '东', '315', '309', 'Yuyuan Rd.', 'W', 'E']
136 | """
137 | imgH = opt2val["imgH"]
138 | paragraph = opt2val["paragraph"]
139 | skip_details = opt2val["skip_details"]
140 |
141 | if (horizontal_list is None) and (free_list is None):
142 | y_max, x_max = img_cv_grey.shape
143 | ratio = x_max / y_max
144 | max_width = int(imgH * ratio)
145 | crop_img = cv2.resize(
146 | img_cv_grey,
147 | (max_width, imgH),
148 |                 interpolation=cv2.INTER_LINEAR,  # cv2 flag; PIL's deprecated Image.ANTIALIAS is not a valid cv2 argument
149 | )
150 | image_list = [([[0, 0], [x_max, 0], [x_max, y_max],
151 | [0, y_max]], crop_img)]
152 | else:
153 | image_list, max_width = get_image_list(
154 | horizontal_list,
155 | free_list,
156 | img_cv_grey,
157 | model_height=imgH,
158 | )
159 |
160 | result = get_text(image_list, self.recognizer, self.converter, opt2val)
161 |
162 | if paragraph:
163 | result = get_paragraph(result, mode="ltr")
164 |
165 | if skip_details: # texts only
166 | return [item[1] for item in result]
167 |         else:  # full outputs: bounding box, text and confidence score
168 | return result
169 |
170 | def __call__(
171 | self,
172 | image,
173 | batch_size: int = 1,
174 | n_workers: int = 0,
175 | skip_details: bool = False,
176 | paragraph: bool = False,
177 | min_size: int = 20,
178 | contrast_ths: float = 0.1,
179 | adjust_contrast: float = 0.5,
180 | filter_ths: float = 0.003,
181 | text_threshold: float = 0.7,
182 | low_text: float = 0.4,
183 | link_threshold: float = 0.4,
184 | canvas_size: int = 2560,
185 | mag_ratio: float = 1.0,
186 | slope_ths: float = 0.1,
187 | ycenter_ths: float = 0.5,
188 | height_ths: float = 0.5,
189 | width_ths: float = 0.5,
190 | add_margin: float = 0.1,
191 | ):
192 | """
193 | Detect text in the image and then recognize it.
194 | :param image: file path or numpy-array or a byte stream object
195 | :param batch_size:
196 | :param n_workers:
197 | :param skip_details:
198 | :param paragraph:
199 | :param min_size:
200 | :param contrast_ths:
201 | :param adjust_contrast:
202 | :param filter_ths:
203 | :param text_threshold:
204 | :param low_text:
205 | :param link_threshold:
206 | :param canvas_size:
207 | :param mag_ratio:
208 | :param slope_ths:
209 | :param ycenter_ths:
210 | :param height_ths:
211 | :param width_ths:
212 | :param add_margin:
213 | :return:
214 | """
215 | # update `opt2val`
216 | self.opt2val["batch_size"] = batch_size
217 | self.opt2val["n_workers"] = n_workers
218 | self.opt2val["skip_details"] = skip_details
219 | self.opt2val["paragraph"] = paragraph
220 | self.opt2val["min_size"] = min_size
221 | self.opt2val["contrast_ths"] = contrast_ths
222 | self.opt2val["adjust_contrast"] = adjust_contrast
223 | self.opt2val["filter_ths"] = filter_ths
224 | self.opt2val["text_threshold"] = text_threshold
225 | self.opt2val["low_text"] = low_text
226 | self.opt2val["link_threshold"] = link_threshold
227 | self.opt2val["canvas_size"] = canvas_size
228 | self.opt2val["mag_ratio"] = mag_ratio
229 | self.opt2val["slope_ths"] = slope_ths
230 | self.opt2val["ycenter_ths"] = ycenter_ths
231 | self.opt2val["height_ths"] = height_ths
232 | self.opt2val["width_ths"] = width_ths
233 | self.opt2val["add_margin"] = add_margin
234 |
235 | img, img_cv_grey = reformat_input(image) # img, img_cv_grey: array
236 |
237 | horizontal_list, free_list = self.detect(img, self.opt2val)
238 | result = self.recognize(
239 | img_cv_grey,
240 | horizontal_list,
241 | free_list,
242 | self.opt2val,
243 | )
244 |
245 | return result
246 |
--------------------------------------------------------------------------------
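A sketch of the "key: value" option-file format consumed by Reader.parse_options() above: each value is run through ast.literal_eval, so numbers and quoted strings come back typed, and anything that is not a Python literal is kept as a raw string. The keys and values below are illustrative, not the shipped configuration file.

import tempfile

from pororo.models.brainOCR.brainocr import Reader

opt_text = "imgH: 64\nimgW: 256\nadjust_contrast: 0.5\nFeatureExtraction: 'VGG'\n"
with tempfile.NamedTemporaryFile("w", suffix=".txt", delete=False) as f:
    f.write(opt_text)
    opt_fp = f.name

opt2val = Reader.parse_options(opt_fp)
print(opt2val)  # {'imgH': 64, 'imgW': 256, 'adjust_contrast': 0.5, 'FeatureExtraction': 'VGG'}

--------------------------------------------------------------------------------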
/pororo/models/brainOCR/modules/transformation.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
7 |
8 |
9 | class TpsSpatialTransformerNetwork(nn.Module):
10 | """ Rectification Network of RARE, namely TPS based STN """
11 |
12 | def __init__(self, F, I_size, I_r_size, I_channel_num: int = 1):
13 | """Based on RARE TPS
14 | input:
15 | batch_I: Batch Input Image [batch_size x I_channel_num x I_height x I_width]
16 | I_size : (height, width) of the input image I
17 | I_r_size : (height, width) of the rectified image I_r
18 | I_channel_num : the number of channels of the input image I
19 | output:
20 | batch_I_r: rectified image [batch_size x I_channel_num x I_r_height x I_r_width]
21 | """
22 | super(TpsSpatialTransformerNetwork, self).__init__()
23 | self.F = F
24 | self.I_size = I_size
25 | self.I_r_size = I_r_size # = (I_r_height, I_r_width)
26 | self.I_channel_num = I_channel_num
27 | self.LocalizationNetwork = LocalizationNetwork(self.F,
28 | self.I_channel_num)
29 | self.GridGenerator = GridGenerator(self.F, self.I_r_size)
30 |
31 | def forward(self, batch_I):
32 | batch_C_prime = self.LocalizationNetwork(batch_I) # batch_size x K x 2
33 | build_P_prime = self.GridGenerator.build_P_prime(
34 | batch_C_prime) # batch_size x n (= I_r_width x I_r_height) x 2
35 | build_P_prime_reshape = build_P_prime.reshape(
36 | [build_P_prime.size(0), self.I_r_size[0], self.I_r_size[1], 2])
37 |
38 | # if torch.__version__ > "1.2.0":
39 | # batch_I_r = F.grid_sample(batch_I, build_P_prime_reshape, padding_mode='border', align_corners=True)
40 | # else:
41 | batch_I_r = F.grid_sample(batch_I,
42 | build_P_prime_reshape,
43 | padding_mode="border")
44 |
45 | return batch_I_r
46 |
47 |
48 | class LocalizationNetwork(nn.Module):
49 | """ Localization Network of RARE, which predicts C' (K x 2) from I (I_width x I_height) """
50 |
51 | def __init__(self, F, I_channel_num: int):
52 | super(LocalizationNetwork, self).__init__()
53 | self.F = F
54 | self.I_channel_num = I_channel_num
55 | self.conv = nn.Sequential(
56 | nn.Conv2d(
57 | in_channels=self.I_channel_num,
58 | out_channels=64,
59 | kernel_size=3,
60 | stride=1,
61 | padding=1,
62 | bias=False,
63 | ),
64 | nn.BatchNorm2d(64),
65 | nn.ReLU(True),
66 | nn.MaxPool2d(2, 2), # batch_size x 64 x I_height/2 x I_width/2
67 | nn.Conv2d(64, 128, 3, 1, 1, bias=False),
68 | nn.BatchNorm2d(128),
69 | nn.ReLU(True),
70 | nn.MaxPool2d(2, 2), # batch_size x 128 x I_height/4 x I_width/4
71 | nn.Conv2d(128, 256, 3, 1, 1, bias=False),
72 | nn.BatchNorm2d(256),
73 | nn.ReLU(True),
74 | nn.MaxPool2d(2, 2), # batch_size x 256 x I_height/8 x I_width/8
75 | nn.Conv2d(256, 512, 3, 1, 1, bias=False),
76 | nn.BatchNorm2d(512),
77 | nn.ReLU(True),
78 | nn.AdaptiveAvgPool2d(1), # batch_size x 512
79 | )
80 |
81 | self.localization_fc1 = nn.Sequential(nn.Linear(512, 256),
82 | nn.ReLU(True))
83 | self.localization_fc2 = nn.Linear(256, self.F * 2)
84 |
85 | # Init fc2 in LocalizationNetwork
86 | self.localization_fc2.weight.data.fill_(0)
87 |
88 | # see RARE paper Fig. 6 (a)
89 | ctrl_pts_x = np.linspace(-1.0, 1.0, int(F / 2))
90 | ctrl_pts_y_top = np.linspace(0.0, -1.0, num=int(F / 2))
91 | ctrl_pts_y_bottom = np.linspace(1.0, 0.0, num=int(F / 2))
92 | ctrl_pts_top = np.stack([ctrl_pts_x, ctrl_pts_y_top], axis=1)
93 | ctrl_pts_bottom = np.stack([ctrl_pts_x, ctrl_pts_y_bottom], axis=1)
94 | initial_bias = np.concatenate([ctrl_pts_top, ctrl_pts_bottom], axis=0)
95 | self.localization_fc2.bias.data = (
96 | torch.from_numpy(initial_bias).float().view(-1))
97 |
98 | def forward(self, batch_I):
99 | """
100 | :param batch_I : Batch Input Image [batch_size x I_channel_num x I_height x I_width]
101 | :return: batch_C_prime : Predicted coordinates of fiducial points for input batch [batch_size x F x 2]
102 | """
103 | batch_size = batch_I.size(0)
104 | features = self.conv(batch_I).view(batch_size, -1)
105 | batch_C_prime = self.localization_fc2(
106 | self.localization_fc1(features)).view(batch_size, self.F, 2)
107 | return batch_C_prime
108 |
109 |
110 | class GridGenerator(nn.Module):
111 | """ Grid Generator of RARE, which produces P_prime by multipling T with P """
112 |
113 | def __init__(self, F, I_r_size):
114 | """ Generate P_hat and inv_delta_C for later """
115 | super(GridGenerator, self).__init__()
116 | self.eps = 1e-6
117 | self.I_r_height, self.I_r_width = I_r_size
118 | self.F = F
119 | self.C = self._build_C(self.F) # F x 2
120 | self.P = self._build_P(self.I_r_width, self.I_r_height)
121 |
122 | # for multi-gpu, you need register buffer
123 | self.register_buffer(
124 | "inv_delta_C",
125 | torch.tensor(self._build_inv_delta_C(self.F,
126 | self.C)).float()) # F+3 x F+3
127 | self.register_buffer("P_hat",
128 | torch.tensor(
129 | self._build_P_hat(self.F, self.C,
130 | self.P)).float()) # n x F+3
131 |
132 | def _build_C(self, F):
133 | """ Return coordinates of fiducial points in I_r; C """
134 | ctrl_pts_x = np.linspace(-1.0, 1.0, int(F / 2))
135 | ctrl_pts_y_top = -1 * np.ones(int(F / 2))
136 | ctrl_pts_y_bottom = np.ones(int(F / 2))
137 | ctrl_pts_top = np.stack([ctrl_pts_x, ctrl_pts_y_top], axis=1)
138 | ctrl_pts_bottom = np.stack([ctrl_pts_x, ctrl_pts_y_bottom], axis=1)
139 | C = np.concatenate([ctrl_pts_top, ctrl_pts_bottom], axis=0)
140 | return C # F x 2
141 |
142 | def _build_inv_delta_C(self, F, C):
143 | """ Return inv_delta_C which is needed to calculate T """
144 | hat_C = np.zeros((F, F), dtype=float) # F x F
145 | for i in range(0, F):
146 | for j in range(i, F):
147 | r = np.linalg.norm(C[i] - C[j])
148 | hat_C[i, j] = r
149 | hat_C[j, i] = r
150 | np.fill_diagonal(hat_C, 1)
151 | hat_C = (hat_C**2) * np.log(hat_C)
152 | # print(C.shape, hat_C.shape)
153 | delta_C = np.concatenate( # F+3 x F+3
154 | [
155 | np.concatenate([np.ones((F, 1)), C, hat_C], axis=1), # F x F+3
156 | np.concatenate([np.zeros(
157 | (2, 3)), np.transpose(C)], axis=1), # 2 x F+3
158 | np.concatenate([np.zeros(
159 | (1, 3)), np.ones((1, F))], axis=1), # 1 x F+3
160 | ],
161 | axis=0,
162 | )
163 | inv_delta_C = np.linalg.inv(delta_C)
164 | return inv_delta_C # F+3 x F+3
165 |
166 | def _build_P(self, I_r_width, I_r_height):
167 | I_r_grid_x = (np.arange(-I_r_width, I_r_width, 2) +
168 | 1.0) / I_r_width # self.I_r_width
169 | I_r_grid_y = (np.arange(-I_r_height, I_r_height, 2) +
170 | 1.0) / I_r_height # self.I_r_height
171 | P = np.stack( # self.I_r_width x self.I_r_height x 2
172 | np.meshgrid(I_r_grid_x, I_r_grid_y),
173 | axis=2)
174 | return P.reshape([-1, 2]) # n (= self.I_r_width x self.I_r_height) x 2
175 |
176 | def _build_P_hat(self, F, C, P):
177 | n = P.shape[0] # n (= self.I_r_width x self.I_r_height)
178 | P_tile = np.tile(np.expand_dims(P, axis=1),
179 | (1, F, 1)) # n x 2 -> n x 1 x 2 -> n x F x 2
180 | C_tile = np.expand_dims(C, axis=0) # 1 x F x 2
181 | P_diff = P_tile - C_tile # n x F x 2
182 | rbf_norm = np.linalg.norm(P_diff, ord=2, axis=2,
183 | keepdims=False) # n x F
184 | rbf = np.multiply(np.square(rbf_norm),
185 | np.log(rbf_norm + self.eps)) # n x F
186 | P_hat = np.concatenate([np.ones((n, 1)), P, rbf], axis=1)
187 | return P_hat # n x F+3
188 |
189 | def build_P_prime(self, batch_C_prime):
190 | """ Generate Grid from batch_C_prime [batch_size x F x 2] """
191 | batch_size = batch_C_prime.size(0)
192 | batch_inv_delta_C = self.inv_delta_C.repeat(batch_size, 1, 1)
193 | batch_P_hat = self.P_hat.repeat(batch_size, 1, 1)
194 | batch_C_prime_with_zeros = torch.cat(
195 | (batch_C_prime, torch.zeros(batch_size, 3, 2).float().to(device)),
196 | dim=1) # batch_size x F+3 x 2
197 | batch_T = torch.bmm(batch_inv_delta_C,
198 | batch_C_prime_with_zeros) # batch_size x F+3 x 2
199 | batch_P_prime = torch.bmm(batch_P_hat, batch_T) # batch_size x n x 2
200 | return batch_P_prime # batch_size x n x 2
201 |
--------------------------------------------------------------------------------
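A sketch that runs the TPS spatial transformer above on a dummy batch. F=20 fiducial points and the 32x100 sizes are illustrative choices, not the values used by the shipped checkpoints; the tensors are moved to the same device the module resolves internally, since build_P_prime() pins its zero padding to that device.

import torch

from pororo.models.brainOCR.modules.transformation import TpsSpatialTransformerNetwork

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

tps = TpsSpatialTransformerNetwork(
    F=20,                # number of fiducial points (must be even)
    I_size=(32, 100),    # (height, width) of the input image
    I_r_size=(32, 100),  # (height, width) of the rectified image
    I_channel_num=1,     # grayscale
).to(device)

batch = torch.rand(2, 1, 32, 100, device=device)  # N x C x H x W
rectified = tps(batch)
print(rectified.shape)  # torch.Size([2, 1, 32, 100])

--------------------------------------------------------------------------------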
/pororo/tasks/utils/download_utils.py:
--------------------------------------------------------------------------------
1 | """Module download related function from. Tenth"""
2 |
3 | import logging
4 | import os
5 | import platform
6 | import sys
7 | import zipfile
8 | from dataclasses import dataclass
9 | from typing import Tuple, Union
10 |
11 | import wget
12 |
13 | from pororo.tasks.utils.config import CONFIGS
14 |
15 | DEFAULT_PREFIX = {
16 | "model": "https://twg.kakaocdn.net/pororo/{lang}/models",
17 | "dict": "https://twg.kakaocdn.net/pororo/{lang}/dicts",
18 | }
19 |
20 |
21 | @dataclass
22 | class TransformerInfo:
23 | r"Dataclass for transformer-based model"
24 | path: str
25 | dict_path: str
26 | src_dict: str
27 | tgt_dict: str
28 | src_tok: Union[str, None]
29 | tgt_tok: Union[str, None]
30 |
31 |
32 | @dataclass
33 | class DownloadInfo:
34 | r"Download information such as defined directory, language and model name"
35 | n_model: str
36 | lang: str
37 | root_dir: str
38 |
39 |
40 | def get_save_dir(save_dir: str = None) -> str:
41 | """
42 | Get default save directory
43 |
44 | Args:
45 |         save_dir (str): User-defined save directory
46 |
47 | Returns:
48 |         str: the save directory to use
49 |
50 | """
51 | # If user wants to manually define save directory
52 | if save_dir:
53 | os.makedirs(save_dir, exist_ok=True)
54 | return save_dir
55 |
56 | pf = platform.system()
57 |
58 | if pf == "Windows":
59 | save_dir = "C:\\pororo"
60 | else:
61 | home_dir = os.path.expanduser("~")
62 | save_dir = os.path.join(home_dir, ".pororo")
63 |
64 | if not os.path.exists(save_dir):
65 | os.makedirs(save_dir)
66 |
67 | return save_dir
68 |
69 |
70 | def get_download_url(n_model: str, key: str, lang: str) -> str:
71 | """
72 | Get download url using default prefix
73 |
74 | Args:
75 | n_model (str): model name
76 | key (str): key name either `model` or `dict`
77 | lang (str): language name
78 |
79 | Returns:
80 | str: generated download url
81 |
82 | """
83 | default_prefix = DEFAULT_PREFIX[key].format(lang=lang)
84 | return f"{default_prefix}/{n_model}"
85 |
86 |
87 | def download_or_load_bert(info: DownloadInfo) -> str:
88 | """
89 | Download fine-tuned BrainBert & BrainSBert model and dict
90 |
91 | Args:
92 | info (DownloadInfo): download information
93 |
94 | Returns:
95 | str: downloaded bert & sbert path
96 |
97 | """
98 | model_path = os.path.join(info.root_dir, info.n_model)
99 |
100 | if not os.path.exists(model_path):
101 | info.n_model += ".zip"
102 | zip_path = os.path.join(info.root_dir, info.n_model)
103 |
104 | type_dir = download_from_url(
105 | info.n_model,
106 | zip_path,
107 | key="model",
108 | lang=info.lang,
109 | )
110 |
111 | zip_file = zipfile.ZipFile(zip_path)
112 | zip_file.extractall(type_dir)
113 | zip_file.close()
114 |
115 | return model_path
116 |
117 |
118 | def download_or_load_transformer(info: DownloadInfo) -> TransformerInfo:
119 | """
120 | Download pre-trained Transformer model and corresponding dict
121 |
122 | Args:
123 | info (DownloadInfo): download information
124 |
125 | Returns:
126 | TransformerInfo: information dataclass for transformer construction
127 |
128 | """
129 | config = CONFIGS[info.n_model.split("/")[-1]]
130 |
131 | src_dict_in = config.src_dict
132 | tgt_dict_in = config.tgt_dict
133 | src_tok = config.src_tok
134 | tgt_tok = config.tgt_tok
135 |
136 | info.n_model += ".pt"
137 | model_path = os.path.join(info.root_dir, info.n_model)
138 |
139 | # Download or load Transformer model
140 | model_type_dir = "/".join(model_path.split("/")[:-1])
141 | if not os.path.exists(model_path):
142 | model_type_dir = download_from_url(
143 | info.n_model,
144 | model_path,
145 | key="model",
146 | lang=info.lang,
147 | )
148 |
149 | dict_type_dir = str()
150 | src_dict, tgt_dict = str(), str()
151 |
152 | # Download or load corresponding dictionary
153 | if src_dict_in:
154 | src_dict = f"{src_dict_in}.txt"
155 | src_dict_path = os.path.join(info.root_dir, f"dicts/{src_dict}")
156 | dict_type_dir = "/".join(src_dict_path.split("/")[:-1])
157 | if not os.path.exists(src_dict_path):
158 | dict_type_dir = download_from_url(
159 | src_dict,
160 | src_dict_path,
161 | key="dict",
162 | lang=info.lang,
163 | )
164 |
165 | if tgt_dict_in:
166 | tgt_dict = f"{tgt_dict_in}.txt"
167 | tgt_dict_path = os.path.join(info.root_dir, f"dicts/{tgt_dict}")
168 | if not os.path.exists(tgt_dict_path):
169 | download_from_url(
170 | tgt_dict,
171 | tgt_dict_path,
172 | key="dict",
173 | lang=info.lang,
174 | )
175 |
176 | # Download or load corresponding tokenizer
177 | src_tok_path, tgt_tok_path = None, None
178 | if src_tok:
179 | src_tok_path = download_or_load(
180 | f"tokenizers/{src_tok}.zip",
181 | lang=info.lang,
182 | )
183 | if tgt_tok:
184 | tgt_tok_path = download_or_load(
185 | f"tokenizers/{tgt_tok}.zip",
186 | lang=info.lang,
187 | )
188 |
189 | return TransformerInfo(
190 | path=model_type_dir,
191 | dict_path=dict_type_dir,
192 | # Drop prefix "dict." and postfix ".txt"
193 | src_dict=".".join(src_dict.split(".")[1:-1]),
194 | # to follow fairseq's dictionary load process
195 | tgt_dict=".".join(tgt_dict.split(".")[1:-1]),
196 | src_tok=src_tok_path,
197 | tgt_tok=tgt_tok_path,
198 | )
199 |
200 |
201 | def download_or_load_misc(info: DownloadInfo) -> str:
202 | """
203 | Download (pre-trained) miscellaneous model
204 |
205 | Args:
206 | info (DownloadInfo): download information
207 |
208 | Returns:
209 | str: miscellaneous model path
210 |
211 | """
212 | # Add postfix <.model> for sentencepiece
213 | if "sentencepiece" in info.n_model:
214 | info.n_model += ".model"
215 |
216 | # Generate target model path using root directory
217 | model_path = os.path.join(info.root_dir, info.n_model)
218 | if not os.path.exists(model_path):
219 | type_dir = download_from_url(
220 | info.n_model,
221 | model_path,
222 | key="model",
223 | lang=info.lang,
224 | )
225 |
226 | if ".zip" in info.n_model:
227 | zip_file = zipfile.ZipFile(model_path)
228 | zip_file.extractall(type_dir)
229 | zip_file.close()
230 |
231 | if ".zip" in info.n_model:
232 | model_path = model_path[:model_path.rfind(".zip")]
233 | return model_path
234 |
235 |
236 | def download_or_load_bart(info: DownloadInfo) -> Union[str, Tuple[str, str]]:
237 | """
238 | Download BART model
239 |
240 | Args:
241 | info (DownloadInfo): download information
242 |
243 | Returns:
244 |         Union[str, Tuple[str, str]]: BART model path (with corresponding SentencePiece)
245 |
246 | """
247 | info.n_model += ".pt"
248 |
249 | model_path = os.path.join(info.root_dir, info.n_model)
250 | if not os.path.exists(model_path):
251 | download_from_url(
252 | info.n_model,
253 | model_path,
254 | key="model",
255 | lang=info.lang,
256 | )
257 |
258 | return model_path
259 |
260 |
261 | def download_from_url(
262 | n_model: str,
263 | model_path: str,
264 | key: str,
265 | lang: str,
266 | ) -> str:
267 | """
268 | Download specified model from Tenth
269 |
270 | Args:
271 | n_model (str): model name
272 | model_path (str): pre-defined model path
273 | key (str): type key (either model or dict)
274 | lang (str): language name
275 |
276 | Returns:
277 | str: default type directory
278 |
279 | """
280 | # Get default type dir path
281 | type_dir = "/".join(model_path.split("/")[:-1])
282 | os.makedirs(type_dir, exist_ok=True)
283 |
284 |     # Get the download URL from Tenth
285 | url = get_download_url(n_model, key=key, lang=lang)
286 |
287 | logging.info("Downloading user-selected model...")
288 | wget.download(url, type_dir)
289 | sys.stderr.write("\n")
290 | sys.stderr.flush()
291 |
292 | return type_dir
293 |
294 |
295 | def download_or_load(
296 | n_model: str,
297 | lang: str,
298 | custom_save_dir: str = None,
299 | ) -> Union[TransformerInfo, str, Tuple[str, str]]:
300 | """
301 | Download or load model based on model information
302 |
303 | Args:
304 | n_model (str): model name
305 | lang (str): language information
306 | custom_save_dir (str, optional): user-defined save directory path. defaults to None.
307 |
308 | Returns:
309 | Union[TransformerInfo, str, Tuple[str, str]]
310 |
311 | """
312 | root_dir = get_save_dir(save_dir=custom_save_dir)
313 | info = DownloadInfo(n_model, lang, root_dir)
314 |
315 | if "transformer" in n_model:
316 | return download_or_load_transformer(info)
317 | if "bert" in n_model:
318 | return download_or_load_bert(info)
319 | if "bart" in n_model and "bpe" not in n_model:
320 | return download_or_load_bart(info)
321 |
322 | return download_or_load_misc(info)
323 |
--------------------------------------------------------------------------------
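A short sketch of how download URLs and the local cache directory are resolved by the helpers above. The model name "misc/ocr-opt.txt" is made up purely to illustrate the URL format.

from pororo.tasks.utils.download_utils import get_download_url, get_save_dir

url = get_download_url("misc/ocr-opt.txt", key="model", lang="ko")
print(url)  # https://twg.kakaocdn.net/pororo/ko/models/misc/ocr-opt.txt

# Note: get_save_dir() creates the directory as a side effect.
print(get_save_dir())  # ~/.pororo on Linux/macOS, C:\pororo on Windows

--------------------------------------------------------------------------------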
/pororo/models/brainOCR/craft_utils.py:
--------------------------------------------------------------------------------
1 | """
2 | This code is adapted from https://github.com/clovaai/CRAFT-pytorch/blob/master/craft_utils.py
3 | MIT License
4 | """
5 |
6 | import math
7 |
8 | import cv2
9 | import numpy as np
10 |
11 |
12 | def warp_coord(Minv, pt):
13 | """auxilary functions: unwarp corodinates: """
14 | out = np.matmul(Minv, (pt[0], pt[1], 1))
15 | return np.array([out[0] / out[2], out[1] / out[2]])
16 |
17 |
18 | def get_det_boxes_core(textmap, linkmap, text_threshold, link_threshold,
19 | low_text):
20 | # prepare data
21 | linkmap = linkmap.copy()
22 | textmap = textmap.copy()
23 | img_h, img_w = textmap.shape
24 |
25 | # labeling method
26 | ret, text_score = cv2.threshold(textmap, low_text, 1, 0)
27 | ret, link_score = cv2.threshold(linkmap, link_threshold, 1, 0)
28 |
29 | text_score_comb = np.clip(text_score + link_score, 0, 1)
30 | nLabels, labels, stats, centroids = cv2.connectedComponentsWithStats(
31 | text_score_comb.astype(np.uint8), connectivity=4)
32 |
33 | det = []
34 | mapper = []
35 | for k in range(1, nLabels):
36 | # size filtering
37 | size = stats[k, cv2.CC_STAT_AREA]
38 | if size < 10:
39 | continue
40 |
41 | # thresholding
42 | if np.max(textmap[labels == k]) < text_threshold:
43 | continue
44 |
45 | # make segmentation map
46 | segmap = np.zeros(textmap.shape, dtype=np.uint8)
47 | segmap[labels == k] = 255
48 | segmap[np.logical_and(link_score == 1,
49 | text_score == 0)] = 0 # remove link area
50 | x, y = stats[k, cv2.CC_STAT_LEFT], stats[k, cv2.CC_STAT_TOP]
51 | w, h = stats[k, cv2.CC_STAT_WIDTH], stats[k, cv2.CC_STAT_HEIGHT]
52 | niter = int(math.sqrt(size * min(w, h) / (w * h)) * 2)
53 | sx, ex, sy, ey = x - niter, x + w + niter + 1, y - niter, y + h + niter + 1
54 | # boundary check
55 | if sx < 0:
56 | sx = 0
57 | if sy < 0:
58 | sy = 0
59 | if ex >= img_w:
60 | ex = img_w
61 | if ey >= img_h:
62 | ey = img_h
63 | kernel = cv2.getStructuringElement(
64 | cv2.MORPH_RECT,
65 | (1 + niter, 1 + niter),
66 | )
67 | segmap[sy:ey, sx:ex] = cv2.dilate(segmap[sy:ey, sx:ex], kernel)
68 |
69 | # make box
70 | np_contours = (np.roll(np.array(np.where(segmap != 0)), 1,
71 | axis=0).transpose().reshape(-1, 2))
72 | rectangle = cv2.minAreaRect(np_contours)
73 | box = cv2.boxPoints(rectangle)
74 |
75 | # align diamond-shape
76 | w, h = np.linalg.norm(box[0] - box[1]), np.linalg.norm(box[1] - box[2])
77 | box_ratio = max(w, h) / (min(w, h) + 1e-5)
78 | if abs(1 - box_ratio) <= 0.1:
79 | l, r = min(np_contours[:, 0]), max(np_contours[:, 0])
80 | t, b = min(np_contours[:, 1]), max(np_contours[:, 1])
81 | box = np.array([[l, t], [r, t], [r, b], [l, b]], dtype=np.float32)
82 |
83 |         # make clockwise order
84 | startidx = box.sum(axis=1).argmin()
85 | box = np.roll(box, 4 - startidx, 0)
86 | box = np.array(box)
87 |
88 | det.append(box)
89 | mapper.append(k)
90 |
91 | return det, labels, mapper
92 |
93 |
94 | def get_poly_core(boxes, labels, mapper, linkmap):
95 | # configs
96 | num_cp = 5
97 | max_len_ratio = 0.7
98 | expand_ratio = 1.45
99 | max_r = 2.0
100 | step_r = 0.2
101 |
102 | polys = []
103 | for k, box in enumerate(boxes):
104 | # size filter for small instance
105 | w, h = int(np.linalg.norm(box[0] - box[1]) +
106 | 1), int(np.linalg.norm(box[1] - box[2]) + 1)
107 | if w < 10 or h < 10:
108 | polys.append(None)
109 | continue
110 |
111 | # warp image
112 | tar = np.float32([[0, 0], [w, 0], [w, h], [0, h]])
113 | M = cv2.getPerspectiveTransform(box, tar)
114 | word_label = cv2.warpPerspective(
115 | labels,
116 | M,
117 | (w, h),
118 | flags=cv2.INTER_NEAREST,
119 | )
120 | try:
121 | Minv = np.linalg.inv(M)
122 |         except np.linalg.LinAlgError:  # perspective matrix is singular and cannot be inverted
123 | polys.append(None)
124 | continue
125 |
126 | # binarization for selected label
127 | cur_label = mapper[k]
128 | word_label[word_label != cur_label] = 0
129 | word_label[word_label > 0] = 1
130 |
131 | # Polygon generation: find top/bottom contours
132 | cp = []
133 | max_len = -1
134 | for i in range(w):
135 | region = np.where(word_label[:, i] != 0)[0]
136 | if len(region) < 2:
137 | continue
138 | cp.append((i, region[0], region[-1]))
139 | length = region[-1] - region[0] + 1
140 | if length > max_len:
141 | max_len = length
142 |
143 | # pass if max_len is similar to h
144 | if h * max_len_ratio < max_len:
145 | polys.append(None)
146 | continue
147 |
148 | # get pivot points with fixed length
149 | tot_seg = num_cp * 2 + 1
150 | seg_w = w / tot_seg # segment width
151 | pp = [None] * num_cp # init pivot points
152 | cp_section = [[0, 0]] * tot_seg
153 | seg_height = [0] * num_cp
154 | seg_num = 0
155 | num_sec = 0
156 | prev_h = -1
157 | for i in range(0, len(cp)):
158 | (x, sy, ey) = cp[i]
159 | if (seg_num + 1) * seg_w <= x and seg_num <= tot_seg:
160 | # average previous segment
161 | if num_sec == 0:
162 | break
163 | cp_section[seg_num] = [
164 | cp_section[seg_num][0] / num_sec,
165 | cp_section[seg_num][1] / num_sec,
166 | ]
167 | num_sec = 0
168 |
169 | # reset variables
170 | seg_num += 1
171 | prev_h = -1
172 |
173 | # accumulate center points
174 | cy = (sy + ey) * 0.5
175 | cur_h = ey - sy + 1
176 | cp_section[seg_num] = [
177 | cp_section[seg_num][0] + x,
178 | cp_section[seg_num][1] + cy,
179 | ]
180 | num_sec += 1
181 |
182 | if seg_num % 2 == 0:
183 | continue # No polygon area
184 |
185 | if prev_h < cur_h:
186 | pp[int((seg_num - 1) / 2)] = (x, cy)
187 | seg_height[int((seg_num - 1) / 2)] = cur_h
188 | prev_h = cur_h
189 |
190 | # processing last segment
191 | if num_sec != 0:
192 | cp_section[-1] = [
193 | cp_section[-1][0] / num_sec, cp_section[-1][1] / num_sec
194 | ]
195 |
196 | # pass if num of pivots is not sufficient or segment width is smaller than character height
197 | if None in pp or seg_w < np.max(seg_height) * 0.25:
198 | polys.append(None)
199 | continue
200 |
201 | # calc median maximum of pivot points
202 | half_char_h = np.median(seg_height) * expand_ratio / 2
203 |
204 |         # calc gradient and apply to make horizontal pivots
205 | new_pp = []
206 | for i, (x, cy) in enumerate(pp):
207 | dx = cp_section[i * 2 + 2][0] - cp_section[i * 2][0]
208 | dy = cp_section[i * 2 + 2][1] - cp_section[i * 2][1]
209 |             if dx == 0:  # gradient is zero
210 | new_pp.append([x, cy - half_char_h, x, cy + half_char_h])
211 | continue
212 | rad = -math.atan2(dy, dx)
213 | c, s = half_char_h * math.cos(rad), half_char_h * math.sin(rad)
214 | new_pp.append([x - s, cy - c, x + s, cy + c])
215 |
216 | # get edge points to cover character heatmaps
217 | isSppFound, isEppFound = False, False
218 | grad_s = (pp[1][1] - pp[0][1]) / (pp[1][0] - pp[0][0]) + (
219 | pp[2][1] - pp[1][1]) / (pp[2][0] - pp[1][0])
220 | grad_e = (pp[-2][1] - pp[-1][1]) / (pp[-2][0] - pp[-1][0]) + (
221 | pp[-3][1] - pp[-2][1]) / (pp[-3][0] - pp[-2][0])
222 | for r in np.arange(0.5, max_r, step_r):
223 | dx = 2 * half_char_h * r
224 | if not isSppFound:
225 | line_img = np.zeros(word_label.shape, dtype=np.uint8)
226 | dy = grad_s * dx
227 | p = np.array(new_pp[0]) - np.array([dx, dy, dx, dy])
228 | cv2.line(
229 | line_img,
230 | (int(p[0]), int(p[1])),
231 | (int(p[2]), int(p[3])),
232 | 1,
233 | thickness=1,
234 | )
235 | if (np.sum(np.logical_and(word_label, line_img)) == 0 or
236 | r + 2 * step_r >= max_r):
237 | spp = p
238 | isSppFound = True
239 | if not isEppFound:
240 | line_img = np.zeros(word_label.shape, dtype=np.uint8)
241 | dy = grad_e * dx
242 | p = np.array(new_pp[-1]) + np.array([dx, dy, dx, dy])
243 | cv2.line(
244 | line_img,
245 | (int(p[0]), int(p[1])),
246 | (int(p[2]), int(p[3])),
247 | 1,
248 | thickness=1,
249 | )
250 | if (np.sum(np.logical_and(word_label, line_img)) == 0 or
251 | r + 2 * step_r >= max_r):
252 | epp = p
253 | isEppFound = True
254 | if isSppFound and isEppFound:
255 | break
256 |
257 | # pass if boundary of polygon is not found
258 | if not (isSppFound and isEppFound):
259 | polys.append(None)
260 | continue
261 |
262 | # make final polygon
263 | poly = []
264 | poly.append(warp_coord(Minv, (spp[0], spp[1])))
265 | for p in new_pp:
266 | poly.append(warp_coord(Minv, (p[0], p[1])))
267 | poly.append(warp_coord(Minv, (epp[0], epp[1])))
268 | poly.append(warp_coord(Minv, (epp[2], epp[3])))
269 | for p in reversed(new_pp):
270 | poly.append(warp_coord(Minv, (p[2], p[3])))
271 | poly.append(warp_coord(Minv, (spp[2], spp[3])))
272 |
273 | # add to final result
274 | polys.append(np.array(poly))
275 |
276 | return polys
277 |
278 |
279 | def get_det_boxes(
280 | textmap,
281 | linkmap,
282 | text_threshold,
283 | link_threshold,
284 | low_text,
285 | poly=False,
286 | ):
287 | boxes, labels, mapper = get_det_boxes_core(
288 | textmap,
289 | linkmap,
290 | text_threshold,
291 | link_threshold,
292 | low_text,
293 | )
294 |
295 | if poly:
296 | polys = get_poly_core(boxes, labels, mapper, linkmap)
297 | else:
298 | polys = [None] * len(boxes)
299 |
300 | return boxes, polys
301 |
302 |
303 | def adjust_result_coordinates(polys, ratio_w, ratio_h, ratio_net=2):
304 | if len(polys) > 0:
305 | polys = np.array(polys)
306 | for k in range(len(polys)):
307 | if polys[k] is not None:
308 | polys[k] *= (ratio_w * ratio_net, ratio_h * ratio_net)
309 | return polys
310 |
--------------------------------------------------------------------------------
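A sketch of get_det_boxes() above on synthetic CRAFT score maps: a single high-score blob in the character heatmap yields a single word box. The thresholds mirror the defaults of Reader.__call__ (text_threshold=0.7, link_threshold=0.4, low_text=0.4); the map sizes and blob position are arbitrary.

import numpy as np

from pororo.models.brainOCR.craft_utils import get_det_boxes

textmap = np.zeros((100, 200), dtype=np.float32)
linkmap = np.zeros((100, 200), dtype=np.float32)
textmap[40:60, 50:150] = 0.9  # one rectangular "text" region

boxes, polys = get_det_boxes(
    textmap,
    linkmap,
    text_threshold=0.7,
    link_threshold=0.4,
    low_text=0.4,
    poly=False,
)
print(len(boxes))  # 1
print(boxes[0])    # four corner points (clockwise) around the blob, as float32

--------------------------------------------------------------------------------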
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/pororo/models/brainOCR/modules/feature_extraction.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 |
4 | class VGGFeatureExtractor(nn.Module):
5 | """ FeatureExtractor of CRNN (https://arxiv.org/pdf/1507.05717.pdf) """
6 |
7 | def __init__(self,
8 | n_input_channels: int = 1,
9 | n_output_channels: int = 512,
10 | opt2val=None):
11 | super(VGGFeatureExtractor, self).__init__()
12 |
13 | self.output_channel = [
14 | int(n_output_channels / 8),
15 | int(n_output_channels / 4),
16 | int(n_output_channels / 2),
17 | n_output_channels,
18 | ] # [64, 128, 256, 512]
19 |
20 | rec_model_ckpt_fp = opt2val["rec_model_ckpt_fp"]
21 | if "baseline" in rec_model_ckpt_fp:
22 | self.ConvNet = nn.Sequential(
23 | nn.Conv2d(n_input_channels, self.output_channel[0], 3, 1, 1),
24 | nn.ReLU(True),
25 | nn.MaxPool2d(2, 2), # 64x16x50
26 | nn.Conv2d(self.output_channel[0], self.output_channel[1], 3, 1,
27 | 1),
28 | nn.ReLU(True),
29 | nn.MaxPool2d(2, 2), # 128x8x25
30 | nn.Conv2d(self.output_channel[1], self.output_channel[2], 3, 1,
31 | 1),
32 | nn.ReLU(True), # 256x8x25
33 | nn.Conv2d(self.output_channel[2], self.output_channel[2], 3, 1,
34 | 1),
35 | nn.ReLU(True),
36 | nn.MaxPool2d((2, 1), (2, 1)), # 256x4x25
37 | nn.Conv2d(self.output_channel[2],
38 | self.output_channel[3],
39 | 3,
40 | 1,
41 | 1,
42 | bias=False),
43 | nn.BatchNorm2d(self.output_channel[3]),
44 | nn.ReLU(True), # 512x4x25
45 | nn.Conv2d(self.output_channel[3],
46 | self.output_channel[3],
47 | 3,
48 | 1,
49 | 1,
50 | bias=False),
51 | nn.BatchNorm2d(self.output_channel[3]),
52 | nn.ReLU(True),
53 | nn.MaxPool2d((2, 1), (2, 1)), # 512x2x25
54 | # nn.Conv2d(self.output_channel[3], self.output_channel[3], 2, 1, 0), nn.ReLU(True)) # 512x1x24
55 | nn.ConvTranspose2d(self.output_channel[3],
56 | self.output_channel[3], 2, 2),
57 | nn.ReLU(True),
58 | ) # 512x4x50
59 | else:
60 | self.ConvNet = nn.Sequential(
61 | nn.Conv2d(n_input_channels, self.output_channel[0], 3, 1, 1),
62 | nn.ReLU(True),
63 | nn.MaxPool2d(2, 2), # 64x16x50
64 | nn.Conv2d(self.output_channel[0], self.output_channel[1], 3, 1,
65 | 1),
66 | nn.ReLU(True),
67 | nn.MaxPool2d(2, 2), # 128x8x25
68 | nn.Conv2d(self.output_channel[1], self.output_channel[2], 3, 1,
69 | 1),
70 | nn.ReLU(True), # 256x8x25
71 | nn.Conv2d(self.output_channel[2], self.output_channel[2], 3, 1,
72 | 1),
73 | nn.ReLU(True),
74 | nn.MaxPool2d((2, 1), (2, 1)), # 256x4x25
75 | nn.Conv2d(self.output_channel[2],
76 | self.output_channel[3],
77 | 3,
78 | 1,
79 | 1,
80 | bias=False),
81 | nn.BatchNorm2d(self.output_channel[3]),
82 | nn.ReLU(True), # 512x4x25
83 | nn.Conv2d(self.output_channel[3],
84 | self.output_channel[3],
85 | 3,
86 | 1,
87 | 1,
88 | bias=False),
89 | nn.BatchNorm2d(self.output_channel[3]),
90 | nn.ReLU(True),
91 | nn.MaxPool2d((2, 1), (2, 1)), # 512x2x25
92 | # nn.Conv2d(self.output_channel[3], self.output_channel[3], 2, 1, 0), nn.ReLU(True)) # 512x1x24
93 | nn.ConvTranspose2d(self.output_channel[3],
94 | self.output_channel[3], 2, 2),
95 | nn.ReLU(True), # 512x4x50
96 | nn.ConvTranspose2d(self.output_channel[3],
97 | self.output_channel[3], 2, 2),
98 | nn.ReLU(True),
99 |             )  # 512x8x100 (each ConvTranspose2d with kernel 2, stride 2 doubles the spatial size)
100 |
101 | def forward(self, x):
102 | return self.ConvNet(x)
103 |
104 |
105 | class ResNetFeatureExtractor(nn.Module):
106 | """
107 | FeatureExtractor of FAN
108 | (http://openaccess.thecvf.com/content_ICCV_2017/papers/Cheng_Focusing_Attention_Towards_ICCV_2017_paper.pdf)
109 | """
110 |
111 | def __init__(self, n_input_channels: int = 1, n_output_channels: int = 512):
112 | super(ResNetFeatureExtractor, self).__init__()
113 | self.ConvNet = ResNet(n_input_channels, n_output_channels, BasicBlock,
114 | [1, 2, 5, 3])
115 |
116 | def forward(self, inputs):
117 | return self.ConvNet(inputs)
118 |
119 |
120 | class BasicBlock(nn.Module):
121 | expansion = 1
122 |
123 | def __init__(self,
124 | inplanes: int,
125 | planes: int,
126 | stride: int = 1,
127 | downsample=None):
128 | super(BasicBlock, self).__init__()
129 | self.conv1 = self._conv3x3(inplanes, planes)
130 | self.bn1 = nn.BatchNorm2d(planes)
131 | self.conv2 = self._conv3x3(planes, planes)
132 | self.bn2 = nn.BatchNorm2d(planes)
133 | self.relu = nn.ReLU(inplace=True)
134 | self.downsample = downsample
135 | self.stride = stride
136 |
137 | def _conv3x3(self, in_planes, out_planes, stride=1):
138 | "3x3 convolution with padding"
139 | return nn.Conv2d(in_planes,
140 | out_planes,
141 | kernel_size=3,
142 | stride=stride,
143 | padding=1,
144 | bias=False)
145 |
146 | def forward(self, x):
147 | residual = x
148 |
149 | out = self.conv1(x)
150 | out = self.bn1(out)
151 | out = self.relu(out)
152 |
153 | out = self.conv2(out)
154 | out = self.bn2(out)
155 |
156 | if self.downsample is not None:
157 | residual = self.downsample(x)
158 | out += residual
159 | out = self.relu(out)
160 |
161 | return out
162 |
163 |
164 | class ResNet(nn.Module):
165 |
166 | def __init__(self, n_input_channels: int, n_output_channels: int, block,
167 | layers):
168 | """
169 | :param n_input_channels (int): The number of input channels of the feature extractor
170 | :param n_output_channels (int): The number of output channels of the feature extractor
171 | :param block:
172 | :param layers:
173 | """
174 | super(ResNet, self).__init__()
175 |
176 | self.output_channel_blocks = [
177 | int(n_output_channels / 4),
178 | int(n_output_channels / 2),
179 | n_output_channels,
180 | n_output_channels,
181 | ]
182 |
183 | self.inplanes = int(n_output_channels / 8)
184 | self.conv0_1 = nn.Conv2d(
185 | n_input_channels,
186 | int(n_output_channels / 16),
187 | kernel_size=3,
188 | stride=1,
189 | padding=1,
190 | bias=False,
191 | )
192 | self.bn0_1 = nn.BatchNorm2d(int(n_output_channels / 16))
193 | self.conv0_2 = nn.Conv2d(
194 | int(n_output_channels / 16),
195 | self.inplanes,
196 | kernel_size=3,
197 | stride=1,
198 | padding=1,
199 | bias=False,
200 | )
201 | self.bn0_2 = nn.BatchNorm2d(self.inplanes)
202 | self.relu = nn.ReLU(inplace=True)
203 |
204 | self.maxpool1 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
205 | self.layer1 = self._make_layer(block, self.output_channel_blocks[0],
206 | layers[0])
207 | self.conv1 = nn.Conv2d(
208 | self.output_channel_blocks[0],
209 | self.output_channel_blocks[0],
210 | kernel_size=3,
211 | stride=1,
212 | padding=1,
213 | bias=False,
214 | )
215 | self.bn1 = nn.BatchNorm2d(self.output_channel_blocks[0])
216 |
217 | self.maxpool2 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
218 | self.layer2 = self._make_layer(block,
219 | self.output_channel_blocks[1],
220 | layers[1],
221 | stride=1)
222 | self.conv2 = nn.Conv2d(
223 | self.output_channel_blocks[1],
224 | self.output_channel_blocks[1],
225 | kernel_size=3,
226 | stride=1,
227 | padding=1,
228 | bias=False,
229 | )
230 | self.bn2 = nn.BatchNorm2d(self.output_channel_blocks[1])
231 |
232 | self.maxpool3 = nn.MaxPool2d(kernel_size=2,
233 | stride=(2, 1),
234 | padding=(0, 1))
235 | self.layer3 = self._make_layer(block,
236 | self.output_channel_blocks[2],
237 | layers[2],
238 | stride=1)
239 | self.conv3 = nn.Conv2d(
240 | self.output_channel_blocks[2],
241 | self.output_channel_blocks[2],
242 | kernel_size=3,
243 | stride=1,
244 | padding=1,
245 | bias=False,
246 | )
247 | self.bn3 = nn.BatchNorm2d(self.output_channel_blocks[2])
248 |
249 | self.layer4 = self._make_layer(block,
250 | self.output_channel_blocks[3],
251 | layers[3],
252 | stride=1)
253 | self.conv4_1 = nn.Conv2d(
254 | self.output_channel_blocks[3],
255 | self.output_channel_blocks[3],
256 | kernel_size=2,
257 | stride=(2, 1),
258 | padding=(0, 1),
259 | bias=False,
260 | )
261 | self.bn4_1 = nn.BatchNorm2d(self.output_channel_blocks[3])
262 | self.conv4_2 = nn.Conv2d(
263 | self.output_channel_blocks[3],
264 | self.output_channel_blocks[3],
265 | kernel_size=2,
266 | stride=1,
267 | padding=0,
268 | bias=False,
269 | )
270 | self.bn4_2 = nn.BatchNorm2d(self.output_channel_blocks[3])
271 |
272 | def _make_layer(self, block, planes, blocks, stride=1):
273 | downsample = None
274 | if stride != 1 or self.inplanes != planes * block.expansion:
275 | downsample = nn.Sequential(
276 | nn.Conv2d(
277 | self.inplanes,
278 | planes * block.expansion,
279 | kernel_size=1,
280 | stride=stride,
281 | bias=False,
282 | ),
283 | nn.BatchNorm2d(planes * block.expansion),
284 | )
285 |
286 | layers = []
287 | layers.append(block(self.inplanes, planes, stride, downsample))
288 | self.inplanes = planes * block.expansion
289 | for i in range(1, blocks):
290 | layers.append(block(self.inplanes, planes))
291 |
292 | return nn.Sequential(*layers)
293 |
294 | def forward(self, x):
295 | x = self.conv0_1(x)
296 | x = self.bn0_1(x)
297 | x = self.relu(x)
298 | x = self.conv0_2(x)
299 | x = self.bn0_2(x)
300 | x = self.relu(x)
301 |
302 | x = self.maxpool1(x)
303 | x = self.layer1(x)
304 | x = self.conv1(x)
305 | x = self.bn1(x)
306 | x = self.relu(x)
307 |
308 | x = self.maxpool2(x)
309 | x = self.layer2(x)
310 | x = self.conv2(x)
311 | x = self.bn2(x)
312 | x = self.relu(x)
313 |
314 | x = self.maxpool3(x)
315 | x = self.layer3(x)
316 | x = self.conv3(x)
317 | x = self.bn3(x)
318 | x = self.relu(x)
319 |
320 | x = self.layer4(x)
321 | x = self.conv4_1(x)
322 | x = self.bn4_1(x)
323 | x = self.relu(x)
324 | x = self.conv4_2(x)
325 | x = self.bn4_2(x)
326 | x = self.relu(x)
327 |
328 | return x
329 |
--------------------------------------------------------------------------------
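The extractors above are easiest to sanity-check by pushing a dummy tensor through them. A minimal sketch, assuming the file is importable under the package layout shown in the tree at the top of this document (the import path and the 32x100 crop size are illustrative assumptions, not part of the original file):

import torch

from pororo.models.brainOCR.modules.feature_extraction import ResNetFeatureExtractor

# two grayscale crops of height 32 and width 100, a typical text-recognition input
dummy = torch.randn(2, 1, 32, 100)
extractor = ResNetFeatureExtractor(n_input_channels=1, n_output_channels=512)
features = extractor(dummy)
print(features.shape)  # the height collapses to 1, e.g. torch.Size([2, 512, 1, 26])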
/pororo/models/brainOCR/_modules.py:
--------------------------------------------------------------------------------
1 | from collections import namedtuple
2 |
3 | import numpy as np
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 | import torch.nn.init as init
8 | from torchvision import models
9 | from torchvision.models.vgg import model_urls
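 # NOTE (added): newer torchvision releases no longer expose `model_urls`, so this
 # import assumes the older torchvision version this repository was written against.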
10 |
11 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
12 |
13 |
14 | def init_weights(modules):
15 | for m in modules:
16 | if isinstance(m, nn.Conv2d):
17 | init.xavier_uniform_(m.weight.data)
18 | if m.bias is not None:
19 | m.bias.data.zero_()
20 | elif isinstance(m, nn.BatchNorm2d):
21 | m.weight.data.fill_(1)
22 | m.bias.data.zero_()
23 | elif isinstance(m, nn.Linear):
24 | m.weight.data.normal_(0, 0.01)
25 | m.bias.data.zero_()
26 |
27 |
28 | class Vgg16BN(torch.nn.Module):
29 |
30 | def __init__(self, pretrained: bool = True, freeze: bool = True):
31 | super(Vgg16BN, self).__init__()
32 | model_urls["vgg16_bn"] = model_urls["vgg16_bn"].replace(
33 | "https://", "http://")
34 | vgg_pretrained_features = models.vgg16_bn(
35 | pretrained=pretrained).features
36 | self.slice1 = torch.nn.Sequential()
37 | self.slice2 = torch.nn.Sequential()
38 | self.slice3 = torch.nn.Sequential()
39 | self.slice4 = torch.nn.Sequential()
40 | self.slice5 = torch.nn.Sequential()
41 | for x in range(12): # conv2_2
42 | self.slice1.add_module(str(x), vgg_pretrained_features[x])
43 | for x in range(12, 19): # conv3_3
44 | self.slice2.add_module(str(x), vgg_pretrained_features[x])
45 | for x in range(19, 29): # conv4_3
46 | self.slice3.add_module(str(x), vgg_pretrained_features[x])
47 | for x in range(29, 39): # conv5_3
48 | self.slice4.add_module(str(x), vgg_pretrained_features[x])
49 |
50 | # fc6, fc7 without atrous conv
51 | self.slice5 = torch.nn.Sequential(
52 | nn.MaxPool2d(kernel_size=3, stride=1, padding=1),
53 | nn.Conv2d(512, 1024, kernel_size=3, padding=6, dilation=6),
54 | nn.Conv2d(1024, 1024, kernel_size=1),
55 | )
56 |
57 | if not pretrained:
58 | init_weights(self.slice1.modules())
59 | init_weights(self.slice2.modules())
60 | init_weights(self.slice3.modules())
61 | init_weights(self.slice4.modules())
62 |
63 | init_weights(
64 | self.slice5.modules()) # no pretrained model for fc6 and fc7
65 |
66 | if freeze:
67 | for param in self.slice1.parameters(): # only first conv
68 | param.requires_grad = False
69 |
70 | def forward(self, x):
71 | h = self.slice1(x)
72 | h_relu2_2 = h
73 | h = self.slice2(h)
74 | h_relu3_2 = h
75 | h = self.slice3(h)
76 | h_relu4_3 = h
77 | h = self.slice4(h)
78 | h_relu5_3 = h
79 | h = self.slice5(h)
80 | h_fc7 = h
81 | vgg_outputs = namedtuple(
82 | "VggOutputs", ["fc7", "relu5_3", "relu4_3", "relu3_2", "relu2_2"])
83 | out = vgg_outputs(h_fc7, h_relu5_3, h_relu4_3, h_relu3_2, h_relu2_2)
84 | return out
85 |
86 |
87 | class VGGFeatureExtractor(nn.Module):
88 | """ FeatureExtractor of CRNN (https://arxiv.org/pdf/1507.05717.pdf) """
89 |
90 | def __init__(self, n_input_channels: int = 1, n_output_channels: int = 512):
91 | super(VGGFeatureExtractor, self).__init__()
92 |
93 | self.output_channel = [
94 | int(n_output_channels / 8),
95 | int(n_output_channels / 4),
96 | int(n_output_channels / 2),
97 | n_output_channels,
98 | ] # [64, 128, 256, 512]
99 | self.ConvNet = nn.Sequential(
100 | nn.Conv2d(n_input_channels, self.output_channel[0], 3, 1, 1),
101 | nn.ReLU(True),
102 | nn.MaxPool2d(2, 2), # 64x16x50
103 | nn.Conv2d(self.output_channel[0], self.output_channel[1], 3, 1, 1),
104 | nn.ReLU(True),
105 | nn.MaxPool2d(2, 2), # 128x8x25
106 | nn.Conv2d(self.output_channel[1], self.output_channel[2], 3, 1, 1),
107 | nn.ReLU(True), # 256x8x25
108 | nn.Conv2d(self.output_channel[2], self.output_channel[2], 3, 1, 1),
109 | nn.ReLU(True),
110 | nn.MaxPool2d((2, 1), (2, 1)), # 256x4x25
111 | nn.Conv2d(
112 | self.output_channel[2],
113 | self.output_channel[3],
114 | 3,
115 | 1,
116 | 1,
117 | bias=False,
118 | ),
119 | nn.BatchNorm2d(self.output_channel[3]),
120 | nn.ReLU(True), # 512x4x25
121 | nn.Conv2d(
122 | self.output_channel[3],
123 | self.output_channel[3],
124 | 3,
125 | 1,
126 | 1,
127 | bias=False,
128 | ),
129 | nn.BatchNorm2d(self.output_channel[3]),
130 | nn.ReLU(True),
131 | nn.MaxPool2d((2, 1), (2, 1)), # 512x2x25
132 | nn.Conv2d(self.output_channel[3], self.output_channel[3], 2, 1, 0),
133 | nn.ReLU(True),
134 | ) # 512x1x24
135 |
136 | def forward(self, x):
137 | return self.ConvNet(x)
138 |
139 |
140 | class BidirectionalLSTM(nn.Module):
141 |
142 | def __init__(self, input_size: int, hidden_size: int, output_size: int):
143 | super(BidirectionalLSTM, self).__init__()
144 | self.rnn = nn.LSTM(
145 | input_size,
146 | hidden_size,
147 | bidirectional=True,
148 | batch_first=True,
149 | )
150 | self.linear = nn.Linear(hidden_size * 2, output_size)
151 |
152 | def forward(self, x):
153 | """
154 | x : visual feature [batch_size x T x input_size]
155 | output : contextual feature [batch_size x T x output_size]
156 | """
157 | self.rnn.flatten_parameters()
158 | recurrent, _ = self.rnn(
159 | x
160 | ) # batch_size x T x input_size -> batch_size x T x (2*hidden_size)
161 | output = self.linear(recurrent) # batch_size x T x output_size
162 | return output
163 |
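 # Added note: BidirectionalLSTM is the sequence-modeling stage of the CRNN. With
 # batch_first=True the LSTM runs over the time axis T of the visual feature, and the
 # linear layer maps the concatenated forward/backward states (2 * hidden_size) back to
 # output_size. For example, an input of shape [2, 26, 512] with hidden_size=256 and
 # output_size=256 comes out as [2, 26, 256].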
164 |
165 | class ResNetFeatureExtractor(nn.Module):
166 | """
167 | FeatureExtractor of FAN
168 | (http://openaccess.thecvf.com/content_ICCV_2017/papers/Cheng_Focusing_Attention_Towards_ICCV_2017_paper.pdf)
169 |
170 | """
171 |
172 | def __init__(self, n_input_channels: int = 1, n_output_channels: int = 512):
173 | super(ResNetFeatureExtractor, self).__init__()
174 | self.ConvNet = ResNet(
175 | n_input_channels,
176 | n_output_channels,
177 | BasicBlock,
178 | [1, 2, 5, 3],
179 | )
180 |
181 | def forward(self, inputs):
182 | return self.ConvNet(inputs)
183 |
184 |
185 | class BasicBlock(nn.Module):
186 | expansion = 1
187 |
188 | def __init__(self,
189 | inplanes: int,
190 | planes: int,
191 | stride: int = 1,
192 | downsample=None):
193 | super(BasicBlock, self).__init__()
194 | self.conv1 = self._conv3x3(inplanes, planes)
195 | self.bn1 = nn.BatchNorm2d(planes)
196 | self.conv2 = self._conv3x3(planes, planes)
197 | self.bn2 = nn.BatchNorm2d(planes)
198 | self.relu = nn.ReLU(inplace=True)
199 | self.downsample = downsample
200 | self.stride = stride
201 |
202 | def _conv3x3(self, in_planes, out_planes, stride=1):
203 | "3x3 convolution with padding"
204 | return nn.Conv2d(
205 | in_planes,
206 | out_planes,
207 | kernel_size=3,
208 | stride=stride,
209 | padding=1,
210 | bias=False,
211 | )
212 |
213 | def forward(self, x):
214 | residual = x
215 |
216 | out = self.conv1(x)
217 | out = self.bn1(out)
218 | out = self.relu(out)
219 |
220 | out = self.conv2(out)
221 | out = self.bn2(out)
222 |
223 | if self.downsample is not None:
224 | residual = self.downsample(x)
225 | out += residual
226 | out = self.relu(out)
227 |
228 | return out
229 |
230 |
231 | class ResNet(nn.Module):
232 |
233 | def __init__(
234 | self,
235 | n_input_channels: int,
236 | n_output_channels: int,
237 | block,
238 | layers,
239 | ):
240 | """
241 | :param n_input_channels (int): The number of input channels of the feature extractor
242 | :param n_output_channels (int): The number of output channels of the feature extractor
243 | :param block:
244 | :param layers:
245 | """
246 | super(ResNet, self).__init__()
247 |
248 | self.output_channel_blocks = [
249 | int(n_output_channels / 4),
250 | int(n_output_channels / 2),
251 | n_output_channels,
252 | n_output_channels,
253 | ]
254 |
255 | self.inplanes = int(n_output_channels / 8)
256 | self.conv0_1 = nn.Conv2d(
257 | n_input_channels,
258 | int(n_output_channels / 16),
259 | kernel_size=3,
260 | stride=1,
261 | padding=1,
262 | bias=False,
263 | )
264 | self.bn0_1 = nn.BatchNorm2d(int(n_output_channels / 16))
265 | self.conv0_2 = nn.Conv2d(
266 | int(n_output_channels / 16),
267 | self.inplanes,
268 | kernel_size=3,
269 | stride=1,
270 | padding=1,
271 | bias=False,
272 | )
273 | self.bn0_2 = nn.BatchNorm2d(self.inplanes)
274 | self.relu = nn.ReLU(inplace=True)
275 |
276 | self.maxpool1 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
277 | self.layer1 = self._make_layer(
278 | block,
279 | self.output_channel_blocks[0],
280 | layers[0],
281 | )
282 | self.conv1 = nn.Conv2d(
283 | self.output_channel_blocks[0],
284 | self.output_channel_blocks[0],
285 | kernel_size=3,
286 | stride=1,
287 | padding=1,
288 | bias=False,
289 | )
290 | self.bn1 = nn.BatchNorm2d(self.output_channel_blocks[0])
291 |
292 | self.maxpool2 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
293 | self.layer2 = self._make_layer(
294 | block,
295 | self.output_channel_blocks[1],
296 | layers[1],
297 | stride=1,
298 | )
299 | self.conv2 = nn.Conv2d(
300 | self.output_channel_blocks[1],
301 | self.output_channel_blocks[1],
302 | kernel_size=3,
303 | stride=1,
304 | padding=1,
305 | bias=False,
306 | )
307 | self.bn2 = nn.BatchNorm2d(self.output_channel_blocks[1])
308 |
309 | self.maxpool3 = nn.MaxPool2d(
310 | kernel_size=2,
311 | stride=(2, 1),
312 | padding=(0, 1),
313 | )
314 | self.layer3 = self._make_layer(
315 | block,
316 | self.output_channel_blocks[2],
317 | layers[2],
318 | stride=1,
319 | )
320 | self.conv3 = nn.Conv2d(
321 | self.output_channel_blocks[2],
322 | self.output_channel_blocks[2],
323 | kernel_size=3,
324 | stride=1,
325 | padding=1,
326 | bias=False,
327 | )
328 | self.bn3 = nn.BatchNorm2d(self.output_channel_blocks[2])
329 |
330 | self.layer4 = self._make_layer(
331 | block,
332 | self.output_channel_blocks[3],
333 | layers[3],
334 | stride=1,
335 | )
336 | self.conv4_1 = nn.Conv2d(
337 | self.output_channel_blocks[3],
338 | self.output_channel_blocks[3],
339 | kernel_size=2,
340 | stride=(2, 1),
341 | padding=(0, 1),
342 | bias=False,
343 | )
344 | self.bn4_1 = nn.BatchNorm2d(self.output_channel_blocks[3])
345 | self.conv4_2 = nn.Conv2d(
346 | self.output_channel_blocks[3],
347 | self.output_channel_blocks[3],
348 | kernel_size=2,
349 | stride=1,
350 | padding=0,
351 | bias=False,
352 | )
353 | self.bn4_2 = nn.BatchNorm2d(self.output_channel_blocks[3])
354 |
355 | def _make_layer(self, block, planes, blocks, stride=1):
356 | downsample = None
357 | if stride != 1 or self.inplanes != planes * block.expansion:
358 | downsample = nn.Sequential(
359 | nn.Conv2d(
360 | self.inplanes,
361 | planes * block.expansion,
362 | kernel_size=1,
363 | stride=stride,
364 | bias=False,
365 | ),
366 | nn.BatchNorm2d(planes * block.expansion),
367 | )
368 |
369 | layers = []
370 | layers.append(block(self.inplanes, planes, stride, downsample))
371 | self.inplanes = planes * block.expansion
372 | for i in range(1, blocks):
373 | layers.append(block(self.inplanes, planes))
374 |
375 | return nn.Sequential(*layers)
376 |
377 | def forward(self, x):
378 | x = self.conv0_1(x)
379 | x = self.bn0_1(x)
380 | x = self.relu(x)
381 | x = self.conv0_2(x)
382 | x = self.bn0_2(x)
383 | x = self.relu(x)
384 |
385 | x = self.maxpool1(x)
386 | x = self.layer1(x)
387 | x = self.conv1(x)
388 | x = self.bn1(x)
389 | x = self.relu(x)
390 |
391 | x = self.maxpool2(x)
392 | x = self.layer2(x)
393 | x = self.conv2(x)
394 | x = self.bn2(x)
395 | x = self.relu(x)
396 |
397 | x = self.maxpool3(x)
398 | x = self.layer3(x)
399 | x = self.conv3(x)
400 | x = self.bn3(x)
401 | x = self.relu(x)
402 |
403 | x = self.layer4(x)
404 | x = self.conv4_1(x)
405 | x = self.bn4_1(x)
406 | x = self.relu(x)
407 | x = self.conv4_2(x)
408 | x = self.bn4_2(x)
409 | x = self.relu(x)
410 |
411 | return x
412 |
413 |
414 | class TpsSpatialTransformerNetwork(nn.Module):
415 | """ Rectification Network of RARE, namely TPS based STN """
416 |
417 | def __init__(self, F, I_size, I_r_size, I_channel_num: int = 1):
418 | """Based on RARE TPS
419 | input:
420 | batch_I: Batch Input Image [batch_size x I_channel_num x I_height x I_width]
421 | I_size : (height, width) of the input image I
422 | I_r_size : (height, width) of the rectified image I_r
423 | I_channel_num : the number of channels of the input image I
424 | output:
425 | batch_I_r: rectified image [batch_size x I_channel_num x I_r_height x I_r_width]
426 | """
427 | super(TpsSpatialTransformerNetwork, self).__init__()
428 | self.F = F
429 | self.I_size = I_size
430 | self.I_r_size = I_r_size # = (I_r_height, I_r_width)
431 | self.I_channel_num = I_channel_num
432 | self.LocalizationNetwork = LocalizationNetwork(
433 | self.F,
434 | self.I_channel_num,
435 | )
436 | self.GridGenerator = GridGenerator(self.F, self.I_r_size)
437 |
438 | def forward(self, batch_I):
439 | batch_C_prime = self.LocalizationNetwork(batch_I) # batch_size x K x 2
440 | build_P_prime = self.GridGenerator.build_P_prime(
441 | batch_C_prime) # batch_size x n (= I_r_width x I_r_height) x 2
442 | build_P_prime_reshape = build_P_prime.reshape(
443 | [build_P_prime.size(0), self.I_r_size[0], self.I_r_size[1], 2])
444 |
445 | batch_I_r = F.grid_sample(
446 | batch_I,
447 | build_P_prime_reshape,
448 | padding_mode="border",
449 | )
450 |
451 | return batch_I_r
452 |
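 # Added note: shape flow through the TPS-STN above, following the comments in forward():
 #   batch_I        : [B, C, H, W]                       input image batch
 #   batch_C_prime  : [B, F, 2]                           fiducial points from LocalizationNetwork
 #   build_P_prime  : [B, I_r_height * I_r_width, 2]      sampling grid from GridGenerator
 #   reshaped grid  : [B, I_r_height, I_r_width, 2]       layout expected by F.grid_sample
 #   batch_I_r      : [B, C, I_r_height, I_r_width]       rectified output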
453 |
454 | class LocalizationNetwork(nn.Module):
455 | """ Localization Network of RARE, which predicts C' (K x 2) from I (I_width x I_height) """
456 |
457 | def __init__(self, F, I_channel_num: int):
458 | super(LocalizationNetwork, self).__init__()
459 | self.F = F
460 | self.I_channel_num = I_channel_num
461 | self.conv = nn.Sequential(
462 | nn.Conv2d(
463 | in_channels=self.I_channel_num,
464 | out_channels=64,
465 | kernel_size=3,
466 | stride=1,
467 | padding=1,
468 | bias=False,
469 | ),
470 | nn.BatchNorm2d(64),
471 | nn.ReLU(True),
472 | nn.MaxPool2d(2, 2), # batch_size x 64 x I_height/2 x I_width/2
473 | nn.Conv2d(64, 128, 3, 1, 1, bias=False),
474 | nn.BatchNorm2d(128),
475 | nn.ReLU(True),
476 | nn.MaxPool2d(2, 2), # batch_size x 128 x I_height/4 x I_width/4
477 | nn.Conv2d(128, 256, 3, 1, 1, bias=False),
478 | nn.BatchNorm2d(256),
479 | nn.ReLU(True),
480 | nn.MaxPool2d(2, 2), # batch_size x 256 x I_height/8 x I_width/8
481 | nn.Conv2d(256, 512, 3, 1, 1, bias=False),
482 | nn.BatchNorm2d(512),
483 | nn.ReLU(True),
484 | nn.AdaptiveAvgPool2d(1), # batch_size x 512
485 | )
486 |
487 | self.localization_fc1 = nn.Sequential(nn.Linear(512, 256),
488 | nn.ReLU(True))
489 | self.localization_fc2 = nn.Linear(256, self.F * 2)
490 |
491 | # Init fc2 in LocalizationNetwork
492 | self.localization_fc2.weight.data.fill_(0)
493 |
494 | # see RARE paper Fig. 6 (a)
495 | ctrl_pts_x = np.linspace(-1.0, 1.0, int(F / 2))
496 | ctrl_pts_y_top = np.linspace(0.0, -1.0, num=int(F / 2))
497 | ctrl_pts_y_bottom = np.linspace(1.0, 0.0, num=int(F / 2))
498 | ctrl_pts_top = np.stack([ctrl_pts_x, ctrl_pts_y_top], axis=1)
499 | ctrl_pts_bottom = np.stack([ctrl_pts_x, ctrl_pts_y_bottom], axis=1)
500 | initial_bias = np.concatenate([ctrl_pts_top, ctrl_pts_bottom], axis=0)
501 | self.localization_fc2.bias.data = (
502 | torch.from_numpy(initial_bias).float().view(-1))
503 |
504 | def forward(self, batch_I):
505 | """
506 | :param batch_I : Batch Input Image [batch_size x I_channel_num x I_height x I_width]
507 | :return: batch_C_prime : Predicted coordinates of fiducial points for input batch [batch_size x F x 2]
508 | """
509 | batch_size = batch_I.size(0)
510 | features = self.conv(batch_I).view(batch_size, -1)
511 | batch_C_prime = self.localization_fc2(
512 | self.localization_fc1(features)).view(batch_size, self.F, 2)
513 | return batch_C_prime
514 |
515 |
516 | class GridGenerator(nn.Module):
517 | """ Grid Generator of RARE, which produces P_prime by multipling T with P """
518 |
519 | def __init__(self, F, I_r_size):
520 | """ Generate P_hat and inv_delta_C for later """
521 | super(GridGenerator, self).__init__()
522 | self.eps = 1e-6
523 | self.I_r_height, self.I_r_width = I_r_size
524 | self.F = F
525 | self.C = self._build_C(self.F) # F x 2
526 | self.P = self._build_P(self.I_r_width, self.I_r_height)
527 |
528 | # for multi-gpu, you need register buffer
529 | self.register_buffer(
530 | "inv_delta_C",
531 | torch.tensor(self._build_inv_delta_C(
532 | self.F,
533 | self.C,
534 | )).float(),
535 | ) # F+3 x F+3
536 | self.register_buffer(
537 | "P_hat",
538 | torch.tensor(self._build_P_hat(
539 | self.F,
540 | self.C,
541 | self.P,
542 | )).float(),
543 | ) # n x F+3
544 |
545 | def _build_C(self, F):
546 | """ Return coordinates of fiducial points in I_r; C """
547 | ctrl_pts_x = np.linspace(-1.0, 1.0, int(F / 2))
548 | ctrl_pts_y_top = -1 * np.ones(int(F / 2))
549 | ctrl_pts_y_bottom = np.ones(int(F / 2))
550 | ctrl_pts_top = np.stack([ctrl_pts_x, ctrl_pts_y_top], axis=1)
551 | ctrl_pts_bottom = np.stack([ctrl_pts_x, ctrl_pts_y_bottom], axis=1)
552 | C = np.concatenate([ctrl_pts_top, ctrl_pts_bottom], axis=0)
553 | return C # F x 2
554 |
555 | def _build_inv_delta_C(self, F, C):
556 | """ Return inv_delta_C which is needed to calculate T """
557 | hat_C = np.zeros((F, F), dtype=float) # F x F
558 | for i in range(0, F):
559 | for j in range(i, F):
560 | r = np.linalg.norm(C[i] - C[j])
561 | hat_C[i, j] = r
562 | hat_C[j, i] = r
563 | np.fill_diagonal(hat_C, 1)
564 | hat_C = (hat_C**2) * np.log(hat_C)
565 | # print(C.shape, hat_C.shape)
566 | delta_C = np.concatenate( # F+3 x F+3
567 | [
568 | np.concatenate([np.ones((F, 1)), C, hat_C], axis=1), # F x F+3
569 | np.concatenate([np.zeros(
570 | (2, 3)), np.transpose(C)], axis=1), # 2 x F+3
571 | np.concatenate([np.zeros(
572 | (1, 3)), np.ones((1, F))], axis=1), # 1 x F+3
573 | ],
574 | axis=0,
575 | )
576 | inv_delta_C = np.linalg.inv(delta_C)
577 | return inv_delta_C # F+3 x F+3
578 |
579 | def _build_P(self, I_r_width, I_r_height):
580 | I_r_grid_x = (np.arange(-I_r_width, I_r_width, 2) +
581 | 1.0) / I_r_width # self.I_r_width
582 | I_r_grid_y = (np.arange(-I_r_height, I_r_height, 2) +
583 | 1.0) / I_r_height # self.I_r_height
584 | P = np.stack( # self.I_r_width x self.I_r_height x 2
585 | np.meshgrid(I_r_grid_x, I_r_grid_y),
586 | axis=2)
587 | return P.reshape([-1, 2]) # n (= self.I_r_width x self.I_r_height) x 2
588 |
589 | def _build_P_hat(self, F, C, P):
590 | n = P.shape[0] # n (= self.I_r_width x self.I_r_height)
591 | P_tile = np.tile(np.expand_dims(P, axis=1),
592 | (1, F, 1)) # n x 2 -> n x 1 x 2 -> n x F x 2
593 | C_tile = np.expand_dims(C, axis=0) # 1 x F x 2
594 | P_diff = P_tile - C_tile # n x F x 2
595 | rbf_norm = np.linalg.norm(
596 | P_diff,
597 | ord=2,
598 | axis=2,
599 | keepdims=False,
600 | ) # n x F
601 | rbf = np.multiply(
602 | np.square(rbf_norm),
603 | np.log(rbf_norm + self.eps),
604 | ) # n x F
605 | P_hat = np.concatenate([np.ones((n, 1)), P, rbf], axis=1)
606 | return P_hat # n x F+3
607 |
608 | def build_P_prime(self, batch_C_prime):
609 | """ Generate Grid from batch_C_prime [batch_size x F x 2] """
610 | batch_size = batch_C_prime.size(0)
611 | batch_inv_delta_C = self.inv_delta_C.repeat(batch_size, 1, 1)
612 | batch_P_hat = self.P_hat.repeat(batch_size, 1, 1)
613 | batch_C_prime_with_zeros = torch.cat(
614 | (batch_C_prime, torch.zeros(batch_size, 3, 2).float().to(device)),
615 | dim=1) # batch_size x F+3 x 2
616 | batch_T = torch.bmm(
617 | batch_inv_delta_C,
618 | batch_C_prime_with_zeros,
619 | ) # batch_size x F+3 x 2
620 | batch_P_prime = torch.bmm(batch_P_hat, batch_T) # batch_size x n x 2
621 | return batch_P_prime # batch_size x n x 2
622 |
--------------------------------------------------------------------------------
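Before moving on to the decoding utilities, here is a minimal end-to-end sketch of the rectification network defined above. The import path, the choice of F=20 fiducial points and the 32x100 image size are illustrative assumptions, not taken from this file:

import torch

from pororo.models.brainOCR._modules import TpsSpatialTransformerNetwork, device

tps = TpsSpatialTransformerNetwork(
    F=20,                # number of fiducial points (even, as in the original RARE setup)
    I_size=(32, 100),    # (height, width) of the input image
    I_r_size=(32, 100),  # (height, width) of the rectified image
    I_channel_num=1,
).to(device)             # GridGenerator builds its zero padding on `device`

batch_I = torch.randn(2, 1, 32, 100).to(device)
batch_I_r = tps(batch_I)
print(batch_I_r.shape)   # torch.Size([2, 1, 32, 100])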
/pororo/models/brainOCR/utils.py:
--------------------------------------------------------------------------------
1 | """
2 | This code is adapted from https://github.com/JaidedAI/EasyOCR/blob/8af936ba1b2f3c230968dc1022d0cd3e9ca1efbb/easyocr/utils.py
3 | """
4 |
5 | import math
6 | import os
7 | from urllib.request import urlretrieve
8 |
9 | import cv2
10 | import numpy as np
11 | import torch
12 | from PIL import Image
13 | from torch import Tensor
14 |
15 | from .imgproc import load_image
16 |
17 |
18 | def consecutive(data, mode: str = "first", stepsize: int = 1):
19 | group = np.split(data, np.where(np.diff(data) != stepsize)[0] + 1)
20 | group = [item for item in group if len(item) > 0]
21 |
22 | if mode == "first":
23 | result = [l[0] for l in group]
24 | elif mode == "last":
25 | result = [l[-1] for l in group]
26 | return result
27 |
28 |
29 | def word_segmentation(
30 | mat,
31 | separator_idx={
32 | "th": [1, 2],
33 | "en": [3, 4]
34 | },
35 | separator_idx_list=[1, 2, 3, 4],
36 | ):
37 | result = []
38 | sep_list = []
39 | start_idx = 0
40 | sep_lang = ""
41 | for sep_idx in separator_idx_list:
42 | if sep_idx % 2 == 0:
43 | mode = "first"
44 | else:
45 | mode = "last"
46 | a = consecutive(np.argwhere(mat == sep_idx).flatten(), mode)
47 | new_sep = [[item, sep_idx] for item in a]
48 | sep_list += new_sep
49 | sep_list = sorted(sep_list, key=lambda x: x[0])
50 |
51 | for sep in sep_list:
52 | for lang in separator_idx.keys():
53 | if sep[1] == separator_idx[lang][0]: # start lang
54 | sep_lang = lang
55 | sep_start_idx = sep[0]
56 | elif sep[1] == separator_idx[lang][1]: # end lang
 57 |             if sep_lang == lang:  # check if the last entry has the same start language
58 | new_sep_pair = [lang, [sep_start_idx + 1, sep[0] - 1]]
59 | if sep_start_idx > start_idx:
60 | result.append(["", [start_idx, sep_start_idx - 1]])
61 | start_idx = sep[0] + 1
62 | result.append(new_sep_pair)
63 | sep_lang = "" # reset
64 |
65 | if start_idx <= len(mat) - 1:
66 | result.append(["", [start_idx, len(mat) - 1]])
67 | return result
68 |
69 |
 70 | # code is based on https://github.com/githubharald/CTCDecoder/blob/master/src/BeamSearch.py
71 | class BeamEntry:
72 | "information about one single beam at specific time-step"
73 |
74 | def __init__(self):
75 | self.prTotal = 0 # blank and non-blank
76 | self.prNonBlank = 0 # non-blank
77 | self.prBlank = 0 # blank
78 | self.prText = 1 # LM score
79 | self.lmApplied = False # flag if LM was already applied to this beam
80 | self.labeling = () # beam-labeling
81 |
82 |
83 | class BeamState:
84 | "information about the beams at specific time-step"
85 |
86 | def __init__(self):
87 | self.entries = {}
88 |
89 | def norm(self):
90 | "length-normalise LM score"
91 | for (k, _) in self.entries.items():
92 | labelingLen = len(self.entries[k].labeling)
93 | self.entries[k].prText = self.entries[k].prText**(
94 | 1.0 / (labelingLen if labelingLen else 1.0))
95 |
96 | def sort(self):
97 | "return beam-labelings, sorted by probability"
98 | beams = [v for (_, v) in self.entries.items()]
99 | sortedBeams = sorted(
100 | beams,
101 | reverse=True,
102 | key=lambda x: x.prTotal * x.prText,
103 | )
104 | return [x.labeling for x in sortedBeams]
105 |
106 | def wordsearch(self, classes, ignore_idx, maxCandidate, dict_list):
107 | beams = [v for (_, v) in self.entries.items()]
108 | sortedBeams = sorted(
109 | beams,
110 | reverse=True,
111 | key=lambda x: x.prTotal * x.prText,
112 | )
113 | if len(sortedBeams) > maxCandidate:
114 | sortedBeams = sortedBeams[:maxCandidate]
115 |
116 | for j, candidate in enumerate(sortedBeams):
117 | idx_list = candidate.labeling
118 | text = ""
119 | for i, l in enumerate(idx_list):
120 | if l not in ignore_idx and (
121 | not (i > 0 and idx_list[i - 1] == idx_list[i])):
122 | text += classes[l]
123 |
124 | if j == 0:
125 | best_text = text
126 | if text in dict_list:
127 | # print('found text: ', text)
128 | best_text = text
129 | break
130 | else:
131 | pass
132 | # print('not in dict: ', text)
133 | return best_text
134 |
135 |
136 | def applyLM(parentBeam, childBeam, classes, lm_model, lm_factor: float = 0.01):
137 | "calculate LM score of child beam by taking score from parent beam and bigram probability of last two chars"
138 | if lm_model is not None and not childBeam.lmApplied:
139 | history = parentBeam.labeling
140 | history = " ".join(
141 | classes[each].replace(" ", "▁") for each in history if each != 0)
142 |
143 | current_char = classes[childBeam.labeling[-1]].replace(" ", "▁")
144 | if current_char == "[blank]":
145 | lmProb = 1
146 | else:
147 | text = history + " " + current_char
148 | lmProb = 10**lm_model.score(text, bos=True) * lm_factor
149 |
150 | childBeam.prText = lmProb # probability of char sequence
151 | childBeam.lmApplied = True # only apply LM once per beam entry
152 |
153 |
154 | def simplify_label(labeling, blankIdx: int = 0):
155 | labeling = np.array(labeling)
156 |
157 | # collapse blank
158 | idx = np.where(~((np.roll(labeling, 1) == labeling) &
159 | (labeling == blankIdx)))[0]
160 | labeling = labeling[idx]
161 |
162 | # get rid of blank between different characters
163 | idx = np.where(~((np.roll(labeling, 1) != np.roll(labeling, -1)) &
164 | (labeling == blankIdx)))[0]
165 |
166 | if len(labeling) > 0:
167 | last_idx = len(labeling) - 1
168 | if last_idx not in idx:
169 | idx = np.append(idx, [last_idx])
170 | labeling = labeling[idx]
171 |
172 | return tuple(labeling)
173 |
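 # Added note: two hand-worked examples of what simplify_label does (worth re-checking
 # before relying on them):
 #   simplify_label((0, 5, 0, 0, 7))  -> (5, 7)       blanks around/between different labels are dropped
 #   simplify_label((5, 0, 5))        -> (5, 0, 5)     the blank between a repeated label is kept,
 #                                                     since CTC needs it to distinguish "aa" from "a"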
174 |
175 | def addBeam(beamState, labeling):
176 | "add beam if it does not yet exist"
177 | if labeling not in beamState.entries:
178 | beamState.entries[labeling] = BeamEntry()
179 |
180 |
181 | def ctcBeamSearch(
182 | mat,
183 | classes: list,
184 | ignore_idx: int,
185 | lm_model,
186 | lm_factor: float = 0.01,
187 | beam_width: int = 5,
188 | ):
189 | blankIdx = 0
190 | maxT, maxC = mat.shape
191 |
192 | # initialise beam state
193 | last = BeamState()
194 | labeling = ()
195 | last.entries[labeling] = BeamEntry()
196 | last.entries[labeling].prBlank = 1
197 | last.entries[labeling].prTotal = 1
198 |
199 | # go over all time-steps
200 | for t in range(maxT):
201 | # print("t=", t)
202 | curr = BeamState()
203 | # get beam-labelings of best beams
204 | bestLabelings = last.sort()[0:beam_width]
205 | # go over best beams
206 | for labeling in bestLabelings:
207 | # print("labeling:", labeling)
208 | # probability of paths ending with a non-blank
209 | prNonBlank = 0
210 | # in case of non-empty beam
211 | if labeling:
212 | # probability of paths with repeated last char at the end
213 | prNonBlank = last.entries[labeling].prNonBlank * mat[
214 | t, labeling[-1]]
215 |
216 | # probability of paths ending with a blank
217 | prBlank = (last.entries[labeling].prTotal) * mat[t, blankIdx]
218 |
219 | # add beam at current time-step if needed
220 | labeling = simplify_label(labeling, blankIdx)
221 | addBeam(curr, labeling)
222 |
223 | # fill in data
224 | curr.entries[labeling].labeling = labeling
225 | curr.entries[labeling].prNonBlank += prNonBlank
226 | curr.entries[labeling].prBlank += prBlank
227 | curr.entries[labeling].prTotal += prBlank + prNonBlank
228 | curr.entries[labeling].prText = last.entries[labeling].prText
229 |             # beam-labeling not changed, therefore the LM score is unchanged from the previous time-step
230 |
231 | curr.entries[labeling].lmApplied = (
232 | True # LM already applied at previous time-step for this beam-labeling
233 | )
234 |
235 | # extend current beam-labeling
236 | # char_highscore = np.argpartition(mat[t, :], -5)[-5:] # run through 5 highest probability
237 | char_highscore = np.where(
238 | mat[t, :] >= 0.5 /
239 | maxC)[0] # run through all probable characters
240 | for c in char_highscore:
241 | # for c in range(maxC - 1):
242 | # add new char to current beam-labeling
243 | newLabeling = labeling + (c,)
244 | newLabeling = simplify_label(newLabeling, blankIdx)
245 |
246 | # if new labeling contains duplicate char at the end, only consider paths ending with a blank
247 | if labeling and labeling[-1] == c:
248 | prNonBlank = mat[t, c] * last.entries[labeling].prBlank
249 | else:
250 | prNonBlank = mat[t, c] * last.entries[labeling].prTotal
251 |
252 | # add beam at current time-step if needed
253 | addBeam(curr, newLabeling)
254 |
255 | # fill in data
256 | curr.entries[newLabeling].labeling = newLabeling
257 | curr.entries[newLabeling].prNonBlank += prNonBlank
258 | curr.entries[newLabeling].prTotal += prNonBlank
259 |
260 | # apply LM
261 | applyLM(
262 | curr.entries[labeling],
263 | curr.entries[newLabeling],
264 | classes,
265 | lm_model,
266 | lm_factor,
267 | )
268 |
269 | # set new beam state
270 |
271 | last = curr
272 |
273 | # normalise LM scores according to beam-labeling-length
274 | last.norm()
275 |
276 | bestLabeling = last.sort()[0] # get most probable labeling
277 | res = ""
278 | for i, l in enumerate(bestLabeling):
279 | # removing repeated characters and blank.
280 | if l != ignore_idx and (not (i > 0 and
281 | bestLabeling[i - 1] == bestLabeling[i])):
282 | res += classes[l]
283 |
284 | return res
285 |
286 |
287 | class CTCLabelConverter(object):
288 | """ Convert between text-label and text-index """
289 |
290 | def __init__(self, vocab: list):
291 | self.char2idx = {char: idx for idx, char in enumerate(vocab)}
292 | self.idx2char = {idx: char for idx, char in enumerate(vocab)}
293 | self.ignored_index = 0
294 | self.vocab = vocab
295 |
296 | def encode(self, texts: list):
297 | """
298 | Convert input texts into indices
299 | texts (list): text labels of each image. [batch_size]
300 |
301 | Returns
302 | text: concatenated text index for CTCLoss.
303 | [sum(text_lengths)] = [text_index_0 + text_index_1 + ... + text_index_(n - 1)]
304 | length: length of each text. [batch_size]
305 | """
306 | lengths = [len(text) for text in texts]
307 | concatenated_text = "".join(texts)
308 | indices = [self.char2idx[char] for char in concatenated_text]
309 |
310 | return torch.IntTensor(indices), torch.IntTensor(lengths)
311 |
312 | def decode_greedy(self, indices: Tensor, lengths: Tensor):
313 | """convert text-index into text-label.
314 |
315 | :param indices (1D int32 Tensor): [N*length,]
316 | :param lengths (1D int32 Tensor): [N,]
317 | :return:
318 | """
319 | texts = []
320 | index = 0
321 | for length in lengths:
322 | text = indices[index:index + length]
323 |
324 | chars = []
325 | for i in range(length):
326 | if (text[i] != self.ignored_index) and (
327 | not (i > 0 and text[i - 1] == text[i])
328 | ): # removing repeated characters and blank (and separator).
329 | chars.append(self.idx2char[text[i].item()])
330 | texts.append("".join(chars))
331 | index += length
332 | return texts
333 |
334 | def decode_beamsearch(self, mat, lm_model, lm_factor, beam_width: int = 5):
335 | texts = []
336 | for i in range(mat.shape[0]):
337 | text = ctcBeamSearch(
338 | mat[i],
339 | self.vocab,
340 | self.ignored_index,
341 | lm_model,
342 | lm_factor,
343 | beam_width,
344 | )
345 | texts.append(text)
346 | return texts
347 |
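 # Added note: a small usage example of the converter (vocab index 0 is the CTC blank):
 #   vocab = ["[blank]", "a", "b", "c"]
 #   converter = CTCLabelConverter(vocab)
 #   converter.encode(["ab", "c"])   -> (IntTensor([1, 2, 3]), IntTensor([2, 1]))
 #   converter.decode_greedy(torch.IntTensor([1, 1, 0, 2]), torch.IntTensor([4]))  -> ["ab"]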
348 |
349 | def four_point_transform(image, rect):
350 | (tl, tr, br, bl) = rect
351 |
352 | widthA = np.sqrt(((br[0] - bl[0])**2) + ((br[1] - bl[1])**2))
353 | widthB = np.sqrt(((tr[0] - tl[0])**2) + ((tr[1] - tl[1])**2))
354 | maxWidth = max(int(widthA), int(widthB))
355 |
356 | # compute the height of the new image, which will be the
357 | # maximum distance between the top-right and bottom-right
358 | # y-coordinates or the top-left and bottom-left y-coordinates
359 | heightA = np.sqrt(((tr[0] - br[0])**2) + ((tr[1] - br[1])**2))
360 | heightB = np.sqrt(((tl[0] - bl[0])**2) + ((tl[1] - bl[1])**2))
361 | maxHeight = max(int(heightA), int(heightB))
362 |
363 | dst = np.array(
364 | [[0, 0], [maxWidth - 1, 0], [maxWidth - 1, maxHeight - 1],
365 | [0, maxHeight - 1]],
366 | dtype="float32",
367 | )
368 |
369 | # compute the perspective transform matrix and then apply it
370 | M = cv2.getPerspectiveTransform(rect, dst)
371 | warped = cv2.warpPerspective(image, M, (maxWidth, maxHeight))
372 |
373 | return warped
374 |
375 |
376 | def group_text_box(
377 | polys,
378 | slope_ths: float = 0.1,
379 | ycenter_ths: float = 0.5,
380 | height_ths: float = 0.5,
381 | width_ths: float = 1.0,
382 | add_margin: float = 0.05,
383 | ):
384 | # poly top-left, top-right, low-right, low-left
385 | horizontal_list, free_list, combined_list, merged_list = [], [], [], []
386 |
387 | for poly in polys:
388 | slope_up = (poly[3] - poly[1]) / np.maximum(10, (poly[2] - poly[0]))
389 | slope_down = (poly[5] - poly[7]) / np.maximum(10, (poly[4] - poly[6]))
390 | if max(abs(slope_up), abs(slope_down)) < slope_ths:
391 | x_max = max([poly[0], poly[2], poly[4], poly[6]])
392 | x_min = min([poly[0], poly[2], poly[4], poly[6]])
393 | y_max = max([poly[1], poly[3], poly[5], poly[7]])
394 | y_min = min([poly[1], poly[3], poly[5], poly[7]])
395 | horizontal_list.append([
396 | x_min, x_max, y_min, y_max, 0.5 * (y_min + y_max), y_max - y_min
397 | ])
398 | else:
399 | height = np.linalg.norm([poly[6] - poly[0], poly[7] - poly[1]])
400 | margin = int(1.44 * add_margin * height)
401 |
402 | theta13 = abs(
403 | np.arctan(
404 | (poly[1] - poly[5]) / np.maximum(10, (poly[0] - poly[4]))))
405 | theta24 = abs(
406 | np.arctan(
407 | (poly[3] - poly[7]) / np.maximum(10, (poly[2] - poly[6]))))
408 | # do I need to clip minimum, maximum value here?
409 | x1 = poly[0] - np.cos(theta13) * margin
410 | y1 = poly[1] - np.sin(theta13) * margin
411 | x2 = poly[2] + np.cos(theta24) * margin
412 | y2 = poly[3] - np.sin(theta24) * margin
413 | x3 = poly[4] + np.cos(theta13) * margin
414 | y3 = poly[5] + np.sin(theta13) * margin
415 | x4 = poly[6] - np.cos(theta24) * margin
416 | y4 = poly[7] + np.sin(theta24) * margin
417 |
418 | free_list.append([[x1, y1], [x2, y2], [x3, y3], [x4, y4]])
419 | horizontal_list = sorted(horizontal_list, key=lambda item: item[4])
420 |
421 | # combine box
422 | new_box = []
423 | for poly in horizontal_list:
424 |
425 | if len(new_box) == 0:
426 | b_height = [poly[5]]
427 | b_ycenter = [poly[4]]
428 | new_box.append(poly)
429 | else:
430 | # comparable height and comparable y_center level up to ths*height
431 | if (abs(np.mean(b_height) - poly[5]) < height_ths *
432 | np.mean(b_height)) and (abs(np.mean(b_ycenter) - poly[4]) <
433 | ycenter_ths * np.mean(b_height)):
434 | b_height.append(poly[5])
435 | b_ycenter.append(poly[4])
436 | new_box.append(poly)
437 | else:
438 | b_height = [poly[5]]
439 | b_ycenter = [poly[4]]
440 | combined_list.append(new_box)
441 | new_box = [poly]
442 | combined_list.append(new_box)
443 |
444 | # merge list use sort again
445 | for boxes in combined_list:
446 | if len(boxes) == 1: # one box per line
447 | box = boxes[0]
448 | margin = int(add_margin * box[5])
449 | merged_list.append([
450 | box[0] - margin, box[1] + margin, box[2] - margin,
451 | box[3] + margin
452 | ])
453 | else: # multiple boxes per line
454 | boxes = sorted(boxes, key=lambda item: item[0])
455 |
456 | merged_box, new_box = [], []
457 | for box in boxes:
458 | if len(new_box) == 0:
459 | x_max = box[1]
460 | new_box.append(box)
461 | else:
462 | if abs(box[0] - x_max) < width_ths * (
463 | box[3] - box[2]): # merge boxes
464 | x_max = box[1]
465 | new_box.append(box)
466 | else:
467 | x_max = box[1]
468 | merged_box.append(new_box)
469 | new_box = [box]
470 | if len(new_box) > 0:
471 | merged_box.append(new_box)
472 |
473 | for mbox in merged_box:
474 | if len(mbox) != 1: # adjacent box in same line
475 | # do I need to add margin here?
476 | x_min = min(mbox, key=lambda x: x[0])[0]
477 | x_max = max(mbox, key=lambda x: x[1])[1]
478 | y_min = min(mbox, key=lambda x: x[2])[2]
479 | y_max = max(mbox, key=lambda x: x[3])[3]
480 |
481 | margin = int(add_margin * (y_max - y_min))
482 |
483 | merged_list.append([
484 | x_min - margin, x_max + margin, y_min - margin,
485 | y_max + margin
486 | ])
487 | else: # non adjacent box in same line
488 | box = mbox[0]
489 |
490 | margin = int(add_margin * (box[3] - box[2]))
491 | merged_list.append([
492 | box[0] - margin,
493 | box[1] + margin,
494 | box[2] - margin,
495 | box[3] + margin,
496 | ])
497 | # may need to check if box is really in image
498 | return merged_list, free_list
499 |
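 # Added note: group_text_box returns two lists. merged_list holds near-horizontal text
 # lines as [x_min, x_max, y_min, y_max] boxes (neighbouring boxes on the same line are
 # merged and padded by add_margin), while free_list keeps the four corner points of
 # tilted boxes whose slope exceeds slope_ths; those are handled later by
 # four_point_transform in get_image_list.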
500 |
501 | def get_image_list(horizontal_list: list,
502 | free_list: list,
503 | img: np.ndarray,
504 | model_height: int = 64):
505 | image_list = []
506 | maximum_y, maximum_x = img.shape
507 |
508 | max_ratio_hori, max_ratio_free = 1, 1
509 | for box in free_list:
510 | rect = np.array(box, dtype="float32")
511 | transformed_img = four_point_transform(img, rect)
512 | ratio = transformed_img.shape[1] / transformed_img.shape[0]
513 | crop_img = cv2.resize(
514 | transformed_img,
515 | (int(model_height * ratio), model_height),
516 |             interpolation=cv2.INTER_LINEAR,  # was Image.ANTIALIAS (removed in newer Pillow); both equal 1
517 | )
518 | # box : [[x1,y1],[x2,y2],[x3,y3],[x4,y4]]
519 | image_list.append((box, crop_img))
520 | max_ratio_free = max(ratio, max_ratio_free)
521 |
522 | max_ratio_free = math.ceil(max_ratio_free)
523 |
524 | for box in horizontal_list:
525 | x_min = max(0, box[0])
526 | x_max = min(box[1], maximum_x)
527 | y_min = max(0, box[2])
528 | y_max = min(box[3], maximum_y)
529 | crop_img = img[y_min:y_max, x_min:x_max]
530 | width = x_max - x_min
531 | height = y_max - y_min
532 | ratio = width / height
533 | crop_img = cv2.resize(
534 | crop_img,
535 | (int(model_height * ratio), model_height),
536 |             interpolation=cv2.INTER_LINEAR,  # was Image.ANTIALIAS (removed in newer Pillow); both equal 1
537 | )
538 | image_list.append((
539 | [
540 | [x_min, y_min],
541 | [x_max, y_min],
542 | [x_max, y_max],
543 | [x_min, y_max],
544 | ],
545 | crop_img,
546 | ))
547 | max_ratio_hori = max(ratio, max_ratio_hori)
548 |
549 | max_ratio_hori = math.ceil(max_ratio_hori)
550 | max_ratio = max(max_ratio_hori, max_ratio_free)
551 | max_width = math.ceil(max_ratio) * model_height
552 |
553 | image_list = sorted(
554 | image_list, key=lambda item: item[0][0][1]) # sort by vertical position
555 | return image_list, max_width
556 |
557 |
558 | def diff(input_list):
559 | return max(input_list) - min(input_list)
560 |
561 |
562 | def get_paragraph(raw_result,
563 | x_ths: int = 1,
564 | y_ths: float = 0.5,
565 | mode: str = "ltr"):
566 | # create basic attributes
567 | box_group = []
568 | for box in raw_result:
569 | all_x = [int(coord[0]) for coord in box[0]]
570 | all_y = [int(coord[1]) for coord in box[0]]
571 | min_x = min(all_x)
572 | max_x = max(all_x)
573 | min_y = min(all_y)
574 | max_y = max(all_y)
575 | height = max_y - min_y
576 | box_group.append([
577 | box[1], min_x, max_x, min_y, max_y, height, 0.5 * (min_y + max_y), 0
578 | ]) # last element indicates group
579 | # cluster boxes into paragraph
580 | current_group = 1
581 | while len([box for box in box_group if box[7] == 0]) > 0:
582 | # group0 = non-group
583 | box_group0 = [box for box in box_group if box[7] == 0]
584 | # new group
585 | if len([box for box in box_group if box[7] == current_group]) == 0:
586 | # assign first box to form new group
587 | box_group0[0][7] = current_group
588 | # try to add group
589 | else:
590 | current_box_group = [
591 | box for box in box_group if box[7] == current_group
592 | ]
593 | mean_height = np.mean([box[5] for box in current_box_group])
594 | # yapf: disable
595 | min_gx = min([box[1] for box in current_box_group]) - x_ths * mean_height
596 | max_gx = max([box[2] for box in current_box_group]) + x_ths * mean_height
597 | min_gy = min([box[3] for box in current_box_group]) - y_ths * mean_height
598 | max_gy = max([box[4] for box in current_box_group]) + y_ths * mean_height
599 | add_box = False
600 | for box in box_group0:
601 | same_horizontal_level = (min_gx <= box[1] <= max_gx) or (min_gx <= box[2] <= max_gx)
602 | same_vertical_level = (min_gy <= box[3] <= max_gy) or (min_gy <= box[4] <= max_gy)
603 | if same_horizontal_level and same_vertical_level:
604 | box[7] = current_group
605 | add_box = True
606 | break
607 | # cannot add more box, go to next group
608 | if not add_box:
609 | current_group += 1
610 | # yapf: enable
611 |     # arrange order within each paragraph
612 | result = []
613 | for i in set(box[7] for box in box_group):
614 | current_box_group = [box for box in box_group if box[7] == i]
615 | mean_height = np.mean([box[5] for box in current_box_group])
616 | min_gx = min([box[1] for box in current_box_group])
617 | max_gx = max([box[2] for box in current_box_group])
618 | min_gy = min([box[3] for box in current_box_group])
619 | max_gy = max([box[4] for box in current_box_group])
620 |
621 | text = ""
622 | while len(current_box_group) > 0:
623 | highest = min([box[6] for box in current_box_group])
624 | candidates = [
625 | box for box in current_box_group
626 | if box[6] < highest + 0.4 * mean_height
627 | ]
628 | # get the far left
629 | if mode == "ltr":
630 | most_left = min([box[1] for box in candidates])
631 | for box in candidates:
632 | if box[1] == most_left:
633 | best_box = box
634 | elif mode == "rtl":
635 | most_right = max([box[2] for box in candidates])
636 | for box in candidates:
637 | if box[2] == most_right:
638 | best_box = box
639 | text += " " + best_box[0]
640 | current_box_group.remove(best_box)
641 |
642 | result.append([
643 | [
644 | [min_gx, min_gy],
645 | [max_gx, min_gy],
646 | [max_gx, max_gy],
647 | [min_gx, max_gy],
648 | ],
649 | text[1:],
650 | ])
651 |
652 | return result
653 |
654 |
655 | def printProgressBar(
656 | prefix="",
657 | suffix="",
658 | decimals: int = 1,
659 | length: int = 100,
660 | fill: str = "█",
661 | printEnd: str = "\r",
662 | ):
663 | """
664 | Call in a loop to create terminal progress bar
665 | @params:
666 | prefix - Optional : prefix string (Str)
667 | suffix - Optional : suffix string (Str)
668 | decimals - Optional : positive number of decimals in percent complete (Int)
669 | length - Optional : character length of bar (Int)
670 | fill - Optional : bar fill character (Str)
671 | printEnd - Optional : end character (e.g. "\r", "\r\n") (Str)
672 | """
673 |
674 | def progress_hook(count, blockSize, totalSize):
675 | progress = count * blockSize / totalSize
676 | percent = ("{0:." + str(decimals) + "f}").format(progress * 100)
677 | filledLength = int(length * progress)
678 | bar = fill * filledLength + "-" * (length - filledLength)
679 | print(f"\r{prefix} |{bar}| {percent}% {suffix}", end=printEnd)
680 |
681 | return progress_hook
682 |
683 |
684 | def reformat_input(image):
685 | """
686 | :param image: image file path or bytes or array
687 | :return:
688 | img (array): (original_image_height, original_image_width, 3)
689 |         img_cv_grey (array): (original_image_height, original_image_width)
690 | """
691 | if type(image) == str:
692 | if image.startswith("http://") or image.startswith("https://"):
693 | tmp, _ = urlretrieve(
694 | image,
695 | reporthook=printProgressBar(
696 | prefix="Progress:",
697 | suffix="Complete",
698 | length=50,
699 | ),
700 | )
701 | img_cv_grey = cv2.imread(tmp, cv2.IMREAD_GRAYSCALE)
702 | os.remove(tmp)
703 | else:
704 | img_cv_grey = cv2.imread(image, cv2.IMREAD_GRAYSCALE)
705 | image = os.path.expanduser(image)
706 | img = load_image(image) # can accept URL
707 | elif type(image) == bytes:
708 | nparr = np.frombuffer(image, np.uint8)
709 | img = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
710 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
711 | img_cv_grey = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
712 |
713 | elif type(image) == np.ndarray:
714 | if len(image.shape) == 2: # grayscale
715 | img_cv_grey = image
716 | img = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
717 | elif len(image.shape) == 3 and image.shape[2] == 3: # BGRscale
718 | img = image
719 | img_cv_grey = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
720 | elif len(image.shape) == 3 and image.shape[2] == 4: # RGBAscale
721 | img = image[:, :, :3]
722 | img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
723 | img_cv_grey = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
724 |
725 | return img, img_cv_grey
726 |
--------------------------------------------------------------------------------
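As a closing example, the input-normalisation helper above can be exercised directly on one of the sample images shipped under assets/ (the import path is an assumption based on the tree at the top of this document):

from pororo.models.brainOCR.utils import reformat_input

# accepts a file path, an http/https URL, raw bytes, or a numpy array
img, img_cv_grey = reformat_input("assets/images/test_image_1.jpg")
print(img.shape)          # (original_image_height, original_image_width, 3)
print(img_cv_grey.shape)  # (original_image_height, original_image_width)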