├── src
│   ├── text_recognizer
│   │   ├── __init__.py
│   │   ├── modules
│   │   │   ├── utils.py
│   │   │   ├── sequence_modeling.py
│   │   │   ├── prediction.py
│   │   │   ├── model.py
│   │   │   ├── model_utils.py
│   │   │   ├── dataset.py
│   │   │   ├── transformation.py
│   │   │   └── feature_extraction.py
│   │   ├── load_model.py
│   │   └── infer.py
│   ├── __init__.py
│   ├── text_detector
│   │   ├── __init__.py
│   │   ├── basenet
│   │   │   ├── __init__.py
│   │   │   └── vgg16_bn.py
│   │   ├── modules
│   │   │   ├── __init__.py
│   │   │   ├── utils.py
│   │   │   ├── imgproc.py
│   │   │   ├── refinenet.py
│   │   │   ├── craft.py
│   │   │   └── craft_utils.py
│   │   ├── load_model.py
│   │   └── infer.py
│   ├── model.py
│   └── engine.py
├── data
│   ├── tes.jpg
│   └── sample_output.jpg
├── pyproject.toml
├── setup.cfg
├── requirements.txt
├── environment.yaml
├── configs
│   ├── craft_config.yaml
│   └── star_config.yaml
├── Dockerfile
├── LICENSE
├── main.py
├── CONTRIBUTING.md
├── .gitignore
├── notebooks
│   ├── test_api.ipynb
│   ├── inference_onnx_engine.ipynb
│   ├── inference_default_engine.ipynb
│   └── export_onnx_model.ipynb
└── README.md
/src/text_recognizer/__init__.py:
--------------------------------------------------------------------------------
1 | from .modules import *
2 |
--------------------------------------------------------------------------------
/src/__init__.py:
--------------------------------------------------------------------------------
1 | from .text_detector import *
2 | from .text_recognizer import *
3 |
--------------------------------------------------------------------------------
/src/text_detector/__init__.py:
--------------------------------------------------------------------------------
1 | from .basenet import *
2 | from .modules import *
3 |
--------------------------------------------------------------------------------
/data/tes.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jakartaresearch/receipt-ocr/HEAD/data/tes.jpg
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [tool.isort]
2 | profile = "black"
3 |
4 | [tool.black]
5 | line-length = 100
--------------------------------------------------------------------------------
/data/sample_output.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jakartaresearch/receipt-ocr/HEAD/data/sample_output.jpg
--------------------------------------------------------------------------------
/src/text_detector/basenet/__init__.py:
--------------------------------------------------------------------------------
1 | from .vgg16_bn import init_weights, vgg16_bn
2 |
3 | __all__ = ['init_weights', 'vgg16_bn']
--------------------------------------------------------------------------------
/setup.cfg:
--------------------------------------------------------------------------------
1 | [flake8]
2 | ignore = Q000, WPS110, WPS214, WPS300, WPS414, WPS420, I003, C812
3 | max-line-length = 100
4 | extend-ignore = E203
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch==1.9.1
2 | torchvision==0.10.1
3 | opencv-python==4.5.3.56
4 | scikit-image==0.18.3
5 | scipy==1.7.1
6 | PyYAML==5.4.1
7 | onnx==1.10.1
8 | onnxruntime==1.9.0
9 | fastapi==0.68.1
10 | uvicorn[standard]
11 | gunicorn
12 | pydantic
--------------------------------------------------------------------------------
/src/text_detector/modules/__init__.py:
--------------------------------------------------------------------------------
1 | from .craft_utils import *
2 | from .imgproc import *
3 | from .utils import yaml_loader
4 | from .craft import CRAFT
5 |
6 |
7 | __all__ = ['getDetBoxes_core', 'getPoly_core', 'getDetBoxes',
8 | 'adjustResultCoordinates', 'normalizeMeanVariance',
9 | 'denormalizeMeanVariance', 'resize_aspect_ratio', 'cvt2HeatmapImg',
10 | 'yaml_loader', 'CRAFT']
11 |
--------------------------------------------------------------------------------
/environment.yaml:
--------------------------------------------------------------------------------
1 | name: receipt-ocr
2 | channels:
3 | - conda-forge
4 | - defaults
5 | dependencies:
6 | - python>=3.8
7 | - pip
8 | - pip:
9 | - torch==1.9.1
10 | - torchvision==0.10.1
11 | - opencv-python==4.5.3.56
12 | - scikit-image==0.18.3
13 | - scipy==1.7.1
14 | - PyYAML==5.4.1
15 | - onnx==1.10.1
16 | - onnxruntime==1.9.0
17 | - fastapi==0.68.1
18 | - uvicorn[standard]
19 | - pydantic
--------------------------------------------------------------------------------
/configs/craft_config.yaml:
--------------------------------------------------------------------------------
1 | # text confidence threshold
2 | text_threshold: 0.7
3 | # text low-bound score
4 | low_text: 0.2
5 | # link confidence threshold
6 | link_threshold: 0.2
7 | # use cuda for inference
8 | cuda: False
9 | # image size for inference
10 | canvas_size: 1280
11 | # image magnification ratio
12 | mag_ratio: 1.5
13 | # enable polygon type
14 | poly: False
15 | # enable link refiner
16 | refine: False
17 | # pretrained refiner model
18 | refiner_model: ''
--------------------------------------------------------------------------------
/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM tiangolo/uvicorn-gunicorn-fastapi:python3.7
2 |
3 | WORKDIR /receipt-ocr
4 | COPY configs /receipt-ocr/configs
5 | COPY models /receipt-ocr/models
6 | COPY src /receipt-ocr/src
7 | COPY main.py /receipt-ocr
8 | COPY requirements.txt /receipt-ocr
9 |
10 | RUN apt-get update
11 | RUN apt-get install ffmpeg libsm6 libxext6 -y
12 | RUN pip install --upgrade pip
13 | RUN pip install --no-cache-dir -r requirements.txt
14 |
15 | EXPOSE 8000
16 | CMD uvicorn main:app --host 0.0.0.0 --port 8000
--------------------------------------------------------------------------------
/src/text_recognizer/modules/utils.py:
--------------------------------------------------------------------------------
1 | import yaml
2 | from yaml.loader import SafeLoader
3 |
4 |
5 | def yaml_loader(filename):
6 | with open(filename) as f:
7 | data = yaml.load(f, Loader=SafeLoader)
8 | return data
9 |
10 |
11 | class DictObj:
12 | def __init__(self, in_dict: dict):
13 | assert isinstance(in_dict, dict)
14 | for key, val in in_dict.items():
15 | if isinstance(val, (list, tuple)):
16 | setattr(self, key, [DictObj(x) if isinstance(
17 | x, dict) else x for x in val])
18 | else:
19 | setattr(self, key, DictObj(val)
20 | if isinstance(val, dict) else val)
21 |
--------------------------------------------------------------------------------
/configs/star_config.yaml:
--------------------------------------------------------------------------------
1 | # input batch size
2 | batch_size: 192
3 | # maximum label length
4 | batch_max_length: 25
5 | # number of data loading workers
6 | workers: 0
7 | # the height of the input image
8 | imgH: 32
9 | # the width of the input image
10 | imgW: 100
11 | # character label
12 | character: '0123456789abcdefghijklmnopqrstuvwxyz'
13 | sensitive: True
14 |
15 | # Model Architecture
16 | Transformation: 'TPS'
17 | FeatureExtraction: 'ResNet'
18 | SequenceModeling: 'BiLSTM'
19 | Prediction: 'Attn'
20 |
21 | # number of fiducial points of TPS-STN
22 | num_fiducial: 20
23 | # the number of input channel of Feature extractor
24 | input_channel: 1
25 | # the number of output channel of Feature extractor
26 | output_channel: 512
27 | # the size of the LSTM hidden state
28 | hidden_size: 256
29 | PAD: False
30 |
--------------------------------------------------------------------------------
/src/text_recognizer/modules/sequence_modeling.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 |
4 | class BidirectionalLSTM(nn.Module):
5 |
6 | def __init__(self, input_size, hidden_size, output_size):
7 | super(BidirectionalLSTM, self).__init__()
8 | self.rnn = nn.LSTM(input_size, hidden_size,
9 | bidirectional=True, batch_first=True)
10 | self.linear = nn.Linear(hidden_size * 2, output_size)
11 |
12 | def forward(self, input):
13 | """
14 | input : visual feature [batch_size x T x input_size]
15 | output : contextual feature [batch_size x T x output_size]
16 | """
17 | self.rnn.flatten_parameters()
18 | # batch_size x T x input_size -> batch_size x T x (2*hidden_size)
19 | recurrent, _ = self.rnn(input)
20 | output = self.linear(recurrent) # batch_size x T x output_size
21 | return output
22 |
--------------------------------------------------------------------------------
/src/text_detector/modules/utils.py:
--------------------------------------------------------------------------------
1 | import yaml
2 | import multiprocessing
3 | import onnxruntime as rt
4 |
5 | from onnxruntime import InferenceSession, get_all_providers
6 | from yaml.loader import SafeLoader
7 |
8 |
9 | def yaml_loader(filename):
10 | with open(filename) as f:
11 | data = yaml.load(f, Loader=SafeLoader)
12 | return data
13 |
14 |
15 | def create_model_for_provider(model_path: str, provider: str) -> InferenceSession:
16 | """Return inference session for ONNX model with specific provider."""
17 | assert provider in get_all_providers(
18 | ), f"provider {provider} not found, {get_all_providers()}"
19 |
20 | sess_options = rt.SessionOptions()
21 | sess_options.intra_op_num_threads = multiprocessing.cpu_count()
22 | sess_options.execution_mode = rt.ExecutionMode.ORT_PARALLEL
23 | sess_options.graph_optimization_level = rt.GraphOptimizationLevel.ORT_ENABLE_ALL
24 |
25 | return InferenceSession(model_path, sess_options, providers=[provider])
26 |
--------------------------------------------------------------------------------
/src/text_recognizer/load_model.py:
--------------------------------------------------------------------------------
1 | import string
2 | import torch
3 |
4 | from .modules.model_utils import CTCLabelConverter, AttnLabelConverter
5 | from .modules.utils import yaml_loader, DictObj
6 | from .modules.model import Model
7 |
8 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
9 |
10 |
11 | def load_star(config_file, model_pth):
12 | cfg = yaml_loader(config_file)
13 | obj_cfg = DictObj(cfg)
14 | """ vocab / character number configuration """
15 | if obj_cfg.sensitive:
16 | obj_cfg.character = string.printable[:-6]
17 |
18 | if "CTC" in obj_cfg.Prediction:
19 | converter = CTCLabelConverter(obj_cfg.character)
20 | else:
21 | converter = AttnLabelConverter(obj_cfg.character)
22 |
23 | obj_cfg.num_class = len(converter.character)
24 |
25 | net = Model(obj_cfg)
26 | net = torch.nn.DataParallel(net).to(device)
27 | print(f"Loading weights from checkpoint ({model_pth})")
28 | net.load_state_dict(torch.load(model_pth, map_location=device))
29 | return obj_cfg, net, converter
30 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021 Jakarta AI Research
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | """Script for Fast API Endpoint."""
2 | import base64
3 | import io
4 | import warnings
5 | import numpy as np
6 | from fastapi import FastAPI
7 | from PIL import Image
8 | from pydantic import BaseModel
9 | from src.engine import DefaultEngine
10 | from src.model import DefaultModel
11 |
12 | warnings.filterwarnings("ignore")
13 |
14 |
15 | app = FastAPI()
16 |
17 | detector_cfg = "configs/craft_config.yaml"
18 | detector_model = "models/text_detector/craft_mlt_25k.pth"
19 | recognizer_cfg = "configs/star_config.yaml"
20 | recognizer_model = "models/text_recognizer/TPS-ResNet-BiLSTM-Attn-case-sensitive.pth"
21 |
22 | model = DefaultModel(detector_cfg, detector_model, recognizer_cfg, recognizer_model)
23 | engine = DefaultEngine(model)
24 |
25 |
26 | class Item(BaseModel):
27 | image: str
28 |
29 |
30 | @app.get("/")
31 | def read_root():
32 | return {"message": "API is running..."}
33 |
34 |
35 | @app.post("/ocr/predict")
36 | def predict(item: Item):
37 | item = item.dict()
38 | img_bytes = base64.b64decode(item["image"].encode("utf-8"))
39 | image = Image.open(io.BytesIO(img_bytes))
40 | image = np.array(image)
41 |
42 | engine.predict(image)
43 | return engine.result
44 |
--------------------------------------------------------------------------------
/src/text_detector/load_model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.backends.cudnn as cudnn
3 |
4 | from collections import OrderedDict
5 | from .modules.utils import yaml_loader, create_model_for_provider
6 | from .modules.craft import CRAFT
7 |
8 |
9 | def copy_state_dict(state_dict):
10 | if list(state_dict.keys())[0].startswith("module"):
11 | start_idx = 1
12 | else:
13 | start_idx = 0
14 | new_state_dict = OrderedDict()
15 | for k, v in state_dict.items():
16 | name = ".".join(k.split(".")[start_idx:])
17 | new_state_dict[name] = v
18 | return new_state_dict
19 |
20 |
21 | def load_craft(config_file, model_pth):
22 | cfg = yaml_loader(config_file)
23 | net = CRAFT()
24 |
25 | print("Loading weights from checkpoint (" + model_pth + ")")
26 | if cfg["cuda"]:
27 | net.load_state_dict(copy_state_dict(torch.load(model_pth)))
28 | else:
29 | net.load_state_dict(copy_state_dict(torch.load(model_pth, map_location="cpu")))
30 |
31 | if cfg["cuda"]:
32 | net = net.cuda()
33 | net = torch.nn.DataParallel(net)
34 | cudnn.benchmark = False
35 |
36 | net.eval()
37 | return cfg, net
38 |
39 |
40 | def load_craft_onnx(config_file, model_pth):
41 | cfg = yaml_loader(config_file)
42 | device = "CUDAExecutionProvider" if torch.cuda.is_available() else "CPUExecutionProvider"
43 | print("Loading weights from checkpoint (" + model_pth + ")")
44 | net = create_model_for_provider(model_pth, device)
45 | return cfg, net
46 |
--------------------------------------------------------------------------------
/src/model.py:
--------------------------------------------------------------------------------
1 | """Model Class."""
2 | from abc import ABC, abstractmethod
3 | from .text_detector.load_model import load_craft, load_craft_onnx
4 | from .text_recognizer.load_model import load_star
5 |
6 |
7 | class BaseModel(ABC):
8 | """Abstract base class for Receipt OCR models."""
9 |
10 | def __init__(self, detector_cfg, detector_model, recognizer_cfg, recognizer_model):
11 | """Init model config.
12 |
13 | Args:
14 |             detector_cfg / detector_model: config file and checkpoint for the text detector
15 |             recognizer_cfg / recognizer_model: config file and checkpoint for the text recognizer
16 | """
17 | self._cfg_detector, self._detector = self._load_detector(detector_cfg, detector_model)
18 | self._cfg_recognizer, self._recognizer, self._converter = self._load_recognizer(
19 | recognizer_cfg, recognizer_model
20 | )
21 |
22 | @property
23 | def cfg_detector(self):
24 | return self._cfg_detector
25 |
26 | @property
27 | def detector(self):
28 | return self._detector
29 |
30 | @property
31 | def cfg_recognizer(self):
32 | return self._cfg_recognizer
33 |
34 | @property
35 | def recognizer(self):
36 | return self._recognizer
37 |
38 | @property
39 | def converter(self):
40 | return self._converter
41 |
42 | @abstractmethod
43 |     def _load_detector(self, detector_cfg, detector_model):
44 | """Return CRAFT model."""
45 |
46 | @abstractmethod
47 |     def _load_recognizer(self, recognizer_cfg, recognizer_model):
48 | """Return STAR model."""
49 |
50 |
51 | class DefaultModel(BaseModel):
52 | """Default implementation of Receipt OCR models."""
53 |
54 | def _load_detector(self, detector_cfg, detector_model):
55 | return load_craft(detector_cfg, detector_model)
56 |
57 | def _load_recognizer(self, recognizer_cfg, recognizer_model):
58 | return load_star(recognizer_cfg, recognizer_model)
59 |
60 |
61 | class ONNXModel(DefaultModel):
62 | """ONNX Model."""
63 |
64 | def _load_detector(self, detector_cfg, detector_model):
65 | return load_craft_onnx(detector_cfg, detector_model)
66 |
--------------------------------------------------------------------------------
/src/text_detector/modules/imgproc.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) 2019-present NAVER Corp.
3 | MIT License
4 | """
5 |
6 | # -*- coding: utf-8 -*-
7 | import numpy as np
8 | import cv2
9 |
10 |
11 | def normalizeMeanVariance(in_img, mean=(0.485, 0.456, 0.406), variance=(0.229, 0.224, 0.225)):
12 | # should be RGB order
13 | img = in_img.copy().astype(np.float32)
14 |
15 | img -= np.array([mean[0] * 255.0, mean[1] * 255.0,
16 | mean[2] * 255.0], dtype=np.float32)
17 | img /= np.array([variance[0] * 255.0, variance[1] * 255.0,
18 | variance[2] * 255.0], dtype=np.float32)
19 | return img
20 |
21 |
22 | def denormalizeMeanVariance(in_img, mean=(0.485, 0.456, 0.406), variance=(0.229, 0.224, 0.225)):
23 | # should be RGB order
24 | img = in_img.copy()
25 | img *= variance
26 | img += mean
27 | img *= 255.0
28 | img = np.clip(img, 0, 255).astype(np.uint8)
29 | return img
30 |
31 |
32 | def resize_aspect_ratio(img, square_size, interpolation, mag_ratio=1):
33 | height, width, channel = img.shape
34 |
35 | # magnify image size
36 | target_size = mag_ratio * max(height, width)
37 |
38 | # set original image size
39 | if target_size > square_size:
40 | target_size = square_size
41 |
42 | ratio = target_size / max(height, width)
43 |
44 | target_h, target_w = int(height * ratio), int(width * ratio)
45 | proc = cv2.resize(img, (target_w, target_h), interpolation=interpolation)
46 |
47 | # make canvas and paste image
48 | target_h32, target_w32 = target_h, target_w
49 | if target_h % 32 != 0:
50 | target_h32 = target_h + (32 - target_h % 32)
51 | if target_w % 32 != 0:
52 | target_w32 = target_w + (32 - target_w % 32)
53 | resized = np.zeros((target_h32, target_w32, channel), dtype=np.float32)
54 | resized[0:target_h, 0:target_w, :] = proc
55 | target_h, target_w = target_h32, target_w32
56 |
57 | size_heatmap = (int(target_w/2), int(target_h/2))
58 |
59 | return resized, ratio, size_heatmap
60 |
61 |
62 | def cvt2HeatmapImg(img):
63 | img = (np.clip(img, 0, 1) * 255).astype(np.uint8)
64 | img = cv2.applyColorMap(img, cv2.COLORMAP_JET)
65 | return img
66 |
--------------------------------------------------------------------------------
/src/text_recognizer/infer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.backends.cudnn as cudnn
3 | import torch.utils.data
4 | import torch.nn.functional as F
5 |
6 | from .modules.dataset import RawDataset, AlignCollate
7 |
8 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
9 |
10 |
11 | def data_preparation(opt, list_data):
12 | AlignCollate_obj = AlignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD)
13 | dataset = RawDataset(list_data=list_data, opt=opt)
14 | data_loader = torch.utils.data.DataLoader(
15 | dataset,
16 | batch_size=opt.batch_size,
17 | shuffle=False,
18 | num_workers=int(opt.workers),
19 | collate_fn=AlignCollate_obj,
20 | pin_memory=True,
21 | )
22 | return data_loader
23 |
24 |
25 | def inference(opt, model, converter, data_loader):
26 | output_pred, output_conf_score = [], []
27 | model.eval()
28 | with torch.no_grad():
29 | for image_tensors, _ in data_loader:
30 | batch_size = image_tensors.size(0)
31 | image = image_tensors.to(device)
32 | # For max length prediction
33 | length_for_pred = torch.IntTensor([opt.batch_max_length] * batch_size).to(device)
34 | text_for_pred = (
35 | torch.LongTensor(batch_size, opt.batch_max_length + 1).fill_(0).to(device)
36 | )
37 |
38 | preds = model(image, text_for_pred, is_train=False)
39 | # select max probabilty (greedy decoding) then decode index to character
40 | _, preds_index = preds.max(2)
41 | preds_str = converter.decode(preds_index, length_for_pred)
42 | preds_prob = F.softmax(preds, dim=2)
43 | preds_max_prob, _ = preds_prob.max(dim=2)
44 |
45 | for pred, pred_max_prob in zip(preds_str, preds_max_prob):
46 | pred_eos = pred.find("[s]")
47 | # prune after "end of sentence" token ([s])
48 | pred = pred[:pred_eos]
49 | pred_max_prob = pred_max_prob[:pred_eos]
50 | # calculate confidence score (= multiply of pred_max_prob)
51 | confidence_score = pred_max_prob.cumprod(dim=0)[-1]
52 | output_pred.append(pred)
53 | output_conf_score.append(confidence_score)
54 | return output_pred, output_conf_score
55 |
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # Welcome to the Receipt OCR contributing guide
2 | We love your input! We want to make contributing to this project as easy and transparent as possible, whether it's:
3 |
4 | - Reporting a bug
5 | - Discussing the current state of the code
6 | - Submitting a fix
7 | - Proposing new features
8 | - Becoming a maintainer
9 |
10 | ## New contributor guide
11 |
12 | To get an overview of the project, read the [README](README.md). Here are some resources to help you get started with open source contributions:
13 |
14 | - [Finding ways to contribute to open source on GitHub](https://docs.github.com/en/get-started/exploring-projects-on-github/finding-ways-to-contribute-to-open-source-on-github)
15 | - [Set up Git](https://docs.github.com/en/get-started/quickstart/set-up-git)
16 | - [GitHub flow](https://docs.github.com/en/get-started/quickstart/github-flow)
17 | - [Collaborating with pull requests](https://docs.github.com/en/github/collaborating-with-pull-requests)
18 |
19 | ## Getting started
20 | 1. Fork the repo and create your branch from `main`.
21 | 2. If you've changed code, update the docstrings and documentation.
22 | 3. Make sure your code lints.
23 | 4. Issue that pull request!
24 |
25 | ### Issues
26 |
27 | **Great Bug Reports** tend to have:
28 |
29 | - A quick summary and/or background
30 | - Steps to reproduce
31 | - Be specific!
32 | - Give sample code if you can.
33 | - What you expected would happen
34 | - What actually happens
35 | - Notes (possibly including why you think this might be happening, or stuff you tried that didn't work)
36 |
37 | #### Create a new issue
38 |
39 | If you spot a problem with the docs, [search if an issue already exists](https://docs.github.com/en/github/searching-for-information-on-github/searching-on-github/searching-issues-and-pull-requests#search-by-the-title-body-or-comments). If a related issue doesn't exist, you can open a new one.
40 |
41 | #### Solve an issue
42 |
43 | Scan through our existing issues to find one that interests you. As a general rule, we don’t assign issues to anyone. If you find an issue to work on, you are welcome to open a PR with a fix.
44 |
45 | ### Code Formatting and Typing
46 |
47 | This project uses flake8 as the code linter and black as the code formatter; the configuration lives in [pyproject.toml](pyproject.toml) and [setup.cfg](setup.cfg).
48 |
49 | ## License
50 |
51 | By contributing, you agree that your contributions will be licensed under its MIT License.
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | *.pth
2 | *.jpg
3 | *.jpeg
4 | *.onnx
5 | settings.json
6 |
7 | # Byte-compiled / optimized / DLL files
8 | __pycache__/
9 | *.py[cod]
10 | *$py.class
11 |
12 | # C extensions
13 | *.so
14 |
15 | # Distribution / packaging
16 | .Python
17 | build/
18 | develop-eggs/
19 | dist/
20 | downloads/
21 | eggs/
22 | .eggs/
23 | lib/
24 | lib64/
25 | parts/
26 | sdist/
27 | var/
28 | wheels/
29 | pip-wheel-metadata/
30 | share/python-wheels/
31 | *.egg-info/
32 | .installed.cfg
33 | *.egg
34 | MANIFEST
35 |
36 | # PyInstaller
37 | # Usually these files are written by a python script from a template
38 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
39 | *.manifest
40 | *.spec
41 |
42 | # Installer logs
43 | pip-log.txt
44 | pip-delete-this-directory.txt
45 |
46 | # Unit test / coverage reports
47 | htmlcov/
48 | .tox/
49 | .nox/
50 | .coverage
51 | .coverage.*
52 | .cache
53 | nosetests.xml
54 | coverage.xml
55 | *.cover
56 | *.py,cover
57 | .hypothesis/
58 | .pytest_cache/
59 |
60 | # Translations
61 | *.mo
62 | *.pot
63 |
64 | # Django stuff:
65 | *.log
66 | local_settings.py
67 | db.sqlite3
68 | db.sqlite3-journal
69 |
70 | # Flask stuff:
71 | instance/
72 | .webassets-cache
73 |
74 | # Scrapy stuff:
75 | .scrapy
76 |
77 | # Sphinx documentation
78 | docs/_build/
79 |
80 | # PyBuilder
81 | target/
82 |
83 | # Jupyter Notebook
84 | .ipynb_checkpoints
85 |
86 | # IPython
87 | profile_default/
88 | ipython_config.py
89 |
90 | # pyenv
91 | .python-version
92 |
93 | # pipenv
94 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
95 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
96 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
97 | # install all needed dependencies.
98 | #Pipfile.lock
99 |
100 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
101 | __pypackages__/
102 |
103 | # Celery stuff
104 | celerybeat-schedule
105 | celerybeat.pid
106 |
107 | # SageMath parsed files
108 | *.sage.py
109 |
110 | # Environments
111 | .env
112 | .venv
113 | env/
114 | venv/
115 | ENV/
116 | env.bak/
117 | venv.bak/
118 |
119 | # Spyder project settings
120 | .spyderproject
121 | .spyproject
122 |
123 | # Rope project settings
124 | .ropeproject
125 |
126 | # mkdocs documentation
127 | /site
128 |
129 | # mypy
130 | .mypy_cache/
131 | .dmypy.json
132 | dmypy.json
133 |
134 | # Pyre type checker
135 | .pyre/
136 |
--------------------------------------------------------------------------------
/src/text_detector/modules/refinenet.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) 2019-present NAVER Corp.
3 | MIT License
4 | """
5 |
6 | # -*- coding: utf-8 -*-
7 | import torch
8 | import torch.nn as nn
9 | import torch.nn.functional as F
10 | from torch.autograd import Variable
11 | from ..basenet.vgg16_bn import init_weights
12 |
13 |
14 | class RefineNet(nn.Module):
15 | def __init__(self):
16 | super(RefineNet, self).__init__()
17 |
18 | self.last_conv = nn.Sequential(
19 | nn.Conv2d(34, 64, kernel_size=3, padding=1), nn.BatchNorm2d(
20 | 64), nn.ReLU(inplace=True),
21 | nn.Conv2d(64, 64, kernel_size=3, padding=1), nn.BatchNorm2d(
22 | 64), nn.ReLU(inplace=True),
23 | nn.Conv2d(64, 64, kernel_size=3, padding=1), nn.BatchNorm2d(
24 | 64), nn.ReLU(inplace=True)
25 | )
26 |
27 | self.aspp1 = nn.Sequential(
28 | nn.Conv2d(64, 128, kernel_size=3, dilation=6,
29 | padding=6), nn.BatchNorm2d(128), nn.ReLU(inplace=True),
30 | nn.Conv2d(128, 128, kernel_size=1), nn.BatchNorm2d(
31 | 128), nn.ReLU(inplace=True),
32 | nn.Conv2d(128, 1, kernel_size=1)
33 | )
34 |
35 | self.aspp2 = nn.Sequential(
36 | nn.Conv2d(64, 128, kernel_size=3, dilation=12,
37 | padding=12), nn.BatchNorm2d(128), nn.ReLU(inplace=True),
38 | nn.Conv2d(128, 128, kernel_size=1), nn.BatchNorm2d(
39 | 128), nn.ReLU(inplace=True),
40 | nn.Conv2d(128, 1, kernel_size=1)
41 | )
42 |
43 | self.aspp3 = nn.Sequential(
44 | nn.Conv2d(64, 128, kernel_size=3, dilation=18,
45 | padding=18), nn.BatchNorm2d(128), nn.ReLU(inplace=True),
46 | nn.Conv2d(128, 128, kernel_size=1), nn.BatchNorm2d(
47 | 128), nn.ReLU(inplace=True),
48 | nn.Conv2d(128, 1, kernel_size=1)
49 | )
50 |
51 | self.aspp4 = nn.Sequential(
52 | nn.Conv2d(64, 128, kernel_size=3, dilation=24,
53 | padding=24), nn.BatchNorm2d(128), nn.ReLU(inplace=True),
54 | nn.Conv2d(128, 128, kernel_size=1), nn.BatchNorm2d(
55 | 128), nn.ReLU(inplace=True),
56 | nn.Conv2d(128, 1, kernel_size=1)
57 | )
58 |
59 | init_weights(self.last_conv.modules())
60 | init_weights(self.aspp1.modules())
61 | init_weights(self.aspp2.modules())
62 | init_weights(self.aspp3.modules())
63 | init_weights(self.aspp4.modules())
64 |
65 | def forward(self, y, upconv4):
66 | refine = torch.cat([y.permute(0, 3, 1, 2), upconv4], dim=1)
67 | refine = self.last_conv(refine)
68 |
69 | aspp1 = self.aspp1(refine)
70 | aspp2 = self.aspp2(refine)
71 | aspp3 = self.aspp3(refine)
72 | aspp4 = self.aspp4(refine)
73 |
74 | # out = torch.add([aspp1, aspp2, aspp3, aspp4], dim=1)
75 | out = aspp1 + aspp2 + aspp3 + aspp4
76 | return out.permute(0, 2, 3, 1) # , refine.permute(0,2,3,1)
77 |
--------------------------------------------------------------------------------
/src/text_detector/basenet/vgg16_bn.py:
--------------------------------------------------------------------------------
1 | from collections import namedtuple
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.init as init
6 | from torchvision import models
7 | from torchvision.models.vgg import model_urls
8 |
9 |
10 | def init_weights(modules):
11 | for m in modules:
12 | if isinstance(m, nn.Conv2d):
13 | init.xavier_uniform_(m.weight.data)
14 | if m.bias is not None:
15 | m.bias.data.zero_()
16 | elif isinstance(m, nn.BatchNorm2d):
17 | m.weight.data.fill_(1)
18 | m.bias.data.zero_()
19 | elif isinstance(m, nn.Linear):
20 | m.weight.data.normal_(0, 0.01)
21 | m.bias.data.zero_()
22 |
23 |
24 | class vgg16_bn(torch.nn.Module):
25 | def __init__(self, pretrained=True, freeze=True):
26 | super(vgg16_bn, self).__init__()
27 | model_urls['vgg16_bn'] = model_urls['vgg16_bn'].replace(
28 | 'https://', 'http://')
29 | vgg_pretrained_features = models.vgg16_bn(
30 | pretrained=pretrained).features
31 | self.slice1 = torch.nn.Sequential()
32 | self.slice2 = torch.nn.Sequential()
33 | self.slice3 = torch.nn.Sequential()
34 | self.slice4 = torch.nn.Sequential()
35 | self.slice5 = torch.nn.Sequential()
36 | for x in range(12): # conv2_2
37 | self.slice1.add_module(str(x), vgg_pretrained_features[x])
38 | for x in range(12, 19): # conv3_3
39 | self.slice2.add_module(str(x), vgg_pretrained_features[x])
40 | for x in range(19, 29): # conv4_3
41 | self.slice3.add_module(str(x), vgg_pretrained_features[x])
42 | for x in range(29, 39): # conv5_3
43 | self.slice4.add_module(str(x), vgg_pretrained_features[x])
44 |
45 | # fc6, fc7 without atrous conv
46 | self.slice5 = torch.nn.Sequential(
47 | nn.MaxPool2d(kernel_size=3, stride=1, padding=1),
48 | nn.Conv2d(512, 1024, kernel_size=3, padding=6, dilation=6),
49 | nn.Conv2d(1024, 1024, kernel_size=1)
50 | )
51 |
52 | if not pretrained:
53 | init_weights(self.slice1.modules())
54 | init_weights(self.slice2.modules())
55 | init_weights(self.slice3.modules())
56 | init_weights(self.slice4.modules())
57 |
58 | # no pretrained model for fc6 and fc7
59 | init_weights(self.slice5.modules())
60 |
61 | if freeze:
62 | for param in self.slice1.parameters(): # only first conv
63 | param.requires_grad = False
64 |
65 | def forward(self, X):
66 | h = self.slice1(X)
67 | h_relu2_2 = h
68 | h = self.slice2(h)
69 | h_relu3_2 = h
70 | h = self.slice3(h)
71 | h_relu4_3 = h
72 | h = self.slice4(h)
73 | h_relu5_3 = h
74 | h = self.slice5(h)
75 | h_fc7 = h
76 | vgg_outputs = namedtuple(
77 | "VggOutputs", ['fc7', 'relu5_3', 'relu4_3', 'relu3_2', 'relu2_2'])
78 | out = vgg_outputs(h_fc7, h_relu5_3, h_relu4_3, h_relu3_2, h_relu2_2)
79 | return out
80 |
--------------------------------------------------------------------------------
/src/text_detector/modules/craft.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) 2019-present NAVER Corp.
3 | MIT License
4 | """
5 |
6 | # -*- coding: utf-8 -*-
7 | import torch
8 | import torch.nn as nn
9 | import torch.nn.functional as F
10 |
11 | from ..basenet.vgg16_bn import vgg16_bn, init_weights
12 |
13 |
14 | class double_conv(nn.Module):
15 | def __init__(self, in_ch, mid_ch, out_ch):
16 | super(double_conv, self).__init__()
17 | self.conv = nn.Sequential(
18 | nn.Conv2d(in_ch + mid_ch, mid_ch, kernel_size=1),
19 | nn.BatchNorm2d(mid_ch),
20 | nn.ReLU(inplace=True),
21 | nn.Conv2d(mid_ch, out_ch, kernel_size=3, padding=1),
22 | nn.BatchNorm2d(out_ch),
23 | nn.ReLU(inplace=True)
24 | )
25 |
26 | def forward(self, x):
27 | x = self.conv(x)
28 | return x
29 |
30 |
31 | class CRAFT(nn.Module):
32 | def __init__(self, pretrained=False, freeze=False):
33 | super(CRAFT, self).__init__()
34 |
35 | """ Base network """
36 | self.basenet = vgg16_bn(pretrained, freeze)
37 |
38 | """ U network """
39 | self.upconv1 = double_conv(1024, 512, 256)
40 | self.upconv2 = double_conv(512, 256, 128)
41 | self.upconv3 = double_conv(256, 128, 64)
42 | self.upconv4 = double_conv(128, 64, 32)
43 |
44 | num_class = 2
45 | self.conv_cls = nn.Sequential(
46 | nn.Conv2d(32, 32, kernel_size=3, padding=1), nn.ReLU(inplace=True),
47 | nn.Conv2d(32, 32, kernel_size=3, padding=1), nn.ReLU(inplace=True),
48 | nn.Conv2d(32, 16, kernel_size=3, padding=1), nn.ReLU(inplace=True),
49 | nn.Conv2d(16, 16, kernel_size=1), nn.ReLU(inplace=True),
50 | nn.Conv2d(16, num_class, kernel_size=1),
51 | )
52 |
53 | init_weights(self.upconv1.modules())
54 | init_weights(self.upconv2.modules())
55 | init_weights(self.upconv3.modules())
56 | init_weights(self.upconv4.modules())
57 | init_weights(self.conv_cls.modules())
58 |
59 | def forward(self, x):
60 | """ Base network """
61 | sources = self.basenet(x)
62 |
63 | """ U network """
64 | y = torch.cat([sources[0], sources[1]], dim=1)
65 | y = self.upconv1(y)
66 |
67 | y = F.interpolate(y, size=sources[2].size()[
68 | 2:], mode='bilinear', align_corners=False)
69 | y = torch.cat([y, sources[2]], dim=1)
70 | y = self.upconv2(y)
71 |
72 | y = F.interpolate(y, size=sources[3].size()[
73 | 2:], mode='bilinear', align_corners=False)
74 | y = torch.cat([y, sources[3]], dim=1)
75 | y = self.upconv3(y)
76 |
77 | y = F.interpolate(y, size=sources[4].size()[
78 | 2:], mode='bilinear', align_corners=False)
79 | y = torch.cat([y, sources[4]], dim=1)
80 | feature = self.upconv4(y)
81 |
82 | y = self.conv_cls(feature)
83 |
84 | return y.permute(0, 2, 3, 1), feature
85 |
86 |
87 | if __name__ == '__main__':
88 | model = CRAFT(pretrained=True).cuda()
89 | output, _ = model(torch.randn(1, 3, 768, 768).cuda())
90 | print(output.shape)
91 |
--------------------------------------------------------------------------------
/src/text_detector/infer.py:
--------------------------------------------------------------------------------
1 | """Copyright (c) 2019-present NAVER Corp.
2 |
3 | MIT License
4 | """
5 | import cv2
6 | import numpy as np
7 | import torch
8 | from .modules import craft_utils
9 | from .modules import imgproc
10 | from torch.autograd import Variable
11 |
12 |
13 | def test_net(
14 | net,
15 | image,
16 | text_threshold,
17 | link_threshold,
18 | low_text,
19 | cuda,
20 | poly,
21 | canvas_size,
22 | mag_ratio,
23 | refine_net=None,
24 | onnx=False,
25 | ):
26 | # resize
27 | img_resized, target_ratio, size_heatmap = imgproc.resize_aspect_ratio(
28 | image, canvas_size, interpolation=cv2.INTER_LINEAR, mag_ratio=mag_ratio
29 | )
30 | ratio_h = ratio_w = 1 / target_ratio
31 |
32 | # preprocessing
33 | x = imgproc.normalizeMeanVariance(img_resized)
34 | x = torch.from_numpy(x).permute(2, 0, 1) # [h, w, c] to [c, h, w]
35 | x = Variable(x.unsqueeze(0)) # [c, h, w] to [b, c, h, w]
36 | if cuda:
37 | x = x.cuda()
38 |
39 | if onnx:
40 | # forward pass
41 | input_onnx = {"input": x.numpy()}
42 | with torch.no_grad():
43 | y, feature = net.run(None, input_onnx)
44 |
45 | # make score and link map
46 | score_text = y[0, :, :, 0]
47 | score_link = y[0, :, :, 1]
48 |
49 | # refine link
50 | if refine_net is not None:
51 | with torch.no_grad():
52 | y_refiner = refine_net(y, feature)
53 | score_link = y_refiner[0, :, :, 0]
54 | else:
55 | # forward pass
56 | with torch.no_grad():
57 | y, feature = net(x)
58 |
59 | # make score and link map
60 | score_text = y[0, :, :, 0].cpu().data.numpy()
61 | score_link = y[0, :, :, 1].cpu().data.numpy()
62 |
63 | # refine link
64 | if refine_net is not None:
65 | with torch.no_grad():
66 | y_refiner = refine_net(y, feature)
67 | score_link = y_refiner[0, :, :, 0].cpu().data.numpy()
68 |
69 | # Post-processing
70 | boxes, polys = craft_utils.getDetBoxes(
71 | score_text, score_link, text_threshold, link_threshold, low_text, poly
72 | )
73 |
74 | # coordinate adjustment
75 | boxes = craft_utils.adjustResultCoordinates(boxes, ratio_w, ratio_h)
76 | polys = craft_utils.adjustResultCoordinates(polys, ratio_w, ratio_h)
77 | for k in range(len(polys)):
78 | if polys[k] is None:
79 | polys[k] = boxes[k]
80 |
81 | # render results (optional)
82 | render_img = score_text.copy()
83 | render_img = np.hstack((render_img, score_link))
84 | ret_score_text = imgproc.cvt2HeatmapImg(render_img)
85 | return boxes, polys, ret_score_text
86 |
87 |
88 | def inference(cfg, net, image, onnx=False):
89 | bboxes, polys, score_text = test_net(
90 | net,
91 | image,
92 | cfg["text_threshold"],
93 | cfg["link_threshold"],
94 | cfg["low_text"],
95 | cfg["cuda"],
96 | cfg["poly"],
97 | cfg["canvas_size"],
98 | cfg["mag_ratio"],
99 | onnx=onnx,
100 | )
101 | return polys
102 |
--------------------------------------------------------------------------------
/notebooks/test_api.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "id": "specific-pulse",
7 | "metadata": {},
8 | "outputs": [],
9 | "source": [
10 | "import requests\n",
11 | "import json\n",
12 | "import base64\n",
13 | "import numpy as np"
14 | ]
15 | },
16 | {
17 | "cell_type": "code",
18 | "execution_count": 2,
19 | "id": "earned-missouri",
20 | "metadata": {},
21 | "outputs": [],
22 | "source": [
23 | "image_file = '../data/tes.jpg'\n",
24 | "\n",
25 | "with open(image_file, \"rb\") as f:\n",
26 | " im_bytes = f.read()\n",
27 | "\n",
28 | "im_b64 = base64.b64encode(im_bytes).decode('utf-8')"
29 | ]
30 | },
31 | {
32 | "cell_type": "code",
33 | "execution_count": 3,
34 | "id": "afraid-cheese",
35 | "metadata": {},
36 | "outputs": [],
37 | "source": [
38 | "data = {\"image\": im_b64}\n",
39 | "payload = json.dumps(data)"
40 | ]
41 | },
42 | {
43 | "cell_type": "code",
44 | "execution_count": 4,
45 | "id": "acquired-factory",
46 | "metadata": {},
47 | "outputs": [],
48 | "source": [
49 | "api_endpoint = \"http://127.0.0.1:8000/ocr/predict\"\n",
50 | "headers = {'content-type': 'application/json'}\n",
51 | "\n",
52 | "response = requests.request(\"POST\", api_endpoint, data=payload, headers=headers)"
53 | ]
54 | },
55 | {
56 | "cell_type": "code",
57 | "execution_count": 5,
58 | "id": "turned-queue",
59 | "metadata": {},
60 | "outputs": [
61 | {
62 | "data": {
63 | "text/plain": [
64 | "200"
65 | ]
66 | },
67 | "execution_count": 5,
68 | "metadata": {},
69 | "output_type": "execute_result"
70 | }
71 | ],
72 | "source": [
73 | "response.status_code"
74 | ]
75 | },
76 | {
77 | "cell_type": "code",
78 | "execution_count": 6,
79 | "id": "hindu-diagnosis",
80 | "metadata": {},
81 | "outputs": [
82 | {
83 | "data": {
84 | "text/plain": [
85 | "['Geprek Bensu Kopo Bandung',\n",
86 | " 'J1, Kopo No. 536, Margasuka, Babakan Ciparay',\n",
87 | " 'KOTA BANDUNG',\n",
88 | " 'order: 33',\n",
89 | " 'Kode',\n",
90 | " 'Tanggal 16-07-2021 11:53:47',\n",
91 | " 'Kasiri Kasir 1 Kopo BDG',\n",
92 | " 'Pelanggant gjk wahyu',\n",
93 | " 'Paket Geprek Bensu Nasi Daun Jeruk GOFOO',\n",
94 | " 'D Level I X 27',\n",
95 | " 't Harga (27',\n",
96 | " 'Dada X',\n",
97 | " 'Ayam Geprek Bensu GOFOOD Original X 17.500',\n",
98 | " 'Harga (17',\n",
99 | " '+ Dada X',\n",
100 | " 'Take Away Charge X 4.000',\n",
101 | " 'Subtotal 49.000',\n",
102 | " 'PB1 (10%) 4.500',\n",
103 | " 'Total 53,500',\n",
104 | " 'Gobiz 53.500',\n",
105 | " 'Kembal Ii',\n",
106 | " 'LUNAS *x',\n",
107 | " 'Terima Kasih']"
108 | ]
109 | },
110 | "execution_count": 6,
111 | "metadata": {},
112 | "output_type": "execute_result"
113 | }
114 | ],
115 | "source": [
116 | "response.json()"
117 | ]
118 | },
119 | {
120 | "cell_type": "code",
121 | "execution_count": null,
122 | "id": "coordinate-trouble",
123 | "metadata": {},
124 | "outputs": [],
125 | "source": []
126 | }
127 | ],
128 | "metadata": {
129 | "kernelspec": {
130 | "display_name": "Python 3",
131 | "language": "python",
132 | "name": "python3"
133 | },
134 | "language_info": {
135 | "codemirror_mode": {
136 | "name": "ipython",
137 | "version": 3
138 | },
139 | "file_extension": ".py",
140 | "mimetype": "text/x-python",
141 | "name": "python",
142 | "nbconvert_exporter": "python",
143 | "pygments_lexer": "ipython3",
144 | "version": "3.8.3"
145 | }
146 | },
147 | "nbformat": 4,
148 | "nbformat_minor": 5
149 | }
150 |
--------------------------------------------------------------------------------
/src/engine.py:
--------------------------------------------------------------------------------
1 | """Engine Class."""
2 | import math
3 | import time
4 | from abc import ABC, abstractmethod
5 | from .model import DefaultModel
6 | from .text_detector.infer import inference as infer_detector
7 | from .text_recognizer.infer import data_preparation, inference as infer_recognizer
8 |
9 |
10 | def timeit(method):
11 | def timed(*args, **kw):
12 | ts = time.time()
13 | result = method(*args, **kw)
14 | te = time.time()
15 | if "log_time" in kw:
16 | name = kw.get("log_name", method.__name__.upper())
17 | kw["log_time"][name] = int((te - ts))
18 | else:
19 | print("%r %2.2f s" % (method.__name__, (te - ts)))
20 | return result
21 |
22 | return timed
23 |
24 |
25 | class BaseEngine(ABC):
26 | def __init__(self, receipt_ocr_model: DefaultModel):
27 | if not isinstance(receipt_ocr_model, DefaultModel):
28 | raise TypeError
29 | self._model: DefaultModel = receipt_ocr_model
30 |
31 | @abstractmethod
32 | def inference_detector(self):
33 | pass
34 |
35 | @abstractmethod
36 | def inference_recognizer(self):
37 | pass
38 |
39 | @abstractmethod
40 | def predict(self):
41 | pass
42 |
43 |
44 | class DefaultEngine(BaseEngine):
45 | @timeit
46 | def inference_detector(self):
47 | output = infer_detector(self._model.cfg_detector, self._model.detector, self._input)
48 | self._out_detector = output
49 |
50 | @timeit
51 | def inference_recognizer(self):
52 | data_loader = data_preparation(opt=self._model.cfg_recognizer, list_data=self._imgs)
53 | pred, conf_score = infer_recognizer(
54 | opt=self._model.cfg_recognizer,
55 | model=self._model.recognizer,
56 | converter=self._model.converter,
57 | data_loader=data_loader,
58 | )
59 | output = list(zip(pred, conf_score, self._coords))
60 |         output = filter(lambda x: x[1] > 0.5, output)  # keep predictions with confidence > 0.5
61 | self.raw_output = sorted(output, key=lambda x: x[2][0])
62 | self.result = self.combine_entity()
63 |
64 | @timeit
65 | def predict(self, image):
66 | self._input = image
67 |
68 | self.inference_detector()
69 | self.get_img_from_bb()
70 | self.inference_recognizer()
71 |
72 | def get_img_from_bb(self):
73 | imgs, coords = [], []
74 | for bb in self._out_detector:
75 | cropped_img, coord = self.crop_img(self._input, bb)
76 | imgs.append(cropped_img)
77 | coords.append(coord)
78 | self._imgs = imgs
79 | self._coords = coords
80 |
81 | def crop_img(self, img, bb):
82 | x1, y1 = bb[0]
83 | x2, y2 = bb[1]
84 | x3, y3 = bb[2]
85 | x4, y4 = bb[3]
86 |
87 | top_left_x = math.ceil(min([x1, x2, x3, x4]))
88 | top_left_y = math.ceil(min([y1, y2, y3, y4]))
89 | bot_right_x = math.ceil(max([x1, x2, x3, x4]))
90 | bot_right_y = math.ceil(max([y1, y2, y3, y4]))
91 | coord = (top_left_y, bot_right_y, top_left_x, bot_right_x)
92 |
93 | cropped_image = img[top_left_y : bot_right_y + 1, top_left_x : bot_right_x + 1]
94 | return cropped_image, coord
95 |
96 | def combine_entity(self):
97 |         thres = 20  # max vertical gap (pixels) between boxes merged into one text line
98 | output, all_entity, entity = [], [], []
99 | entity.append(self.raw_output[0])
100 |
101 | for idx in range(len(self.raw_output) - 1):
102 | diff = abs(self.raw_output[idx][2][0] - self.raw_output[idx + 1][2][0])
103 | if diff < thres:
104 | entity.append(self.raw_output[idx + 1])
105 | else:
106 | all_entity.append(entity)
107 |                 entity = [self.raw_output[idx + 1]]
108 |         all_entity.append(entity)  # flush the last accumulated entity group
109 |
110 | # Sorting entity by coordinates
111 | for idx in range(len(all_entity)):
112 | all_entity[idx] = sorted(all_entity[idx], key=lambda x: (x[2][3], x[2][1], x[2][2]))
113 |
114 | # Concatenate Entity
115 | for entity in all_entity:
116 | tmp = [x[0] for x in entity]
117 | output.append(" ".join(tmp))
118 | return output
119 |
120 |
121 | class ONNXEngine(DefaultEngine):
122 | @timeit
123 | def inference_detector(self):
124 | output = infer_detector(
125 | self._model.cfg_detector, self._model.detector, self._input, onnx=True
126 | )
127 | self._out_detector = output
128 |
--------------------------------------------------------------------------------
/src/text_recognizer/modules/prediction.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
5 |
6 |
7 | class Attention(nn.Module):
8 |
9 | def __init__(self, input_size, hidden_size, num_classes):
10 | super(Attention, self).__init__()
11 | self.attention_cell = AttentionCell(
12 | input_size, hidden_size, num_classes)
13 | self.hidden_size = hidden_size
14 | self.num_classes = num_classes
15 | self.generator = nn.Linear(hidden_size, num_classes)
16 |
17 | def _char_to_onehot(self, input_char, onehot_dim=38):
18 | input_char = input_char.unsqueeze(1)
19 | batch_size = input_char.size(0)
20 | one_hot = torch.FloatTensor(batch_size, onehot_dim).zero_().to(device)
21 | one_hot = one_hot.scatter_(1, input_char, 1)
22 | return one_hot
23 |
24 | def forward(self, batch_H, text, is_train=True, batch_max_length=25):
25 | """
26 | input:
27 | batch_H : contextual_feature H = hidden state of encoder. [batch_size x num_steps x contextual_feature_channels]
28 | text : the text-index of each image. [batch_size x (max_length+1)]. +1 for [GO] token. text[:, 0] = [GO].
29 | output: probability distribution at each step [batch_size x num_steps x num_classes]
30 | """
31 | batch_size = batch_H.size(0)
32 | num_steps = batch_max_length + 1 # +1 for [s] at end of sentence.
33 |
34 | output_hiddens = torch.FloatTensor(
35 | batch_size, num_steps, self.hidden_size).fill_(0).to(device)
36 | hidden = (torch.FloatTensor(batch_size, self.hidden_size).fill_(0).to(device),
37 | torch.FloatTensor(batch_size, self.hidden_size).fill_(0).to(device))
38 |
39 | if is_train:
40 | for i in range(num_steps):
41 | # one-hot vectors for a i-th char. in a batch
42 | char_onehots = self._char_to_onehot(
43 | text[:, i], onehot_dim=self.num_classes)
44 | # hidden : decoder's hidden s_{t-1}, batch_H : encoder's hidden H, char_onehots : one-hot(y_{t-1})
45 | hidden, alpha = self.attention_cell(
46 | hidden, batch_H, char_onehots)
47 | # LSTM hidden index (0: hidden, 1: Cell)
48 | output_hiddens[:, i, :] = hidden[0]
49 | probs = self.generator(output_hiddens)
50 |
51 | else:
52 | targets = torch.LongTensor(batch_size).fill_(
53 | 0).to(device) # [GO] token
54 | probs = torch.FloatTensor(
55 | batch_size, num_steps, self.num_classes).fill_(0).to(device)
56 |
57 | for i in range(num_steps):
58 | char_onehots = self._char_to_onehot(
59 | targets, onehot_dim=self.num_classes)
60 | hidden, alpha = self.attention_cell(
61 | hidden, batch_H, char_onehots)
62 | probs_step = self.generator(hidden[0])
63 | probs[:, i, :] = probs_step
64 | _, next_input = probs_step.max(1)
65 | targets = next_input
66 |
67 | return probs # batch_size x num_steps x num_classes
68 |
69 |
70 | class AttentionCell(nn.Module):
71 |
72 | def __init__(self, input_size, hidden_size, num_embeddings):
73 | super(AttentionCell, self).__init__()
74 | self.i2h = nn.Linear(input_size, hidden_size, bias=False)
75 | # either i2i or h2h should have bias
76 | self.h2h = nn.Linear(hidden_size, hidden_size)
77 | self.score = nn.Linear(hidden_size, 1, bias=False)
78 | self.rnn = nn.LSTMCell(input_size + num_embeddings, hidden_size)
79 | self.hidden_size = hidden_size
80 |
81 | def forward(self, prev_hidden, batch_H, char_onehots):
82 | # [batch_size x num_encoder_step x num_channel] -> [batch_size x num_encoder_step x hidden_size]
83 | batch_H_proj = self.i2h(batch_H)
84 | prev_hidden_proj = self.h2h(prev_hidden[0]).unsqueeze(1)
85 | # batch_size x num_encoder_step * 1
86 | e = self.score(torch.tanh(batch_H_proj + prev_hidden_proj))
87 |
88 | alpha = F.softmax(e, dim=1)
89 | context = torch.bmm(alpha.permute(0, 2, 1), batch_H).squeeze(
90 | 1) # batch_size x num_channel
91 | # batch_size x (num_channel + num_embedding)
92 | concat_context = torch.cat([context, char_onehots], 1)
93 | cur_hidden = self.rnn(concat_context, prev_hidden)
94 | return cur_hidden, alpha
95 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Optical Character Recognition for Receipt
2 |
3 | ## Sample Results
4 | Input Image | Output
5 | :----------------------:|:----------------------:
6 |  | 
7 |
8 | ## References
9 |
10 | | Title | Author | Year | Github | Paper | Download Model|
11 | | ----------------------------------------------------------------------------------------| ---------------- | ---- | --------- | ----- | -------- |
12 | | Character Region Awareness for Text Detection | Clova AI Research, NAVER Corp.| 2019 | https://github.com/clovaai/CRAFT-pytorch | https://arxiv.org/abs/1904.01941 | [craft_mlt_25k.pth](https://drive.google.com/file/d/1Jk4eGD7crsqCCg9C9VjCLkMN3ze8kutZ/view)|
13 | | What Is Wrong With Scene Text Recognition Model Comparisons? Dataset and Model Analysis | Clova AI Research, NAVER Corp.| 2019 | https://github.com/clovaai/deep-text-recognition-benchmark | https://arxiv.org/abs/1904.01906 | [TPS-ResNet-BiLSTM-Attn-case-sensitive.pth](https://www.dropbox.com/sh/j3xmli4di1zuv3s/AAArdcPgz7UFxIHUuKNOeKv_a?dl=0) |
14 |
15 | ## Folder structure
16 | ```
17 | .
18 | ├─ configs
19 | | ├─ craft_config.yaml
20 | | └─ star_config.yaml
21 | ├─ data
22 | | ├─ sample_output.jpg
23 | | └─ tes.jpg
24 | ├─ notebooks
25 | | ├─ export_onnx_model.ipynb
26 | | ├─ inference_default_engine.ipynb
27 | | ├─ inference_onnx_engine.ipynb
28 | | └─ test_api.ipynb
29 | ├─ src
30 | | ├─ text_detector
31 | | │ ├─ basenet
32 | | │ │ ├─ __init__.py
33 | | │ │ └─ vgg16_bn.py
34 | | │ ├─ modules
35 | | │ │ ├─ __init__.py
36 | | │ │ ├─ craft.py
37 | | │ │ ├─ craft_utils.py
38 | | │ │ ├─ imgproc.py
39 | | │ │ ├─ refinenet.py
40 | | │ │ └─ utils.py
41 | | │ ├─ __init__.py
42 | | │ ├─ infer.py
43 | | │ └─ load_model.py
44 | | ├─ text_recognizer
45 | | │ ├─ modules
46 | | │ │ ├─ dataset.py
47 | | │ │ ├─ feature_extraction.py
48 | | │ │ ├─ model.py
49 | | │ │ ├─ model_utils.py
50 | | │ │ ├─ prediction.py
51 | | │ │ ├─ sequence_modeling.py
52 | | │ │ ├─ transformation.py
53 | | │ │ └─ utils.py
54 | | │ ├─ __init__.py
55 | | │ ├─ infer.py
56 | | │ └─ load_model.py
57 | | ├─ __init__.py
58 | | ├─ engine.py
59 | | └─ model.py
60 | ├─ .gitignore
61 | ├─ CONTRIBUTING.md
62 | ├─ Dockerfile
63 | ├─ environment.yaml
64 | ├─ LICENSE
65 | ├─ main.py
66 | ├─ pyproject.toml
67 | ├─ README.md
68 | ├─ requirements.txt
69 | └─ setup.cfg
70 | ```
71 |
72 | ## Model Preparation
73 | You need to create a `models` folder to store the following files:
74 | - detector_model = "models/text_detector/craft_mlt_25k.pth"
75 | - recognizer_model = "models/text_recognizer/TPS-ResNet-BiLSTM-Attn-case-sensitive.pth"
76 |
77 | Download all of the pretrained models from the links in the "References" section.
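
With the weights in place, the detector and recognizer can also be used directly from Python, without the API layer. The snippet below is a minimal sketch that mirrors how `main.py` wires the pieces together; the paths are the ones listed above, and `data/tes.jpg` is the sample image shipped with this repo:

```python
import numpy as np
from PIL import Image

from src.engine import DefaultEngine
from src.model import DefaultModel

# config files and checkpoints, following the layout described above
model = DefaultModel(
    "configs/craft_config.yaml",
    "models/text_detector/craft_mlt_25k.pth",
    "configs/star_config.yaml",
    "models/text_recognizer/TPS-ResNet-BiLSTM-Attn-case-sensitive.pth",
)
engine = DefaultEngine(model)

# load the sample receipt as an RGB numpy array and run detection + recognition
image = np.array(Image.open("data/tes.jpg"))
engine.predict(image)
print(engine.result)  # recognized text lines, top to bottom
```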
78 |
79 | ## Requirements
80 | You can set up the environment using pip or conda:
81 | ```
82 | pip install -r requirements.txt
83 | ```
84 | or
85 | ```
86 | conda env create -f environment.yaml
87 | ```
88 |
89 | ## Container
90 | ```
91 | docker build -t receipt-ocr .
92 | docker run -d --name receipt-ocr-service -p 8000:8000 receipt-ocr
93 | docker start receipt-ocr-service
94 | docker stop receipt-ocr-service
95 | ```
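
Once the container is up, `/ocr/predict` expects a JSON body with a single `image` field containing the base64-encoded picture and returns the recognized lines. A minimal client sketch along the lines of `notebooks/test_api.ipynb` (note that `requests` is used by the notebook but not pinned in `requirements.txt`):

```python
import base64
import json

import requests

# encode the sample receipt as base64
with open("data/tes.jpg", "rb") as f:
    im_b64 = base64.b64encode(f.read()).decode("utf-8")

payload = json.dumps({"image": im_b64})
headers = {"content-type": "application/json"}

# port 8000 matches the Dockerfile's EXPOSE/CMD and the docker run mapping above
response = requests.post("http://127.0.0.1:8000/ocr/predict", data=payload, headers=headers)
print(response.status_code)
print(response.json())  # e.g. ['Geprek Bensu Kopo Bandung', ...]
```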
96 |
97 | ## How to contribute?
98 | Check the docs [here](CONTRIBUTING.md)
99 |
100 | ## Creator
101 | [rubentea16](https://github.com/rubentea16)
--------------------------------------------------------------------------------
/src/text_recognizer/modules/model.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) 2019-present NAVER Corp.
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 | http://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | """
16 |
17 | import torch.nn as nn
18 |
19 | from .transformation import TPS_SpatialTransformerNetwork
20 | from .feature_extraction import VGG_FeatureExtractor, RCNN_FeatureExtractor, ResNet_FeatureExtractor
21 | from .sequence_modeling import BidirectionalLSTM
22 | from .prediction import Attention
23 |
24 |
25 | class Model(nn.Module):
26 |
27 | def __init__(self, opt):
28 | super(Model, self).__init__()
29 | self.opt = opt
30 | self.stages = {'Trans': opt.Transformation, 'Feat': opt.FeatureExtraction,
31 | 'Seq': opt.SequenceModeling, 'Pred': opt.Prediction}
32 |
33 | """ Transformation """
34 | if opt.Transformation == 'TPS':
35 | self.Transformation = TPS_SpatialTransformerNetwork(F=opt.num_fiducial,
36 | I_size=(
37 | opt.imgH, opt.imgW),
38 | I_r_size=(
39 | opt.imgH, opt.imgW),
40 | I_channel_num=opt.input_channel)
41 | else:
42 | print('No Transformation module specified')
43 |
44 | """ FeatureExtraction """
45 | if opt.FeatureExtraction == 'VGG':
46 | self.FeatureExtraction = VGG_FeatureExtractor(
47 | opt.input_channel, opt.output_channel)
48 | elif opt.FeatureExtraction == 'RCNN':
49 | self.FeatureExtraction = RCNN_FeatureExtractor(
50 | opt.input_channel, opt.output_channel)
51 | elif opt.FeatureExtraction == 'ResNet':
52 | self.FeatureExtraction = ResNet_FeatureExtractor(
53 | opt.input_channel, opt.output_channel)
54 | else:
55 | raise Exception('No FeatureExtraction module specified')
56 | # int(imgH/16-1) * 512
57 | self.FeatureExtraction_output = opt.output_channel
58 | self.AdaptiveAvgPool = nn.AdaptiveAvgPool2d(
59 | (None, 1)) # Transform final (imgH/16-1) -> 1
60 |
61 | """ Sequence modeling"""
62 | if opt.SequenceModeling == 'BiLSTM':
63 | self.SequenceModeling = nn.Sequential(
64 | BidirectionalLSTM(self.FeatureExtraction_output,
65 | opt.hidden_size, opt.hidden_size),
66 | BidirectionalLSTM(opt.hidden_size, opt.hidden_size, opt.hidden_size))
67 | self.SequenceModeling_output = opt.hidden_size
68 | else:
69 | print('No SequenceModeling module specified')
70 | self.SequenceModeling_output = self.FeatureExtraction_output
71 |
72 | """ Prediction """
73 | if opt.Prediction == 'CTC':
74 | self.Prediction = nn.Linear(
75 | self.SequenceModeling_output, opt.num_class)
76 | elif opt.Prediction == 'Attn':
77 | self.Prediction = Attention(
78 | self.SequenceModeling_output, opt.hidden_size, opt.num_class)
79 | else:
80 | raise Exception('Prediction is neither CTC or Attn')
81 |
82 | def forward(self, input, text, is_train=True):
83 | """ Transformation stage """
84 | if not self.stages['Trans'] == "None":
85 | input = self.Transformation(input)
86 |
87 | """ Feature extraction stage """
88 | visual_feature = self.FeatureExtraction(input)
89 | visual_feature = self.AdaptiveAvgPool(
90 | visual_feature.permute(0, 3, 1, 2)) # [b, c, h, w] -> [b, w, c, h]
91 | visual_feature = visual_feature.squeeze(3)
92 |
93 | """ Sequence modeling stage """
94 | if self.stages['Seq'] == 'BiLSTM':
95 | contextual_feature = self.SequenceModeling(visual_feature)
96 | else:
97 | # for convenience. this is NOT contextually modeled by BiLSTM
98 | contextual_feature = visual_feature
99 |
100 | """ Prediction stage """
101 | if self.stages['Pred'] == 'CTC':
102 | prediction = self.Prediction(contextual_feature.contiguous())
103 | else:
104 | prediction = self.Prediction(contextual_feature.contiguous(
105 | ), text, is_train, batch_max_length=self.opt.batch_max_length)
106 |
107 | return prediction
108 |
--------------------------------------------------------------------------------
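A minimal instantiation sketch for the four-stage Model above, assuming the TPS-ResNet-BiLSTM-Attn configuration. The option values are illustrative stand-ins for configs/star_config.yaml, and the zero-filled decoder input follows the usual inference convention of the upstream deep-text-recognition-benchmark Attention head (prediction.py is not shown here):

from types import SimpleNamespace

import torch

from src.text_recognizer.modules.model import Model

# Illustrative option values; the real ones live in configs/star_config.yaml.
opt = SimpleNamespace(
    Transformation='TPS', FeatureExtraction='ResNet',
    SequenceModeling='BiLSTM', Prediction='Attn',
    num_fiducial=20, imgH=32, imgW=100,
    input_channel=1, output_channel=512, hidden_size=256,
    num_class=96, batch_max_length=25,
)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = Model(opt).to(device).eval()

image = torch.rand(1, opt.input_channel, opt.imgH, opt.imgW, device=device)
# Zero-filled decoder input, the usual convention for Attn-head inference.
text_for_pred = torch.zeros(1, opt.batch_max_length + 1, dtype=torch.long, device=device)

with torch.no_grad():
    preds = model(image, text_for_pred, is_train=False)
print(preds.shape)  # expected: [1, batch_max_length + 1, num_class] for the Attn head
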
/src/text_recognizer/modules/model_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
3 |
4 |
5 | class CTCLabelConverter(object):
6 | """ Convert between text-label and text-index """
7 |
8 | def __init__(self, character):
9 | # character (str): set of the possible characters.
10 | dict_character = list(character)
11 |
12 | self.dict = {}
13 | for i, char in enumerate(dict_character):
14 | # NOTE: 0 is reserved for 'CTCblank' token required by CTCLoss
15 | self.dict[char] = i + 1
16 |
17 | # dummy '[CTCblank]' token for CTCLoss (index 0)
18 | self.character = ['[CTCblank]'] + dict_character
19 |
20 | def encode(self, text, batch_max_length=25):
21 | """convert text-label into text-index.
22 | input:
23 | text: text labels of each image. [batch_size]
24 | batch_max_length: max length of text label in the batch. 25 by default
25 |
26 | output:
27 | text: text index for CTCLoss. [batch_size, batch_max_length]
28 | length: length of each text. [batch_size]
29 | """
30 | length = [len(s) for s in text]
31 |
32 | # The index used for padding (=0) would not affect the CTC loss calculation.
33 | batch_text = torch.LongTensor(len(text), batch_max_length).fill_(0)
34 | for i, t in enumerate(text):
35 | text = list(t)
36 | text = [self.dict[char] for char in text]
37 | batch_text[i][:len(text)] = torch.LongTensor(text)
38 | return (batch_text.to(device), torch.IntTensor(length).to(device))
39 |
40 | def decode(self, text_index, length):
41 | """ convert text-index into text-label. """
42 | texts = []
43 | for index, l in enumerate(length):
44 | t = text_index[index, :]
45 |
46 | char_list = []
47 | for i in range(l):
48 | # removing repeated characters and blank.
49 | if t[i] != 0 and (not (i > 0 and t[i - 1] == t[i])):
50 | char_list.append(self.character[t[i]])
51 | text = ''.join(char_list)
52 |
53 | texts.append(text)
54 | return texts
55 |
56 |
57 | class CTCLabelConverterForBaiduWarpctc(object):
58 | """ Convert between text-label and text-index for baidu warpctc """
59 |
60 | def __init__(self, character):
61 | # character (str): set of the possible characters.
62 | dict_character = list(character)
63 |
64 | self.dict = {}
65 | for i, char in enumerate(dict_character):
66 | # NOTE: 0 is reserved for 'CTCblank' token required by CTCLoss
67 | self.dict[char] = i + 1
68 |
69 | # dummy '[CTCblank]' token for CTCLoss (index 0)
70 | self.character = ['[CTCblank]'] + dict_character
71 |
72 | def encode(self, text, batch_max_length=25):
73 | """convert text-label into text-index.
74 | input:
75 | text: text labels of each image. [batch_size]
76 | output:
77 | text: concatenated text index for CTCLoss.
78 | [sum(text_lengths)] = [text_index_0 + text_index_1 + ... + text_index_(n - 1)]
79 | length: length of each text. [batch_size]
80 | """
81 | length = [len(s) for s in text]
82 | text = ''.join(text)
83 | text = [self.dict[char] for char in text]
84 |
85 | return (torch.IntTensor(text), torch.IntTensor(length))
86 |
87 | def decode(self, text_index, length):
88 | """ convert text-index into text-label. """
89 | texts = []
90 | index = 0
91 | for l in length:
92 | t = text_index[index:index + l]
93 |
94 | char_list = []
95 | for i in range(l):
96 | # removing repeated characters and blank.
97 | if t[i] != 0 and (not (i > 0 and t[i - 1] == t[i])):
98 | char_list.append(self.character[t[i]])
99 | text = ''.join(char_list)
100 |
101 | texts.append(text)
102 | index += l
103 | return texts
104 |
105 |
106 | class AttnLabelConverter(object):
107 | """ Convert between text-label and text-index """
108 |
109 | def __init__(self, character):
110 | # character (str): set of the possible characters.
111 | # [GO] for the start token of the attention decoder. [s] for end-of-sentence token.
112 | list_token = ['[GO]', '[s]'] # ['[s]','[UNK]','[PAD]','[GO]']
113 | list_character = list(character)
114 | self.character = list_token + list_character
115 |
116 | self.dict = {}
117 | for i, char in enumerate(self.character):
118 | # print(i, char)
119 | self.dict[char] = i
120 |
121 | def encode(self, text, batch_max_length=25):
122 | """ convert text-label into text-index.
123 | input:
124 | text: text labels of each image. [batch_size]
125 | batch_max_length: max length of text label in the batch. 25 by default
126 |
127 | output:
128 | text : the input of attention decoder. [batch_size x (max_length+2)] +1 for [GO] token and +1 for [s] token.
129 | text[:, 0] is [GO] token and text is padded with [GO] token after [s] token.
130 |             length : the length of the attention decoder output, which also counts the [s] token. [3, 7, ....] [batch_size]
131 | """
132 | length = [len(s) + 1 for s in text] # +1 for [s] at end of sentence.
133 | # batch_max_length = max(length) # this is not allowed for multi-gpu setting
134 | batch_max_length += 1
135 | # additional +1 for [GO] at first step. batch_text is padded with [GO] token after [s] token.
136 | batch_text = torch.LongTensor(len(text), batch_max_length + 1).fill_(0)
137 | for i, t in enumerate(text):
138 | text = list(t)
139 | text.append('[s]')
140 | text = [self.dict[char] for char in text]
141 | # batch_text[:, 0] = [GO] token
142 | batch_text[i][1:1 + len(text)] = torch.LongTensor(text)
143 | return (batch_text.to(device), torch.IntTensor(length).to(device))
144 |
145 | def decode(self, text_index, length):
146 | """ convert text-index into text-label. """
147 | texts = []
148 | for index, l in enumerate(length):
149 | text = ''.join([self.character[i] for i in text_index[index, :]])
150 | texts.append(text)
151 | return texts
152 |
153 |
154 | class Averager(object):
155 | """Compute average for torch.Tensor, used for loss average."""
156 |
157 | def __init__(self):
158 | self.reset()
159 |
160 | def add(self, v):
161 | count = v.data.numel()
162 | v = v.data.sum()
163 | self.n_count += count
164 | self.sum += v
165 |
166 | def reset(self):
167 | self.n_count = 0
168 | self.sum = 0
169 |
170 | def val(self):
171 | res = 0
172 | if self.n_count != 0:
173 | res = self.sum / float(self.n_count)
174 | return res
175 |
--------------------------------------------------------------------------------
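A small round-trip sketch for the CTCLabelConverter above. The character set is an illustrative stand-in for the one defined in configs/star_config.yaml, and the encoded labels are used here as a stand-in for per-timestep argmax indices (exact because neither word repeats adjacent characters):

from src.text_recognizer.modules.model_utils import CTCLabelConverter

# Illustrative character set; the real one comes from configs/star_config.yaml.
converter = CTCLabelConverter('abcdefghijklmnopqrstuvwxyz0123456789')

# encode pads every label to batch_max_length with index 0, the [CTCblank] slot.
batch_text, lengths = converter.encode(['kopo', 'bandung'], batch_max_length=25)
print(batch_text.shape)    # torch.Size([2, 25])
print(lengths.tolist())    # [4, 7]

# decode drops blanks and collapses adjacent repeats ("k-oo-p-o" -> "kopo").
print(converter.decode(batch_text, lengths))  # ['kopo', 'bandung']
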
/src/text_recognizer/modules/dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import re
4 | import six
5 | import math
6 | import numpy as np
7 | import torch
8 |
9 | import torchvision.transforms as transforms
10 | from PIL import Image
11 | from torch.utils.data import Dataset, ConcatDataset, Subset
12 | from torch._utils import _accumulate
13 |
14 |
15 | class Batch_Balanced_Dataset(object):
16 |
17 | def __init__(self, opt):
18 | """
19 | Modulate the data ratio in the batch.
20 | For example, when select_data is "MJ-ST" and batch_ratio is "0.5-0.5",
21 |         50% of the batch is filled with MJ and the other 50% is filled with ST.
22 | """
23 | log = open(f'./saved_models/{opt.exp_name}/log_dataset.txt', 'a')
24 | dashed_line = '-' * 80
25 | print(dashed_line)
26 | log.write(dashed_line + '\n')
27 | print(
28 | f'dataset_root: {opt.train_data}\nopt.select_data: {opt.select_data}\nopt.batch_ratio: {opt.batch_ratio}')
29 | log.write(
30 | f'dataset_root: {opt.train_data}\nopt.select_data: {opt.select_data}\nopt.batch_ratio: {opt.batch_ratio}\n')
31 | assert len(opt.select_data) == len(opt.batch_ratio)
32 |
33 | _AlignCollate = AlignCollate(
34 | imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD)
35 | self.data_loader_list = []
36 | self.dataloader_iter_list = []
37 | batch_size_list = []
38 | Total_batch_size = 0
39 | for selected_d, batch_ratio_d in zip(opt.select_data, opt.batch_ratio):
40 | _batch_size = max(round(opt.batch_size * float(batch_ratio_d)), 1)
41 | print(dashed_line)
42 | log.write(dashed_line + '\n')
43 | _dataset, _dataset_log = hierarchical_dataset(
44 | root=opt.train_data, opt=opt, select_data=[selected_d])
45 | total_number_dataset = len(_dataset)
46 | log.write(_dataset_log)
47 |
48 | """
49 | The total number of data can be modified with opt.total_data_usage_ratio.
50 | ex) opt.total_data_usage_ratio = 1 indicates 100% usage, and 0.2 indicates 20% usage.
51 | See 4.2 section in our paper.
52 | """
53 | number_dataset = int(total_number_dataset *
54 | float(opt.total_data_usage_ratio))
55 | dataset_split = [number_dataset,
56 | total_number_dataset - number_dataset]
57 | indices = range(total_number_dataset)
58 | _dataset, _ = [Subset(_dataset, indices[offset - length:offset])
59 | for offset, length in zip(_accumulate(dataset_split), dataset_split)]
60 | selected_d_log = f'num total samples of {selected_d}: {total_number_dataset} x {opt.total_data_usage_ratio} (total_data_usage_ratio) = {len(_dataset)}\n'
61 | selected_d_log += f'num samples of {selected_d} per batch: {opt.batch_size} x {float(batch_ratio_d)} (batch_ratio) = {_batch_size}'
62 | print(selected_d_log)
63 | log.write(selected_d_log + '\n')
64 | batch_size_list.append(str(_batch_size))
65 | Total_batch_size += _batch_size
66 |
67 | _data_loader = torch.utils.data.DataLoader(
68 | _dataset, batch_size=_batch_size,
69 | shuffle=True,
70 | num_workers=int(opt.workers),
71 | collate_fn=_AlignCollate, pin_memory=True)
72 | self.data_loader_list.append(_data_loader)
73 | self.dataloader_iter_list.append(iter(_data_loader))
74 |
75 | Total_batch_size_log = f'{dashed_line}\n'
76 | batch_size_sum = '+'.join(batch_size_list)
77 | Total_batch_size_log += f'Total_batch_size: {batch_size_sum} = {Total_batch_size}\n'
78 | Total_batch_size_log += f'{dashed_line}'
79 | opt.batch_size = Total_batch_size
80 |
81 | print(Total_batch_size_log)
82 | log.write(Total_batch_size_log + '\n')
83 | log.close()
84 |
85 | def get_batch(self):
86 | balanced_batch_images = []
87 | balanced_batch_texts = []
88 |
89 | for i, data_loader_iter in enumerate(self.dataloader_iter_list):
90 | try:
91 |                 image, text = next(data_loader_iter)
92 | balanced_batch_images.append(image)
93 | balanced_batch_texts += text
94 | except StopIteration:
95 | self.dataloader_iter_list[i] = iter(self.data_loader_list[i])
96 |                 image, text = next(self.dataloader_iter_list[i])
97 | balanced_batch_images.append(image)
98 | balanced_batch_texts += text
99 | except ValueError:
100 | pass
101 |
102 | balanced_batch_images = torch.cat(balanced_batch_images, 0)
103 |
104 | return balanced_batch_images, balanced_batch_texts
105 |
106 |
107 | class RawDataset(Dataset):
108 |
109 | def __init__(self, list_data, opt):
110 | self.opt = opt
111 | self.image_list = list_data
112 | self.nSamples = len(self.image_list)
113 |
114 | def __len__(self):
115 | return self.nSamples
116 |
117 | def __getitem__(self, index):
118 | gray_img = Image.fromarray(self.image_list[index]).convert('L')
119 | return (gray_img, 'image')
120 |
121 |
122 | class ResizeNormalize(object):
123 |
124 | def __init__(self, size, interpolation=Image.BICUBIC):
125 | self.size = size
126 | self.interpolation = interpolation
127 | self.toTensor = transforms.ToTensor()
128 |
129 | def __call__(self, img):
130 | img = img.resize(self.size, self.interpolation)
131 | img = self.toTensor(img)
132 | img.sub_(0.5).div_(0.5)
133 | return img
134 |
135 |
136 | class NormalizePAD(object):
137 |
138 | def __init__(self, max_size, PAD_type='right'):
139 | self.toTensor = transforms.ToTensor()
140 | self.max_size = max_size
141 | self.max_width_half = math.floor(max_size[2] / 2)
142 | self.PAD_type = PAD_type
143 |
144 | def __call__(self, img):
145 | img = self.toTensor(img)
146 | img.sub_(0.5).div_(0.5)
147 | c, h, w = img.size()
148 | Pad_img = torch.FloatTensor(*self.max_size).fill_(0)
149 | Pad_img[:, :, :w] = img # right pad
150 | if self.max_size[2] != w: # add border Pad
151 | Pad_img[:, :, w:] = img[:, :, w -
152 | 1].unsqueeze(2).expand(c, h, self.max_size[2] - w)
153 |
154 | return Pad_img
155 |
156 |
157 | class AlignCollate(object):
158 |
159 | def __init__(self, imgH=32, imgW=100, keep_ratio_with_pad=False):
160 | self.imgH = imgH
161 | self.imgW = imgW
162 | self.keep_ratio_with_pad = keep_ratio_with_pad
163 |
164 | def __call__(self, batch):
165 | batch = filter(lambda x: x is not None, batch)
166 | images, labels = zip(*batch)
167 |
168 | if self.keep_ratio_with_pad: # same concept with 'Rosetta' paper
169 | resized_max_w = self.imgW
170 | input_channel = 3 if images[0].mode == 'RGB' else 1
171 | transform = NormalizePAD((input_channel, self.imgH, resized_max_w))
172 |
173 | resized_images = []
174 | for image in images:
175 | w, h = image.size
176 | ratio = w / float(h)
177 | if math.ceil(self.imgH * ratio) > self.imgW:
178 | resized_w = self.imgW
179 | else:
180 | resized_w = math.ceil(self.imgH * ratio)
181 |
182 | resized_image = image.resize(
183 | (resized_w, self.imgH), Image.BICUBIC)
184 | resized_images.append(transform(resized_image))
185 | # resized_image.save('./image_test/%d_test.jpg' % w)
186 |
187 | image_tensors = torch.cat([t.unsqueeze(0)
188 | for t in resized_images], 0)
189 |
190 | else:
191 | transform = ResizeNormalize((self.imgW, self.imgH))
192 | image_tensors = [transform(image) for image in images]
193 | image_tensors = torch.cat([t.unsqueeze(0)
194 | for t in image_tensors], 0)
195 |
196 | return image_tensors, labels
197 |
198 |
199 | def tensor2im(image_tensor, imtype=np.uint8):
200 | image_numpy = image_tensor.cpu().float().numpy()
201 | if image_numpy.shape[0] == 1:
202 | image_numpy = np.tile(image_numpy, (3, 1, 1))
203 | image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0
204 | return image_numpy.astype(imtype)
205 |
206 |
207 | def save_image(image_numpy, image_path):
208 | image_pil = Image.fromarray(image_numpy)
209 | image_pil.save(image_path)
210 |
--------------------------------------------------------------------------------
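A minimal sketch of feeding detector crops through RawDataset and AlignCollate for recognition. The crops are random arrays standing in for cropped word images, and the sizes are illustrative assumptions rather than values read from the config:

from types import SimpleNamespace

import numpy as np
import torch

from src.text_recognizer.modules.dataset import AlignCollate, RawDataset

opt = SimpleNamespace(imgH=32, imgW=100, PAD=True)  # illustrative values

# Random HxWx3 arrays standing in for word crops produced by the detector.
crops = [np.random.randint(0, 255, (48, 160, 3), dtype=np.uint8) for _ in range(4)]

dataset = RawDataset(list_data=crops, opt=opt)
collate = AlignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD)
loader = torch.utils.data.DataLoader(dataset, batch_size=4,
                                     shuffle=False, collate_fn=collate)

images, labels = next(iter(loader))
print(images.shape)  # torch.Size([4, 1, 32, 100]) -- grayscale, aspect-ratio padded
print(labels)        # ('image', 'image', 'image', 'image')
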
/src/text_recognizer/modules/transformation.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
6 |
7 |
8 | class TPS_SpatialTransformerNetwork(nn.Module):
9 | """ Rectification Network of RARE, namely TPS based STN """
10 |
11 | def __init__(self, F, I_size, I_r_size, I_channel_num=1):
12 | """ Based on RARE TPS
13 | input:
14 | batch_I: Batch Input Image [batch_size x I_channel_num x I_height x I_width]
15 | I_size : (height, width) of the input image I
16 | I_r_size : (height, width) of the rectified image I_r
17 | I_channel_num : the number of channels of the input image I
18 | output:
19 | batch_I_r: rectified image [batch_size x I_channel_num x I_r_height x I_r_width]
20 | """
21 | super(TPS_SpatialTransformerNetwork, self).__init__()
22 | self.F = F
23 | self.I_size = I_size
24 | self.I_r_size = I_r_size # = (I_r_height, I_r_width)
25 | self.I_channel_num = I_channel_num
26 | self.LocalizationNetwork = LocalizationNetwork(
27 | self.F, self.I_channel_num)
28 | self.GridGenerator = GridGenerator(self.F, self.I_r_size)
29 |
30 | def forward(self, batch_I):
31 | batch_C_prime = self.LocalizationNetwork(batch_I) # batch_size x K x 2
32 | # batch_size x n (= I_r_width x I_r_height) x 2
33 | build_P_prime = self.GridGenerator.build_P_prime(batch_C_prime)
34 | build_P_prime_reshape = build_P_prime.reshape(
35 | [build_P_prime.size(0), self.I_r_size[0], self.I_r_size[1], 2])
36 |
37 |         if tuple(map(int, torch.__version__.split('.')[:2])) > (1, 2):  # numeric version check
38 | batch_I_r = F.grid_sample(
39 | batch_I, build_P_prime_reshape, padding_mode='border', align_corners=True)
40 | else:
41 | batch_I_r = F.grid_sample(
42 | batch_I, build_P_prime_reshape, padding_mode='border')
43 |
44 | return batch_I_r
45 |
46 |
47 | class LocalizationNetwork(nn.Module):
48 | """ Localization Network of RARE, which predicts C' (K x 2) from I (I_width x I_height) """
49 |
50 | def __init__(self, F, I_channel_num):
51 | super(LocalizationNetwork, self).__init__()
52 | self.F = F
53 | self.I_channel_num = I_channel_num
54 | self.conv = nn.Sequential(
55 | nn.Conv2d(in_channels=self.I_channel_num, out_channels=64, kernel_size=3, stride=1, padding=1,
56 | bias=False), nn.BatchNorm2d(64), nn.ReLU(True),
57 | nn.MaxPool2d(2, 2), # batch_size x 64 x I_height/2 x I_width/2
58 | nn.Conv2d(64, 128, 3, 1, 1, bias=False), nn.BatchNorm2d(
59 | 128), nn.ReLU(True),
60 | nn.MaxPool2d(2, 2), # batch_size x 128 x I_height/4 x I_width/4
61 | nn.Conv2d(128, 256, 3, 1, 1, bias=False), nn.BatchNorm2d(
62 | 256), nn.ReLU(True),
63 | nn.MaxPool2d(2, 2), # batch_size x 256 x I_height/8 x I_width/8
64 | nn.Conv2d(256, 512, 3, 1, 1, bias=False), nn.BatchNorm2d(
65 | 512), nn.ReLU(True),
66 | nn.AdaptiveAvgPool2d(1) # batch_size x 512
67 | )
68 |
69 | self.localization_fc1 = nn.Sequential(
70 | nn.Linear(512, 256), nn.ReLU(True))
71 | self.localization_fc2 = nn.Linear(256, self.F * 2)
72 |
73 | # Init fc2 in LocalizationNetwork
74 | self.localization_fc2.weight.data.fill_(0)
75 | """ see RARE paper Fig. 6 (a) """
76 | ctrl_pts_x = np.linspace(-1.0, 1.0, int(F / 2))
77 | ctrl_pts_y_top = np.linspace(0.0, -1.0, num=int(F / 2))
78 | ctrl_pts_y_bottom = np.linspace(1.0, 0.0, num=int(F / 2))
79 | ctrl_pts_top = np.stack([ctrl_pts_x, ctrl_pts_y_top], axis=1)
80 | ctrl_pts_bottom = np.stack([ctrl_pts_x, ctrl_pts_y_bottom], axis=1)
81 | initial_bias = np.concatenate([ctrl_pts_top, ctrl_pts_bottom], axis=0)
82 | self.localization_fc2.bias.data = torch.from_numpy(
83 | initial_bias).float().view(-1)
84 |
85 | def forward(self, batch_I):
86 | """
87 | input: batch_I : Batch Input Image [batch_size x I_channel_num x I_height x I_width]
88 | output: batch_C_prime : Predicted coordinates of fiducial points for input batch [batch_size x F x 2]
89 | """
90 | batch_size = batch_I.size(0)
91 | features = self.conv(batch_I).view(batch_size, -1)
92 | batch_C_prime = self.localization_fc2(
93 | self.localization_fc1(features)).view(batch_size, self.F, 2)
94 | return batch_C_prime
95 |
96 |
97 | class GridGenerator(nn.Module):
98 |     """ Grid Generator of RARE, which produces P_prime by multiplying T with P """
99 |
100 | def __init__(self, F, I_r_size):
101 | """ Generate P_hat and inv_delta_C for later """
102 | super(GridGenerator, self).__init__()
103 | self.eps = 1e-6
104 | self.I_r_height, self.I_r_width = I_r_size
105 | self.F = F
106 | self.C = self._build_C(self.F) # F x 2
107 | self.P = self._build_P(self.I_r_width, self.I_r_height)
108 | # for multi-gpu, you need register buffer
109 | self.register_buffer("inv_delta_C", torch.tensor(
110 | self._build_inv_delta_C(self.F, self.C)).float()) # F+3 x F+3
111 | self.register_buffer("P_hat", torch.tensor(
112 | self._build_P_hat(self.F, self.C, self.P)).float()) # n x F+3
113 | # for fine-tuning with different image width, you may use below instead of self.register_buffer
114 | # self.inv_delta_C = torch.tensor(self._build_inv_delta_C(self.F, self.C)).float().cuda() # F+3 x F+3
115 | # self.P_hat = torch.tensor(self._build_P_hat(self.F, self.C, self.P)).float().cuda() # n x F+3
116 |
117 | def _build_C(self, F):
118 | """ Return coordinates of fiducial points in I_r; C """
119 | ctrl_pts_x = np.linspace(-1.0, 1.0, int(F / 2))
120 | ctrl_pts_y_top = -1 * np.ones(int(F / 2))
121 | ctrl_pts_y_bottom = np.ones(int(F / 2))
122 | ctrl_pts_top = np.stack([ctrl_pts_x, ctrl_pts_y_top], axis=1)
123 | ctrl_pts_bottom = np.stack([ctrl_pts_x, ctrl_pts_y_bottom], axis=1)
124 | C = np.concatenate([ctrl_pts_top, ctrl_pts_bottom], axis=0)
125 | return C # F x 2
126 |
127 | def _build_inv_delta_C(self, F, C):
128 | """ Return inv_delta_C which is needed to calculate T """
129 | hat_C = np.zeros((F, F), dtype=float) # F x F
130 | for i in range(0, F):
131 | for j in range(i, F):
132 | r = np.linalg.norm(C[i] - C[j])
133 | hat_C[i, j] = r
134 | hat_C[j, i] = r
135 | np.fill_diagonal(hat_C, 1)
136 | hat_C = (hat_C ** 2) * np.log(hat_C)
137 | # print(C.shape, hat_C.shape)
138 | delta_C = np.concatenate( # F+3 x F+3
139 | [
140 | np.concatenate([np.ones((F, 1)), C, hat_C], axis=1), # F x F+3
141 | np.concatenate(
142 | [np.zeros((2, 3)), np.transpose(C)], axis=1), # 2 x F+3
143 | np.concatenate(
144 | [np.zeros((1, 3)), np.ones((1, F))], axis=1) # 1 x F+3
145 | ],
146 | axis=0
147 | )
148 | inv_delta_C = np.linalg.inv(delta_C)
149 | return inv_delta_C # F+3 x F+3
150 |
151 | def _build_P(self, I_r_width, I_r_height):
152 | I_r_grid_x = (np.arange(-I_r_width, I_r_width, 2) +
153 | 1.0) / I_r_width # self.I_r_width
154 | I_r_grid_y = (np.arange(-I_r_height, I_r_height, 2) +
155 | 1.0) / I_r_height # self.I_r_height
156 | P = np.stack( # self.I_r_width x self.I_r_height x 2
157 | np.meshgrid(I_r_grid_x, I_r_grid_y),
158 | axis=2
159 | )
160 | return P.reshape([-1, 2]) # n (= self.I_r_width x self.I_r_height) x 2
161 |
162 | def _build_P_hat(self, F, C, P):
163 | n = P.shape[0] # n (= self.I_r_width x self.I_r_height)
164 | P_tile = np.tile(np.expand_dims(P, axis=1), (1, F, 1)
165 | ) # n x 2 -> n x 1 x 2 -> n x F x 2
166 | C_tile = np.expand_dims(C, axis=0) # 1 x F x 2
167 | P_diff = P_tile - C_tile # n x F x 2
168 | rbf_norm = np.linalg.norm(
169 | P_diff, ord=2, axis=2, keepdims=False) # n x F
170 | rbf = np.multiply(np.square(rbf_norm), np.log(
171 | rbf_norm + self.eps)) # n x F
172 | P_hat = np.concatenate([np.ones((n, 1)), P, rbf], axis=1)
173 | return P_hat # n x F+3
174 |
175 | def build_P_prime(self, batch_C_prime):
176 | """ Generate Grid from batch_C_prime [batch_size x F x 2] """
177 | batch_size = batch_C_prime.size(0)
178 | batch_inv_delta_C = self.inv_delta_C.repeat(batch_size, 1, 1)
179 | batch_P_hat = self.P_hat.repeat(batch_size, 1, 1)
180 | batch_C_prime_with_zeros = torch.cat((batch_C_prime, torch.zeros(
181 | batch_size, 3, 2).float().to(device)), dim=1) # batch_size x F+3 x 2
182 | # batch_size x F+3 x 2
183 | batch_T = torch.bmm(batch_inv_delta_C, batch_C_prime_with_zeros)
184 | batch_P_prime = torch.bmm(batch_P_hat, batch_T) # batch_size x n x 2
185 | return batch_P_prime # batch_size x n x 2
186 |
--------------------------------------------------------------------------------
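A quick shape check for the TPS rectifier above; F=20 fiducial points and a 32x100 single-channel input are illustrative assumptions in line with the usual STAR-Net setup:

import torch

from src.text_recognizer.modules.transformation import TPS_SpatialTransformerNetwork

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

tps = TPS_SpatialTransformerNetwork(F=20, I_size=(32, 100), I_r_size=(32, 100),
                                    I_channel_num=1).to(device)
batch_I = torch.rand(2, 1, 32, 100, device=device)  # [batch, channels, height, width]

with torch.no_grad():
    batch_I_r = tps(batch_I)  # thin-plate-spline warp sampled by F.grid_sample
print(batch_I_r.shape)        # torch.Size([2, 1, 32, 100])
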
/src/text_detector/modules/craft_utils.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) 2019-present NAVER Corp.
3 | MIT License
4 | """
5 |
6 | # -*- coding: utf-8 -*-
7 | import numpy as np
8 | import cv2
9 | import math
10 |
11 | """ auxiliary functions """
12 | # unwarp coordinates
13 |
14 |
15 | def warpCoord(Minv, pt):
16 | out = np.matmul(Minv, (pt[0], pt[1], 1))
17 | return np.array([out[0]/out[2], out[1]/out[2]])
18 |
19 |
20 | """ end of auxiliary functions """
21 |
22 |
23 | def getDetBoxes_core(textmap, linkmap, text_threshold, link_threshold, low_text):
24 | # prepare data
25 | linkmap = linkmap.copy()
26 | textmap = textmap.copy()
27 | img_h, img_w = textmap.shape
28 |
29 | """ labeling method """
30 | ret, text_score = cv2.threshold(textmap, low_text, 1, 0)
31 | ret, link_score = cv2.threshold(linkmap, link_threshold, 1, 0)
32 |
33 | text_score_comb = np.clip(text_score + link_score, 0, 1)
34 | nLabels, labels, stats, centroids = cv2.connectedComponentsWithStats(
35 | text_score_comb.astype(np.uint8), connectivity=4)
36 |
37 | det = []
38 | mapper = []
39 | for k in range(1, nLabels):
40 | # size filtering
41 | size = stats[k, cv2.CC_STAT_AREA]
42 | if size < 10:
43 | continue
44 |
45 | # thresholding
46 | if np.max(textmap[labels == k]) < text_threshold:
47 | continue
48 |
49 | # make segmentation map
50 | segmap = np.zeros(textmap.shape, dtype=np.uint8)
51 | segmap[labels == k] = 255
52 | # remove link area
53 | segmap[np.logical_and(link_score == 1, text_score == 0)] = 0
54 | x, y = stats[k, cv2.CC_STAT_LEFT], stats[k, cv2.CC_STAT_TOP]
55 | w, h = stats[k, cv2.CC_STAT_WIDTH], stats[k, cv2.CC_STAT_HEIGHT]
56 | niter = int(math.sqrt(size * min(w, h) / (w * h)) * 2)
57 | sx, ex, sy, ey = x - niter, x + w + niter + 1, y - niter, y + h + niter + 1
58 | # boundary check
59 | if sx < 0:
60 | sx = 0
61 | if sy < 0:
62 | sy = 0
63 | if ex >= img_w:
64 | ex = img_w
65 | if ey >= img_h:
66 | ey = img_h
67 | kernel = cv2.getStructuringElement(
68 | cv2.MORPH_RECT, (1 + niter, 1 + niter))
69 | segmap[sy:ey, sx:ex] = cv2.dilate(segmap[sy:ey, sx:ex], kernel)
70 |
71 | # make box
72 | np_contours = np.roll(np.array(np.where(segmap != 0)),
73 | 1, axis=0).transpose().reshape(-1, 2)
74 | rectangle = cv2.minAreaRect(np_contours)
75 | box = cv2.boxPoints(rectangle)
76 |
77 | # align diamond-shape
78 | w, h = np.linalg.norm(box[0] - box[1]), np.linalg.norm(box[1] - box[2])
79 | box_ratio = max(w, h) / (min(w, h) + 1e-5)
80 | if abs(1 - box_ratio) <= 0.1:
81 | l, r = min(np_contours[:, 0]), max(np_contours[:, 0])
82 | t, b = min(np_contours[:, 1]), max(np_contours[:, 1])
83 | box = np.array([[l, t], [r, t], [r, b], [l, b]], dtype=np.float32)
84 |
85 | # make clock-wise order
86 | startidx = box.sum(axis=1).argmin()
87 | box = np.roll(box, 4-startidx, 0)
88 | box = np.array(box)
89 |
90 | det.append(box)
91 | mapper.append(k)
92 |
93 | return det, labels, mapper
94 |
95 |
96 | def getPoly_core(boxes, labels, mapper, linkmap):
97 | # configs
98 | num_cp = 5
99 | max_len_ratio = 0.7
100 | expand_ratio = 1.45
101 | max_r = 2.0
102 | step_r = 0.2
103 |
104 | polys = []
105 | for k, box in enumerate(boxes):
106 | # size filter for small instance
107 | w, h = int(np.linalg.norm(box[0] - box[1]) +
108 | 1), int(np.linalg.norm(box[1] - box[2]) + 1)
109 | if w < 10 or h < 10:
110 | polys.append(None)
111 | continue
112 |
113 | # warp image
114 | tar = np.float32([[0, 0], [w, 0], [w, h], [0, h]])
115 | M = cv2.getPerspectiveTransform(box, tar)
116 | word_label = cv2.warpPerspective(
117 | labels, M, (w, h), flags=cv2.INTER_NEAREST)
118 | try:
119 | Minv = np.linalg.inv(M)
120 |         except np.linalg.LinAlgError:
121 | polys.append(None)
122 | continue
123 |
124 | # binarization for selected label
125 | cur_label = mapper[k]
126 | word_label[word_label != cur_label] = 0
127 | word_label[word_label > 0] = 1
128 |
129 | """ Polygon generation """
130 | # find top/bottom contours
131 | cp = []
132 | max_len = -1
133 | for i in range(w):
134 | region = np.where(word_label[:, i] != 0)[0]
135 | if len(region) < 2:
136 | continue
137 | cp.append((i, region[0], region[-1]))
138 | length = region[-1] - region[0] + 1
139 | if length > max_len:
140 | max_len = length
141 |
142 | # pass if max_len is similar to h
143 | if h * max_len_ratio < max_len:
144 | polys.append(None)
145 | continue
146 |
147 | # get pivot points with fixed length
148 | tot_seg = num_cp * 2 + 1
149 | seg_w = w / tot_seg # segment width
150 | pp = [None] * num_cp # init pivot points
151 | cp_section = [[0, 0]] * tot_seg
152 | seg_height = [0] * num_cp
153 | seg_num = 0
154 | num_sec = 0
155 | prev_h = -1
156 | for i in range(0, len(cp)):
157 | (x, sy, ey) = cp[i]
158 | if (seg_num + 1) * seg_w <= x and seg_num <= tot_seg:
159 | # average previous segment
160 | if num_sec == 0:
161 | break
162 | cp_section[seg_num] = [cp_section[seg_num][0] /
163 | num_sec, cp_section[seg_num][1] / num_sec]
164 | num_sec = 0
165 |
166 | # reset variables
167 | seg_num += 1
168 | prev_h = -1
169 |
170 | # accumulate center points
171 | cy = (sy + ey) * 0.5
172 | cur_h = ey - sy + 1
173 | cp_section[seg_num] = [cp_section[seg_num]
174 | [0] + x, cp_section[seg_num][1] + cy]
175 | num_sec += 1
176 |
177 | if seg_num % 2 == 0:
178 | continue # No polygon area
179 |
180 | if prev_h < cur_h:
181 | pp[int((seg_num - 1)/2)] = (x, cy)
182 | seg_height[int((seg_num - 1)/2)] = cur_h
183 | prev_h = cur_h
184 |
185 | # processing last segment
186 | if num_sec != 0:
187 | cp_section[-1] = [cp_section[-1][0] /
188 | num_sec, cp_section[-1][1] / num_sec]
189 |
190 |         # pass if num of pivots is not sufficient or segment width is smaller than character height
191 | if None in pp or seg_w < np.max(seg_height) * 0.25:
192 | polys.append(None)
193 | continue
194 |
195 | # calc median maximum of pivot points
196 | half_char_h = np.median(seg_height) * expand_ratio / 2
197 |
198 | # calc gradiant and apply to make horizontal pivots
199 | new_pp = []
200 | for i, (x, cy) in enumerate(pp):
201 | dx = cp_section[i * 2 + 2][0] - cp_section[i * 2][0]
202 | dy = cp_section[i * 2 + 2][1] - cp_section[i * 2][1]
203 |             if dx == 0:  # gradient is zero
204 | new_pp.append([x, cy - half_char_h, x, cy + half_char_h])
205 | continue
206 | rad = - math.atan2(dy, dx)
207 | c, s = half_char_h * math.cos(rad), half_char_h * math.sin(rad)
208 | new_pp.append([x - s, cy - c, x + s, cy + c])
209 |
210 | # get edge points to cover character heatmaps
211 | isSppFound, isEppFound = False, False
212 | grad_s = (pp[1][1] - pp[0][1]) / (pp[1][0] - pp[0][0]) + \
213 | (pp[2][1] - pp[1][1]) / (pp[2][0] - pp[1][0])
214 | grad_e = (pp[-2][1] - pp[-1][1]) / (pp[-2][0] - pp[-1][0]) + \
215 | (pp[-3][1] - pp[-2][1]) / (pp[-3][0] - pp[-2][0])
216 | for r in np.arange(0.5, max_r, step_r):
217 | dx = 2 * half_char_h * r
218 | if not isSppFound:
219 | line_img = np.zeros(word_label.shape, dtype=np.uint8)
220 | dy = grad_s * dx
221 | p = np.array(new_pp[0]) - np.array([dx, dy, dx, dy])
222 | cv2.line(line_img, (int(p[0]), int(p[1])),
223 | (int(p[2]), int(p[3])), 1, thickness=1)
224 | if np.sum(np.logical_and(word_label, line_img)) == 0 or r + 2 * step_r >= max_r:
225 | spp = p
226 | isSppFound = True
227 | if not isEppFound:
228 | line_img = np.zeros(word_label.shape, dtype=np.uint8)
229 | dy = grad_e * dx
230 | p = np.array(new_pp[-1]) + np.array([dx, dy, dx, dy])
231 | cv2.line(line_img, (int(p[0]), int(p[1])),
232 | (int(p[2]), int(p[3])), 1, thickness=1)
233 | if np.sum(np.logical_and(word_label, line_img)) == 0 or r + 2 * step_r >= max_r:
234 | epp = p
235 | isEppFound = True
236 | if isSppFound and isEppFound:
237 | break
238 |
239 | # pass if boundary of polygon is not found
240 | if not (isSppFound and isEppFound):
241 | polys.append(None)
242 | continue
243 |
244 | # make final polygon
245 | poly = []
246 | poly.append(warpCoord(Minv, (spp[0], spp[1])))
247 | for p in new_pp:
248 | poly.append(warpCoord(Minv, (p[0], p[1])))
249 | poly.append(warpCoord(Minv, (epp[0], epp[1])))
250 | poly.append(warpCoord(Minv, (epp[2], epp[3])))
251 | for p in reversed(new_pp):
252 | poly.append(warpCoord(Minv, (p[2], p[3])))
253 | poly.append(warpCoord(Minv, (spp[2], spp[3])))
254 |
255 | # add to final result
256 | polys.append(np.array(poly))
257 |
258 | return polys
259 |
260 |
261 | def getDetBoxes(textmap, linkmap, text_threshold, link_threshold, low_text, poly=False):
262 | boxes, labels, mapper = getDetBoxes_core(
263 | textmap, linkmap, text_threshold, link_threshold, low_text)
264 |
265 | if poly:
266 | polys = getPoly_core(boxes, labels, mapper, linkmap)
267 | else:
268 | polys = [None] * len(boxes)
269 |
270 | return boxes, polys
271 |
272 |
273 | def adjustResultCoordinates(polys, ratio_w, ratio_h, ratio_net=2):
274 | if len(polys) > 0:
275 | polys = np.array(polys)
276 | for k in range(len(polys)):
277 | if polys[k] is not None:
278 | polys[k] *= (ratio_w * ratio_net, ratio_h * ratio_net)
279 | return polys
280 |
--------------------------------------------------------------------------------
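A minimal sketch of the post-processing entry points above. The score maps are random stand-ins for the CRAFT region/affinity heatmaps, and the thresholds are illustrative values rather than the ones in configs/craft_config.yaml:

import numpy as np

from src.text_detector.modules.craft_utils import adjustResultCoordinates, getDetBoxes

# Random stand-ins for the region (text) and affinity (link) score maps.
textmap = np.random.rand(360, 240).astype(np.float32)
linkmap = np.random.rand(360, 240).astype(np.float32)

boxes, polys = getDetBoxes(textmap, linkmap,
                           text_threshold=0.7, link_threshold=0.4, low_text=0.4,
                           poly=False)

# Map boxes back to the original image scale: ratio_net=2 (the default) undoes the
# half-resolution score maps, while ratio_w / ratio_h undo the resize step.
ratio_w = ratio_h = 1.0
boxes = adjustResultCoordinates(boxes, ratio_w, ratio_h)
print(len(boxes), 'candidate boxes')
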
/notebooks/inference_onnx_engine.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "id": "false-shipping",
7 | "metadata": {},
8 | "outputs": [],
9 | "source": [
10 | "import sys\n",
11 | "import io\n",
12 | "import base64\n",
13 | "import numpy as np\n",
14 | "import warnings\n",
15 | "from PIL import Image\n",
16 | "\n",
17 | "warnings.filterwarnings(\"ignore\")\n",
18 | "sys.path.append('..')\n",
19 | "\n",
20 | "from src.model import ONNXModel\n",
21 | "from src.engine import ONNXEngine"
22 | ]
23 | },
24 | {
25 | "cell_type": "markdown",
26 | "id": "automated-graduation",
27 | "metadata": {},
28 | "source": [
29 | "## Load model network and weight"
30 | ]
31 | },
32 | {
33 | "cell_type": "code",
34 | "execution_count": 2,
35 | "id": "periodic-timothy",
36 | "metadata": {},
37 | "outputs": [],
38 | "source": [
39 | "detector_cfg = '../configs/craft_config.yaml'\n",
40 | "detector_model = '../models/text_detector/craft.onnx'\n",
41 | "recognizer_cfg = '../configs/star_config.yaml'\n",
42 | "recognizer_model = '../models/text_recognizer/TPS-ResNet-BiLSTM-Attn-case-sensitive.pth'"
43 | ]
44 | },
45 | {
46 | "cell_type": "code",
47 | "execution_count": 3,
48 | "id": "recovered-consortium",
49 | "metadata": {},
50 | "outputs": [
51 | {
52 | "name": "stdout",
53 | "output_type": "stream",
54 | "text": [
55 | "Loading weights from checkpoint (../models/text_detector/craft.onnx)\n",
56 | "Loading weights from checkpoint (../models/text_recognizer/TPS-ResNet-BiLSTM-Attn-case-sensitive.pth)\n"
57 | ]
58 | }
59 | ],
60 | "source": [
61 | "model = ONNXModel(detector_cfg, detector_model, \n",
62 | " recognizer_cfg, recognizer_model)"
63 | ]
64 | },
65 | {
66 | "cell_type": "markdown",
67 | "id": "inside-salon",
68 | "metadata": {},
69 | "source": [
70 | "## Load Engine"
71 | ]
72 | },
73 | {
74 | "cell_type": "code",
75 | "execution_count": 4,
76 | "id": "editorial-bargain",
77 | "metadata": {},
78 | "outputs": [],
79 | "source": [
80 | "engine = ONNXEngine(model)"
81 | ]
82 | },
83 | {
84 | "cell_type": "markdown",
85 | "id": "known-seattle",
86 | "metadata": {},
87 | "source": [
88 | "## Input Data Preparation"
89 | ]
90 | },
91 | {
92 | "cell_type": "code",
93 | "execution_count": 5,
94 | "id": "sonic-jacksonville",
95 | "metadata": {},
96 | "outputs": [],
97 | "source": [
98 | "image_file = '../data/tes.jpg'\n",
99 | "\n",
100 | "with open(image_file, \"rb\") as f:\n",
101 | " im_bytes = f.read()\n",
102 | "\n",
103 | "im_b64 = base64.b64encode(im_bytes).decode('utf-8')"
104 | ]
105 | },
106 | {
107 | "cell_type": "code",
108 | "execution_count": 6,
109 | "id": "featured-boost",
110 | "metadata": {},
111 | "outputs": [],
112 | "source": [
113 | "img_bytes = base64.b64decode(im_b64.encode('utf-8'))\n",
114 | "image = Image.open(io.BytesIO(img_bytes))\n",
115 | "image = np.array(image)"
116 | ]
117 | },
118 | {
119 | "cell_type": "markdown",
120 | "id": "plain-prediction",
121 | "metadata": {},
122 | "source": [
123 | "## Prediction"
124 | ]
125 | },
126 | {
127 | "cell_type": "code",
128 | "execution_count": 7,
129 | "id": "biblical-davis",
130 | "metadata": {},
131 | "outputs": [
132 | {
133 | "name": "stdout",
134 | "output_type": "stream",
135 | "text": [
136 | "'inference_detector' 2.73 s\n",
137 | "'inference_recognizer' 3.43 s\n",
138 | "'predict' 6.17 s\n"
139 | ]
140 | }
141 | ],
142 | "source": [
143 | "engine.predict(image)"
144 | ]
145 | },
146 | {
147 | "cell_type": "code",
148 | "execution_count": 8,
149 | "id": "burning-jacket",
150 | "metadata": {
151 | "scrolled": true
152 | },
153 | "outputs": [
154 | {
155 | "data": {
156 | "text/plain": [
157 | "[('Kopo', tensor(0.9991), (620, 680, 481, 575)),\n",
158 | " ('Bandung', tensor(1.0000), (622, 673, 575, 713)),\n",
159 | " ('Geprek', tensor(1.0000), (626, 676, 253, 374)),\n",
160 | " ('Bensu', tensor(0.9999), (626, 676, 380, 481)),\n",
161 | " ('Babakan', tensor(0.9999), (666, 716, 602, 740)),\n",
162 | " ('Ciparay', tensor(0.9997), (666, 716, 743, 884)),\n",
163 | " ('536,', tensor(0.9938), (666, 720, 313, 397)),\n",
164 | " ('Margasuka,', tensor(0.9953), (669, 723, 404, 592)),\n",
165 | " ('J1,', tensor(0.5975), (673, 723, 85, 145)),\n",
166 | " ('Kopo', tensor(0.9966), (673, 723, 155, 239)),\n",
167 | " ('No.', tensor(0.7720), (673, 720, 246, 306)),\n",
168 | " ('KOTA', tensor(0.9945), (710, 763, 367, 454)),\n",
169 | " ('BANDUNG', tensor(0.9989), (710, 763, 458, 599)),\n",
170 | " ('33', tensor(0.7718), (804, 854, 246, 296)),\n",
171 | " ('order:', tensor(0.9085), (804, 858, 120, 240)),\n",
172 | " ('Kode', tensor(0.9972), (851, 901, 48, 135)),\n",
173 | " ('16-07-2021', tensor(0.9492), (888, 942, 212, 400)),\n",
174 | " ('11:53:47', tensor(0.9910), (891, 942, 411, 562)),\n",
175 | " ('Tanggal', tensor(0.9934), (895, 948, 51, 185)),\n",
176 | " ('BDG', tensor(0.9957), (931, 985, 407, 474)),\n",
177 | " ('Kasiri', tensor(0.5920), (938, 989, 51, 169)),\n",
178 | " ('Kasir', tensor(0.9999), (938, 989, 175, 279)),\n",
179 | " ('Kopo', tensor(0.9991), (938, 989, 316, 404)),\n",
180 | " ('1', tensor(0.5032), (945, 982, 286, 306)),\n",
181 | " ('Pelanggant', tensor(0.6637), (980, 1039, 49, 241)),\n",
182 | " ('gjk', tensor(0.9945), (982, 1036, 246, 313)),\n",
183 | " ('wahyu', tensor(0.9993), (982, 1032, 320, 424)),\n",
184 | " ('GOFOO', tensor(0.9861), (1059, 1116, 673, 780)),\n",
185 | " ('Paket', tensor(0.9999), (1066, 1120, 48, 152)),\n",
186 | " ('Geprek', tensor(0.9998), (1068, 1126, 157, 281)),\n",
187 | " ('Bensu', tensor(0.9999), (1069, 1120, 283, 387)),\n",
188 | " ('Nasi', tensor(0.9995), (1069, 1120, 394, 471)),\n",
189 | " ('Daun', tensor(0.9998), (1069, 1116, 481, 565)),\n",
190 | " ('Jeruk', tensor(0.9994), (1069, 1116, 572, 673)),\n",
191 | " ('27', tensor(0.9975), (1110, 1160, 787, 834)),\n",
192 | " ('Level', tensor(0.9845), (1116, 1167, 125, 222)),\n",
193 | " ('D', tensor(0.9934), (1120, 1160, 51, 81)),\n",
194 | " ('I', tensor(0.6730), (1123, 1160, 236, 256)),\n",
195 | " ('X', tensor(0.8183), (1123, 1160, 266, 296)),\n",
196 | " ('(27', tensor(0.9810), (1160, 1210, 195, 263)),\n",
197 | " ('Harga', tensor(1.0000), (1163, 1214, 88, 189)),\n",
198 | " ('t', tensor(0.5064), (1173, 1204, 58, 78)),\n",
199 | " ('Dada', tensor(0.9999), (1207, 1254, 88, 172)),\n",
200 | " ('X', tensor(0.9218), (1214, 1251, 179, 209)),\n",
201 | " ('17.500', tensor(0.8386), (1241, 1291, 787, 908)),\n",
202 | " ('Original', tensor(0.9989), (1241, 1301, 532, 689)),\n",
203 | " ('GOFOOD', tensor(0.9986), (1244, 1298, 370, 498)),\n",
204 | " ('Bensu', tensor(0.9998), (1247, 1298, 269, 370)),\n",
205 | " ('Ayam', tensor(0.9882), (1251, 1301, 54, 138)),\n",
206 | " ('Geprek', tensor(0.9999), (1251, 1301, 142, 263)),\n",
207 | " ('X', tensor(0.9251), (1251, 1291, 696, 726)),\n",
208 | " ('(17', tensor(0.9848), (1291, 1345, 195, 263)),\n",
209 | " ('Harga', tensor(1.0000), (1293, 1351, 85, 194)),\n",
210 | " ('Dada', tensor(0.9999), (1338, 1389, 88, 175)),\n",
211 | " ('X', tensor(0.9166), (1345, 1385, 179, 209)),\n",
212 | " ('+', tensor(0.9094), (1352, 1382, 58, 81)),\n",
213 | " ('4.000', tensor(0.9562), (1375, 1426, 807, 908)),\n",
214 | " ('Charge', tensor(0.9999), (1380, 1438, 230, 355)),\n",
215 | " ('Take', tensor(0.9999), (1382, 1432, 58, 138)),\n",
216 | " ('Away', tensor(0.9999), (1384, 1437, 143, 231)),\n",
217 | " ('X', tensor(0.8809), (1389, 1426, 357, 390)),\n",
218 | " ('49.000', tensor(0.6948), (1463, 1513, 790, 911)),\n",
219 | " ('Subtotal', tensor(0.9988), (1469, 1523, 51, 209)),\n",
220 | " ('4.500', tensor(0.9441), (1510, 1560, 807, 911)),\n",
221 | " ('PB1', tensor(0.9708), (1516, 1567, 54, 118)),\n",
222 | " ('(10%)', tensor(0.9886), (1516, 1567, 128, 229)),\n",
223 | " ('53,500', tensor(0.7126), (1597, 1647, 790, 911)),\n",
224 | " ('Total', tensor(0.9979), (1607, 1657, 58, 158)),\n",
225 | " ('53.500', tensor(0.8583), (1684, 1738, 790, 911)),\n",
226 | " ('Gobiz', tensor(0.9990), (1692, 1749, 73, 181)),\n",
227 | " ('Kembal', tensor(0.9943), (1738, 1792, 58, 162)),\n",
228 | " ('Ii', tensor(0.7184), (1738, 1789, 152, 195)),\n",
229 | " ('LUNAS', tensor(0.9925), (1825, 1876, 444, 548)),\n",
230 | " ('*x', tensor(0.5476), (1829, 1869, 552, 602)),\n",
231 | " ('Kasih', tensor(0.9972), (1910, 1963, 505, 612)),\n",
232 | " ('Terima', tensor(0.9996), (1913, 1963, 384, 505)),\n",
233 | " ('POS', tensor(0.9836), (2041, 2094, 622, 693)),\n",
234 | " ('by', tensor(0.9994), (2047, 2101, 448, 498)),\n",
235 | " ('Pawoon', tensor(0.9994), (2047, 2098, 498, 622)),\n",
236 | " ('Powered', tensor(0.9999), (2051, 2101, 303, 444))]"
237 | ]
238 | },
239 | "execution_count": 8,
240 | "metadata": {},
241 | "output_type": "execute_result"
242 | }
243 | ],
244 | "source": [
245 | "engine.raw_output"
246 | ]
247 | },
248 | {
249 | "cell_type": "code",
250 | "execution_count": 9,
251 | "id": "outstanding-munich",
252 | "metadata": {
253 | "scrolled": true
254 | },
255 | "outputs": [
256 | {
257 | "data": {
258 | "text/plain": [
259 | "['Geprek Bensu Kopo Bandung',\n",
260 | " 'J1, Kopo No. 536, Margasuka, Babakan Ciparay',\n",
261 | " 'KOTA BANDUNG',\n",
262 | " 'order: 33',\n",
263 | " 'Kode',\n",
264 | " 'Tanggal 16-07-2021 11:53:47',\n",
265 | " 'Kasiri Kasir 1 Kopo BDG',\n",
266 | " 'Pelanggant gjk wahyu',\n",
267 | " 'Paket Geprek Bensu Nasi Daun Jeruk GOFOO',\n",
268 | " 'D Level I X 27',\n",
269 | " 't Harga (27',\n",
270 | " 'Dada X',\n",
271 | " 'Ayam Geprek Bensu GOFOOD Original X 17.500',\n",
272 | " 'Harga (17',\n",
273 | " '+ Dada X',\n",
274 | " 'Take Away Charge X 4.000',\n",
275 | " 'Subtotal 49.000',\n",
276 | " 'PB1 (10%) 4.500',\n",
277 | " 'Total 53,500',\n",
278 | " 'Gobiz 53.500',\n",
279 | " 'Kembal Ii',\n",
280 | " 'LUNAS *x',\n",
281 | " 'Terima Kasih']"
282 | ]
283 | },
284 | "execution_count": 9,
285 | "metadata": {},
286 | "output_type": "execute_result"
287 | }
288 | ],
289 | "source": [
290 | "engine.result"
291 | ]
292 | },
293 | {
294 | "cell_type": "code",
295 | "execution_count": null,
296 | "id": "narrow-satellite",
297 | "metadata": {},
298 | "outputs": [],
299 | "source": []
300 | }
301 | ],
302 | "metadata": {
303 | "kernelspec": {
304 | "display_name": "Python [conda env:receipt-ocr]",
305 | "language": "python",
306 | "name": "conda-env-receipt-ocr-py"
307 | },
308 | "language_info": {
309 | "codemirror_mode": {
310 | "name": "ipython",
311 | "version": 3
312 | },
313 | "file_extension": ".py",
314 | "mimetype": "text/x-python",
315 | "name": "python",
316 | "nbconvert_exporter": "python",
317 | "pygments_lexer": "ipython3",
318 | "version": "3.9.7"
319 | }
320 | },
321 | "nbformat": 4,
322 | "nbformat_minor": 5
323 | }
324 |
--------------------------------------------------------------------------------
/notebooks/inference_default_engine.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "id": "romantic-checklist",
7 | "metadata": {},
8 | "outputs": [],
9 | "source": [
10 | "import sys\n",
11 | "import io\n",
12 | "import base64\n",
13 | "import numpy as np\n",
14 | "import warnings\n",
15 | "from PIL import Image\n",
16 | "\n",
17 | "warnings.filterwarnings(\"ignore\")\n",
18 | "sys.path.append('..')\n",
19 | "\n",
20 | "from src.model import DefaultModel\n",
21 | "from src.engine import DefaultEngine"
22 | ]
23 | },
24 | {
25 | "cell_type": "markdown",
26 | "id": "hourly-russell",
27 | "metadata": {},
28 | "source": [
29 | "## Load model network and weight"
30 | ]
31 | },
32 | {
33 | "cell_type": "code",
34 | "execution_count": 2,
35 | "id": "covered-closure",
36 | "metadata": {},
37 | "outputs": [],
38 | "source": [
39 | "detector_cfg = '../configs/craft_config.yaml'\n",
40 | "detector_model = '../models/text_detector/craft_mlt_25k.pth'\n",
41 | "recognizer_cfg = '../configs/star_config.yaml'\n",
42 | "recognizer_model = '../models/text_recognizer/TPS-ResNet-BiLSTM-Attn-case-sensitive.pth'"
43 | ]
44 | },
45 | {
46 | "cell_type": "code",
47 | "execution_count": 3,
48 | "id": "intense-ordinance",
49 | "metadata": {},
50 | "outputs": [
51 | {
52 | "name": "stdout",
53 | "output_type": "stream",
54 | "text": [
55 | "Loading weights from checkpoint (../models/text_detector/craft_mlt_25k.pth)\n",
56 | "Loading weights from checkpoint (../models/text_recognizer/TPS-ResNet-BiLSTM-Attn-case-sensitive.pth)\n"
57 | ]
58 | }
59 | ],
60 | "source": [
61 | "model = DefaultModel(detector_cfg, detector_model, \n",
62 | " recognizer_cfg, recognizer_model)"
63 | ]
64 | },
65 | {
66 | "cell_type": "markdown",
67 | "id": "continued-shift",
68 | "metadata": {},
69 | "source": [
70 | "## Load Engine"
71 | ]
72 | },
73 | {
74 | "cell_type": "code",
75 | "execution_count": 4,
76 | "id": "marine-legend",
77 | "metadata": {},
78 | "outputs": [],
79 | "source": [
80 | "engine = DefaultEngine(model)"
81 | ]
82 | },
83 | {
84 | "cell_type": "markdown",
85 | "id": "insured-cooperation",
86 | "metadata": {},
87 | "source": [
88 | "## Input Data Preparation"
89 | ]
90 | },
91 | {
92 | "cell_type": "code",
93 | "execution_count": 5,
94 | "id": "cardiac-mother",
95 | "metadata": {},
96 | "outputs": [],
97 | "source": [
98 | "image_file = '../data/tes.jpg'\n",
99 | "\n",
100 | "with open(image_file, \"rb\") as f:\n",
101 | " im_bytes = f.read()\n",
102 | "\n",
103 | "im_b64 = base64.b64encode(im_bytes).decode('utf-8')"
104 | ]
105 | },
106 | {
107 | "cell_type": "code",
108 | "execution_count": 6,
109 | "id": "weird-psychology",
110 | "metadata": {},
111 | "outputs": [],
112 | "source": [
113 | "img_bytes = base64.b64decode(im_b64.encode('utf-8'))\n",
114 | "image = Image.open(io.BytesIO(img_bytes))\n",
115 | "image = np.array(image)"
116 | ]
117 | },
118 | {
119 | "cell_type": "markdown",
120 | "id": "adjustable-prevention",
121 | "metadata": {},
122 | "source": [
123 | "## Prediction"
124 | ]
125 | },
126 | {
127 | "cell_type": "code",
128 | "execution_count": 7,
129 | "id": "southern-connecticut",
130 | "metadata": {},
131 | "outputs": [
132 | {
133 | "name": "stdout",
134 | "output_type": "stream",
135 | "text": [
136 | "'inference_detector' 2.88 s\n",
137 | "'inference_recognizer' 3.67 s\n",
138 | "'predict' 6.56 s\n"
139 | ]
140 | }
141 | ],
142 | "source": [
143 | "engine.predict(image)"
144 | ]
145 | },
146 | {
147 | "cell_type": "code",
148 | "execution_count": 8,
149 | "id": "composite-royalty",
150 | "metadata": {
151 | "scrolled": true
152 | },
153 | "outputs": [
154 | {
155 | "data": {
156 | "text/plain": [
157 | "[('Kopo', tensor(0.9991), (620, 680, 481, 575)),\n",
158 | " ('Bandung', tensor(1.0000), (622, 673, 575, 713)),\n",
159 | " ('Geprek', tensor(1.0000), (626, 676, 253, 374)),\n",
160 | " ('Bensu', tensor(0.9999), (626, 676, 380, 481)),\n",
161 | " ('Babakan', tensor(0.9999), (666, 716, 602, 740)),\n",
162 | " ('Ciparay', tensor(0.9997), (666, 716, 743, 884)),\n",
163 | " ('536,', tensor(0.9938), (666, 720, 313, 397)),\n",
164 | " ('Margasuka,', tensor(0.9953), (669, 723, 404, 592)),\n",
165 | " ('J1,', tensor(0.5975), (673, 723, 85, 145)),\n",
166 | " ('Kopo', tensor(0.9966), (673, 723, 155, 239)),\n",
167 | " ('No.', tensor(0.7720), (673, 720, 246, 306)),\n",
168 | " ('KOTA', tensor(0.9945), (710, 763, 367, 454)),\n",
169 | " ('BANDUNG', tensor(0.9989), (710, 763, 458, 599)),\n",
170 | " ('33', tensor(0.7718), (804, 854, 246, 296)),\n",
171 | " ('order:', tensor(0.9085), (804, 858, 120, 240)),\n",
172 | " ('Kode', tensor(0.9972), (851, 901, 48, 135)),\n",
173 | " ('16-07-2021', tensor(0.9492), (888, 942, 212, 400)),\n",
174 | " ('11:53:47', tensor(0.9910), (891, 942, 411, 562)),\n",
175 | " ('Tanggal', tensor(0.9934), (895, 948, 51, 185)),\n",
176 | " ('BDG', tensor(0.9957), (931, 985, 407, 474)),\n",
177 | " ('Kasiri', tensor(0.5920), (938, 989, 51, 169)),\n",
178 | " ('Kasir', tensor(0.9999), (938, 989, 175, 279)),\n",
179 | " ('Kopo', tensor(0.9991), (938, 989, 316, 404)),\n",
180 | " ('1', tensor(0.5032), (945, 982, 286, 306)),\n",
181 | " ('Pelanggant', tensor(0.6637), (980, 1039, 49, 241)),\n",
182 | " ('gjk', tensor(0.9945), (982, 1036, 246, 313)),\n",
183 | " ('wahyu', tensor(0.9993), (982, 1032, 320, 424)),\n",
184 | " ('GOFOO', tensor(0.9861), (1059, 1116, 673, 780)),\n",
185 | " ('Paket', tensor(0.9999), (1066, 1120, 48, 152)),\n",
186 | " ('Geprek', tensor(0.9998), (1068, 1126, 157, 281)),\n",
187 | " ('Bensu', tensor(0.9999), (1069, 1120, 283, 387)),\n",
188 | " ('Nasi', tensor(0.9995), (1069, 1120, 394, 471)),\n",
189 | " ('Daun', tensor(0.9998), (1069, 1116, 481, 565)),\n",
190 | " ('Jeruk', tensor(0.9994), (1069, 1116, 572, 673)),\n",
191 | " ('27', tensor(0.9975), (1110, 1160, 787, 834)),\n",
192 | " ('Level', tensor(0.9845), (1116, 1167, 125, 222)),\n",
193 | " ('D', tensor(0.9934), (1120, 1160, 51, 81)),\n",
194 | " ('I', tensor(0.6730), (1123, 1160, 236, 256)),\n",
195 | " ('X', tensor(0.8183), (1123, 1160, 266, 296)),\n",
196 | " ('(27', tensor(0.9810), (1160, 1210, 195, 263)),\n",
197 | " ('Harga', tensor(1.0000), (1163, 1214, 88, 189)),\n",
198 | " ('t', tensor(0.5064), (1173, 1204, 58, 78)),\n",
199 | " ('Dada', tensor(0.9999), (1207, 1254, 88, 172)),\n",
200 | " ('X', tensor(0.9218), (1214, 1251, 179, 209)),\n",
201 | " ('17.500', tensor(0.8386), (1241, 1291, 787, 908)),\n",
202 | " ('Original', tensor(0.9989), (1241, 1301, 532, 689)),\n",
203 | " ('GOFOOD', tensor(0.9986), (1244, 1298, 370, 498)),\n",
204 | " ('Bensu', tensor(0.9998), (1247, 1298, 269, 370)),\n",
205 | " ('Ayam', tensor(0.9882), (1251, 1301, 54, 138)),\n",
206 | " ('Geprek', tensor(0.9999), (1251, 1301, 142, 263)),\n",
207 | " ('X', tensor(0.9251), (1251, 1291, 696, 726)),\n",
208 | " ('(17', tensor(0.9848), (1291, 1345, 195, 263)),\n",
209 | " ('Harga', tensor(1.0000), (1293, 1351, 85, 194)),\n",
210 | " ('Dada', tensor(0.9999), (1338, 1389, 88, 175)),\n",
211 | " ('X', tensor(0.9166), (1345, 1385, 179, 209)),\n",
212 | " ('+', tensor(0.9094), (1352, 1382, 58, 81)),\n",
213 | " ('4.000', tensor(0.9562), (1375, 1426, 807, 908)),\n",
214 | " ('Charge', tensor(0.9999), (1380, 1438, 230, 355)),\n",
215 | " ('Take', tensor(0.9999), (1382, 1432, 58, 138)),\n",
216 | " ('Away', tensor(0.9999), (1384, 1437, 143, 231)),\n",
217 | " ('X', tensor(0.8809), (1389, 1426, 357, 390)),\n",
218 | " ('49.000', tensor(0.6948), (1463, 1513, 790, 911)),\n",
219 | " ('Subtotal', tensor(0.9988), (1469, 1523, 51, 209)),\n",
220 | " ('4.500', tensor(0.9441), (1510, 1560, 807, 911)),\n",
221 | " ('PB1', tensor(0.9708), (1516, 1567, 54, 118)),\n",
222 | " ('(10%)', tensor(0.9886), (1516, 1567, 128, 229)),\n",
223 | " ('53,500', tensor(0.7126), (1597, 1647, 790, 911)),\n",
224 | " ('Total', tensor(0.9979), (1607, 1657, 58, 158)),\n",
225 | " ('53.500', tensor(0.8583), (1684, 1738, 790, 911)),\n",
226 | " ('Gobiz', tensor(0.9990), (1692, 1749, 73, 181)),\n",
227 | " ('Kembal', tensor(0.9943), (1738, 1792, 58, 162)),\n",
228 | " ('Ii', tensor(0.7184), (1738, 1789, 152, 195)),\n",
229 | " ('LUNAS', tensor(0.9925), (1825, 1876, 444, 548)),\n",
230 | " ('*x', tensor(0.5476), (1829, 1869, 552, 602)),\n",
231 | " ('Kasih', tensor(0.9972), (1910, 1963, 505, 612)),\n",
232 | " ('Terima', tensor(0.9996), (1913, 1963, 384, 505)),\n",
233 | " ('POS', tensor(0.9836), (2041, 2094, 622, 693)),\n",
234 | " ('by', tensor(0.9994), (2047, 2101, 448, 498)),\n",
235 | " ('Pawoon', tensor(0.9994), (2047, 2098, 498, 622)),\n",
236 | " ('Powered', tensor(0.9999), (2051, 2101, 303, 444))]"
237 | ]
238 | },
239 | "execution_count": 8,
240 | "metadata": {},
241 | "output_type": "execute_result"
242 | }
243 | ],
244 | "source": [
245 | "engine.raw_output"
246 | ]
247 | },
248 | {
249 | "cell_type": "code",
250 | "execution_count": 9,
251 | "id": "helpful-madrid",
252 | "metadata": {
253 | "scrolled": true
254 | },
255 | "outputs": [
256 | {
257 | "data": {
258 | "text/plain": [
259 | "['Geprek Bensu Kopo Bandung',\n",
260 | " 'J1, Kopo No. 536, Margasuka, Babakan Ciparay',\n",
261 | " 'KOTA BANDUNG',\n",
262 | " 'order: 33',\n",
263 | " 'Kode',\n",
264 | " 'Tanggal 16-07-2021 11:53:47',\n",
265 | " 'Kasiri Kasir 1 Kopo BDG',\n",
266 | " 'Pelanggant gjk wahyu',\n",
267 | " 'Paket Geprek Bensu Nasi Daun Jeruk GOFOO',\n",
268 | " 'D Level I X 27',\n",
269 | " 't Harga (27',\n",
270 | " 'Dada X',\n",
271 | " 'Ayam Geprek Bensu GOFOOD Original X 17.500',\n",
272 | " 'Harga (17',\n",
273 | " '+ Dada X',\n",
274 | " 'Take Away Charge X 4.000',\n",
275 | " 'Subtotal 49.000',\n",
276 | " 'PB1 (10%) 4.500',\n",
277 | " 'Total 53,500',\n",
278 | " 'Gobiz 53.500',\n",
279 | " 'Kembal Ii',\n",
280 | " 'LUNAS *x',\n",
281 | " 'Terima Kasih']"
282 | ]
283 | },
284 | "execution_count": 9,
285 | "metadata": {},
286 | "output_type": "execute_result"
287 | }
288 | ],
289 | "source": [
290 | "engine.result"
291 | ]
292 | },
293 | {
294 | "cell_type": "code",
295 | "execution_count": null,
296 | "id": "valid-dominican",
297 | "metadata": {},
298 | "outputs": [],
299 | "source": []
300 | }
301 | ],
302 | "metadata": {
303 | "kernelspec": {
304 | "display_name": "Python [conda env:receipt-ocr]",
305 | "language": "python",
306 | "name": "conda-env-receipt-ocr-py"
307 | },
308 | "language_info": {
309 | "codemirror_mode": {
310 | "name": "ipython",
311 | "version": 3
312 | },
313 | "file_extension": ".py",
314 | "mimetype": "text/x-python",
315 | "name": "python",
316 | "nbconvert_exporter": "python",
317 | "pygments_lexer": "ipython3",
318 | "version": "3.9.7"
319 | }
320 | },
321 | "nbformat": 4,
322 | "nbformat_minor": 5
323 | }
324 |
--------------------------------------------------------------------------------
/src/text_recognizer/modules/feature_extraction.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch.nn.functional as F
3 |
4 |
5 | class VGG_FeatureExtractor(nn.Module):
6 | """ FeatureExtractor of CRNN (https://arxiv.org/pdf/1507.05717.pdf) """
7 |
8 | def __init__(self, input_channel, output_channel=512):
9 | super(VGG_FeatureExtractor, self).__init__()
10 | self.output_channel = [int(output_channel / 8), int(output_channel / 4),
11 | int(output_channel / 2), output_channel] # [64, 128, 256, 512]
12 | self.ConvNet = nn.Sequential(
13 | nn.Conv2d(input_channel,
14 | self.output_channel[0], 3, 1, 1), nn.ReLU(True),
15 | nn.MaxPool2d(2, 2), # 64x16x50
16 | nn.Conv2d(
17 | self.output_channel[0], self.output_channel[1], 3, 1, 1), nn.ReLU(True),
18 | nn.MaxPool2d(2, 2), # 128x8x25
19 | nn.Conv2d(self.output_channel[1], self.output_channel[2], 3, 1, 1), nn.ReLU(
20 | True), # 256x8x25
21 | nn.Conv2d(
22 | self.output_channel[2], self.output_channel[2], 3, 1, 1), nn.ReLU(True),
23 | nn.MaxPool2d((2, 1), (2, 1)), # 256x4x25
24 | nn.Conv2d(
25 | self.output_channel[2], self.output_channel[3], 3, 1, 1, bias=False),
26 | nn.BatchNorm2d(self.output_channel[3]), nn.ReLU(True), # 512x4x25
27 | nn.Conv2d(
28 | self.output_channel[3], self.output_channel[3], 3, 1, 1, bias=False),
29 | nn.BatchNorm2d(self.output_channel[3]), nn.ReLU(True),
30 | nn.MaxPool2d((2, 1), (2, 1)), # 512x2x25
31 | nn.Conv2d(self.output_channel[3], self.output_channel[3], 2, 1, 0), nn.ReLU(True)) # 512x1x24
32 |
33 | def forward(self, input):
34 | return self.ConvNet(input)
35 |
36 |
37 | class RCNN_FeatureExtractor(nn.Module):
38 | """ FeatureExtractor of GRCNN (https://papers.nips.cc/paper/6637-gated-recurrent-convolution-neural-network-for-ocr.pdf) """
39 |
40 | def __init__(self, input_channel, output_channel=512):
41 | super(RCNN_FeatureExtractor, self).__init__()
42 | self.output_channel = [int(output_channel / 8), int(output_channel / 4),
43 | int(output_channel / 2), output_channel] # [64, 128, 256, 512]
44 | self.ConvNet = nn.Sequential(
45 | nn.Conv2d(input_channel,
46 | self.output_channel[0], 3, 1, 1), nn.ReLU(True),
47 | nn.MaxPool2d(2, 2), # 64 x 16 x 50
48 | GRCL(self.output_channel[0], self.output_channel[0],
49 | num_iteration=5, kernel_size=3, pad=1),
50 | nn.MaxPool2d(2, 2), # 64 x 8 x 25
51 | GRCL(self.output_channel[0], self.output_channel[1],
52 | num_iteration=5, kernel_size=3, pad=1),
53 | nn.MaxPool2d(2, (2, 1), (0, 1)), # 128 x 4 x 26
54 | GRCL(self.output_channel[1], self.output_channel[2],
55 | num_iteration=5, kernel_size=3, pad=1),
56 | nn.MaxPool2d(2, (2, 1), (0, 1)), # 256 x 2 x 27
57 | nn.Conv2d(
58 | self.output_channel[2], self.output_channel[3], 2, 1, 0, bias=False),
59 | nn.BatchNorm2d(self.output_channel[3]), nn.ReLU(True)) # 512 x 1 x 26
60 |
61 | def forward(self, input):
62 | return self.ConvNet(input)
63 |
64 |
65 | class ResNet_FeatureExtractor(nn.Module):
66 | """ FeatureExtractor of FAN (http://openaccess.thecvf.com/content_ICCV_2017/papers/Cheng_Focusing_Attention_Towards_ICCV_2017_paper.pdf) """
67 |
68 | def __init__(self, input_channel, output_channel=512):
69 | super(ResNet_FeatureExtractor, self).__init__()
70 | self.ConvNet = ResNet(input_channel, output_channel,
71 | BasicBlock, [1, 2, 5, 3])
72 |
73 | def forward(self, input):
74 | return self.ConvNet(input)
75 |
76 |
77 | # For Gated RCNN
78 | class GRCL(nn.Module):
79 |
80 | def __init__(self, input_channel, output_channel, num_iteration, kernel_size, pad):
81 | super(GRCL, self).__init__()
82 | self.wgf_u = nn.Conv2d(
83 | input_channel, output_channel, 1, 1, 0, bias=False)
84 | self.wgr_x = nn.Conv2d(
85 | output_channel, output_channel, 1, 1, 0, bias=False)
86 | self.wf_u = nn.Conv2d(input_channel, output_channel,
87 | kernel_size, 1, pad, bias=False)
88 | self.wr_x = nn.Conv2d(output_channel, output_channel,
89 | kernel_size, 1, pad, bias=False)
90 |
91 | self.BN_x_init = nn.BatchNorm2d(output_channel)
92 |
93 | self.num_iteration = num_iteration
94 | self.GRCL = [GRCL_unit(output_channel) for _ in range(num_iteration)]
95 | self.GRCL = nn.Sequential(*self.GRCL)
96 |
97 | def forward(self, input):
98 |         """ The input of GRCL is consistent over time t, which is denoted by u(0);
99 |         thus wgf_u / wf_u is also consistent over time t.
100 | """
101 | wgf_u = self.wgf_u(input)
102 | wf_u = self.wf_u(input)
103 | x = F.relu(self.BN_x_init(wf_u))
104 |
105 | for i in range(self.num_iteration):
106 | x = self.GRCL[i](wgf_u, self.wgr_x(x), wf_u, self.wr_x(x))
107 |
108 | return x
109 |
110 |
111 | class GRCL_unit(nn.Module):
112 |
113 | def __init__(self, output_channel):
114 | super(GRCL_unit, self).__init__()
115 | self.BN_gfu = nn.BatchNorm2d(output_channel)
116 | self.BN_grx = nn.BatchNorm2d(output_channel)
117 | self.BN_fu = nn.BatchNorm2d(output_channel)
118 | self.BN_rx = nn.BatchNorm2d(output_channel)
119 | self.BN_Gx = nn.BatchNorm2d(output_channel)
120 |
121 | def forward(self, wgf_u, wgr_x, wf_u, wr_x):
122 | G_first_term = self.BN_gfu(wgf_u)
123 | G_second_term = self.BN_grx(wgr_x)
124 | G = F.sigmoid(G_first_term + G_second_term)
125 |
126 | x_first_term = self.BN_fu(wf_u)
127 | x_second_term = self.BN_Gx(self.BN_rx(wr_x) * G)
128 | x = F.relu(x_first_term + x_second_term)
129 |
130 | return x
131 |
132 |
133 | class BasicBlock(nn.Module):
134 | expansion = 1
135 |
136 | def __init__(self, inplanes, planes, stride=1, downsample=None):
137 | super(BasicBlock, self).__init__()
138 | self.conv1 = self._conv3x3(inplanes, planes)
139 | self.bn1 = nn.BatchNorm2d(planes)
140 | self.conv2 = self._conv3x3(planes, planes)
141 | self.bn2 = nn.BatchNorm2d(planes)
142 | self.relu = nn.ReLU(inplace=True)
143 | self.downsample = downsample
144 | self.stride = stride
145 |
146 | def _conv3x3(self, in_planes, out_planes, stride=1):
147 | "3x3 convolution with padding"
148 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
149 | padding=1, bias=False)
150 |
151 | def forward(self, x):
152 | residual = x
153 |
154 | out = self.conv1(x)
155 | out = self.bn1(out)
156 | out = self.relu(out)
157 |
158 | out = self.conv2(out)
159 | out = self.bn2(out)
160 |
161 | if self.downsample is not None:
162 | residual = self.downsample(x)
163 | out += residual
164 | out = self.relu(out)
165 |
166 | return out
167 |
168 |
169 | class ResNet(nn.Module):
170 |
171 | def __init__(self, input_channel, output_channel, block, layers):
172 | super(ResNet, self).__init__()
173 |
174 | self.output_channel_block = [
175 | int(output_channel / 4), int(output_channel / 2), output_channel, output_channel]
176 |
177 | self.inplanes = int(output_channel / 8)
178 | self.conv0_1 = nn.Conv2d(input_channel, int(output_channel / 16),
179 | kernel_size=3, stride=1, padding=1, bias=False)
180 | self.bn0_1 = nn.BatchNorm2d(int(output_channel / 16))
181 | self.conv0_2 = nn.Conv2d(int(output_channel / 16), self.inplanes,
182 | kernel_size=3, stride=1, padding=1, bias=False)
183 | self.bn0_2 = nn.BatchNorm2d(self.inplanes)
184 | self.relu = nn.ReLU(inplace=True)
185 |
186 | self.maxpool1 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
187 | self.layer1 = self._make_layer(
188 | block, self.output_channel_block[0], layers[0])
189 |         self.conv1 = nn.Conv2d(self.output_channel_block[0], self.output_channel_block[0],
190 |                                kernel_size=3, stride=1, padding=1, bias=False)
191 | self.bn1 = nn.BatchNorm2d(self.output_channel_block[0])
192 |
193 | self.maxpool2 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
194 | self.layer2 = self._make_layer(
195 | block, self.output_channel_block[1], layers[1], stride=1)
196 |         self.conv2 = nn.Conv2d(self.output_channel_block[1], self.output_channel_block[1],
197 |                                kernel_size=3, stride=1, padding=1, bias=False)
198 | self.bn2 = nn.BatchNorm2d(self.output_channel_block[1])
199 |
200 | self.maxpool3 = nn.MaxPool2d(
201 | kernel_size=2, stride=(2, 1), padding=(0, 1))
202 | self.layer3 = self._make_layer(
203 | block, self.output_channel_block[2], layers[2], stride=1)
204 |         self.conv3 = nn.Conv2d(self.output_channel_block[2], self.output_channel_block[2],
205 |                                kernel_size=3, stride=1, padding=1, bias=False)
206 | self.bn3 = nn.BatchNorm2d(self.output_channel_block[2])
207 |
208 | self.layer4 = self._make_layer(
209 | block, self.output_channel_block[3], layers[3], stride=1)
210 |         self.conv4_1 = nn.Conv2d(self.output_channel_block[3], self.output_channel_block[3],
211 |                                  kernel_size=2, stride=(2, 1), padding=(0, 1), bias=False)
212 | self.bn4_1 = nn.BatchNorm2d(self.output_channel_block[3])
213 |         self.conv4_2 = nn.Conv2d(self.output_channel_block[3], self.output_channel_block[3],
214 |                                  kernel_size=2, stride=1, padding=0, bias=False)
215 | self.bn4_2 = nn.BatchNorm2d(self.output_channel_block[3])
216 |
217 | def _make_layer(self, block, planes, blocks, stride=1):
218 | downsample = None
219 | if stride != 1 or self.inplanes != planes * block.expansion:
220 | downsample = nn.Sequential(
221 | nn.Conv2d(self.inplanes, planes * block.expansion,
222 | kernel_size=1, stride=stride, bias=False),
223 | nn.BatchNorm2d(planes * block.expansion),
224 | )
225 |
226 | layers = []
227 | layers.append(block(self.inplanes, planes, stride, downsample))
228 | self.inplanes = planes * block.expansion
229 | for i in range(1, blocks):
230 | layers.append(block(self.inplanes, planes))
231 |
232 | return nn.Sequential(*layers)
233 |
234 | def forward(self, x):
235 | x = self.conv0_1(x)
236 | x = self.bn0_1(x)
237 | x = self.relu(x)
238 | x = self.conv0_2(x)
239 | x = self.bn0_2(x)
240 | x = self.relu(x)
241 |
242 | x = self.maxpool1(x)
243 | x = self.layer1(x)
244 | x = self.conv1(x)
245 | x = self.bn1(x)
246 | x = self.relu(x)
247 |
248 | x = self.maxpool2(x)
249 | x = self.layer2(x)
250 | x = self.conv2(x)
251 | x = self.bn2(x)
252 | x = self.relu(x)
253 |
254 | x = self.maxpool3(x)
255 | x = self.layer3(x)
256 | x = self.conv3(x)
257 | x = self.bn3(x)
258 | x = self.relu(x)
259 |
260 | x = self.layer4(x)
261 | x = self.conv4_1(x)
262 | x = self.bn4_1(x)
263 | x = self.relu(x)
264 | x = self.conv4_2(x)
265 | x = self.bn4_2(x)
266 | x = self.relu(x)
267 |
268 | return x
269 |
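270 |
271 | # Minimal smoke-test sketch, assuming the 32 x 100 crop geometry implied by the
272 | # layer comments above. Each extractor maps an N x C x H x W crop to a feature
273 | # map whose height collapses to 1, so every remaining column can be read as one
274 | # time step for a downstream sequence model. The expected shapes quoted below
275 | # (e.g. 1 x 512 x 1 x 24 for the VGG variant) are read off those comments, not
276 | # asserted against the rest of the pipeline.
277 | if __name__ == '__main__':
278 |     import torch
279 |
280 |     dummy = torch.randn(1, 1, 32, 100)  # grayscale crop: N x C x H x W
281 |     extractors = [VGG_FeatureExtractor(1), RCNN_FeatureExtractor(1), ResNet_FeatureExtractor(1)]
282 |     for extractor in extractors:
283 |         with torch.no_grad():
284 |             feat = extractor(dummy)
285 |         print(extractor.__class__.__name__, tuple(feat.shape))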
--------------------------------------------------------------------------------
/notebooks/export_onnx_model.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "id": "enormous-reward",
7 | "metadata": {},
8 | "outputs": [],
9 | "source": [
10 | "import sys\n",
11 | "import warnings\n",
12 | "import onnx\n",
13 | "import torch\n",
14 | "import torch.onnx\n",
15 | "\n",
16 | "warnings.filterwarnings(\"ignore\")\n",
17 | "sys.path.append('..')\n",
18 | "\n",
19 | "from src.model import DefaultModel"
20 | ]
21 | },
22 | {
23 | "cell_type": "code",
24 | "execution_count": 2,
25 | "id": "stock-christian",
26 | "metadata": {},
27 | "outputs": [
28 | {
29 | "name": "stdout",
30 | "output_type": "stream",
31 | "text": [
32 | "cpu\n"
33 | ]
34 | }
35 | ],
36 | "source": [
37 | "device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')\n",
38 | "print(device)"
39 | ]
40 | },
41 | {
42 | "cell_type": "markdown",
43 | "id": "partial-slope",
44 | "metadata": {},
45 | "source": [
46 |     "## Load the model network and weights"
47 | ]
48 | },
49 | {
50 | "cell_type": "code",
51 | "execution_count": 3,
52 | "id": "coastal-brunei",
53 | "metadata": {},
54 | "outputs": [],
55 | "source": [
56 | "detector_cfg = '../configs/craft_config.yaml'\n",
57 | "detector_model = '../models/text_detector/craft_mlt_25k.pth'\n",
58 | "recognizer_cfg = '../configs/star_config.yaml'\n",
59 | "recognizer_model = '../models/text_recognizer/TPS-ResNet-BiLSTM-Attn-case-sensitive.pth'"
60 | ]
61 | },
62 | {
63 | "cell_type": "code",
64 | "execution_count": 4,
65 | "id": "brief-victor",
66 | "metadata": {},
67 | "outputs": [
68 | {
69 | "name": "stdout",
70 | "output_type": "stream",
71 | "text": [
72 | "Loading weights from checkpoint (../models/text_detector/craft_mlt_25k.pth)\n",
73 | "Loading weights from checkpoint (../models/text_recognizer/TPS-ResNet-BiLSTM-Attn-case-sensitive.pth)\n"
74 | ]
75 | }
76 | ],
77 | "source": [
78 | "model = DefaultModel(detector_cfg, detector_model, \n",
79 | " recognizer_cfg, recognizer_model)"
80 | ]
81 | },
82 | {
83 | "cell_type": "markdown",
84 | "id": "unusual-fence",
85 | "metadata": {},
86 | "source": [
87 | "# Detector"
88 | ]
89 | },
90 | {
91 | "cell_type": "markdown",
92 | "id": "congressional-enforcement",
93 | "metadata": {},
94 | "source": [
95 |     "## Export the Model\n",
96 |     "Dummy input layout: Batch Size x Channel x Height x Width"
97 | ]
98 | },
99 | {
100 | "cell_type": "code",
101 | "execution_count": 5,
102 | "id": "lasting-identification",
103 | "metadata": {},
104 | "outputs": [],
105 | "source": [
106 | "detector_dummy_input = torch.randn(1, 3, 1280, 720)"
107 | ]
108 | },
109 | {
110 | "cell_type": "code",
111 | "execution_count": 6,
112 | "id": "expanded-enzyme",
113 | "metadata": {
114 | "scrolled": true
115 | },
116 | "outputs": [
117 | {
118 | "data": {
119 | "text/plain": [
120 | "(tensor([[[[0.0010, 0.0002],\n",
121 | " [0.0085, 0.0020],\n",
122 | " [0.0010, 0.0002],\n",
123 | " ...,\n",
124 | " [0.0010, 0.0002],\n",
125 | " [0.0019, 0.0016],\n",
126 | " [0.0010, 0.0002]],\n",
127 | " \n",
128 | " [[0.0022, 0.0016],\n",
129 | " [0.0010, 0.0002],\n",
130 | " [0.0010, 0.0002],\n",
131 | " ...,\n",
132 | " [0.0010, 0.0002],\n",
133 | " [0.0010, 0.0002],\n",
134 | " [0.0010, 0.0002]],\n",
135 | " \n",
136 | " [[0.0010, 0.0002],\n",
137 | " [0.0010, 0.0002],\n",
138 | " [0.0010, 0.0002],\n",
139 | " ...,\n",
140 | " [0.0010, 0.0002],\n",
141 | " [0.0010, 0.0002],\n",
142 | " [0.0010, 0.0002]],\n",
143 | " \n",
144 | " ...,\n",
145 | " \n",
146 | " [[0.0010, 0.0002],\n",
147 | " [0.0010, 0.0002],\n",
148 | " [0.0010, 0.0002],\n",
149 | " ...,\n",
150 | " [0.0010, 0.0002],\n",
151 | " [0.0010, 0.0002],\n",
152 | " [0.0010, 0.0002]],\n",
153 | " \n",
154 | " [[0.0128, 0.0013],\n",
155 | " [0.0056, 0.0013],\n",
156 | " [0.0065, 0.0015],\n",
157 | " ...,\n",
158 | " [0.0010, 0.0002],\n",
159 | " [0.0010, 0.0002],\n",
160 | " [0.0020, 0.0006]],\n",
161 | " \n",
162 | " [[0.0065, 0.0086],\n",
163 | " [0.0084, 0.0021],\n",
164 | " [0.0093, 0.0018],\n",
165 | " ...,\n",
166 | " [0.0037, 0.0011],\n",
167 | " [0.0024, 0.0007],\n",
168 | " [0.0100, 0.0079]]]], grad_fn=),\n",
169 | " tensor([[[[0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0552],\n",
170 | " [0.0870, 0.0165, 0.2143, ..., 0.0000, 0.0000, 0.0886],\n",
171 | " [0.0047, 0.0641, 0.2276, ..., 0.0000, 0.0000, 0.0913],\n",
172 | " ...,\n",
173 | " [0.0000, 0.0000, 0.2815, ..., 1.2193, 1.1758, 1.3355],\n",
174 | " [0.0000, 0.0000, 0.4303, ..., 1.3450, 1.2745, 1.4509],\n",
175 | " [0.0000, 0.0000, 0.2416, ..., 0.8716, 0.9751, 1.0762]],\n",
176 | " \n",
177 | " [[0.1752, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.3734],\n",
178 | " [0.1021, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0797],\n",
179 | " [0.1667, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0534],\n",
180 | " ...,\n",
181 | " [0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],\n",
182 | " [0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],\n",
183 | " [0.0576, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.1832]],\n",
184 | " \n",
185 | " [[1.0205, 0.6065, 0.8373, ..., 0.2779, 0.2694, 0.6539],\n",
186 | " [0.9442, 0.3013, 0.7273, ..., 0.6779, 0.8180, 1.2405],\n",
187 | " [1.3239, 0.6361, 0.8955, ..., 0.7086, 0.7718, 1.0392],\n",
188 | " ...,\n",
189 | " [0.0000, 0.0000, 0.0000, ..., 0.0000, 0.2606, 2.2630],\n",
190 | " [0.0000, 0.0000, 0.0000, ..., 0.4381, 0.4628, 1.9584],\n",
191 | " [0.0000, 0.0000, 0.0000, ..., 0.2608, 0.3710, 1.2387]],\n",
192 | " \n",
193 | " ...,\n",
194 | " \n",
195 | " [[0.0000, 0.0000, 0.0221, ..., 0.0000, 0.0000, 0.0000],\n",
196 | " [0.0000, 0.2290, 0.2829, ..., 0.0000, 0.0000, 0.0000],\n",
197 | " [0.0000, 0.0408, 0.0449, ..., 0.0000, 0.0000, 0.0000],\n",
198 | " ...,\n",
199 | " [0.7339, 1.3421, 1.8549, ..., 2.7119, 2.3538, 1.3433],\n",
200 | " [0.2865, 0.6955, 0.9729, ..., 1.9042, 1.6367, 1.0503],\n",
201 | " [0.3265, 0.5312, 0.6013, ..., 1.1321, 0.9693, 0.7332]],\n",
202 | " \n",
203 | " [[0.3608, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],\n",
204 | " [0.0601, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],\n",
205 | " [0.0870, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],\n",
206 | " ...,\n",
207 | " [0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],\n",
208 | " [0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],\n",
209 | " [0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]],\n",
210 | " \n",
211 | " [[0.0937, 0.1586, 0.0410, ..., 0.1801, 0.1988, 0.4503],\n",
212 | " [0.0000, 0.0638, 0.1298, ..., 0.0000, 0.0304, 0.4700],\n",
213 | " [0.0000, 0.1438, 0.1875, ..., 0.0000, 0.1855, 0.5608],\n",
214 | " ...,\n",
215 | " [0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],\n",
216 | " [0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],\n",
217 | " [0.0997, 0.0000, 0.0000, ..., 0.0000, 0.2321, 0.6662]]]],\n",
218 | " grad_fn=))"
219 | ]
220 | },
221 | "execution_count": 6,
222 | "metadata": {},
223 | "output_type": "execute_result"
224 | }
225 | ],
226 | "source": [
227 | "model.detector(detector_dummy_input)"
228 | ]
229 | },
230 | {
231 | "cell_type": "code",
232 | "execution_count": 7,
233 | "id": "devoted-information",
234 | "metadata": {},
235 | "outputs": [],
236 | "source": [
237 | "out_detector_model = '../models/text_detector/craft.onnx'"
238 | ]
239 | },
240 | {
241 | "cell_type": "code",
242 | "execution_count": 8,
243 | "id": "arbitrary-valuation",
244 | "metadata": {},
245 | "outputs": [],
246 | "source": [
247 | "# Export the model\n",
248 | "torch.onnx.export(model.detector, \n",
249 | " detector_dummy_input,\n",
250 | " out_detector_model,\n",
251 | " export_params=True,\n",
252 | " opset_version=13,\n",
253 | " do_constant_folding=True,\n",
254 | " input_names = ['input'],\n",
255 | " output_names = ['output'],\n",
256 | " dynamic_axes={'input' : {0:'batch_size', 2:'height', 3:'width'},\n",
257 | " 'output' : {0:'batch_size'}})"
258 | ]
259 | },
260 | {
261 | "cell_type": "markdown",
262 | "id": "reflected-librarian",
263 | "metadata": {},
264 | "source": [
265 | "## Inspecting Model"
266 | ]
267 | },
268 | {
269 | "cell_type": "code",
270 | "execution_count": 9,
271 | "id": "cognitive-albania",
272 | "metadata": {
273 | "scrolled": true
274 | },
275 | "outputs": [
276 | {
277 | "name": "stdout",
278 | "output_type": "stream",
279 | "text": [
280 | "graph torch-jit-export (\n",
281 | " %input[FLOAT, batch_sizex3xheightxwidth]\n",
282 | ") initializers (\n",
283 | " %basenet.slice5.1.weight[FLOAT, 1024x512x3x3]\n",
284 | " %basenet.slice5.1.bias[FLOAT, 1024]\n",
285 | " %basenet.slice5.2.weight[FLOAT, 1024x1024x1x1]\n",
286 | " %basenet.slice5.2.bias[FLOAT, 1024]\n",
287 | " %conv_cls.0.weight[FLOAT, 32x32x3x3]\n",
288 | " %conv_cls.0.bias[FLOAT, 32]\n",
289 | " %conv_cls.2.weight[FLOAT, 32x32x3x3]\n",
290 | " %conv_cls.2.bias[FLOAT, 32]\n",
291 | " %conv_cls.4.weight[FLOAT, 16x32x3x3]\n",
292 | " %conv_cls.4.bias[FLOAT, 16]\n",
293 | " %conv_cls.6.weight[FLOAT, 16x16x1x1]\n",
294 | " %conv_cls.6.bias[FLOAT, 16]\n",
295 | " %conv_cls.8.weight[FLOAT, 2x16x1x1]\n",
296 | " %conv_cls.8.bias[FLOAT, 2]\n",
297 | " %299[FLOAT, 64x3x3x3]\n",
298 | " %300[FLOAT, 64]\n",
299 | " %302[FLOAT, 64x64x3x3]\n",
300 | " %303[FLOAT, 64]\n",
301 | " %305[FLOAT, 128x64x3x3]\n",
302 | " %306[FLOAT, 128]\n",
303 | " %308[FLOAT, 128x128x3x3]\n",
304 | " %309[FLOAT, 128]\n",
305 | " %311[FLOAT, 256x128x3x3]\n",
306 | " %312[FLOAT, 256]\n",
307 | " %314[FLOAT, 256x256x3x3]\n",
308 | " %315[FLOAT, 256]\n",
309 | " %317[FLOAT, 256x256x3x3]\n",
310 | " %318[FLOAT, 256]\n",
311 | " %320[FLOAT, 512x256x3x3]\n",
312 | " %321[FLOAT, 512]\n",
313 | " %323[FLOAT, 512x512x3x3]\n",
314 | " %324[FLOAT, 512]\n",
315 | " %326[FLOAT, 512x512x3x3]\n",
316 | " %327[FLOAT, 512]\n",
317 | " %329[FLOAT, 512x512x3x3]\n",
318 | " %330[FLOAT, 512]\n",
319 | " %332[FLOAT, 512x512x3x3]\n",
320 | " %333[FLOAT, 512]\n",
321 | " %335[FLOAT, 512x1536x1x1]\n",
322 | " %336[FLOAT, 512]\n",
323 | " %338[FLOAT, 256x512x3x3]\n",
324 | " %339[FLOAT, 256]\n",
325 | " %341[FLOAT, 256x768x1x1]\n",
326 | " %342[FLOAT, 256]\n",
327 | " %344[FLOAT, 128x256x3x3]\n",
328 | " %345[FLOAT, 128]\n",
329 | " %347[FLOAT, 128x384x1x1]\n",
330 | " %348[FLOAT, 128]\n",
331 | " %350[FLOAT, 64x128x3x3]\n",
332 | " %351[FLOAT, 64]\n",
333 | " %353[FLOAT, 64x192x1x1]\n",
334 | " %354[FLOAT, 64]\n",
335 | " %356[FLOAT, 32x64x3x3]\n",
336 | " %357[FLOAT, 32]\n",
337 | ") {\n",
338 | " %298 = Conv[dilations = [1, 1], group = 1, kernel_shape = [3, 3], pads = [1, 1, 1, 1], strides = [1, 1]](%input, %299, %300)\n",
339 | " %157 = Relu(%298)\n",
340 | " %301 = Conv[dilations = [1, 1], group = 1, kernel_shape = [3, 3], pads = [1, 1, 1, 1], strides = [1, 1]](%157, %302, %303)\n",
341 | " %160 = Relu(%301)\n",
342 | " %161 = MaxPool[ceil_mode = 0, kernel_shape = [2, 2], pads = [0, 0, 0, 0], strides = [2, 2]](%160)\n",
343 | " %304 = Conv[dilations = [1, 1], group = 1, kernel_shape = [3, 3], pads = [1, 1, 1, 1], strides = [1, 1]](%161, %305, %306)\n",
344 | " %164 = Relu(%304)\n",
345 | " %307 = Conv[dilations = [1, 1], group = 1, kernel_shape = [3, 3], pads = [1, 1, 1, 1], strides = [1, 1]](%164, %308, %309)\n",
346 | " %167 = Relu(%307)\n",
347 | " %168 = MaxPool[ceil_mode = 0, kernel_shape = [2, 2], pads = [0, 0, 0, 0], strides = [2, 2]](%167)\n",
348 | " %310 = Conv[dilations = [1, 1], group = 1, kernel_shape = [3, 3], pads = [1, 1, 1, 1], strides = [1, 1]](%168, %311, %312)\n",
349 | " %171 = Relu(%310)\n",
350 | " %313 = Conv[dilations = [1, 1], group = 1, kernel_shape = [3, 3], pads = [1, 1, 1, 1], strides = [1, 1]](%171, %314, %315)\n",
351 | " %174 = Relu(%313)\n",
352 | " %316 = Conv[dilations = [1, 1], group = 1, kernel_shape = [3, 3], pads = [1, 1, 1, 1], strides = [1, 1]](%174, %317, %318)\n",
353 | " %177 = Relu(%316)\n",
354 | " %178 = MaxPool[ceil_mode = 0, kernel_shape = [2, 2], pads = [0, 0, 0, 0], strides = [2, 2]](%177)\n",
355 | " %319 = Conv[dilations = [1, 1], group = 1, kernel_shape = [3, 3], pads = [1, 1, 1, 1], strides = [1, 1]](%178, %320, %321)\n",
356 | " %181 = Relu(%319)\n",
357 | " %322 = Conv[dilations = [1, 1], group = 1, kernel_shape = [3, 3], pads = [1, 1, 1, 1], strides = [1, 1]](%181, %323, %324)\n",
358 | " %184 = Relu(%322)\n",
359 | " %325 = Conv[dilations = [1, 1], group = 1, kernel_shape = [3, 3], pads = [1, 1, 1, 1], strides = [1, 1]](%184, %326, %327)\n",
360 | " %187 = Relu(%325)\n",
361 | " %188 = MaxPool[ceil_mode = 0, kernel_shape = [2, 2], pads = [0, 0, 0, 0], strides = [2, 2]](%187)\n",
362 | " %328 = Conv[dilations = [1, 1], group = 1, kernel_shape = [3, 3], pads = [1, 1, 1, 1], strides = [1, 1]](%188, %329, %330)\n",
363 | " %191 = Relu(%328)\n",
364 | " %331 = Conv[dilations = [1, 1], group = 1, kernel_shape = [3, 3], pads = [1, 1, 1, 1], strides = [1, 1]](%191, %332, %333)\n",
365 | " %194 = MaxPool[ceil_mode = 0, kernel_shape = [3, 3], pads = [1, 1, 1, 1], strides = [1, 1]](%331)\n",
366 | " %195 = Conv[dilations = [6, 6], group = 1, kernel_shape = [3, 3], pads = [6, 6, 6, 6], strides = [1, 1]](%194, %basenet.slice5.1.weight, %basenet.slice5.1.bias)\n",
367 | " %196 = Conv[dilations = [1, 1], group = 1, kernel_shape = [1, 1], pads = [0, 0, 0, 0], strides = [1, 1]](%195, %basenet.slice5.2.weight, %basenet.slice5.2.bias)\n",
368 | " %197 = Concat[axis = 1](%196, %331)\n",
369 | " %334 = Conv[dilations = [1, 1], group = 1, kernel_shape = [1, 1], pads = [0, 0, 0, 0], strides = [1, 1]](%197, %335, %336)\n",
370 | " %200 = Relu(%334)\n",
371 | " %337 = Conv[dilations = [1, 1], group = 1, kernel_shape = [3, 3], pads = [1, 1, 1, 1], strides = [1, 1]](%200, %338, %339)\n",
372 | " %203 = Relu(%337)\n",
373 | " %204 = Shape(%184)\n",
374 | " %205 = Constant[value = ]()\n",
375 | " %206 = Gather[axis = 0](%204, %205)\n",
376 | " %207 = Shape(%184)\n",
377 | " %208 = Constant[value = ]()\n",
378 | " %209 = Gather[axis = 0](%207, %208)\n",
379 | " %210 = Constant[value = ]()\n",
380 | " %211 = Unsqueeze(%206, %210)\n",
381 | " %212 = Constant[value = ]()\n",
382 | " %213 = Unsqueeze(%209, %212)\n",
383 | " %214 = Concat[axis = 0](%211, %213)\n",
384 | " %215 = Shape(%203)\n",
385 | " %216 = Constant[value = ]()\n",
386 | " %217 = Constant[value = ]()\n",
387 | " %218 = Constant[value = ]()\n",
388 | " %219 = Slice(%215, %217, %218, %216)\n",
389 | " %220 = Cast[to = 7](%214)\n",
390 | " %221 = Concat[axis = 0](%219, %220)\n",
391 | " %224 = Resize[coordinate_transformation_mode = 'pytorch_half_pixel', cubic_coeff_a = -0.75, mode = 'linear', nearest_mode = 'floor'](%203, %, %, %221)\n",
392 | " %225 = Concat[axis = 1](%224, %184)\n",
393 | " %340 = Conv[dilations = [1, 1], group = 1, kernel_shape = [1, 1], pads = [0, 0, 0, 0], strides = [1, 1]](%225, %341, %342)\n",
394 | " %228 = Relu(%340)\n",
395 | " %343 = Conv[dilations = [1, 1], group = 1, kernel_shape = [3, 3], pads = [1, 1, 1, 1], strides = [1, 1]](%228, %344, %345)\n",
396 | " %231 = Relu(%343)\n",
397 | " %232 = Shape(%174)\n",
398 | " %233 = Constant[value = ]()\n",
399 | " %234 = Gather[axis = 0](%232, %233)\n",
400 | " %235 = Shape(%174)\n",
401 | " %236 = Constant[value = ]()\n",
402 | " %237 = Gather[axis = 0](%235, %236)\n",
403 | " %238 = Constant[value = ]()\n",
404 | " %239 = Unsqueeze(%234, %238)\n",
405 | " %240 = Constant[value = ]()\n",
406 | " %241 = Unsqueeze(%237, %240)\n",
407 | " %242 = Concat[axis = 0](%239, %241)\n",
408 | " %243 = Shape(%231)\n",
409 | " %244 = Constant[value = ]()\n",
410 | " %245 = Constant[value = ]()\n",
411 | " %246 = Constant[value = ]()\n",
412 | " %247 = Slice(%243, %245, %246, %244)\n",
413 | " %248 = Cast[to = 7](%242)\n",
414 | " %249 = Concat[axis = 0](%247, %248)\n",
415 | " %252 = Resize[coordinate_transformation_mode = 'pytorch_half_pixel', cubic_coeff_a = -0.75, mode = 'linear', nearest_mode = 'floor'](%231, %, %, %249)\n",
416 | " %253 = Concat[axis = 1](%252, %174)\n",
417 | " %346 = Conv[dilations = [1, 1], group = 1, kernel_shape = [1, 1], pads = [0, 0, 0, 0], strides = [1, 1]](%253, %347, %348)\n",
418 | " %256 = Relu(%346)\n",
419 | " %349 = Conv[dilations = [1, 1], group = 1, kernel_shape = [3, 3], pads = [1, 1, 1, 1], strides = [1, 1]](%256, %350, %351)\n",
420 | " %259 = Relu(%349)\n",
421 | " %260 = Shape(%167)\n",
422 | " %261 = Constant[value = ]()\n",
423 | " %262 = Gather[axis = 0](%260, %261)\n",
424 | " %263 = Shape(%167)\n",
425 | " %264 = Constant[value = ]()\n",
426 | " %265 = Gather[axis = 0](%263, %264)\n",
427 | " %266 = Constant[value = ]()\n",
428 | " %267 = Unsqueeze(%262, %266)\n",
429 | " %268 = Constant[value = ]()\n",
430 | " %269 = Unsqueeze(%265, %268)\n",
431 | " %270 = Concat[axis = 0](%267, %269)\n",
432 | " %271 = Shape(%259)\n",
433 | " %272 = Constant[value = ]()\n",
434 | " %273 = Constant[value = ]()\n",
435 | " %274 = Constant[value = ]()\n",
436 | " %275 = Slice(%271, %273, %274, %272)\n",
437 | " %276 = Cast[to = 7](%270)\n",
438 | " %277 = Concat[axis = 0](%275, %276)\n",
439 | " %280 = Resize[coordinate_transformation_mode = 'pytorch_half_pixel', cubic_coeff_a = -0.75, mode = 'linear', nearest_mode = 'floor'](%259, %, %, %277)\n",
440 | " %281 = Concat[axis = 1](%280, %167)\n",
441 | " %352 = Conv[dilations = [1, 1], group = 1, kernel_shape = [1, 1], pads = [0, 0, 0, 0], strides = [1, 1]](%281, %353, %354)\n",
442 | " %284 = Relu(%352)\n",
443 | " %355 = Conv[dilations = [1, 1], group = 1, kernel_shape = [3, 3], pads = [1, 1, 1, 1], strides = [1, 1]](%284, %356, %357)\n",
444 | " %287 = Relu(%355)\n",
445 | " %288 = Conv[dilations = [1, 1], group = 1, kernel_shape = [3, 3], pads = [1, 1, 1, 1], strides = [1, 1]](%287, %conv_cls.0.weight, %conv_cls.0.bias)\n",
446 | " %289 = Relu(%288)\n",
447 | " %290 = Conv[dilations = [1, 1], group = 1, kernel_shape = [3, 3], pads = [1, 1, 1, 1], strides = [1, 1]](%289, %conv_cls.2.weight, %conv_cls.2.bias)\n",
448 | " %291 = Relu(%290)\n",
449 | " %292 = Conv[dilations = [1, 1], group = 1, kernel_shape = [3, 3], pads = [1, 1, 1, 1], strides = [1, 1]](%291, %conv_cls.4.weight, %conv_cls.4.bias)\n",
450 | " %293 = Relu(%292)\n",
451 | " %294 = Conv[dilations = [1, 1], group = 1, kernel_shape = [1, 1], pads = [0, 0, 0, 0], strides = [1, 1]](%293, %conv_cls.6.weight, %conv_cls.6.bias)\n",
452 | " %295 = Relu(%294)\n",
453 | " %296 = Conv[dilations = [1, 1], group = 1, kernel_shape = [1, 1], pads = [0, 0, 0, 0], strides = [1, 1]](%295, %conv_cls.8.weight, %conv_cls.8.bias)\n",
454 | " %output = Transpose[perm = [0, 2, 3, 1]](%296)\n",
455 | " return %output, %287\n",
456 | "}\n"
457 | ]
458 | }
459 | ],
460 | "source": [
461 | "# Load the ONNX model\n",
462 | "onnx_model = onnx.load(out_detector_model)\n",
463 | "\n",
464 | "# Check that the IR is well formed\n",
465 | "onnx.checker.check_model(onnx_model)\n",
466 | "\n",
467 | "# Print a human readable representation of the graph\n",
468 | "print(onnx.helper.printable_graph(onnx_model.graph))"
469 | ]
470 | },
471 | {
472 | "cell_type": "markdown",
473 | "id": "israeli-northwest",
474 | "metadata": {},
475 | "source": [
476 | "# Recognizer\n",
477 | "\n",
478 |     "Export fails: the grid_sampler operator is not supported by the ONNX exporter (unresolved upstream issue):\n",
479 | "https://github.com/pytorch/pytorch/issues/27212"
480 | ]
481 | },
482 | {
483 | "cell_type": "markdown",
484 | "id": "eleven-institution",
485 | "metadata": {},
486 | "source": [
487 |     "## Export the Model\n",
488 |     "Dummy input layout: Batch Size x Channel x Height x Width"
489 | ]
490 | },
491 | {
492 | "cell_type": "code",
493 | "execution_count": 10,
494 | "id": "italian-trailer",
495 | "metadata": {},
496 | "outputs": [],
497 | "source": [
498 | "recognizer_dummy_input = torch.randn(100, 1, 32, 100)\n",
499 | "recognizer_dummy_text = torch.LongTensor(100, 26).fill_(0)"
500 | ]
501 | },
502 | {
503 | "cell_type": "code",
504 | "execution_count": 11,
505 | "id": "sized-truck",
506 | "metadata": {
507 | "scrolled": true
508 | },
509 | "outputs": [
510 | {
511 | "data": {
512 | "text/plain": [
513 | "tensor([[[ -6.9061, -7.5613, -5.4793, ..., -2.4262, -4.9112, -4.1666],\n",
514 | " [-11.9275, -13.6367, -11.6911, ..., -10.5464, -11.6802, -11.4881],\n",
515 | " [-15.8177, -16.6920, -15.7986, ..., -15.0340, -15.5713, -15.5707],\n",
516 | " ...,\n",
517 | " [-13.8261, -2.1892, -13.1530, ..., -13.9286, -13.0425, -13.7967],\n",
518 | " [-13.8304, -2.1870, -13.1481, ..., -13.9520, -13.0352, -13.8002],\n",
519 | " [-13.8197, -2.1547, -13.1142, ..., -13.9470, -13.0154, -13.7876]],\n",
520 | "\n",
521 | " [[ -8.2746, -8.6355, -6.8791, ..., -6.2013, -8.2145, -6.9954],\n",
522 | " [-13.3459, -18.1712, -12.9131, ..., -12.4395, -12.8243, -11.7817],\n",
523 | " [-15.9471, -18.3970, -13.9544, ..., -14.9989, -15.8966, -14.6564],\n",
524 | " ...,\n",
525 | " [-13.8625, -3.7098, -11.3744, ..., -13.9167, -12.4986, -14.3229],\n",
526 | " [-13.8717, -3.6356, -11.3872, ..., -13.9434, -12.5487, -14.3171],\n",
527 | " [-13.8459, -3.6648, -11.3504, ..., -13.9306, -12.5294, -14.2871]],\n",
528 | "\n",
529 | " [[ -9.2234, -8.9115, -7.3134, ..., -8.0188, -9.4038, -7.2299],\n",
530 | " [-14.8482, -16.4109, -13.4717, ..., -14.9134, -15.2613, -14.8936],\n",
531 | " [-18.2801, -17.6252, -17.8121, ..., -17.7865, -18.4960, -17.5789],\n",
532 | " ...,\n",
533 | " [-14.0301, -4.4127, -15.5972, ..., -13.7900, -13.4122, -14.4714],\n",
534 | " [-14.0370, -4.2523, -15.5777, ..., -13.7668, -13.4058, -14.4595],\n",
535 | " [-14.0503, -4.2084, -15.6115, ..., -13.8018, -13.4401, -14.5031]],\n",
536 | "\n",
537 | " ...,\n",
538 | "\n",
539 | " [[ -7.1346, -8.7684, -3.9511, ..., -6.0389, -7.8837, -6.5712],\n",
540 | " [-13.3955, -14.4589, -13.1261, ..., -12.2595, -12.9576, -12.9921],\n",
541 | " [-16.3747, -14.4627, -16.1623, ..., -15.8627, -17.5044, -15.9218],\n",
542 | " ...,\n",
543 | " [-12.4704, -3.2038, -12.0584, ..., -13.4266, -12.3316, -12.8561],\n",
544 | " [-12.4187, -3.2129, -11.9786, ..., -13.3757, -12.2918, -12.8280],\n",
545 | " [-12.3670, -3.2296, -11.9009, ..., -13.3300, -12.2517, -12.7935]],\n",
546 | "\n",
547 | " [[ -8.4369, -8.5261, -6.7348, ..., -7.9101, -9.0706, -8.3969],\n",
548 | " [-13.5088, -12.4055, -13.0897, ..., -12.9229, -13.2860, -13.7020],\n",
549 | " [-15.6351, -13.3638, -15.9259, ..., -15.3991, -15.9885, -16.2208],\n",
550 | " ...,\n",
551 | " [-11.8698, -4.2842, -11.2330, ..., -13.6111, -12.6553, -13.3193],\n",
552 | " [-11.8772, -4.2873, -11.2386, ..., -13.6150, -12.6616, -13.3284],\n",
553 | " [-11.8831, -4.3029, -11.2439, ..., -13.6189, -12.6683, -13.3351]],\n",
554 | "\n",
555 | " [[ -5.7706, -4.6843, -6.4071, ..., -3.7921, -4.4570, -4.4655],\n",
556 | " [-12.4181, -13.1821, -10.1006, ..., -9.4514, -11.3294, -11.0276],\n",
557 | " [-13.5660, -13.7678, -9.6513, ..., -13.9703, -14.8214, -13.9399],\n",
558 | " ...,\n",
559 | " [-13.8621, -3.4651, -13.2679, ..., -14.5769, -13.8149, -14.4902],\n",
560 | " [-13.8660, -3.4625, -13.2367, ..., -14.5793, -13.8194, -14.4901],\n",
561 | " [-13.8639, -3.4490, -13.2186, ..., -14.5810, -13.8137, -14.4848]]],\n",
562 | " grad_fn=)"
563 | ]
564 | },
565 | "execution_count": 11,
566 | "metadata": {},
567 | "output_type": "execute_result"
568 | }
569 | ],
570 | "source": [
571 | "model.recognizer.module(recognizer_dummy_input, recognizer_dummy_text)"
572 | ]
573 | },
574 | {
575 | "cell_type": "code",
576 | "execution_count": 12,
577 | "id": "enabling-philippines",
578 | "metadata": {},
579 | "outputs": [],
580 | "source": [
581 | "out_recognizer_model = '../models/text_recognizer/star.onnx'"
582 | ]
583 | },
584 | {
585 | "cell_type": "code",
586 | "execution_count": 13,
587 | "id": "roman-medicare",
588 | "metadata": {},
589 | "outputs": [
590 | {
591 | "ename": "RuntimeError",
592 | "evalue": "Exporting the operator grid_sampler to ONNX opset version 13 is not supported. Please feel free to request support or submit a pull request on PyTorch GitHub.",
593 | "output_type": "error",
594 | "traceback": [
595 | "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
596 | "\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)",
597 | "\u001b[0;32m/tmp/ipykernel_17483/4145972323.py\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0;31m# Export the model\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m torch.onnx.export(model.recognizer.module, \n\u001b[0m\u001b[1;32m 3\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mrecognizer_dummy_input\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mrecognizer_dummy_text\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0mout_recognizer_model\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0mexport_params\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
598 | "\u001b[0;32m~/miniconda3/envs/receipt-ocr/lib/python3.9/site-packages/torch/onnx/__init__.py\u001b[0m in \u001b[0;36mexport\u001b[0;34m(model, args, f, export_params, verbose, training, input_names, output_names, aten, export_raw_ir, operator_export_type, opset_version, _retain_param_name, do_constant_folding, example_outputs, strip_doc_string, dynamic_axes, keep_initializers_as_inputs, custom_opsets, enable_onnx_checker, use_external_data_format)\u001b[0m\n\u001b[1;32m 273\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 274\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0monnx\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mutils\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 275\u001b[0;31m return utils.export(model, args, f, export_params, verbose, training,\n\u001b[0m\u001b[1;32m 276\u001b[0m \u001b[0minput_names\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moutput_names\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0maten\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mexport_raw_ir\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 277\u001b[0m \u001b[0moperator_export_type\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mopset_version\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0m_retain_param_name\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
599 | "\u001b[0;32m~/miniconda3/envs/receipt-ocr/lib/python3.9/site-packages/torch/onnx/utils.py\u001b[0m in \u001b[0;36mexport\u001b[0;34m(model, args, f, export_params, verbose, training, input_names, output_names, aten, export_raw_ir, operator_export_type, opset_version, _retain_param_name, do_constant_folding, example_outputs, strip_doc_string, dynamic_axes, keep_initializers_as_inputs, custom_opsets, enable_onnx_checker, use_external_data_format)\u001b[0m\n\u001b[1;32m 86\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 87\u001b[0m \u001b[0moperator_export_type\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mOperatorExportTypes\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mONNX\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 88\u001b[0;31m _export(model, args, f, export_params, verbose, training, input_names, output_names,\n\u001b[0m\u001b[1;32m 89\u001b[0m \u001b[0moperator_export_type\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0moperator_export_type\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mopset_version\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mopset_version\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 90\u001b[0m \u001b[0m_retain_param_name\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0m_retain_param_name\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdo_constant_folding\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mdo_constant_folding\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
600 | "\u001b[0;32m~/miniconda3/envs/receipt-ocr/lib/python3.9/site-packages/torch/onnx/utils.py\u001b[0m in \u001b[0;36m_export\u001b[0;34m(model, args, f, export_params, verbose, training, input_names, output_names, operator_export_type, export_type, example_outputs, opset_version, _retain_param_name, do_constant_folding, strip_doc_string, dynamic_axes, keep_initializers_as_inputs, fixed_batch_size, custom_opsets, add_node_names, enable_onnx_checker, use_external_data_format, onnx_shape_inference)\u001b[0m\n\u001b[1;32m 687\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 688\u001b[0m \u001b[0mgraph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mparams_dict\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtorch_out\u001b[0m \u001b[0;34m=\u001b[0m\u001b[0;31m \u001b[0m\u001b[0;31m\\\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 689\u001b[0;31m _model_to_graph(model, args, verbose, input_names,\n\u001b[0m\u001b[1;32m 690\u001b[0m \u001b[0moutput_names\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moperator_export_type\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 691\u001b[0m \u001b[0mexample_outputs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0m_retain_param_name\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
601 | "\u001b[0;32m~/miniconda3/envs/receipt-ocr/lib/python3.9/site-packages/torch/onnx/utils.py\u001b[0m in \u001b[0;36m_model_to_graph\u001b[0;34m(model, args, verbose, input_names, output_names, operator_export_type, example_outputs, _retain_param_name, do_constant_folding, _disable_torch_constant_prop, fixed_batch_size, training, dynamic_axes)\u001b[0m\n\u001b[1;32m 461\u001b[0m \u001b[0mparams_dict\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m_get_named_param_dict\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mgraph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mparams\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 462\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 463\u001b[0;31m graph = _optimize_graph(graph, operator_export_type,\n\u001b[0m\u001b[1;32m 464\u001b[0m \u001b[0m_disable_torch_constant_prop\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0m_disable_torch_constant_prop\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 465\u001b[0m \u001b[0mfixed_batch_size\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mfixed_batch_size\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mparams_dict\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mparams_dict\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
602 | "\u001b[0;32m~/miniconda3/envs/receipt-ocr/lib/python3.9/site-packages/torch/onnx/utils.py\u001b[0m in \u001b[0;36m_optimize_graph\u001b[0;34m(graph, operator_export_type, _disable_torch_constant_prop, fixed_batch_size, params_dict, dynamic_axes, input_names, module)\u001b[0m\n\u001b[1;32m 198\u001b[0m \u001b[0mdynamic_axes\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m{\u001b[0m\u001b[0;34m}\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mdynamic_axes\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mNone\u001b[0m \u001b[0;32melse\u001b[0m \u001b[0mdynamic_axes\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 199\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_C\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_jit_pass_onnx_set_dynamic_input_shape\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mgraph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdynamic_axes\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput_names\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 200\u001b[0;31m \u001b[0mgraph\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_C\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_jit_pass_onnx\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mgraph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moperator_export_type\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 201\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_C\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_jit_pass_lint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mgraph\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 202\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
603 | "\u001b[0;32m~/miniconda3/envs/receipt-ocr/lib/python3.9/site-packages/torch/onnx/__init__.py\u001b[0m in \u001b[0;36m_run_symbolic_function\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 311\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_run_symbolic_function\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 312\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0monnx\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mutils\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 313\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mutils\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_run_symbolic_function\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 314\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 315\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
604 | "\u001b[0;32m~/miniconda3/envs/receipt-ocr/lib/python3.9/site-packages/torch/onnx/utils.py\u001b[0m in \u001b[0;36m_run_symbolic_function\u001b[0;34m(g, block, n, inputs, env, operator_export_type)\u001b[0m\n\u001b[1;32m 988\u001b[0m \u001b[0;31m# Export it regularly\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 989\u001b[0m \u001b[0mdomain\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m''\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 990\u001b[0;31m \u001b[0msymbolic_fn\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m_find_symbolic_in_registry\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdomain\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mop_name\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mopset_version\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moperator_export_type\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 991\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0msymbolic_fn\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 992\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
605 | "\u001b[0;32m~/miniconda3/envs/receipt-ocr/lib/python3.9/site-packages/torch/onnx/utils.py\u001b[0m in \u001b[0;36m_find_symbolic_in_registry\u001b[0;34m(domain, op_name, opset_version, operator_export_type)\u001b[0m\n\u001b[1;32m 942\u001b[0m \u001b[0;31m# Use the original node directly\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 943\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 944\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0msym_registry\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_registered_op\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mop_name\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdomain\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mopset_version\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 945\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 946\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
606 | "\u001b[0;32m~/miniconda3/envs/receipt-ocr/lib/python3.9/site-packages/torch/onnx/symbolic_registry.py\u001b[0m in \u001b[0;36mget_registered_op\u001b[0;34m(opname, domain, version)\u001b[0m\n\u001b[1;32m 114\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 115\u001b[0m \u001b[0mmsg\u001b[0m \u001b[0;34m+=\u001b[0m \u001b[0;34m\"Please feel free to request support or submit a pull request on PyTorch GitHub.\"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 116\u001b[0;31m \u001b[0;32mraise\u001b[0m \u001b[0mRuntimeError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmsg\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 117\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0m_registry\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdomain\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mversion\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mopname\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
607 | "\u001b[0;31mRuntimeError\u001b[0m: Exporting the operator grid_sampler to ONNX opset version 13 is not supported. Please feel free to request support or submit a pull request on PyTorch GitHub."
608 | ]
609 | }
610 | ],
611 | "source": [
612 | "# Export the model\n",
613 | "torch.onnx.export(model.recognizer.module, \n",
614 | " (recognizer_dummy_input, recognizer_dummy_text),\n",
615 | " out_recognizer_model,\n",
616 | " export_params=True,\n",
617 | " opset_version=13,\n",
618 | " do_constant_folding=True,\n",
619 | " input_names = ['input'],\n",
620 | " output_names = ['output'],\n",
621 | " dynamic_axes={'input' : {0:'batch_size', 2:'height', 3:'width'},\n",
622 | " 'output' : {0:'batch_size'}})"
623 | ]
624 | },
625 | {
626 | "cell_type": "markdown",
627 | "id": "outside-grill",
628 | "metadata": {},
629 | "source": [
630 | "## Inspecting Model"
631 | ]
632 | },
633 | {
634 | "cell_type": "code",
635 | "execution_count": null,
636 | "id": "understood-lobby",
637 | "metadata": {},
638 | "outputs": [],
639 | "source": [
640 | "# Load the ONNX model\n",
641 | "onnx_model = onnx.load(out_recognizer_model)\n",
642 | "\n",
643 | "# Check that the IR is well formed\n",
644 | "onnx.checker.check_model(onnx_model)\n",
645 | "\n",
646 | "# Print a human readable representation of the graph\n",
647 | "print(onnx.helper.printable_graph(onnx_model.graph))"
648 | ]
649 | }
650 | ],
651 | "metadata": {
652 | "kernelspec": {
653 | "display_name": "Python [conda env:receipt-ocr]",
654 | "language": "python",
655 | "name": "conda-env-receipt-ocr-py"
656 | },
657 | "language_info": {
658 | "codemirror_mode": {
659 | "name": "ipython",
660 | "version": 3
661 | },
662 | "file_extension": ".py",
663 | "mimetype": "text/x-python",
664 | "name": "python",
665 | "nbconvert_exporter": "python",
666 | "pygments_lexer": "ipython3",
667 | "version": "3.9.7"
668 | }
669 | },
670 | "nbformat": 4,
671 | "nbformat_minor": 5
672 | }
673 |
--------------------------------------------------------------------------------
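The recognizer export above stops at aten::grid_sampler, which the TPS transformation stage uses and which the ONNX exporter does not support at opset 13 (see the PyTorch issue linked in the notebook); exporting with operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK is sometimes suggested, but it only defers the problem, since the resulting graph carries an ATen node that a standard onnxruntime build typically cannot execute. The detector export does complete, and a quick sanity check is to run the saved graph through onnxruntime. The sketch below is illustrative only: it assumes the ../models/text_detector/craft.onnx file written by the export cell, feeds a random NCHW tensor through the session, and relies on the 'input' name declared in torch.onnx.export.

# Minimal sketch: run the exported CRAFT detector through onnxruntime.
# Assumes ../models/text_detector/craft.onnx exists (produced by the export cell).
import numpy as np
import onnxruntime as ort

sess = ort.InferenceSession('../models/text_detector/craft.onnx')

# Random N x C x H x W input; height and width were exported as dynamic axes,
# so other resolutions should be accepted by the same session as well.
dummy = np.random.randn(1, 3, 1280, 720).astype(np.float32)

# The exported graph has two outputs (the permuted score map and an intermediate
# feature map); passing None as the output list returns both in graph order.
outputs = sess.run(None, {'input': dummy})
print([o.shape for o in outputs])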