├── src ├── text_recognizer │ ├── __init__.py │ ├── modules │ │ ├── utils.py │ │ ├── sequence_modeling.py │ │ ├── prediction.py │ │ ├── model.py │ │ ├── model_utils.py │ │ ├── dataset.py │ │ ├── transformation.py │ │ └── feature_extraction.py │ ├── load_model.py │ └── infer.py ├── __init__.py ├── text_detector │ ├── __init__.py │ ├── basenet │ │ ├── __init__.py │ │ └── vgg16_bn.py │ ├── modules │ │ ├── __init__.py │ │ ├── utils.py │ │ ├── imgproc.py │ │ ├── refinenet.py │ │ ├── craft.py │ │ └── craft_utils.py │ ├── load_model.py │ └── infer.py ├── model.py └── engine.py ├── data ├── tes.jpg └── sample_output.jpg ├── pyproject.toml ├── setup.cfg ├── requirements.txt ├── environment.yaml ├── configs ├── craft_config.yaml └── star_config.yaml ├── Dockerfile ├── LICENSE ├── main.py ├── CONTRIBUTING.md ├── .gitignore ├── notebooks ├── test_api.ipynb ├── inference_onnx_engine.ipynb ├── inference_default_engine.ipynb └── export_onnx_model.ipynb └── README.md /src/text_recognizer/__init__.py: -------------------------------------------------------------------------------- 1 | from .modules import * 2 | -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- 1 | from .text_detector import * 2 | from .text_recognizer import * 3 | -------------------------------------------------------------------------------- /src/text_detector/__init__.py: -------------------------------------------------------------------------------- 1 | from .basenet import * 2 | from .modules import * 3 | -------------------------------------------------------------------------------- /data/tes.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jakartaresearch/receipt-ocr/HEAD/data/tes.jpg -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.isort] 2 | profile = "black" 3 | 4 | [tool.black] 5 | line-length = 100 -------------------------------------------------------------------------------- /data/sample_output.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jakartaresearch/receipt-ocr/HEAD/data/sample_output.jpg -------------------------------------------------------------------------------- /src/text_detector/basenet/__init__.py: -------------------------------------------------------------------------------- 1 | from .vgg16_bn import init_weights, vgg16_bn 2 | 3 | __all__ = ['init_weights', 'vgg16_bn'] -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [flake8] 2 | ignore = Q000, WPS110, WPS214, WPS300, WPS414, WPS420, I003, C812 3 | max-line-length = 100 4 | extend-ignore = E203 -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==1.9.1 2 | torchvision==0.10.1 3 | opencv-python==4.5.3.56 4 | scikit-image==0.18.3 5 | scipy==1.7.1 6 | PyYAML==5.4.1 7 | onnx==1.10.1 8 | onnxruntime==1.9.0 9 | fastapi==0.68.1 10 | uvicorn[standard] 11 | gunicorn 12 | pydantic -------------------------------------------------------------------------------- 
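The pinned requirements above are the full runtime stack: torch/torchvision for the CRAFT detector and STAR recognizer, onnx/onnxruntime for the exported engine, and fastapi/uvicorn/gunicorn for the service. A minimal local-setup sketch, reusing the same commands the README and Dockerfile use and assuming the pretrained weights are already placed under `models/` as described in the README:

```
# isolated environment with the pinned dependencies
python -m venv .venv
source .venv/bin/activate
pip install -r requirements.txt

# serve the API locally (same entrypoint as the Dockerfile CMD)
uvicorn main:app --host 0.0.0.0 --port 8000
```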
/src/text_detector/modules/__init__.py: -------------------------------------------------------------------------------- 1 | from .craft_utils import * 2 | from .imgproc import * 3 | from .utils import yaml_loader 4 | from .craft import CRAFT 5 | 6 | 7 | __all__ = ['getDetBoxes_core', 'getPoly_core', 'getDetBoxes', 8 | 'adjustResultCoordinates', 'normalizeMeanVariance', 9 | 'denormalizeMeanVariance', 'resize_aspect_ratio', 'cvt2HeatmapImg', 10 | 'yaml_loader', 'CRAFT'] 11 | -------------------------------------------------------------------------------- /environment.yaml: -------------------------------------------------------------------------------- 1 | name: receipt-ocr 2 | channels: 3 | - conda-forge 4 | - defaults 5 | dependencies: 6 | - python>=3.8 7 | - pip 8 | - pip: 9 | - torch==1.9.1 10 | - torchvision==0.10.1 11 | - opencv-python==4.5.3.56 12 | - scikit-image==0.18.3 13 | - scipy==1.7.1 14 | - PyYAML==5.4.1 15 | - onnx==1.10.1 16 | - onnxruntime==1.9.0 17 | - fastapi==0.68.1 18 | - uvicorn[standard] 19 | - pydantic -------------------------------------------------------------------------------- /configs/craft_config.yaml: -------------------------------------------------------------------------------- 1 | #text confidence threshold 2 | text_threshold: 0.7 3 | #text low-bound score 4 | low_text: 0.2 5 | #link confidence threshold 6 | link_threshold: 0.2 7 | #Use cuda for inference 8 | cuda: False 9 | #image size for inference 10 | canvas_size: 1280 11 | #image magnification ratio 12 | mag_ratio: 1.5 13 | #enable polygon type 14 | poly: False 15 | #enable link refiner 16 | refine: False 17 | #pretrained refiner model 18 | refiner_model: '' -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM tiangolo/uvicorn-gunicorn-fastapi:python3.7 2 | 3 | WORKDIR /receipt-ocr 4 | COPY configs /receipt-ocr/configs 5 | COPY models /receipt-ocr/models 6 | COPY src /receipt-ocr/src 7 | COPY main.py /receipt-ocr 8 | COPY requirements.txt /receipt-ocr 9 | 10 | RUN apt-get update 11 | RUN apt-get install ffmpeg libsm6 libxext6 -y 12 | RUN pip install --upgrade pip 13 | RUN pip install --no-cache-dir -r requirements.txt 14 | 15 | EXPOSE 8000 16 | CMD uvicorn main:app --host 0.0.0.0 --port 8000 -------------------------------------------------------------------------------- /src/text_recognizer/modules/utils.py: -------------------------------------------------------------------------------- 1 | import yaml 2 | from yaml.loader import SafeLoader 3 | 4 | 5 | def yaml_loader(filename): 6 | with open(filename) as f: 7 | data = yaml.load(f, Loader=SafeLoader) 8 | return data 9 | 10 | 11 | class DictObj: 12 | def __init__(self, in_dict: dict): 13 | assert isinstance(in_dict, dict) 14 | for key, val in in_dict.items(): 15 | if isinstance(val, (list, tuple)): 16 | setattr(self, key, [DictObj(x) if isinstance( 17 | x, dict) else x for x in val]) 18 | else: 19 | setattr(self, key, DictObj(val) 20 | if isinstance(val, dict) else val) 21 | -------------------------------------------------------------------------------- /configs/star_config.yaml: -------------------------------------------------------------------------------- 1 | # input batch size 2 | batch_size: 192 3 | # maximum label length 4 | batch_max_length: 25 5 | # number of data loading workers 6 | workers: 0 7 | # the height of the input image 8 | imgH: 32 9 | # the width of the input image 10 | imgW: 100 11 | # character 
label 12 | character: '0123456789abcdefghijklmnopqrstuvwxyz' 13 | sensitive: True 14 | 15 | # Model Architecture 16 | Transformation: 'TPS' 17 | FeatureExtraction: 'ResNet' 18 | SequenceModeling: 'BiLSTM' 19 | Prediction: 'Attn' 20 | 21 | # number of fiducial points of TPS-STN 22 | num_fiducial: 20 23 | # the number of input channel of Feature extractor 24 | input_channel: 1 25 | # the number of output channel of Feature extractor 26 | output_channel: 512 27 | # the size of the LSTM hidden state 28 | hidden_size: 256 29 | PAD: False 30 | -------------------------------------------------------------------------------- /src/text_recognizer/modules/sequence_modeling.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | class BidirectionalLSTM(nn.Module): 5 | 6 | def __init__(self, input_size, hidden_size, output_size): 7 | super(BidirectionalLSTM, self).__init__() 8 | self.rnn = nn.LSTM(input_size, hidden_size, 9 | bidirectional=True, batch_first=True) 10 | self.linear = nn.Linear(hidden_size * 2, output_size) 11 | 12 | def forward(self, input): 13 | """ 14 | input : visual feature [batch_size x T x input_size] 15 | output : contextual feature [batch_size x T x output_size] 16 | """ 17 | self.rnn.flatten_parameters() 18 | # batch_size x T x input_size -> batch_size x T x (2*hidden_size) 19 | recurrent, _ = self.rnn(input) 20 | output = self.linear(recurrent) # batch_size x T x output_size 21 | return output 22 | -------------------------------------------------------------------------------- /src/text_detector/modules/utils.py: -------------------------------------------------------------------------------- 1 | import yaml 2 | import multiprocessing 3 | import onnxruntime as rt 4 | 5 | from onnxruntime import InferenceSession, get_all_providers 6 | from yaml.loader import SafeLoader 7 | 8 | 9 | def yaml_loader(filename): 10 | with open(filename) as f: 11 | data = yaml.load(f, Loader=SafeLoader) 12 | return data 13 | 14 | 15 | def create_model_for_provider(model_path: str, provider: str) -> InferenceSession: 16 | """Return inference session for ONNX model with specific provider.""" 17 | assert provider in get_all_providers( 18 | ), f"provider {provider} not found, {get_all_providers()}" 19 | 20 | sess_options = rt.SessionOptions() 21 | sess_options.intra_op_num_threads = multiprocessing.cpu_count() 22 | sess_options.execution_mode = rt.ExecutionMode.ORT_PARALLEL 23 | sess_options.graph_optimization_level = rt.GraphOptimizationLevel.ORT_ENABLE_ALL 24 | 25 | return InferenceSession(model_path, sess_options, providers=[provider]) 26 | -------------------------------------------------------------------------------- /src/text_recognizer/load_model.py: -------------------------------------------------------------------------------- 1 | import string 2 | import torch 3 | 4 | from .modules.model_utils import CTCLabelConverter, AttnLabelConverter 5 | from .modules.utils import yaml_loader, DictObj 6 | from .modules.model import Model 7 | 8 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 9 | 10 | 11 | def load_star(config_file, model_pth): 12 | cfg = yaml_loader(config_file) 13 | obj_cfg = DictObj(cfg) 14 | """ vocab / character number configuration """ 15 | if obj_cfg.sensitive: 16 | obj_cfg.character = string.printable[:-6] 17 | 18 | if "CTC" in obj_cfg.Prediction: 19 | converter = CTCLabelConverter(obj_cfg.character) 20 | else: 21 | converter = AttnLabelConverter(obj_cfg.character) 22 | 23 | obj_cfg.num_class 
= len(converter.character) 24 | 25 | net = Model(obj_cfg) 26 | net = torch.nn.DataParallel(net).to(device) 27 | print(f"Loading weights from checkpoint ({model_pth})") 28 | net.load_state_dict(torch.load(model_pth, map_location=device)) 29 | return obj_cfg, net, converter 30 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Jakarta AI Research 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | """Script for Fast API Endpoint.""" 2 | import base64 3 | import io 4 | import warnings 5 | import numpy as np 6 | from fastapi import FastAPI 7 | from PIL import Image 8 | from pydantic import BaseModel 9 | from src.engine import DefaultEngine 10 | from src.model import DefaultModel 11 | 12 | warnings.filterwarnings("ignore") 13 | 14 | 15 | app = FastAPI() 16 | 17 | detector_cfg = "configs/craft_config.yaml" 18 | detector_model = "models/text_detector/craft_mlt_25k.pth" 19 | recognizer_cfg = "configs/star_config.yaml" 20 | recognizer_model = "models/text_recognizer/TPS-ResNet-BiLSTM-Attn-case-sensitive.pth" 21 | 22 | model = DefaultModel(detector_cfg, detector_model, recognizer_cfg, recognizer_model) 23 | engine = DefaultEngine(model) 24 | 25 | 26 | class Item(BaseModel): 27 | image: str 28 | 29 | 30 | @app.get("/") 31 | def read_root(): 32 | return {"message": "API is running..."} 33 | 34 | 35 | @app.post("/ocr/predict") 36 | def predict(item: Item): 37 | item = item.dict() 38 | img_bytes = base64.b64decode(item["image"].encode("utf-8")) 39 | image = Image.open(io.BytesIO(img_bytes)) 40 | image = np.array(image) 41 | 42 | engine.predict(image) 43 | return engine.result 44 | -------------------------------------------------------------------------------- /src/text_detector/load_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.backends.cudnn as cudnn 3 | 4 | from collections import OrderedDict 5 | from .modules.utils import yaml_loader, create_model_for_provider 6 | from .modules.craft import CRAFT 7 | 8 | 9 | def copy_state_dict(state_dict): 10 | if list(state_dict.keys())[0].startswith("module"): 11 | start_idx = 1 12 | else: 13 | start_idx = 0 14 | 
new_state_dict = OrderedDict() 15 | for k, v in state_dict.items(): 16 | name = ".".join(k.split(".")[start_idx:]) 17 | new_state_dict[name] = v 18 | return new_state_dict 19 | 20 | 21 | def load_craft(config_file, model_pth): 22 | cfg = yaml_loader(config_file) 23 | net = CRAFT() 24 | 25 | print("Loading weights from checkpoint (" + model_pth + ")") 26 | if cfg["cuda"]: 27 | net.load_state_dict(copy_state_dict(torch.load(model_pth))) 28 | else: 29 | net.load_state_dict(copy_state_dict(torch.load(model_pth, map_location="cpu"))) 30 | 31 | if cfg["cuda"]: 32 | net = net.cuda() 33 | net = torch.nn.DataParallel(net) 34 | cudnn.benchmark = False 35 | 36 | net.eval() 37 | return cfg, net 38 | 39 | 40 | def load_craft_onnx(config_file, model_pth): 41 | cfg = yaml_loader(config_file) 42 | device = "CUDAExecutionProvider" if torch.cuda.is_available() else "CPUExecutionProvider" 43 | print("Loading weights from checkpoint (" + model_pth + ")") 44 | net = create_model_for_provider(model_pth, device) 45 | return cfg, net 46 | -------------------------------------------------------------------------------- /src/model.py: -------------------------------------------------------------------------------- 1 | """Model Class.""" 2 | from abc import ABC, abstractmethod 3 | from .text_detector.load_model import load_craft, load_craft_onnx 4 | from .text_recognizer.load_model import load_star 5 | 6 | 7 | class BaseModel(ABC): 8 | """Abstract base class for Receipt OCR models.""" 9 | 10 | def __init__(self, detector_cfg, detector_model, recognizer_cfg, recognizer_model): 11 | """Init model config. 12 | 13 | Args: 14 | detector_cfg: config file for text detector 15 | recognizer_cfg: config file for text recognizer 16 | """ 17 | self._cfg_detector, self._detector = self._load_detector(detector_cfg, detector_model) 18 | self._cfg_recognizer, self._recognizer, self._converter = self._load_recognizer( 19 | recognizer_cfg, recognizer_model 20 | ) 21 | 22 | @property 23 | def cfg_detector(self): 24 | return self._cfg_detector 25 | 26 | @property 27 | def detector(self): 28 | return self._detector 29 | 30 | @property 31 | def cfg_recognizer(self): 32 | return self._cfg_recognizer 33 | 34 | @property 35 | def recognizer(self): 36 | return self._recognizer 37 | 38 | @property 39 | def converter(self): 40 | return self._converter 41 | 42 | @abstractmethod 43 | def _load_detector(self): 44 | """Return CRAFT model.""" 45 | 46 | @abstractmethod 47 | def _load_recognizer(self): 48 | """Return STAR model.""" 49 | 50 | 51 | class DefaultModel(BaseModel): 52 | """Default implementation of Receipt OCR models.""" 53 | 54 | def _load_detector(self, detector_cfg, detector_model): 55 | return load_craft(detector_cfg, detector_model) 56 | 57 | def _load_recognizer(self, recognizer_cfg, recognizer_model): 58 | return load_star(recognizer_cfg, recognizer_model) 59 | 60 | 61 | class ONNXModel(DefaultModel): 62 | """ONNX Model.""" 63 | 64 | def _load_detector(self, detector_cfg, detector_model): 65 | return load_craft_onnx(detector_cfg, detector_model) 66 | -------------------------------------------------------------------------------- /src/text_detector/modules/imgproc.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2019-present NAVER Corp. 
3 | MIT License 4 | """ 5 | 6 | # -*- coding: utf-8 -*- 7 | import numpy as np 8 | import cv2 9 | 10 | 11 | def normalizeMeanVariance(in_img, mean=(0.485, 0.456, 0.406), variance=(0.229, 0.224, 0.225)): 12 | # should be RGB order 13 | img = in_img.copy().astype(np.float32) 14 | 15 | img -= np.array([mean[0] * 255.0, mean[1] * 255.0, 16 | mean[2] * 255.0], dtype=np.float32) 17 | img /= np.array([variance[0] * 255.0, variance[1] * 255.0, 18 | variance[2] * 255.0], dtype=np.float32) 19 | return img 20 | 21 | 22 | def denormalizeMeanVariance(in_img, mean=(0.485, 0.456, 0.406), variance=(0.229, 0.224, 0.225)): 23 | # should be RGB order 24 | img = in_img.copy() 25 | img *= variance 26 | img += mean 27 | img *= 255.0 28 | img = np.clip(img, 0, 255).astype(np.uint8) 29 | return img 30 | 31 | 32 | def resize_aspect_ratio(img, square_size, interpolation, mag_ratio=1): 33 | height, width, channel = img.shape 34 | 35 | # magnify image size 36 | target_size = mag_ratio * max(height, width) 37 | 38 | # set original image size 39 | if target_size > square_size: 40 | target_size = square_size 41 | 42 | ratio = target_size / max(height, width) 43 | 44 | target_h, target_w = int(height * ratio), int(width * ratio) 45 | proc = cv2.resize(img, (target_w, target_h), interpolation=interpolation) 46 | 47 | # make canvas and paste image 48 | target_h32, target_w32 = target_h, target_w 49 | if target_h % 32 != 0: 50 | target_h32 = target_h + (32 - target_h % 32) 51 | if target_w % 32 != 0: 52 | target_w32 = target_w + (32 - target_w % 32) 53 | resized = np.zeros((target_h32, target_w32, channel), dtype=np.float32) 54 | resized[0:target_h, 0:target_w, :] = proc 55 | target_h, target_w = target_h32, target_w32 56 | 57 | size_heatmap = (int(target_w/2), int(target_h/2)) 58 | 59 | return resized, ratio, size_heatmap 60 | 61 | 62 | def cvt2HeatmapImg(img): 63 | img = (np.clip(img, 0, 1) * 255).astype(np.uint8) 64 | img = cv2.applyColorMap(img, cv2.COLORMAP_JET) 65 | return img 66 | -------------------------------------------------------------------------------- /src/text_recognizer/infer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.backends.cudnn as cudnn 3 | import torch.utils.data 4 | import torch.nn.functional as F 5 | 6 | from .modules.dataset import RawDataset, AlignCollate 7 | 8 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 9 | 10 | 11 | def data_preparation(opt, list_data): 12 | AlignCollate_obj = AlignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD) 13 | dataset = RawDataset(list_data=list_data, opt=opt) 14 | data_loader = torch.utils.data.DataLoader( 15 | dataset, 16 | batch_size=opt.batch_size, 17 | shuffle=False, 18 | num_workers=int(opt.workers), 19 | collate_fn=AlignCollate_obj, 20 | pin_memory=True, 21 | ) 22 | return data_loader 23 | 24 | 25 | def inference(opt, model, converter, data_loader): 26 | output_pred, output_conf_score = [], [] 27 | model.eval() 28 | with torch.no_grad(): 29 | for image_tensors, _ in data_loader: 30 | batch_size = image_tensors.size(0) 31 | image = image_tensors.to(device) 32 | # For max length prediction 33 | length_for_pred = torch.IntTensor([opt.batch_max_length] * batch_size).to(device) 34 | text_for_pred = ( 35 | torch.LongTensor(batch_size, opt.batch_max_length + 1).fill_(0).to(device) 36 | ) 37 | 38 | preds = model(image, text_for_pred, is_train=False) 39 | # select max probabilty (greedy decoding) then decode index to character 40 | _, preds_index = 
preds.max(2) 41 | preds_str = converter.decode(preds_index, length_for_pred) 42 | preds_prob = F.softmax(preds, dim=2) 43 | preds_max_prob, _ = preds_prob.max(dim=2) 44 | 45 | for pred, pred_max_prob in zip(preds_str, preds_max_prob): 46 | pred_eos = pred.find("[s]") 47 | # prune after "end of sentence" token ([s]) 48 | pred = pred[:pred_eos] 49 | pred_max_prob = pred_max_prob[:pred_eos] 50 | # calculate confidence score (= multiply of pred_max_prob) 51 | confidence_score = pred_max_prob.cumprod(dim=0)[-1] 52 | output_pred.append(pred) 53 | output_conf_score.append(confidence_score) 54 | return output_pred, output_conf_score 55 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Welcome to Receipt OCR docs contributing guide 2 | We love your input! We want to make contributing to this project as easy and transparent as possible, whether it's: 3 | 4 | - Reporting a bug 5 | - Discussing the current state of the code 6 | - Submitting a fix 7 | - Proposing new features 8 | - Becoming a maintainer 9 | 10 | ## New contributor guide 11 | 12 | To get an overview of the project, read the [README](README.md). Here are some resources to help you get started with open source contributions: 13 | 14 | - [Finding ways to contribute to open source on GitHub](https://docs.github.com/en/get-started/exploring-projects-on-github/finding-ways-to-contribute-to-open-source-on-github) 15 | - [Set up Git](https://docs.github.com/en/get-started/quickstart/set-up-git) 16 | - [GitHub flow](https://docs.github.com/en/get-started/quickstart/github-flow) 17 | - [Collaborating with pull requests](https://docs.github.com/en/github/collaborating-with-pull-requests) 18 | 19 | ## Getting started 20 | 1. Fork the repo and create your branch from `main`. 21 | 2. If you've changed code, update the docstrings and documentation. 22 | 3. Make sure your code lints. 23 | 4. Issue that pull request! 24 | 25 | ### Issues 26 | 27 | **Great Bug Reports** tend to have: 28 | 29 | - A quick summary and/or background 30 | - Steps to reproduce 31 | - Be specific! 32 | - Give sample code if you can. 33 | - What you expected would happen 34 | - What actually happens 35 | - Notes (possibly including why you think this might be happening, or stuff you tried that didn't work) 36 | 37 | #### Create a new issue 38 | 39 | If you spot a problem with the docs, [search if an issue already exists](https://docs.github.com/en/github/searching-for-information-on-github/searching-on-github/searching-issues-and-pull-requests#search-by-the-title-body-or-comments). 40 | 41 | #### Solve an issue 42 | 43 | Scan through our existing issues to find one that interests you. As a general rule, we don’t assign issues to anyone. If you find an issue to work on, you are welcome to open a PR with a fix. 44 | 45 | ### Code Formatting and Typing 46 | 47 | This project use flake8 as code linter and black as code formatter, with config that you can check on [pyproject.toml](pyproject.toml) and [setup.cfg](setup.cfg) 48 | 49 | ## License 50 | 51 | By contributing, you agree that your contributions will be licensed under its MIT License. 
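For the "Code Formatting and Typing" section above, a minimal sketch of running the formatter and linter locally before opening a PR; the tools are not pinned in requirements.txt, so the install line is an assumption:

```
pip install isort black flake8

isort .    # import ordering, [tool.isort] profile from pyproject.toml
black .    # formatting, line-length 100 from [tool.black] in pyproject.toml
flake8 .   # linting, [flake8] rules from setup.cfg
```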
-------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.pth 2 | *.jpg 3 | *.jpeg 4 | *.onnx 5 | settings.json 6 | 7 | # Byte-compiled / optimized / DLL files 8 | __pycache__/ 9 | *.py[cod] 10 | *$py.class 11 | 12 | # C extensions 13 | *.so 14 | 15 | # Distribution / packaging 16 | .Python 17 | build/ 18 | develop-eggs/ 19 | dist/ 20 | downloads/ 21 | eggs/ 22 | .eggs/ 23 | lib/ 24 | lib64/ 25 | parts/ 26 | sdist/ 27 | var/ 28 | wheels/ 29 | pip-wheel-metadata/ 30 | share/python-wheels/ 31 | *.egg-info/ 32 | .installed.cfg 33 | *.egg 34 | MANIFEST 35 | 36 | # PyInstaller 37 | # Usually these files are written by a python script from a template 38 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 39 | *.manifest 40 | *.spec 41 | 42 | # Installer logs 43 | pip-log.txt 44 | pip-delete-this-directory.txt 45 | 46 | # Unit test / coverage reports 47 | htmlcov/ 48 | .tox/ 49 | .nox/ 50 | .coverage 51 | .coverage.* 52 | .cache 53 | nosetests.xml 54 | coverage.xml 55 | *.cover 56 | *.py,cover 57 | .hypothesis/ 58 | .pytest_cache/ 59 | 60 | # Translations 61 | *.mo 62 | *.pot 63 | 64 | # Django stuff: 65 | *.log 66 | local_settings.py 67 | db.sqlite3 68 | db.sqlite3-journal 69 | 70 | # Flask stuff: 71 | instance/ 72 | .webassets-cache 73 | 74 | # Scrapy stuff: 75 | .scrapy 76 | 77 | # Sphinx documentation 78 | docs/_build/ 79 | 80 | # PyBuilder 81 | target/ 82 | 83 | # Jupyter Notebook 84 | .ipynb_checkpoints 85 | 86 | # IPython 87 | profile_default/ 88 | ipython_config.py 89 | 90 | # pyenv 91 | .python-version 92 | 93 | # pipenv 94 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 95 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 96 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 97 | # install all needed dependencies. 98 | #Pipfile.lock 99 | 100 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 101 | __pypackages__/ 102 | 103 | # Celery stuff 104 | celerybeat-schedule 105 | celerybeat.pid 106 | 107 | # SageMath parsed files 108 | *.sage.py 109 | 110 | # Environments 111 | .env 112 | .venv 113 | env/ 114 | venv/ 115 | ENV/ 116 | env.bak/ 117 | venv.bak/ 118 | 119 | # Spyder project settings 120 | .spyderproject 121 | .spyproject 122 | 123 | # Rope project settings 124 | .ropeproject 125 | 126 | # mkdocs documentation 127 | /site 128 | 129 | # mypy 130 | .mypy_cache/ 131 | .dmypy.json 132 | dmypy.json 133 | 134 | # Pyre type checker 135 | .pyre/ 136 | -------------------------------------------------------------------------------- /src/text_detector/modules/refinenet.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2019-present NAVER Corp. 
3 | MIT License 4 | """ 5 | 6 | # -*- coding: utf-8 -*- 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | from torch.autograd import Variable 11 | from ..basenet.vgg16_bn import init_weights 12 | 13 | 14 | class RefineNet(nn.Module): 15 | def __init__(self): 16 | super(RefineNet, self).__init__() 17 | 18 | self.last_conv = nn.Sequential( 19 | nn.Conv2d(34, 64, kernel_size=3, padding=1), nn.BatchNorm2d( 20 | 64), nn.ReLU(inplace=True), 21 | nn.Conv2d(64, 64, kernel_size=3, padding=1), nn.BatchNorm2d( 22 | 64), nn.ReLU(inplace=True), 23 | nn.Conv2d(64, 64, kernel_size=3, padding=1), nn.BatchNorm2d( 24 | 64), nn.ReLU(inplace=True) 25 | ) 26 | 27 | self.aspp1 = nn.Sequential( 28 | nn.Conv2d(64, 128, kernel_size=3, dilation=6, 29 | padding=6), nn.BatchNorm2d(128), nn.ReLU(inplace=True), 30 | nn.Conv2d(128, 128, kernel_size=1), nn.BatchNorm2d( 31 | 128), nn.ReLU(inplace=True), 32 | nn.Conv2d(128, 1, kernel_size=1) 33 | ) 34 | 35 | self.aspp2 = nn.Sequential( 36 | nn.Conv2d(64, 128, kernel_size=3, dilation=12, 37 | padding=12), nn.BatchNorm2d(128), nn.ReLU(inplace=True), 38 | nn.Conv2d(128, 128, kernel_size=1), nn.BatchNorm2d( 39 | 128), nn.ReLU(inplace=True), 40 | nn.Conv2d(128, 1, kernel_size=1) 41 | ) 42 | 43 | self.aspp3 = nn.Sequential( 44 | nn.Conv2d(64, 128, kernel_size=3, dilation=18, 45 | padding=18), nn.BatchNorm2d(128), nn.ReLU(inplace=True), 46 | nn.Conv2d(128, 128, kernel_size=1), nn.BatchNorm2d( 47 | 128), nn.ReLU(inplace=True), 48 | nn.Conv2d(128, 1, kernel_size=1) 49 | ) 50 | 51 | self.aspp4 = nn.Sequential( 52 | nn.Conv2d(64, 128, kernel_size=3, dilation=24, 53 | padding=24), nn.BatchNorm2d(128), nn.ReLU(inplace=True), 54 | nn.Conv2d(128, 128, kernel_size=1), nn.BatchNorm2d( 55 | 128), nn.ReLU(inplace=True), 56 | nn.Conv2d(128, 1, kernel_size=1) 57 | ) 58 | 59 | init_weights(self.last_conv.modules()) 60 | init_weights(self.aspp1.modules()) 61 | init_weights(self.aspp2.modules()) 62 | init_weights(self.aspp3.modules()) 63 | init_weights(self.aspp4.modules()) 64 | 65 | def forward(self, y, upconv4): 66 | refine = torch.cat([y.permute(0, 3, 1, 2), upconv4], dim=1) 67 | refine = self.last_conv(refine) 68 | 69 | aspp1 = self.aspp1(refine) 70 | aspp2 = self.aspp2(refine) 71 | aspp3 = self.aspp3(refine) 72 | aspp4 = self.aspp4(refine) 73 | 74 | # out = torch.add([aspp1, aspp2, aspp3, aspp4], dim=1) 75 | out = aspp1 + aspp2 + aspp3 + aspp4 76 | return out.permute(0, 2, 3, 1) # , refine.permute(0,2,3,1) 77 | -------------------------------------------------------------------------------- /src/text_detector/basenet/vgg16_bn.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.init as init 6 | from torchvision import models 7 | from torchvision.models.vgg import model_urls 8 | 9 | 10 | def init_weights(modules): 11 | for m in modules: 12 | if isinstance(m, nn.Conv2d): 13 | init.xavier_uniform_(m.weight.data) 14 | if m.bias is not None: 15 | m.bias.data.zero_() 16 | elif isinstance(m, nn.BatchNorm2d): 17 | m.weight.data.fill_(1) 18 | m.bias.data.zero_() 19 | elif isinstance(m, nn.Linear): 20 | m.weight.data.normal_(0, 0.01) 21 | m.bias.data.zero_() 22 | 23 | 24 | class vgg16_bn(torch.nn.Module): 25 | def __init__(self, pretrained=True, freeze=True): 26 | super(vgg16_bn, self).__init__() 27 | model_urls['vgg16_bn'] = model_urls['vgg16_bn'].replace( 28 | 'https://', 'http://') 29 | vgg_pretrained_features = models.vgg16_bn( 
30 | pretrained=pretrained).features 31 | self.slice1 = torch.nn.Sequential() 32 | self.slice2 = torch.nn.Sequential() 33 | self.slice3 = torch.nn.Sequential() 34 | self.slice4 = torch.nn.Sequential() 35 | self.slice5 = torch.nn.Sequential() 36 | for x in range(12): # conv2_2 37 | self.slice1.add_module(str(x), vgg_pretrained_features[x]) 38 | for x in range(12, 19): # conv3_3 39 | self.slice2.add_module(str(x), vgg_pretrained_features[x]) 40 | for x in range(19, 29): # conv4_3 41 | self.slice3.add_module(str(x), vgg_pretrained_features[x]) 42 | for x in range(29, 39): # conv5_3 43 | self.slice4.add_module(str(x), vgg_pretrained_features[x]) 44 | 45 | # fc6, fc7 without atrous conv 46 | self.slice5 = torch.nn.Sequential( 47 | nn.MaxPool2d(kernel_size=3, stride=1, padding=1), 48 | nn.Conv2d(512, 1024, kernel_size=3, padding=6, dilation=6), 49 | nn.Conv2d(1024, 1024, kernel_size=1) 50 | ) 51 | 52 | if not pretrained: 53 | init_weights(self.slice1.modules()) 54 | init_weights(self.slice2.modules()) 55 | init_weights(self.slice3.modules()) 56 | init_weights(self.slice4.modules()) 57 | 58 | # no pretrained model for fc6 and fc7 59 | init_weights(self.slice5.modules()) 60 | 61 | if freeze: 62 | for param in self.slice1.parameters(): # only first conv 63 | param.requires_grad = False 64 | 65 | def forward(self, X): 66 | h = self.slice1(X) 67 | h_relu2_2 = h 68 | h = self.slice2(h) 69 | h_relu3_2 = h 70 | h = self.slice3(h) 71 | h_relu4_3 = h 72 | h = self.slice4(h) 73 | h_relu5_3 = h 74 | h = self.slice5(h) 75 | h_fc7 = h 76 | vgg_outputs = namedtuple( 77 | "VggOutputs", ['fc7', 'relu5_3', 'relu4_3', 'relu3_2', 'relu2_2']) 78 | out = vgg_outputs(h_fc7, h_relu5_3, h_relu4_3, h_relu3_2, h_relu2_2) 79 | return out 80 | -------------------------------------------------------------------------------- /src/text_detector/modules/craft.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2019-present NAVER Corp. 
3 | MIT License 4 | """ 5 | 6 | # -*- coding: utf-8 -*- 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | 11 | from ..basenet.vgg16_bn import vgg16_bn, init_weights 12 | 13 | 14 | class double_conv(nn.Module): 15 | def __init__(self, in_ch, mid_ch, out_ch): 16 | super(double_conv, self).__init__() 17 | self.conv = nn.Sequential( 18 | nn.Conv2d(in_ch + mid_ch, mid_ch, kernel_size=1), 19 | nn.BatchNorm2d(mid_ch), 20 | nn.ReLU(inplace=True), 21 | nn.Conv2d(mid_ch, out_ch, kernel_size=3, padding=1), 22 | nn.BatchNorm2d(out_ch), 23 | nn.ReLU(inplace=True) 24 | ) 25 | 26 | def forward(self, x): 27 | x = self.conv(x) 28 | return x 29 | 30 | 31 | class CRAFT(nn.Module): 32 | def __init__(self, pretrained=False, freeze=False): 33 | super(CRAFT, self).__init__() 34 | 35 | """ Base network """ 36 | self.basenet = vgg16_bn(pretrained, freeze) 37 | 38 | """ U network """ 39 | self.upconv1 = double_conv(1024, 512, 256) 40 | self.upconv2 = double_conv(512, 256, 128) 41 | self.upconv3 = double_conv(256, 128, 64) 42 | self.upconv4 = double_conv(128, 64, 32) 43 | 44 | num_class = 2 45 | self.conv_cls = nn.Sequential( 46 | nn.Conv2d(32, 32, kernel_size=3, padding=1), nn.ReLU(inplace=True), 47 | nn.Conv2d(32, 32, kernel_size=3, padding=1), nn.ReLU(inplace=True), 48 | nn.Conv2d(32, 16, kernel_size=3, padding=1), nn.ReLU(inplace=True), 49 | nn.Conv2d(16, 16, kernel_size=1), nn.ReLU(inplace=True), 50 | nn.Conv2d(16, num_class, kernel_size=1), 51 | ) 52 | 53 | init_weights(self.upconv1.modules()) 54 | init_weights(self.upconv2.modules()) 55 | init_weights(self.upconv3.modules()) 56 | init_weights(self.upconv4.modules()) 57 | init_weights(self.conv_cls.modules()) 58 | 59 | def forward(self, x): 60 | """ Base network """ 61 | sources = self.basenet(x) 62 | 63 | """ U network """ 64 | y = torch.cat([sources[0], sources[1]], dim=1) 65 | y = self.upconv1(y) 66 | 67 | y = F.interpolate(y, size=sources[2].size()[ 68 | 2:], mode='bilinear', align_corners=False) 69 | y = torch.cat([y, sources[2]], dim=1) 70 | y = self.upconv2(y) 71 | 72 | y = F.interpolate(y, size=sources[3].size()[ 73 | 2:], mode='bilinear', align_corners=False) 74 | y = torch.cat([y, sources[3]], dim=1) 75 | y = self.upconv3(y) 76 | 77 | y = F.interpolate(y, size=sources[4].size()[ 78 | 2:], mode='bilinear', align_corners=False) 79 | y = torch.cat([y, sources[4]], dim=1) 80 | feature = self.upconv4(y) 81 | 82 | y = self.conv_cls(feature) 83 | 84 | return y.permute(0, 2, 3, 1), feature 85 | 86 | 87 | if __name__ == '__main__': 88 | model = CRAFT(pretrained=True).cuda() 89 | output, _ = model(torch.randn(1, 3, 768, 768).cuda()) 90 | print(output.shape) 91 | -------------------------------------------------------------------------------- /src/text_detector/infer.py: -------------------------------------------------------------------------------- 1 | """Copyright (c) 2019-present NAVER Corp. 
2 | 3 | MIT License 4 | """ 5 | import cv2 6 | import numpy as np 7 | import torch 8 | from .modules import craft_utils 9 | from .modules import imgproc 10 | from torch.autograd import Variable 11 | 12 | 13 | def test_net( 14 | net, 15 | image, 16 | text_threshold, 17 | link_threshold, 18 | low_text, 19 | cuda, 20 | poly, 21 | canvas_size, 22 | mag_ratio, 23 | refine_net=None, 24 | onnx=False, 25 | ): 26 | # resize 27 | img_resized, target_ratio, size_heatmap = imgproc.resize_aspect_ratio( 28 | image, canvas_size, interpolation=cv2.INTER_LINEAR, mag_ratio=mag_ratio 29 | ) 30 | ratio_h = ratio_w = 1 / target_ratio 31 | 32 | # preprocessing 33 | x = imgproc.normalizeMeanVariance(img_resized) 34 | x = torch.from_numpy(x).permute(2, 0, 1) # [h, w, c] to [c, h, w] 35 | x = Variable(x.unsqueeze(0)) # [c, h, w] to [b, c, h, w] 36 | if cuda: 37 | x = x.cuda() 38 | 39 | if onnx: 40 | # forward pass 41 | input_onnx = {"input": x.numpy()} 42 | with torch.no_grad(): 43 | y, feature = net.run(None, input_onnx) 44 | 45 | # make score and link map 46 | score_text = y[0, :, :, 0] 47 | score_link = y[0, :, :, 1] 48 | 49 | # refine link 50 | if refine_net is not None: 51 | with torch.no_grad(): 52 | y_refiner = refine_net(y, feature) 53 | score_link = y_refiner[0, :, :, 0] 54 | else: 55 | # forward pass 56 | with torch.no_grad(): 57 | y, feature = net(x) 58 | 59 | # make score and link map 60 | score_text = y[0, :, :, 0].cpu().data.numpy() 61 | score_link = y[0, :, :, 1].cpu().data.numpy() 62 | 63 | # refine link 64 | if refine_net is not None: 65 | with torch.no_grad(): 66 | y_refiner = refine_net(y, feature) 67 | score_link = y_refiner[0, :, :, 0].cpu().data.numpy() 68 | 69 | # Post-processing 70 | boxes, polys = craft_utils.getDetBoxes( 71 | score_text, score_link, text_threshold, link_threshold, low_text, poly 72 | ) 73 | 74 | # coordinate adjustment 75 | boxes = craft_utils.adjustResultCoordinates(boxes, ratio_w, ratio_h) 76 | polys = craft_utils.adjustResultCoordinates(polys, ratio_w, ratio_h) 77 | for k in range(len(polys)): 78 | if polys[k] is None: 79 | polys[k] = boxes[k] 80 | 81 | # render results (optional) 82 | render_img = score_text.copy() 83 | render_img = np.hstack((render_img, score_link)) 84 | ret_score_text = imgproc.cvt2HeatmapImg(render_img) 85 | return boxes, polys, ret_score_text 86 | 87 | 88 | def inference(cfg, net, image, onnx=False): 89 | bboxes, polys, score_text = test_net( 90 | net, 91 | image, 92 | cfg["text_threshold"], 93 | cfg["link_threshold"], 94 | cfg["low_text"], 95 | cfg["cuda"], 96 | cfg["poly"], 97 | cfg["canvas_size"], 98 | cfg["mag_ratio"], 99 | onnx=onnx, 100 | ) 101 | return polys 102 | -------------------------------------------------------------------------------- /notebooks/test_api.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "id": "specific-pulse", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "import requests\n", 11 | "import json\n", 12 | "import base64\n", 13 | "import numpy as np" 14 | ] 15 | }, 16 | { 17 | "cell_type": "code", 18 | "execution_count": 2, 19 | "id": "earned-missouri", 20 | "metadata": {}, 21 | "outputs": [], 22 | "source": [ 23 | "image_file = '../data/tes.jpg'\n", 24 | "\n", 25 | "with open(image_file, \"rb\") as f:\n", 26 | " im_bytes = f.read()\n", 27 | "\n", 28 | "im_b64 = base64.b64encode(im_bytes).decode('utf-8')" 29 | ] 30 | }, 31 | { 32 | "cell_type": "code", 33 | "execution_count": 3, 34 | 
"id": "afraid-cheese", 35 | "metadata": {}, 36 | "outputs": [], 37 | "source": [ 38 | "data = {\"image\": im_b64}\n", 39 | "payload = json.dumps(data)" 40 | ] 41 | }, 42 | { 43 | "cell_type": "code", 44 | "execution_count": 4, 45 | "id": "acquired-factory", 46 | "metadata": {}, 47 | "outputs": [], 48 | "source": [ 49 | "api_endpoint = \"http://127.0.0.1:8000/ocr/predict\"\n", 50 | "headers = {'content-type': 'application/json'}\n", 51 | "\n", 52 | "response = requests.request(\"POST\", api_endpoint, data=payload, headers=headers)" 53 | ] 54 | }, 55 | { 56 | "cell_type": "code", 57 | "execution_count": 5, 58 | "id": "turned-queue", 59 | "metadata": {}, 60 | "outputs": [ 61 | { 62 | "data": { 63 | "text/plain": [ 64 | "200" 65 | ] 66 | }, 67 | "execution_count": 5, 68 | "metadata": {}, 69 | "output_type": "execute_result" 70 | } 71 | ], 72 | "source": [ 73 | "response.status_code" 74 | ] 75 | }, 76 | { 77 | "cell_type": "code", 78 | "execution_count": 6, 79 | "id": "hindu-diagnosis", 80 | "metadata": {}, 81 | "outputs": [ 82 | { 83 | "data": { 84 | "text/plain": [ 85 | "['Geprek Bensu Kopo Bandung',\n", 86 | " 'J1, Kopo No. 536, Margasuka, Babakan Ciparay',\n", 87 | " 'KOTA BANDUNG',\n", 88 | " 'order: 33',\n", 89 | " 'Kode',\n", 90 | " 'Tanggal 16-07-2021 11:53:47',\n", 91 | " 'Kasiri Kasir 1 Kopo BDG',\n", 92 | " 'Pelanggant gjk wahyu',\n", 93 | " 'Paket Geprek Bensu Nasi Daun Jeruk GOFOO',\n", 94 | " 'D Level I X 27',\n", 95 | " 't Harga (27',\n", 96 | " 'Dada X',\n", 97 | " 'Ayam Geprek Bensu GOFOOD Original X 17.500',\n", 98 | " 'Harga (17',\n", 99 | " '+ Dada X',\n", 100 | " 'Take Away Charge X 4.000',\n", 101 | " 'Subtotal 49.000',\n", 102 | " 'PB1 (10%) 4.500',\n", 103 | " 'Total 53,500',\n", 104 | " 'Gobiz 53.500',\n", 105 | " 'Kembal Ii',\n", 106 | " 'LUNAS *x',\n", 107 | " 'Terima Kasih']" 108 | ] 109 | }, 110 | "execution_count": 6, 111 | "metadata": {}, 112 | "output_type": "execute_result" 113 | } 114 | ], 115 | "source": [ 116 | "response.json()" 117 | ] 118 | }, 119 | { 120 | "cell_type": "code", 121 | "execution_count": null, 122 | "id": "coordinate-trouble", 123 | "metadata": {}, 124 | "outputs": [], 125 | "source": [] 126 | } 127 | ], 128 | "metadata": { 129 | "kernelspec": { 130 | "display_name": "Python 3", 131 | "language": "python", 132 | "name": "python3" 133 | }, 134 | "language_info": { 135 | "codemirror_mode": { 136 | "name": "ipython", 137 | "version": 3 138 | }, 139 | "file_extension": ".py", 140 | "mimetype": "text/x-python", 141 | "name": "python", 142 | "nbconvert_exporter": "python", 143 | "pygments_lexer": "ipython3", 144 | "version": "3.8.3" 145 | } 146 | }, 147 | "nbformat": 4, 148 | "nbformat_minor": 5 149 | } 150 | -------------------------------------------------------------------------------- /src/engine.py: -------------------------------------------------------------------------------- 1 | """Engine Class.""" 2 | import math 3 | import time 4 | from abc import ABC, abstractmethod 5 | from .model import DefaultModel 6 | from .text_detector.infer import inference as infer_detector 7 | from .text_recognizer.infer import data_preparation, inference as infer_recognizer 8 | 9 | 10 | def timeit(method): 11 | def timed(*args, **kw): 12 | ts = time.time() 13 | result = method(*args, **kw) 14 | te = time.time() 15 | if "log_time" in kw: 16 | name = kw.get("log_name", method.__name__.upper()) 17 | kw["log_time"][name] = int((te - ts)) 18 | else: 19 | print("%r %2.2f s" % (method.__name__, (te - ts))) 20 | return result 21 | 22 | return timed 23 | 24 | 25 | 
class BaseEngine(ABC): 26 | def __init__(self, receipt_ocr_model: DefaultModel): 27 | if not isinstance(receipt_ocr_model, DefaultModel): 28 | raise TypeError 29 | self._model: DefaultModel = receipt_ocr_model 30 | 31 | @abstractmethod 32 | def inference_detector(self): 33 | pass 34 | 35 | @abstractmethod 36 | def inference_recognizer(self): 37 | pass 38 | 39 | @abstractmethod 40 | def predict(self): 41 | pass 42 | 43 | 44 | class DefaultEngine(BaseEngine): 45 | @timeit 46 | def inference_detector(self): 47 | output = infer_detector(self._model.cfg_detector, self._model.detector, self._input) 48 | self._out_detector = output 49 | 50 | @timeit 51 | def inference_recognizer(self): 52 | data_loader = data_preparation(opt=self._model.cfg_recognizer, list_data=self._imgs) 53 | pred, conf_score = infer_recognizer( 54 | opt=self._model.cfg_recognizer, 55 | model=self._model.recognizer, 56 | converter=self._model.converter, 57 | data_loader=data_loader, 58 | ) 59 | output = list(zip(pred, conf_score, self._coords)) 60 | output = filter(lambda x: x[1] > 0.5, output) 61 | self.raw_output = sorted(output, key=lambda x: x[2][0]) 62 | self.result = self.combine_entity() 63 | 64 | @timeit 65 | def predict(self, image): 66 | self._input = image 67 | 68 | self.inference_detector() 69 | self.get_img_from_bb() 70 | self.inference_recognizer() 71 | 72 | def get_img_from_bb(self): 73 | imgs, coords = [], [] 74 | for bb in self._out_detector: 75 | cropped_img, coord = self.crop_img(self._input, bb) 76 | imgs.append(cropped_img) 77 | coords.append(coord) 78 | self._imgs = imgs 79 | self._coords = coords 80 | 81 | def crop_img(self, img, bb): 82 | x1, y1 = bb[0] 83 | x2, y2 = bb[1] 84 | x3, y3 = bb[2] 85 | x4, y4 = bb[3] 86 | 87 | top_left_x = math.ceil(min([x1, x2, x3, x4])) 88 | top_left_y = math.ceil(min([y1, y2, y3, y4])) 89 | bot_right_x = math.ceil(max([x1, x2, x3, x4])) 90 | bot_right_y = math.ceil(max([y1, y2, y3, y4])) 91 | coord = (top_left_y, bot_right_y, top_left_x, bot_right_x) 92 | 93 | cropped_image = img[top_left_y : bot_right_y + 1, top_left_x : bot_right_x + 1] 94 | return cropped_image, coord 95 | 96 | def combine_entity(self): 97 | thres = 20 98 | output, all_entity, entity = [], [], [] 99 | entity.append(self.raw_output[0]) 100 | 101 | for idx in range(len(self.raw_output) - 1): 102 | diff = abs(self.raw_output[idx][2][0] - self.raw_output[idx + 1][2][0]) 103 | if diff < thres: 104 | entity.append(self.raw_output[idx + 1]) 105 | else: 106 | all_entity.append(entity) 107 | entity = [] 108 | entity.append(self.raw_output[idx + 1]) 109 | 110 | # Sorting entity by coordinates 111 | for idx in range(len(all_entity)): 112 | all_entity[idx] = sorted(all_entity[idx], key=lambda x: (x[2][3], x[2][1], x[2][2])) 113 | 114 | # Concatenate Entity 115 | for entity in all_entity: 116 | tmp = [x[0] for x in entity] 117 | output.append(" ".join(tmp)) 118 | return output 119 | 120 | 121 | class ONNXEngine(DefaultEngine): 122 | @timeit 123 | def inference_detector(self): 124 | output = infer_detector( 125 | self._model.cfg_detector, self._model.detector, self._input, onnx=True 126 | ) 127 | self._out_detector = output 128 | -------------------------------------------------------------------------------- /src/text_recognizer/modules/prediction.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 5 | 6 | 7 | class 
Attention(nn.Module): 8 | 9 | def __init__(self, input_size, hidden_size, num_classes): 10 | super(Attention, self).__init__() 11 | self.attention_cell = AttentionCell( 12 | input_size, hidden_size, num_classes) 13 | self.hidden_size = hidden_size 14 | self.num_classes = num_classes 15 | self.generator = nn.Linear(hidden_size, num_classes) 16 | 17 | def _char_to_onehot(self, input_char, onehot_dim=38): 18 | input_char = input_char.unsqueeze(1) 19 | batch_size = input_char.size(0) 20 | one_hot = torch.FloatTensor(batch_size, onehot_dim).zero_().to(device) 21 | one_hot = one_hot.scatter_(1, input_char, 1) 22 | return one_hot 23 | 24 | def forward(self, batch_H, text, is_train=True, batch_max_length=25): 25 | """ 26 | input: 27 | batch_H : contextual_feature H = hidden state of encoder. [batch_size x num_steps x contextual_feature_channels] 28 | text : the text-index of each image. [batch_size x (max_length+1)]. +1 for [GO] token. text[:, 0] = [GO]. 29 | output: probability distribution at each step [batch_size x num_steps x num_classes] 30 | """ 31 | batch_size = batch_H.size(0) 32 | num_steps = batch_max_length + 1 # +1 for [s] at end of sentence. 33 | 34 | output_hiddens = torch.FloatTensor( 35 | batch_size, num_steps, self.hidden_size).fill_(0).to(device) 36 | hidden = (torch.FloatTensor(batch_size, self.hidden_size).fill_(0).to(device), 37 | torch.FloatTensor(batch_size, self.hidden_size).fill_(0).to(device)) 38 | 39 | if is_train: 40 | for i in range(num_steps): 41 | # one-hot vectors for a i-th char. in a batch 42 | char_onehots = self._char_to_onehot( 43 | text[:, i], onehot_dim=self.num_classes) 44 | # hidden : decoder's hidden s_{t-1}, batch_H : encoder's hidden H, char_onehots : one-hot(y_{t-1}) 45 | hidden, alpha = self.attention_cell( 46 | hidden, batch_H, char_onehots) 47 | # LSTM hidden index (0: hidden, 1: Cell) 48 | output_hiddens[:, i, :] = hidden[0] 49 | probs = self.generator(output_hiddens) 50 | 51 | else: 52 | targets = torch.LongTensor(batch_size).fill_( 53 | 0).to(device) # [GO] token 54 | probs = torch.FloatTensor( 55 | batch_size, num_steps, self.num_classes).fill_(0).to(device) 56 | 57 | for i in range(num_steps): 58 | char_onehots = self._char_to_onehot( 59 | targets, onehot_dim=self.num_classes) 60 | hidden, alpha = self.attention_cell( 61 | hidden, batch_H, char_onehots) 62 | probs_step = self.generator(hidden[0]) 63 | probs[:, i, :] = probs_step 64 | _, next_input = probs_step.max(1) 65 | targets = next_input 66 | 67 | return probs # batch_size x num_steps x num_classes 68 | 69 | 70 | class AttentionCell(nn.Module): 71 | 72 | def __init__(self, input_size, hidden_size, num_embeddings): 73 | super(AttentionCell, self).__init__() 74 | self.i2h = nn.Linear(input_size, hidden_size, bias=False) 75 | # either i2i or h2h should have bias 76 | self.h2h = nn.Linear(hidden_size, hidden_size) 77 | self.score = nn.Linear(hidden_size, 1, bias=False) 78 | self.rnn = nn.LSTMCell(input_size + num_embeddings, hidden_size) 79 | self.hidden_size = hidden_size 80 | 81 | def forward(self, prev_hidden, batch_H, char_onehots): 82 | # [batch_size x num_encoder_step x num_channel] -> [batch_size x num_encoder_step x hidden_size] 83 | batch_H_proj = self.i2h(batch_H) 84 | prev_hidden_proj = self.h2h(prev_hidden[0]).unsqueeze(1) 85 | # batch_size x num_encoder_step * 1 86 | e = self.score(torch.tanh(batch_H_proj + prev_hidden_proj)) 87 | 88 | alpha = F.softmax(e, dim=1) 89 | context = torch.bmm(alpha.permute(0, 2, 1), batch_H).squeeze( 90 | 1) # batch_size x num_channel 91 | # batch_size 
x (num_channel + num_embedding) 92 | concat_context = torch.cat([context, char_onehots], 1) 93 | cur_hidden = self.rnn(concat_context, prev_hidden) 94 | return cur_hidden, alpha 95 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Optical Character Recognition for Receipt 2 | 3 | ## Sample Results 4 | Input Image | Output 5 | :----------------------:|:----------------------: 6 | | 7 | 8 | ## References 9 | 10 | | Title | Author | Year | Github | Paper | Download Model| 11 | | ----------------------------------------------------------------------------------------| ---------------- | ---- | --------- | ----- | -------- | 12 | | Character Region Awareness for Text Detection | Clova AI Research, NAVER Corp.| 2019 | https://github.com/clovaai/CRAFT-pytorch | https://arxiv.org/abs/1904.01941 | [craft_mlt_25k.pth](https://drive.google.com/file/d/1Jk4eGD7crsqCCg9C9VjCLkMN3ze8kutZ/view)| 13 | | What Is Wrong With Scene Text Recognition Model Comparisons? Dataset and Model Analysis | Clova AI Research, NAVER Corp.| 2019 | https://github.com/clovaai/deep-text-recognition-benchmark | https://arxiv.org/abs/1904.01906 | [TPS-ResNet-BiLSTM-Attn-case-sensitive.pth](https://www.dropbox.com/sh/j3xmli4di1zuv3s/AAArdcPgz7UFxIHUuKNOeKv_a?dl=0) | 14 | 15 | ## Folder structure 16 | ``` 17 | . 18 | ├─ configs 19 | | ├─ craft_config.yaml 20 | | └─ star_config.yaml 21 | ├─ data 22 | | ├─ sample_output.jpg 23 | | └─ tes.jpg 24 | ├─ notebooks 25 | | ├─ export_onnx_model.ipynb 26 | | ├─ inference_default_engine.ipynb 27 | | ├─ inference_onnx_engine.ipynb 28 | | └─ test_api.ipynb 29 | ├─ src 30 | | ├─ text_detector 31 | | │ ├─ basenet 32 | | │ │ ├─ __init__.py 33 | | │ │ └─ vgg16_bn.py 34 | | │ ├─ modules 35 | | │ │ ├─ __init__.py 36 | | │ │ ├─ craft.py 37 | | │ │ ├─ craft_utils.py 38 | | │ │ ├─ imgproc.py 39 | | │ │ ├─ refinenet.py 40 | | │ │ └─ utils.py 41 | | │ ├─ __init__.py 42 | | │ ├─ infer.py 43 | | │ └─ load_model.py 44 | | ├─ text_recognizer 45 | | │ ├─ modules 46 | | │ │ ├─ dataset.py 47 | | │ │ ├─ feature_extraction.py 48 | | │ │ ├─ model.py 49 | | │ │ ├─ model_utils.py 50 | | │ │ ├─ prediction.py 51 | | │ │ ├─ sequence_modeling.py 52 | | │ │ ├─ transformation.py 53 | | │ │ └─ utils.py 54 | | │ ├─ __init__.py 55 | | │ ├─ infer.py 56 | | │ └─ load_model.py 57 | | ├─ __init__.py 58 | | ├─ engine.py 59 | | └─ model.py 60 | ├─ .gitignore 61 | ├─ CONTRIBUTING.md 62 | ├─ Dockerfile 63 | ├─ environment.yaml 64 | ├─ LICENSE 65 | ├─ main.py 66 | ├─ pyproject.toml 67 | ├─ README.md 68 | ├─ requirements.txt 69 | ├─ setup.cfg 70 | ``` 71 | 72 | ## Model Preparation 73 | You need to create "models" folder to store this: 74 | - detector_model = "models/text_detector/craft_mlt_25k.pth" 75 | - recognizer_model = "models/text_recognizer/TPS-ResNet-BiLSTM-Attn-case-sensitive.pth" 76 | 77 | Download all of pretrained models from "References" section 78 | 79 | ## Requirements 80 | You can setup the environment using conda or pip 81 | ``` 82 | pip install -r requirements.txt 83 | ``` 84 | or 85 | ``` 86 | conda env create -f environment.yaml 87 | ``` 88 | 89 | ## Container 90 | ``` 91 | docker build -t receipt-ocr . 92 | docker run -d --name receipt-ocr-service -p 80:80 receipt-ocr 93 | docker start receipt-ocr-service 94 | docker stop receipt-ocr-service 95 | ``` 96 | 97 | ## How to contribute? 
98 | Check the docs [here](CONTRIBUTING.md) 99 | 100 | ## Creator 101 | [![](https://github.com/andreaschandra/git-assets/blob/master/pictures/ruben.png)](https://github.com/rubentea16) -------------------------------------------------------------------------------- /src/text_recognizer/modules/model.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2019-present NAVER Corp. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | """ 16 | 17 | import torch.nn as nn 18 | 19 | from .transformation import TPS_SpatialTransformerNetwork 20 | from .feature_extraction import VGG_FeatureExtractor, RCNN_FeatureExtractor, ResNet_FeatureExtractor 21 | from .sequence_modeling import BidirectionalLSTM 22 | from .prediction import Attention 23 | 24 | 25 | class Model(nn.Module): 26 | 27 | def __init__(self, opt): 28 | super(Model, self).__init__() 29 | self.opt = opt 30 | self.stages = {'Trans': opt.Transformation, 'Feat': opt.FeatureExtraction, 31 | 'Seq': opt.SequenceModeling, 'Pred': opt.Prediction} 32 | 33 | """ Transformation """ 34 | if opt.Transformation == 'TPS': 35 | self.Transformation = TPS_SpatialTransformerNetwork(F=opt.num_fiducial, 36 | I_size=( 37 | opt.imgH, opt.imgW), 38 | I_r_size=( 39 | opt.imgH, opt.imgW), 40 | I_channel_num=opt.input_channel) 41 | else: 42 | print('No Transformation module specified') 43 | 44 | """ FeatureExtraction """ 45 | if opt.FeatureExtraction == 'VGG': 46 | self.FeatureExtraction = VGG_FeatureExtractor( 47 | opt.input_channel, opt.output_channel) 48 | elif opt.FeatureExtraction == 'RCNN': 49 | self.FeatureExtraction = RCNN_FeatureExtractor( 50 | opt.input_channel, opt.output_channel) 51 | elif opt.FeatureExtraction == 'ResNet': 52 | self.FeatureExtraction = ResNet_FeatureExtractor( 53 | opt.input_channel, opt.output_channel) 54 | else: 55 | raise Exception('No FeatureExtraction module specified') 56 | # int(imgH/16-1) * 512 57 | self.FeatureExtraction_output = opt.output_channel 58 | self.AdaptiveAvgPool = nn.AdaptiveAvgPool2d( 59 | (None, 1)) # Transform final (imgH/16-1) -> 1 60 | 61 | """ Sequence modeling""" 62 | if opt.SequenceModeling == 'BiLSTM': 63 | self.SequenceModeling = nn.Sequential( 64 | BidirectionalLSTM(self.FeatureExtraction_output, 65 | opt.hidden_size, opt.hidden_size), 66 | BidirectionalLSTM(opt.hidden_size, opt.hidden_size, opt.hidden_size)) 67 | self.SequenceModeling_output = opt.hidden_size 68 | else: 69 | print('No SequenceModeling module specified') 70 | self.SequenceModeling_output = self.FeatureExtraction_output 71 | 72 | """ Prediction """ 73 | if opt.Prediction == 'CTC': 74 | self.Prediction = nn.Linear( 75 | self.SequenceModeling_output, opt.num_class) 76 | elif opt.Prediction == 'Attn': 77 | self.Prediction = Attention( 78 | self.SequenceModeling_output, opt.hidden_size, opt.num_class) 79 | else: 80 | raise Exception('Prediction is neither CTC or Attn') 81 | 82 | def forward(self, input, text, is_train=True): 83 | """ Transformation stage """ 84 | if not 
self.stages['Trans'] == "None": 85 | input = self.Transformation(input) 86 | 87 | """ Feature extraction stage """ 88 | visual_feature = self.FeatureExtraction(input) 89 | visual_feature = self.AdaptiveAvgPool( 90 | visual_feature.permute(0, 3, 1, 2)) # [b, c, h, w] -> [b, w, c, h] 91 | visual_feature = visual_feature.squeeze(3) 92 | 93 | """ Sequence modeling stage """ 94 | if self.stages['Seq'] == 'BiLSTM': 95 | contextual_feature = self.SequenceModeling(visual_feature) 96 | else: 97 | # for convenience. this is NOT contextually modeled by BiLSTM 98 | contextual_feature = visual_feature 99 | 100 | """ Prediction stage """ 101 | if self.stages['Pred'] == 'CTC': 102 | prediction = self.Prediction(contextual_feature.contiguous()) 103 | else: 104 | prediction = self.Prediction(contextual_feature.contiguous( 105 | ), text, is_train, batch_max_length=self.opt.batch_max_length) 106 | 107 | return prediction 108 | -------------------------------------------------------------------------------- /src/text_recognizer/modules/model_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 3 | 4 | 5 | class CTCLabelConverter(object): 6 | """ Convert between text-label and text-index """ 7 | 8 | def __init__(self, character): 9 | # character (str): set of the possible characters. 10 | dict_character = list(character) 11 | 12 | self.dict = {} 13 | for i, char in enumerate(dict_character): 14 | # NOTE: 0 is reserved for 'CTCblank' token required by CTCLoss 15 | self.dict[char] = i + 1 16 | 17 | # dummy '[CTCblank]' token for CTCLoss (index 0) 18 | self.character = ['[CTCblank]'] + dict_character 19 | 20 | def encode(self, text, batch_max_length=25): 21 | """convert text-label into text-index. 22 | input: 23 | text: text labels of each image. [batch_size] 24 | batch_max_length: max length of text label in the batch. 25 by default 25 | 26 | output: 27 | text: text index for CTCLoss. [batch_size, batch_max_length] 28 | length: length of each text. [batch_size] 29 | """ 30 | length = [len(s) for s in text] 31 | 32 | # The index used for padding (=0) would not affect the CTC loss calculation. 33 | batch_text = torch.LongTensor(len(text), batch_max_length).fill_(0) 34 | for i, t in enumerate(text): 35 | text = list(t) 36 | text = [self.dict[char] for char in text] 37 | batch_text[i][:len(text)] = torch.LongTensor(text) 38 | return (batch_text.to(device), torch.IntTensor(length).to(device)) 39 | 40 | def decode(self, text_index, length): 41 | """ convert text-index into text-label. """ 42 | texts = [] 43 | for index, l in enumerate(length): 44 | t = text_index[index, :] 45 | 46 | char_list = [] 47 | for i in range(l): 48 | # removing repeated characters and blank. 49 | if t[i] != 0 and (not (i > 0 and t[i - 1] == t[i])): 50 | char_list.append(self.character[t[i]]) 51 | text = ''.join(char_list) 52 | 53 | texts.append(text) 54 | return texts 55 | 56 | 57 | class CTCLabelConverterForBaiduWarpctc(object): 58 | """ Convert between text-label and text-index for baidu warpctc """ 59 | 60 | def __init__(self, character): 61 | # character (str): set of the possible characters. 
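        # NOTE: unlike CTCLabelConverter above, which pads every label to batch_max_length,
        # this variant follows the Baidu warp-ctc convention: its encode() below concatenates
        # the targets of the whole batch into one flat IntTensor and returns the per-sample
        # lengths alongside it.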
62 | dict_character = list(character) 63 | 64 | self.dict = {} 65 | for i, char in enumerate(dict_character): 66 | # NOTE: 0 is reserved for 'CTCblank' token required by CTCLoss 67 | self.dict[char] = i + 1 68 | 69 | # dummy '[CTCblank]' token for CTCLoss (index 0) 70 | self.character = ['[CTCblank]'] + dict_character 71 | 72 | def encode(self, text, batch_max_length=25): 73 | """convert text-label into text-index. 74 | input: 75 | text: text labels of each image. [batch_size] 76 | output: 77 | text: concatenated text index for CTCLoss. 78 | [sum(text_lengths)] = [text_index_0 + text_index_1 + ... + text_index_(n - 1)] 79 | length: length of each text. [batch_size] 80 | """ 81 | length = [len(s) for s in text] 82 | text = ''.join(text) 83 | text = [self.dict[char] for char in text] 84 | 85 | return (torch.IntTensor(text), torch.IntTensor(length)) 86 | 87 | def decode(self, text_index, length): 88 | """ convert text-index into text-label. """ 89 | texts = [] 90 | index = 0 91 | for l in length: 92 | t = text_index[index:index + l] 93 | 94 | char_list = [] 95 | for i in range(l): 96 | # removing repeated characters and blank. 97 | if t[i] != 0 and (not (i > 0 and t[i - 1] == t[i])): 98 | char_list.append(self.character[t[i]]) 99 | text = ''.join(char_list) 100 | 101 | texts.append(text) 102 | index += l 103 | return texts 104 | 105 | 106 | class AttnLabelConverter(object): 107 | """ Convert between text-label and text-index """ 108 | 109 | def __init__(self, character): 110 | # character (str): set of the possible characters. 111 | # [GO] for the start token of the attention decoder. [s] for end-of-sentence token. 112 | list_token = ['[GO]', '[s]'] # ['[s]','[UNK]','[PAD]','[GO]'] 113 | list_character = list(character) 114 | self.character = list_token + list_character 115 | 116 | self.dict = {} 117 | for i, char in enumerate(self.character): 118 | # print(i, char) 119 | self.dict[char] = i 120 | 121 | def encode(self, text, batch_max_length=25): 122 | """ convert text-label into text-index. 123 | input: 124 | text: text labels of each image. [batch_size] 125 | batch_max_length: max length of text label in the batch. 25 by default 126 | 127 | output: 128 | text : the input of attention decoder. [batch_size x (max_length+2)] +1 for [GO] token and +1 for [s] token. 129 | text[:, 0] is [GO] token and text is padded with [GO] token after [s] token. 130 | length : the length of output of attention decoder, which count [s] token also. [3, 7, ....] [batch_size] 131 | """ 132 | length = [len(s) + 1 for s in text] # +1 for [s] at end of sentence. 133 | # batch_max_length = max(length) # this is not allowed for multi-gpu setting 134 | batch_max_length += 1 135 | # additional +1 for [GO] at first step. batch_text is padded with [GO] token after [s] token. 136 | batch_text = torch.LongTensor(len(text), batch_max_length + 1).fill_(0) 137 | for i, t in enumerate(text): 138 | text = list(t) 139 | text.append('[s]') 140 | text = [self.dict[char] for char in text] 141 | # batch_text[:, 0] = [GO] token 142 | batch_text[i][1:1 + len(text)] = torch.LongTensor(text) 143 | return (batch_text.to(device), torch.IntTensor(length).to(device)) 144 | 145 | def decode(self, text_index, length): 146 | """ convert text-index into text-label. 
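            Note: index 0 decodes back to the '[GO]' token and the joined string still
            contains the '[s]' end token, so callers are expected to trim each result
            at the first '[s]'.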
""" 147 | texts = [] 148 | for index, l in enumerate(length): 149 | text = ''.join([self.character[i] for i in text_index[index, :]]) 150 | texts.append(text) 151 | return texts 152 | 153 | 154 | class Averager(object): 155 | """Compute average for torch.Tensor, used for loss average.""" 156 | 157 | def __init__(self): 158 | self.reset() 159 | 160 | def add(self, v): 161 | count = v.data.numel() 162 | v = v.data.sum() 163 | self.n_count += count 164 | self.sum += v 165 | 166 | def reset(self): 167 | self.n_count = 0 168 | self.sum = 0 169 | 170 | def val(self): 171 | res = 0 172 | if self.n_count != 0: 173 | res = self.sum / float(self.n_count) 174 | return res 175 | -------------------------------------------------------------------------------- /src/text_recognizer/modules/dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import re 4 | import six 5 | import math 6 | import numpy as np 7 | import torch 8 | 9 | import torchvision.transforms as transforms 10 | from PIL import Image 11 | from torch.utils.data import Dataset, ConcatDataset, Subset 12 | from torch._utils import _accumulate 13 | 14 | 15 | class Batch_Balanced_Dataset(object): 16 | 17 | def __init__(self, opt): 18 | """ 19 | Modulate the data ratio in the batch. 20 | For example, when select_data is "MJ-ST" and batch_ratio is "0.5-0.5", 21 | the 50% of the batch is filled with MJ and the other 50% of the batch is filled with ST. 22 | """ 23 | log = open(f'./saved_models/{opt.exp_name}/log_dataset.txt', 'a') 24 | dashed_line = '-' * 80 25 | print(dashed_line) 26 | log.write(dashed_line + '\n') 27 | print( 28 | f'dataset_root: {opt.train_data}\nopt.select_data: {opt.select_data}\nopt.batch_ratio: {opt.batch_ratio}') 29 | log.write( 30 | f'dataset_root: {opt.train_data}\nopt.select_data: {opt.select_data}\nopt.batch_ratio: {opt.batch_ratio}\n') 31 | assert len(opt.select_data) == len(opt.batch_ratio) 32 | 33 | _AlignCollate = AlignCollate( 34 | imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD) 35 | self.data_loader_list = [] 36 | self.dataloader_iter_list = [] 37 | batch_size_list = [] 38 | Total_batch_size = 0 39 | for selected_d, batch_ratio_d in zip(opt.select_data, opt.batch_ratio): 40 | _batch_size = max(round(opt.batch_size * float(batch_ratio_d)), 1) 41 | print(dashed_line) 42 | log.write(dashed_line + '\n') 43 | _dataset, _dataset_log = hierarchical_dataset( 44 | root=opt.train_data, opt=opt, select_data=[selected_d]) 45 | total_number_dataset = len(_dataset) 46 | log.write(_dataset_log) 47 | 48 | """ 49 | The total number of data can be modified with opt.total_data_usage_ratio. 50 | ex) opt.total_data_usage_ratio = 1 indicates 100% usage, and 0.2 indicates 20% usage. 51 | See 4.2 section in our paper. 
52 | """ 53 | number_dataset = int(total_number_dataset * 54 | float(opt.total_data_usage_ratio)) 55 | dataset_split = [number_dataset, 56 | total_number_dataset - number_dataset] 57 | indices = range(total_number_dataset) 58 | _dataset, _ = [Subset(_dataset, indices[offset - length:offset]) 59 | for offset, length in zip(_accumulate(dataset_split), dataset_split)] 60 | selected_d_log = f'num total samples of {selected_d}: {total_number_dataset} x {opt.total_data_usage_ratio} (total_data_usage_ratio) = {len(_dataset)}\n' 61 | selected_d_log += f'num samples of {selected_d} per batch: {opt.batch_size} x {float(batch_ratio_d)} (batch_ratio) = {_batch_size}' 62 | print(selected_d_log) 63 | log.write(selected_d_log + '\n') 64 | batch_size_list.append(str(_batch_size)) 65 | Total_batch_size += _batch_size 66 | 67 | _data_loader = torch.utils.data.DataLoader( 68 | _dataset, batch_size=_batch_size, 69 | shuffle=True, 70 | num_workers=int(opt.workers), 71 | collate_fn=_AlignCollate, pin_memory=True) 72 | self.data_loader_list.append(_data_loader) 73 | self.dataloader_iter_list.append(iter(_data_loader)) 74 | 75 | Total_batch_size_log = f'{dashed_line}\n' 76 | batch_size_sum = '+'.join(batch_size_list) 77 | Total_batch_size_log += f'Total_batch_size: {batch_size_sum} = {Total_batch_size}\n' 78 | Total_batch_size_log += f'{dashed_line}' 79 | opt.batch_size = Total_batch_size 80 | 81 | print(Total_batch_size_log) 82 | log.write(Total_batch_size_log + '\n') 83 | log.close() 84 | 85 | def get_batch(self): 86 | balanced_batch_images = [] 87 | balanced_batch_texts = [] 88 | 89 | for i, data_loader_iter in enumerate(self.dataloader_iter_list): 90 | try: 91 | image, text = data_loader_iter.next() 92 | balanced_batch_images.append(image) 93 | balanced_batch_texts += text 94 | except StopIteration: 95 | self.dataloader_iter_list[i] = iter(self.data_loader_list[i]) 96 | image, text = self.dataloader_iter_list[i].next() 97 | balanced_batch_images.append(image) 98 | balanced_batch_texts += text 99 | except ValueError: 100 | pass 101 | 102 | balanced_batch_images = torch.cat(balanced_batch_images, 0) 103 | 104 | return balanced_batch_images, balanced_batch_texts 105 | 106 | 107 | class RawDataset(Dataset): 108 | 109 | def __init__(self, list_data, opt): 110 | self.opt = opt 111 | self.image_list = list_data 112 | self.nSamples = len(self.image_list) 113 | 114 | def __len__(self): 115 | return self.nSamples 116 | 117 | def __getitem__(self, index): 118 | gray_img = Image.fromarray(self.image_list[index]).convert('L') 119 | return (gray_img, 'image') 120 | 121 | 122 | class ResizeNormalize(object): 123 | 124 | def __init__(self, size, interpolation=Image.BICUBIC): 125 | self.size = size 126 | self.interpolation = interpolation 127 | self.toTensor = transforms.ToTensor() 128 | 129 | def __call__(self, img): 130 | img = img.resize(self.size, self.interpolation) 131 | img = self.toTensor(img) 132 | img.sub_(0.5).div_(0.5) 133 | return img 134 | 135 | 136 | class NormalizePAD(object): 137 | 138 | def __init__(self, max_size, PAD_type='right'): 139 | self.toTensor = transforms.ToTensor() 140 | self.max_size = max_size 141 | self.max_width_half = math.floor(max_size[2] / 2) 142 | self.PAD_type = PAD_type 143 | 144 | def __call__(self, img): 145 | img = self.toTensor(img) 146 | img.sub_(0.5).div_(0.5) 147 | c, h, w = img.size() 148 | Pad_img = torch.FloatTensor(*self.max_size).fill_(0) 149 | Pad_img[:, :, :w] = img # right pad 150 | if self.max_size[2] != w: # add border Pad 151 | Pad_img[:, :, w:] = img[:, :, w - 152 | 
1].unsqueeze(2).expand(c, h, self.max_size[2] - w) 153 | 154 | return Pad_img 155 | 156 | 157 | class AlignCollate(object): 158 | 159 | def __init__(self, imgH=32, imgW=100, keep_ratio_with_pad=False): 160 | self.imgH = imgH 161 | self.imgW = imgW 162 | self.keep_ratio_with_pad = keep_ratio_with_pad 163 | 164 | def __call__(self, batch): 165 | batch = filter(lambda x: x is not None, batch) 166 | images, labels = zip(*batch) 167 | 168 | if self.keep_ratio_with_pad: # same concept with 'Rosetta' paper 169 | resized_max_w = self.imgW 170 | input_channel = 3 if images[0].mode == 'RGB' else 1 171 | transform = NormalizePAD((input_channel, self.imgH, resized_max_w)) 172 | 173 | resized_images = [] 174 | for image in images: 175 | w, h = image.size 176 | ratio = w / float(h) 177 | if math.ceil(self.imgH * ratio) > self.imgW: 178 | resized_w = self.imgW 179 | else: 180 | resized_w = math.ceil(self.imgH * ratio) 181 | 182 | resized_image = image.resize( 183 | (resized_w, self.imgH), Image.BICUBIC) 184 | resized_images.append(transform(resized_image)) 185 | # resized_image.save('./image_test/%d_test.jpg' % w) 186 | 187 | image_tensors = torch.cat([t.unsqueeze(0) 188 | for t in resized_images], 0) 189 | 190 | else: 191 | transform = ResizeNormalize((self.imgW, self.imgH)) 192 | image_tensors = [transform(image) for image in images] 193 | image_tensors = torch.cat([t.unsqueeze(0) 194 | for t in image_tensors], 0) 195 | 196 | return image_tensors, labels 197 | 198 | 199 | def tensor2im(image_tensor, imtype=np.uint8): 200 | image_numpy = image_tensor.cpu().float().numpy() 201 | if image_numpy.shape[0] == 1: 202 | image_numpy = np.tile(image_numpy, (3, 1, 1)) 203 | image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0 204 | return image_numpy.astype(imtype) 205 | 206 | 207 | def save_image(image_numpy, image_path): 208 | image_pil = Image.fromarray(image_numpy) 209 | image_pil.save(image_path) 210 | -------------------------------------------------------------------------------- /src/text_recognizer/modules/transformation.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 6 | 7 | 8 | class TPS_SpatialTransformerNetwork(nn.Module): 9 | """ Rectification Network of RARE, namely TPS based STN """ 10 | 11 | def __init__(self, F, I_size, I_r_size, I_channel_num=1): 12 | """ Based on RARE TPS 13 | input: 14 | batch_I: Batch Input Image [batch_size x I_channel_num x I_height x I_width] 15 | I_size : (height, width) of the input image I 16 | I_r_size : (height, width) of the rectified image I_r 17 | I_channel_num : the number of channels of the input image I 18 | output: 19 | batch_I_r: rectified image [batch_size x I_channel_num x I_r_height x I_r_width] 20 | """ 21 | super(TPS_SpatialTransformerNetwork, self).__init__() 22 | self.F = F 23 | self.I_size = I_size 24 | self.I_r_size = I_r_size # = (I_r_height, I_r_width) 25 | self.I_channel_num = I_channel_num 26 | self.LocalizationNetwork = LocalizationNetwork( 27 | self.F, self.I_channel_num) 28 | self.GridGenerator = GridGenerator(self.F, self.I_r_size) 29 | 30 | def forward(self, batch_I): 31 | batch_C_prime = self.LocalizationNetwork(batch_I) # batch_size x K x 2 32 | # batch_size x n (= I_r_width x I_r_height) x 2 33 | build_P_prime = self.GridGenerator.build_P_prime(batch_C_prime) 34 | build_P_prime_reshape = build_P_prime.reshape( 
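            # build_P_prime holds, for every pixel of the rectified output, its sampling
            # coordinate in the input image; reshaping to [batch, I_r_height, I_r_width, 2]
            # gives the (N, H_out, W_out, 2) layout that F.grid_sample below expects.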
35 | [build_P_prime.size(0), self.I_r_size[0], self.I_r_size[1], 2]) 36 | 37 | if torch.__version__ > "1.2.0": 38 | batch_I_r = F.grid_sample( 39 | batch_I, build_P_prime_reshape, padding_mode='border', align_corners=True) 40 | else: 41 | batch_I_r = F.grid_sample( 42 | batch_I, build_P_prime_reshape, padding_mode='border') 43 | 44 | return batch_I_r 45 | 46 | 47 | class LocalizationNetwork(nn.Module): 48 | """ Localization Network of RARE, which predicts C' (K x 2) from I (I_width x I_height) """ 49 | 50 | def __init__(self, F, I_channel_num): 51 | super(LocalizationNetwork, self).__init__() 52 | self.F = F 53 | self.I_channel_num = I_channel_num 54 | self.conv = nn.Sequential( 55 | nn.Conv2d(in_channels=self.I_channel_num, out_channels=64, kernel_size=3, stride=1, padding=1, 56 | bias=False), nn.BatchNorm2d(64), nn.ReLU(True), 57 | nn.MaxPool2d(2, 2), # batch_size x 64 x I_height/2 x I_width/2 58 | nn.Conv2d(64, 128, 3, 1, 1, bias=False), nn.BatchNorm2d( 59 | 128), nn.ReLU(True), 60 | nn.MaxPool2d(2, 2), # batch_size x 128 x I_height/4 x I_width/4 61 | nn.Conv2d(128, 256, 3, 1, 1, bias=False), nn.BatchNorm2d( 62 | 256), nn.ReLU(True), 63 | nn.MaxPool2d(2, 2), # batch_size x 256 x I_height/8 x I_width/8 64 | nn.Conv2d(256, 512, 3, 1, 1, bias=False), nn.BatchNorm2d( 65 | 512), nn.ReLU(True), 66 | nn.AdaptiveAvgPool2d(1) # batch_size x 512 67 | ) 68 | 69 | self.localization_fc1 = nn.Sequential( 70 | nn.Linear(512, 256), nn.ReLU(True)) 71 | self.localization_fc2 = nn.Linear(256, self.F * 2) 72 | 73 | # Init fc2 in LocalizationNetwork 74 | self.localization_fc2.weight.data.fill_(0) 75 | """ see RARE paper Fig. 6 (a) """ 76 | ctrl_pts_x = np.linspace(-1.0, 1.0, int(F / 2)) 77 | ctrl_pts_y_top = np.linspace(0.0, -1.0, num=int(F / 2)) 78 | ctrl_pts_y_bottom = np.linspace(1.0, 0.0, num=int(F / 2)) 79 | ctrl_pts_top = np.stack([ctrl_pts_x, ctrl_pts_y_top], axis=1) 80 | ctrl_pts_bottom = np.stack([ctrl_pts_x, ctrl_pts_y_bottom], axis=1) 81 | initial_bias = np.concatenate([ctrl_pts_top, ctrl_pts_bottom], axis=0) 82 | self.localization_fc2.bias.data = torch.from_numpy( 83 | initial_bias).float().view(-1) 84 | 85 | def forward(self, batch_I): 86 | """ 87 | input: batch_I : Batch Input Image [batch_size x I_channel_num x I_height x I_width] 88 | output: batch_C_prime : Predicted coordinates of fiducial points for input batch [batch_size x F x 2] 89 | """ 90 | batch_size = batch_I.size(0) 91 | features = self.conv(batch_I).view(batch_size, -1) 92 | batch_C_prime = self.localization_fc2( 93 | self.localization_fc1(features)).view(batch_size, self.F, 2) 94 | return batch_C_prime 95 | 96 | 97 | class GridGenerator(nn.Module): 98 | """ Grid Generator of RARE, which produces P_prime by multipling T with P """ 99 | 100 | def __init__(self, F, I_r_size): 101 | """ Generate P_hat and inv_delta_C for later """ 102 | super(GridGenerator, self).__init__() 103 | self.eps = 1e-6 104 | self.I_r_height, self.I_r_width = I_r_size 105 | self.F = F 106 | self.C = self._build_C(self.F) # F x 2 107 | self.P = self._build_P(self.I_r_width, self.I_r_height) 108 | # for multi-gpu, you need register buffer 109 | self.register_buffer("inv_delta_C", torch.tensor( 110 | self._build_inv_delta_C(self.F, self.C)).float()) # F+3 x F+3 111 | self.register_buffer("P_hat", torch.tensor( 112 | self._build_P_hat(self.F, self.C, self.P)).float()) # n x F+3 113 | # for fine-tuning with different image width, you may use below instead of self.register_buffer 114 | # self.inv_delta_C = torch.tensor(self._build_inv_delta_C(self.F, 
self.C)).float().cuda() # F+3 x F+3 115 | # self.P_hat = torch.tensor(self._build_P_hat(self.F, self.C, self.P)).float().cuda() # n x F+3 116 | 117 | def _build_C(self, F): 118 | """ Return coordinates of fiducial points in I_r; C """ 119 | ctrl_pts_x = np.linspace(-1.0, 1.0, int(F / 2)) 120 | ctrl_pts_y_top = -1 * np.ones(int(F / 2)) 121 | ctrl_pts_y_bottom = np.ones(int(F / 2)) 122 | ctrl_pts_top = np.stack([ctrl_pts_x, ctrl_pts_y_top], axis=1) 123 | ctrl_pts_bottom = np.stack([ctrl_pts_x, ctrl_pts_y_bottom], axis=1) 124 | C = np.concatenate([ctrl_pts_top, ctrl_pts_bottom], axis=0) 125 | return C # F x 2 126 | 127 | def _build_inv_delta_C(self, F, C): 128 | """ Return inv_delta_C which is needed to calculate T """ 129 | hat_C = np.zeros((F, F), dtype=float) # F x F 130 | for i in range(0, F): 131 | for j in range(i, F): 132 | r = np.linalg.norm(C[i] - C[j]) 133 | hat_C[i, j] = r 134 | hat_C[j, i] = r 135 | np.fill_diagonal(hat_C, 1) 136 | hat_C = (hat_C ** 2) * np.log(hat_C) 137 | # print(C.shape, hat_C.shape) 138 | delta_C = np.concatenate( # F+3 x F+3 139 | [ 140 | np.concatenate([np.ones((F, 1)), C, hat_C], axis=1), # F x F+3 141 | np.concatenate( 142 | [np.zeros((2, 3)), np.transpose(C)], axis=1), # 2 x F+3 143 | np.concatenate( 144 | [np.zeros((1, 3)), np.ones((1, F))], axis=1) # 1 x F+3 145 | ], 146 | axis=0 147 | ) 148 | inv_delta_C = np.linalg.inv(delta_C) 149 | return inv_delta_C # F+3 x F+3 150 | 151 | def _build_P(self, I_r_width, I_r_height): 152 | I_r_grid_x = (np.arange(-I_r_width, I_r_width, 2) + 153 | 1.0) / I_r_width # self.I_r_width 154 | I_r_grid_y = (np.arange(-I_r_height, I_r_height, 2) + 155 | 1.0) / I_r_height # self.I_r_height 156 | P = np.stack( # self.I_r_width x self.I_r_height x 2 157 | np.meshgrid(I_r_grid_x, I_r_grid_y), 158 | axis=2 159 | ) 160 | return P.reshape([-1, 2]) # n (= self.I_r_width x self.I_r_height) x 2 161 | 162 | def _build_P_hat(self, F, C, P): 163 | n = P.shape[0] # n (= self.I_r_width x self.I_r_height) 164 | P_tile = np.tile(np.expand_dims(P, axis=1), (1, F, 1) 165 | ) # n x 2 -> n x 1 x 2 -> n x F x 2 166 | C_tile = np.expand_dims(C, axis=0) # 1 x F x 2 167 | P_diff = P_tile - C_tile # n x F x 2 168 | rbf_norm = np.linalg.norm( 169 | P_diff, ord=2, axis=2, keepdims=False) # n x F 170 | rbf = np.multiply(np.square(rbf_norm), np.log( 171 | rbf_norm + self.eps)) # n x F 172 | P_hat = np.concatenate([np.ones((n, 1)), P, rbf], axis=1) 173 | return P_hat # n x F+3 174 | 175 | def build_P_prime(self, batch_C_prime): 176 | """ Generate Grid from batch_C_prime [batch_size x F x 2] """ 177 | batch_size = batch_C_prime.size(0) 178 | batch_inv_delta_C = self.inv_delta_C.repeat(batch_size, 1, 1) 179 | batch_P_hat = self.P_hat.repeat(batch_size, 1, 1) 180 | batch_C_prime_with_zeros = torch.cat((batch_C_prime, torch.zeros( 181 | batch_size, 3, 2).float().to(device)), dim=1) # batch_size x F+3 x 2 182 | # batch_size x F+3 x 2 183 | batch_T = torch.bmm(batch_inv_delta_C, batch_C_prime_with_zeros) 184 | batch_P_prime = torch.bmm(batch_P_hat, batch_T) # batch_size x n x 2 185 | return batch_P_prime # batch_size x n x 2 186 | -------------------------------------------------------------------------------- /src/text_detector/modules/craft_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2019-present NAVER Corp. 
3 | MIT License 4 | """ 5 | 6 | # -*- coding: utf-8 -*- 7 | import numpy as np 8 | import cv2 9 | import math 10 | 11 | """ auxilary functions """ 12 | # unwarp corodinates 13 | 14 | 15 | def warpCoord(Minv, pt): 16 | out = np.matmul(Minv, (pt[0], pt[1], 1)) 17 | return np.array([out[0]/out[2], out[1]/out[2]]) 18 | 19 | 20 | """ end of auxilary functions """ 21 | 22 | 23 | def getDetBoxes_core(textmap, linkmap, text_threshold, link_threshold, low_text): 24 | # prepare data 25 | linkmap = linkmap.copy() 26 | textmap = textmap.copy() 27 | img_h, img_w = textmap.shape 28 | 29 | """ labeling method """ 30 | ret, text_score = cv2.threshold(textmap, low_text, 1, 0) 31 | ret, link_score = cv2.threshold(linkmap, link_threshold, 1, 0) 32 | 33 | text_score_comb = np.clip(text_score + link_score, 0, 1) 34 | nLabels, labels, stats, centroids = cv2.connectedComponentsWithStats( 35 | text_score_comb.astype(np.uint8), connectivity=4) 36 | 37 | det = [] 38 | mapper = [] 39 | for k in range(1, nLabels): 40 | # size filtering 41 | size = stats[k, cv2.CC_STAT_AREA] 42 | if size < 10: 43 | continue 44 | 45 | # thresholding 46 | if np.max(textmap[labels == k]) < text_threshold: 47 | continue 48 | 49 | # make segmentation map 50 | segmap = np.zeros(textmap.shape, dtype=np.uint8) 51 | segmap[labels == k] = 255 52 | # remove link area 53 | segmap[np.logical_and(link_score == 1, text_score == 0)] = 0 54 | x, y = stats[k, cv2.CC_STAT_LEFT], stats[k, cv2.CC_STAT_TOP] 55 | w, h = stats[k, cv2.CC_STAT_WIDTH], stats[k, cv2.CC_STAT_HEIGHT] 56 | niter = int(math.sqrt(size * min(w, h) / (w * h)) * 2) 57 | sx, ex, sy, ey = x - niter, x + w + niter + 1, y - niter, y + h + niter + 1 58 | # boundary check 59 | if sx < 0: 60 | sx = 0 61 | if sy < 0: 62 | sy = 0 63 | if ex >= img_w: 64 | ex = img_w 65 | if ey >= img_h: 66 | ey = img_h 67 | kernel = cv2.getStructuringElement( 68 | cv2.MORPH_RECT, (1 + niter, 1 + niter)) 69 | segmap[sy:ey, sx:ex] = cv2.dilate(segmap[sy:ey, sx:ex], kernel) 70 | 71 | # make box 72 | np_contours = np.roll(np.array(np.where(segmap != 0)), 73 | 1, axis=0).transpose().reshape(-1, 2) 74 | rectangle = cv2.minAreaRect(np_contours) 75 | box = cv2.boxPoints(rectangle) 76 | 77 | # align diamond-shape 78 | w, h = np.linalg.norm(box[0] - box[1]), np.linalg.norm(box[1] - box[2]) 79 | box_ratio = max(w, h) / (min(w, h) + 1e-5) 80 | if abs(1 - box_ratio) <= 0.1: 81 | l, r = min(np_contours[:, 0]), max(np_contours[:, 0]) 82 | t, b = min(np_contours[:, 1]), max(np_contours[:, 1]) 83 | box = np.array([[l, t], [r, t], [r, b], [l, b]], dtype=np.float32) 84 | 85 | # make clock-wise order 86 | startidx = box.sum(axis=1).argmin() 87 | box = np.roll(box, 4-startidx, 0) 88 | box = np.array(box) 89 | 90 | det.append(box) 91 | mapper.append(k) 92 | 93 | return det, labels, mapper 94 | 95 | 96 | def getPoly_core(boxes, labels, mapper, linkmap): 97 | # configs 98 | num_cp = 5 99 | max_len_ratio = 0.7 100 | expand_ratio = 1.45 101 | max_r = 2.0 102 | step_r = 0.2 103 | 104 | polys = [] 105 | for k, box in enumerate(boxes): 106 | # size filter for small instance 107 | w, h = int(np.linalg.norm(box[0] - box[1]) + 108 | 1), int(np.linalg.norm(box[1] - box[2]) + 1) 109 | if w < 10 or h < 10: 110 | polys.append(None) 111 | continue 112 | 113 | # warp image 114 | tar = np.float32([[0, 0], [w, 0], [w, h], [0, h]]) 115 | M = cv2.getPerspectiveTransform(box, tar) 116 | word_label = cv2.warpPerspective( 117 | labels, M, (w, h), flags=cv2.INTER_NEAREST) 118 | try: 119 | Minv = np.linalg.inv(M) 120 | except: 121 | polys.append(None) 122 | 
continue 123 | 124 | # binarization for selected label 125 | cur_label = mapper[k] 126 | word_label[word_label != cur_label] = 0 127 | word_label[word_label > 0] = 1 128 | 129 | """ Polygon generation """ 130 | # find top/bottom contours 131 | cp = [] 132 | max_len = -1 133 | for i in range(w): 134 | region = np.where(word_label[:, i] != 0)[0] 135 | if len(region) < 2: 136 | continue 137 | cp.append((i, region[0], region[-1])) 138 | length = region[-1] - region[0] + 1 139 | if length > max_len: 140 | max_len = length 141 | 142 | # pass if max_len is similar to h 143 | if h * max_len_ratio < max_len: 144 | polys.append(None) 145 | continue 146 | 147 | # get pivot points with fixed length 148 | tot_seg = num_cp * 2 + 1 149 | seg_w = w / tot_seg # segment width 150 | pp = [None] * num_cp # init pivot points 151 | cp_section = [[0, 0]] * tot_seg 152 | seg_height = [0] * num_cp 153 | seg_num = 0 154 | num_sec = 0 155 | prev_h = -1 156 | for i in range(0, len(cp)): 157 | (x, sy, ey) = cp[i] 158 | if (seg_num + 1) * seg_w <= x and seg_num <= tot_seg: 159 | # average previous segment 160 | if num_sec == 0: 161 | break 162 | cp_section[seg_num] = [cp_section[seg_num][0] / 163 | num_sec, cp_section[seg_num][1] / num_sec] 164 | num_sec = 0 165 | 166 | # reset variables 167 | seg_num += 1 168 | prev_h = -1 169 | 170 | # accumulate center points 171 | cy = (sy + ey) * 0.5 172 | cur_h = ey - sy + 1 173 | cp_section[seg_num] = [cp_section[seg_num] 174 | [0] + x, cp_section[seg_num][1] + cy] 175 | num_sec += 1 176 | 177 | if seg_num % 2 == 0: 178 | continue # No polygon area 179 | 180 | if prev_h < cur_h: 181 | pp[int((seg_num - 1)/2)] = (x, cy) 182 | seg_height[int((seg_num - 1)/2)] = cur_h 183 | prev_h = cur_h 184 | 185 | # processing last segment 186 | if num_sec != 0: 187 | cp_section[-1] = [cp_section[-1][0] / 188 | num_sec, cp_section[-1][1] / num_sec] 189 | 190 | # pass if num of pivots is not sufficient or segment widh is smaller than character height 191 | if None in pp or seg_w < np.max(seg_height) * 0.25: 192 | polys.append(None) 193 | continue 194 | 195 | # calc median maximum of pivot points 196 | half_char_h = np.median(seg_height) * expand_ratio / 2 197 | 198 | # calc gradiant and apply to make horizontal pivots 199 | new_pp = [] 200 | for i, (x, cy) in enumerate(pp): 201 | dx = cp_section[i * 2 + 2][0] - cp_section[i * 2][0] 202 | dy = cp_section[i * 2 + 2][1] - cp_section[i * 2][1] 203 | if dx == 0: # gradient if zero 204 | new_pp.append([x, cy - half_char_h, x, cy + half_char_h]) 205 | continue 206 | rad = - math.atan2(dy, dx) 207 | c, s = half_char_h * math.cos(rad), half_char_h * math.sin(rad) 208 | new_pp.append([x - s, cy - c, x + s, cy + c]) 209 | 210 | # get edge points to cover character heatmaps 211 | isSppFound, isEppFound = False, False 212 | grad_s = (pp[1][1] - pp[0][1]) / (pp[1][0] - pp[0][0]) + \ 213 | (pp[2][1] - pp[1][1]) / (pp[2][0] - pp[1][0]) 214 | grad_e = (pp[-2][1] - pp[-1][1]) / (pp[-2][0] - pp[-1][0]) + \ 215 | (pp[-3][1] - pp[-2][1]) / (pp[-3][0] - pp[-2][0]) 216 | for r in np.arange(0.5, max_r, step_r): 217 | dx = 2 * half_char_h * r 218 | if not isSppFound: 219 | line_img = np.zeros(word_label.shape, dtype=np.uint8) 220 | dy = grad_s * dx 221 | p = np.array(new_pp[0]) - np.array([dx, dy, dx, dy]) 222 | cv2.line(line_img, (int(p[0]), int(p[1])), 223 | (int(p[2]), int(p[3])), 1, thickness=1) 224 | if np.sum(np.logical_and(word_label, line_img)) == 0 or r + 2 * step_r >= max_r: 225 | spp = p 226 | isSppFound = True 227 | if not isEppFound: 228 | line_img = 
np.zeros(word_label.shape, dtype=np.uint8) 229 | dy = grad_e * dx 230 | p = np.array(new_pp[-1]) + np.array([dx, dy, dx, dy]) 231 | cv2.line(line_img, (int(p[0]), int(p[1])), 232 | (int(p[2]), int(p[3])), 1, thickness=1) 233 | if np.sum(np.logical_and(word_label, line_img)) == 0 or r + 2 * step_r >= max_r: 234 | epp = p 235 | isEppFound = True 236 | if isSppFound and isEppFound: 237 | break 238 | 239 | # pass if boundary of polygon is not found 240 | if not (isSppFound and isEppFound): 241 | polys.append(None) 242 | continue 243 | 244 | # make final polygon 245 | poly = [] 246 | poly.append(warpCoord(Minv, (spp[0], spp[1]))) 247 | for p in new_pp: 248 | poly.append(warpCoord(Minv, (p[0], p[1]))) 249 | poly.append(warpCoord(Minv, (epp[0], epp[1]))) 250 | poly.append(warpCoord(Minv, (epp[2], epp[3]))) 251 | for p in reversed(new_pp): 252 | poly.append(warpCoord(Minv, (p[2], p[3]))) 253 | poly.append(warpCoord(Minv, (spp[2], spp[3]))) 254 | 255 | # add to final result 256 | polys.append(np.array(poly)) 257 | 258 | return polys 259 | 260 | 261 | def getDetBoxes(textmap, linkmap, text_threshold, link_threshold, low_text, poly=False): 262 | boxes, labels, mapper = getDetBoxes_core( 263 | textmap, linkmap, text_threshold, link_threshold, low_text) 264 | 265 | if poly: 266 | polys = getPoly_core(boxes, labels, mapper, linkmap) 267 | else: 268 | polys = [None] * len(boxes) 269 | 270 | return boxes, polys 271 | 272 | 273 | def adjustResultCoordinates(polys, ratio_w, ratio_h, ratio_net=2): 274 | if len(polys) > 0: 275 | polys = np.array(polys) 276 | for k in range(len(polys)): 277 | if polys[k] is not None: 278 | polys[k] *= (ratio_w * ratio_net, ratio_h * ratio_net) 279 | return polys 280 | -------------------------------------------------------------------------------- /notebooks/inference_onnx_engine.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "id": "false-shipping", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "import sys\n", 11 | "import io\n", 12 | "import base64\n", 13 | "import numpy as np\n", 14 | "import warnings\n", 15 | "from PIL import Image\n", 16 | "\n", 17 | "warnings.filterwarnings(\"ignore\")\n", 18 | "sys.path.append('..')\n", 19 | "\n", 20 | "from src.model import ONNXModel\n", 21 | "from src.engine import ONNXEngine" 22 | ] 23 | }, 24 | { 25 | "cell_type": "markdown", 26 | "id": "automated-graduation", 27 | "metadata": {}, 28 | "source": [ 29 | "## Load model network and weight" 30 | ] 31 | }, 32 | { 33 | "cell_type": "code", 34 | "execution_count": 2, 35 | "id": "periodic-timothy", 36 | "metadata": {}, 37 | "outputs": [], 38 | "source": [ 39 | "detector_cfg = '../configs/craft_config.yaml'\n", 40 | "detector_model = '../models/text_detector/craft.onnx'\n", 41 | "recognizer_cfg = '../configs/star_config.yaml'\n", 42 | "recognizer_model = '../models/text_recognizer/TPS-ResNet-BiLSTM-Attn-case-sensitive.pth'" 43 | ] 44 | }, 45 | { 46 | "cell_type": "code", 47 | "execution_count": 3, 48 | "id": "recovered-consortium", 49 | "metadata": {}, 50 | "outputs": [ 51 | { 52 | "name": "stdout", 53 | "output_type": "stream", 54 | "text": [ 55 | "Loading weights from checkpoint (../models/text_detector/craft.onnx)\n", 56 | "Loading weights from checkpoint (../models/text_recognizer/TPS-ResNet-BiLSTM-Attn-case-sensitive.pth)\n" 57 | ] 58 | } 59 | ], 60 | "source": [ 61 | "model = ONNXModel(detector_cfg, detector_model, \n", 62 | " recognizer_cfg, 
recognizer_model)" 63 | ] 64 | }, 65 | { 66 | "cell_type": "markdown", 67 | "id": "inside-salon", 68 | "metadata": {}, 69 | "source": [ 70 | "## Load Engine" 71 | ] 72 | }, 73 | { 74 | "cell_type": "code", 75 | "execution_count": 4, 76 | "id": "editorial-bargain", 77 | "metadata": {}, 78 | "outputs": [], 79 | "source": [ 80 | "engine = ONNXEngine(model)" 81 | ] 82 | }, 83 | { 84 | "cell_type": "markdown", 85 | "id": "known-seattle", 86 | "metadata": {}, 87 | "source": [ 88 | "## Input Data Preparation" 89 | ] 90 | }, 91 | { 92 | "cell_type": "code", 93 | "execution_count": 5, 94 | "id": "sonic-jacksonville", 95 | "metadata": {}, 96 | "outputs": [], 97 | "source": [ 98 | "image_file = '../data/tes.jpg'\n", 99 | "\n", 100 | "with open(image_file, \"rb\") as f:\n", 101 | " im_bytes = f.read()\n", 102 | "\n", 103 | "im_b64 = base64.b64encode(im_bytes).decode('utf-8')" 104 | ] 105 | }, 106 | { 107 | "cell_type": "code", 108 | "execution_count": 6, 109 | "id": "featured-boost", 110 | "metadata": {}, 111 | "outputs": [], 112 | "source": [ 113 | "img_bytes = base64.b64decode(im_b64.encode('utf-8'))\n", 114 | "image = Image.open(io.BytesIO(img_bytes))\n", 115 | "image = np.array(image)" 116 | ] 117 | }, 118 | { 119 | "cell_type": "markdown", 120 | "id": "plain-prediction", 121 | "metadata": {}, 122 | "source": [ 123 | "## Prediction" 124 | ] 125 | }, 126 | { 127 | "cell_type": "code", 128 | "execution_count": 7, 129 | "id": "biblical-davis", 130 | "metadata": {}, 131 | "outputs": [ 132 | { 133 | "name": "stdout", 134 | "output_type": "stream", 135 | "text": [ 136 | "'inference_detector' 2.73 s\n", 137 | "'inference_recognizer' 3.43 s\n", 138 | "'predict' 6.17 s\n" 139 | ] 140 | } 141 | ], 142 | "source": [ 143 | "engine.predict(image)" 144 | ] 145 | }, 146 | { 147 | "cell_type": "code", 148 | "execution_count": 8, 149 | "id": "burning-jacket", 150 | "metadata": { 151 | "scrolled": true 152 | }, 153 | "outputs": [ 154 | { 155 | "data": { 156 | "text/plain": [ 157 | "[('Kopo', tensor(0.9991), (620, 680, 481, 575)),\n", 158 | " ('Bandung', tensor(1.0000), (622, 673, 575, 713)),\n", 159 | " ('Geprek', tensor(1.0000), (626, 676, 253, 374)),\n", 160 | " ('Bensu', tensor(0.9999), (626, 676, 380, 481)),\n", 161 | " ('Babakan', tensor(0.9999), (666, 716, 602, 740)),\n", 162 | " ('Ciparay', tensor(0.9997), (666, 716, 743, 884)),\n", 163 | " ('536,', tensor(0.9938), (666, 720, 313, 397)),\n", 164 | " ('Margasuka,', tensor(0.9953), (669, 723, 404, 592)),\n", 165 | " ('J1,', tensor(0.5975), (673, 723, 85, 145)),\n", 166 | " ('Kopo', tensor(0.9966), (673, 723, 155, 239)),\n", 167 | " ('No.', tensor(0.7720), (673, 720, 246, 306)),\n", 168 | " ('KOTA', tensor(0.9945), (710, 763, 367, 454)),\n", 169 | " ('BANDUNG', tensor(0.9989), (710, 763, 458, 599)),\n", 170 | " ('33', tensor(0.7718), (804, 854, 246, 296)),\n", 171 | " ('order:', tensor(0.9085), (804, 858, 120, 240)),\n", 172 | " ('Kode', tensor(0.9972), (851, 901, 48, 135)),\n", 173 | " ('16-07-2021', tensor(0.9492), (888, 942, 212, 400)),\n", 174 | " ('11:53:47', tensor(0.9910), (891, 942, 411, 562)),\n", 175 | " ('Tanggal', tensor(0.9934), (895, 948, 51, 185)),\n", 176 | " ('BDG', tensor(0.9957), (931, 985, 407, 474)),\n", 177 | " ('Kasiri', tensor(0.5920), (938, 989, 51, 169)),\n", 178 | " ('Kasir', tensor(0.9999), (938, 989, 175, 279)),\n", 179 | " ('Kopo', tensor(0.9991), (938, 989, 316, 404)),\n", 180 | " ('1', tensor(0.5032), (945, 982, 286, 306)),\n", 181 | " ('Pelanggant', tensor(0.6637), (980, 1039, 49, 241)),\n", 182 | " ('gjk', tensor(0.9945), (982, 
1036, 246, 313)),\n", 183 | " ('wahyu', tensor(0.9993), (982, 1032, 320, 424)),\n", 184 | " ('GOFOO', tensor(0.9861), (1059, 1116, 673, 780)),\n", 185 | " ('Paket', tensor(0.9999), (1066, 1120, 48, 152)),\n", 186 | " ('Geprek', tensor(0.9998), (1068, 1126, 157, 281)),\n", 187 | " ('Bensu', tensor(0.9999), (1069, 1120, 283, 387)),\n", 188 | " ('Nasi', tensor(0.9995), (1069, 1120, 394, 471)),\n", 189 | " ('Daun', tensor(0.9998), (1069, 1116, 481, 565)),\n", 190 | " ('Jeruk', tensor(0.9994), (1069, 1116, 572, 673)),\n", 191 | " ('27', tensor(0.9975), (1110, 1160, 787, 834)),\n", 192 | " ('Level', tensor(0.9845), (1116, 1167, 125, 222)),\n", 193 | " ('D', tensor(0.9934), (1120, 1160, 51, 81)),\n", 194 | " ('I', tensor(0.6730), (1123, 1160, 236, 256)),\n", 195 | " ('X', tensor(0.8183), (1123, 1160, 266, 296)),\n", 196 | " ('(27', tensor(0.9810), (1160, 1210, 195, 263)),\n", 197 | " ('Harga', tensor(1.0000), (1163, 1214, 88, 189)),\n", 198 | " ('t', tensor(0.5064), (1173, 1204, 58, 78)),\n", 199 | " ('Dada', tensor(0.9999), (1207, 1254, 88, 172)),\n", 200 | " ('X', tensor(0.9218), (1214, 1251, 179, 209)),\n", 201 | " ('17.500', tensor(0.8386), (1241, 1291, 787, 908)),\n", 202 | " ('Original', tensor(0.9989), (1241, 1301, 532, 689)),\n", 203 | " ('GOFOOD', tensor(0.9986), (1244, 1298, 370, 498)),\n", 204 | " ('Bensu', tensor(0.9998), (1247, 1298, 269, 370)),\n", 205 | " ('Ayam', tensor(0.9882), (1251, 1301, 54, 138)),\n", 206 | " ('Geprek', tensor(0.9999), (1251, 1301, 142, 263)),\n", 207 | " ('X', tensor(0.9251), (1251, 1291, 696, 726)),\n", 208 | " ('(17', tensor(0.9848), (1291, 1345, 195, 263)),\n", 209 | " ('Harga', tensor(1.0000), (1293, 1351, 85, 194)),\n", 210 | " ('Dada', tensor(0.9999), (1338, 1389, 88, 175)),\n", 211 | " ('X', tensor(0.9166), (1345, 1385, 179, 209)),\n", 212 | " ('+', tensor(0.9094), (1352, 1382, 58, 81)),\n", 213 | " ('4.000', tensor(0.9562), (1375, 1426, 807, 908)),\n", 214 | " ('Charge', tensor(0.9999), (1380, 1438, 230, 355)),\n", 215 | " ('Take', tensor(0.9999), (1382, 1432, 58, 138)),\n", 216 | " ('Away', tensor(0.9999), (1384, 1437, 143, 231)),\n", 217 | " ('X', tensor(0.8809), (1389, 1426, 357, 390)),\n", 218 | " ('49.000', tensor(0.6948), (1463, 1513, 790, 911)),\n", 219 | " ('Subtotal', tensor(0.9988), (1469, 1523, 51, 209)),\n", 220 | " ('4.500', tensor(0.9441), (1510, 1560, 807, 911)),\n", 221 | " ('PB1', tensor(0.9708), (1516, 1567, 54, 118)),\n", 222 | " ('(10%)', tensor(0.9886), (1516, 1567, 128, 229)),\n", 223 | " ('53,500', tensor(0.7126), (1597, 1647, 790, 911)),\n", 224 | " ('Total', tensor(0.9979), (1607, 1657, 58, 158)),\n", 225 | " ('53.500', tensor(0.8583), (1684, 1738, 790, 911)),\n", 226 | " ('Gobiz', tensor(0.9990), (1692, 1749, 73, 181)),\n", 227 | " ('Kembal', tensor(0.9943), (1738, 1792, 58, 162)),\n", 228 | " ('Ii', tensor(0.7184), (1738, 1789, 152, 195)),\n", 229 | " ('LUNAS', tensor(0.9925), (1825, 1876, 444, 548)),\n", 230 | " ('*x', tensor(0.5476), (1829, 1869, 552, 602)),\n", 231 | " ('Kasih', tensor(0.9972), (1910, 1963, 505, 612)),\n", 232 | " ('Terima', tensor(0.9996), (1913, 1963, 384, 505)),\n", 233 | " ('POS', tensor(0.9836), (2041, 2094, 622, 693)),\n", 234 | " ('by', tensor(0.9994), (2047, 2101, 448, 498)),\n", 235 | " ('Pawoon', tensor(0.9994), (2047, 2098, 498, 622)),\n", 236 | " ('Powered', tensor(0.9999), (2051, 2101, 303, 444))]" 237 | ] 238 | }, 239 | "execution_count": 8, 240 | "metadata": {}, 241 | "output_type": "execute_result" 242 | } 243 | ], 244 | "source": [ 245 | "engine.raw_output" 246 | ] 247 | }, 248 | { 249 | 
"cell_type": "code", 250 | "execution_count": 9, 251 | "id": "outstanding-munich", 252 | "metadata": { 253 | "scrolled": true 254 | }, 255 | "outputs": [ 256 | { 257 | "data": { 258 | "text/plain": [ 259 | "['Geprek Bensu Kopo Bandung',\n", 260 | " 'J1, Kopo No. 536, Margasuka, Babakan Ciparay',\n", 261 | " 'KOTA BANDUNG',\n", 262 | " 'order: 33',\n", 263 | " 'Kode',\n", 264 | " 'Tanggal 16-07-2021 11:53:47',\n", 265 | " 'Kasiri Kasir 1 Kopo BDG',\n", 266 | " 'Pelanggant gjk wahyu',\n", 267 | " 'Paket Geprek Bensu Nasi Daun Jeruk GOFOO',\n", 268 | " 'D Level I X 27',\n", 269 | " 't Harga (27',\n", 270 | " 'Dada X',\n", 271 | " 'Ayam Geprek Bensu GOFOOD Original X 17.500',\n", 272 | " 'Harga (17',\n", 273 | " '+ Dada X',\n", 274 | " 'Take Away Charge X 4.000',\n", 275 | " 'Subtotal 49.000',\n", 276 | " 'PB1 (10%) 4.500',\n", 277 | " 'Total 53,500',\n", 278 | " 'Gobiz 53.500',\n", 279 | " 'Kembal Ii',\n", 280 | " 'LUNAS *x',\n", 281 | " 'Terima Kasih']" 282 | ] 283 | }, 284 | "execution_count": 9, 285 | "metadata": {}, 286 | "output_type": "execute_result" 287 | } 288 | ], 289 | "source": [ 290 | "engine.result" 291 | ] 292 | }, 293 | { 294 | "cell_type": "code", 295 | "execution_count": null, 296 | "id": "narrow-satellite", 297 | "metadata": {}, 298 | "outputs": [], 299 | "source": [] 300 | } 301 | ], 302 | "metadata": { 303 | "kernelspec": { 304 | "display_name": "Python [conda env:receipt-ocr]", 305 | "language": "python", 306 | "name": "conda-env-receipt-ocr-py" 307 | }, 308 | "language_info": { 309 | "codemirror_mode": { 310 | "name": "ipython", 311 | "version": 3 312 | }, 313 | "file_extension": ".py", 314 | "mimetype": "text/x-python", 315 | "name": "python", 316 | "nbconvert_exporter": "python", 317 | "pygments_lexer": "ipython3", 318 | "version": "3.9.7" 319 | } 320 | }, 321 | "nbformat": 4, 322 | "nbformat_minor": 5 323 | } 324 | -------------------------------------------------------------------------------- /notebooks/inference_default_engine.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "id": "romantic-checklist", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "import sys\n", 11 | "import io\n", 12 | "import base64\n", 13 | "import numpy as np\n", 14 | "import warnings\n", 15 | "from PIL import Image\n", 16 | "\n", 17 | "warnings.filterwarnings(\"ignore\")\n", 18 | "sys.path.append('..')\n", 19 | "\n", 20 | "from src.model import DefaultModel\n", 21 | "from src.engine import DefaultEngine" 22 | ] 23 | }, 24 | { 25 | "cell_type": "markdown", 26 | "id": "hourly-russell", 27 | "metadata": {}, 28 | "source": [ 29 | "## Load model network and weight" 30 | ] 31 | }, 32 | { 33 | "cell_type": "code", 34 | "execution_count": 2, 35 | "id": "covered-closure", 36 | "metadata": {}, 37 | "outputs": [], 38 | "source": [ 39 | "detector_cfg = '../configs/craft_config.yaml'\n", 40 | "detector_model = '../models/text_detector/craft_mlt_25k.pth'\n", 41 | "recognizer_cfg = '../configs/star_config.yaml'\n", 42 | "recognizer_model = '../models/text_recognizer/TPS-ResNet-BiLSTM-Attn-case-sensitive.pth'" 43 | ] 44 | }, 45 | { 46 | "cell_type": "code", 47 | "execution_count": 3, 48 | "id": "intense-ordinance", 49 | "metadata": {}, 50 | "outputs": [ 51 | { 52 | "name": "stdout", 53 | "output_type": "stream", 54 | "text": [ 55 | "Loading weights from checkpoint (../models/text_detector/craft_mlt_25k.pth)\n", 56 | "Loading weights from checkpoint 
(../models/text_recognizer/TPS-ResNet-BiLSTM-Attn-case-sensitive.pth)\n" 57 | ] 58 | } 59 | ], 60 | "source": [ 61 | "model = DefaultModel(detector_cfg, detector_model, \n", 62 | " recognizer_cfg, recognizer_model)" 63 | ] 64 | }, 65 | { 66 | "cell_type": "markdown", 67 | "id": "continued-shift", 68 | "metadata": {}, 69 | "source": [ 70 | "## Load Engine" 71 | ] 72 | }, 73 | { 74 | "cell_type": "code", 75 | "execution_count": 4, 76 | "id": "marine-legend", 77 | "metadata": {}, 78 | "outputs": [], 79 | "source": [ 80 | "engine = DefaultEngine(model)" 81 | ] 82 | }, 83 | { 84 | "cell_type": "markdown", 85 | "id": "insured-cooperation", 86 | "metadata": {}, 87 | "source": [ 88 | "## Input Data Preparation" 89 | ] 90 | }, 91 | { 92 | "cell_type": "code", 93 | "execution_count": 5, 94 | "id": "cardiac-mother", 95 | "metadata": {}, 96 | "outputs": [], 97 | "source": [ 98 | "image_file = '../data/tes.jpg'\n", 99 | "\n", 100 | "with open(image_file, \"rb\") as f:\n", 101 | " im_bytes = f.read()\n", 102 | "\n", 103 | "im_b64 = base64.b64encode(im_bytes).decode('utf-8')" 104 | ] 105 | }, 106 | { 107 | "cell_type": "code", 108 | "execution_count": 6, 109 | "id": "weird-psychology", 110 | "metadata": {}, 111 | "outputs": [], 112 | "source": [ 113 | "img_bytes = base64.b64decode(im_b64.encode('utf-8'))\n", 114 | "image = Image.open(io.BytesIO(img_bytes))\n", 115 | "image = np.array(image)" 116 | ] 117 | }, 118 | { 119 | "cell_type": "markdown", 120 | "id": "adjustable-prevention", 121 | "metadata": {}, 122 | "source": [ 123 | "## Prediction" 124 | ] 125 | }, 126 | { 127 | "cell_type": "code", 128 | "execution_count": 7, 129 | "id": "southern-connecticut", 130 | "metadata": {}, 131 | "outputs": [ 132 | { 133 | "name": "stdout", 134 | "output_type": "stream", 135 | "text": [ 136 | "'inference_detector' 2.88 s\n", 137 | "'inference_recognizer' 3.67 s\n", 138 | "'predict' 6.56 s\n" 139 | ] 140 | } 141 | ], 142 | "source": [ 143 | "engine.predict(image)" 144 | ] 145 | }, 146 | { 147 | "cell_type": "code", 148 | "execution_count": 8, 149 | "id": "composite-royalty", 150 | "metadata": { 151 | "scrolled": true 152 | }, 153 | "outputs": [ 154 | { 155 | "data": { 156 | "text/plain": [ 157 | "[('Kopo', tensor(0.9991), (620, 680, 481, 575)),\n", 158 | " ('Bandung', tensor(1.0000), (622, 673, 575, 713)),\n", 159 | " ('Geprek', tensor(1.0000), (626, 676, 253, 374)),\n", 160 | " ('Bensu', tensor(0.9999), (626, 676, 380, 481)),\n", 161 | " ('Babakan', tensor(0.9999), (666, 716, 602, 740)),\n", 162 | " ('Ciparay', tensor(0.9997), (666, 716, 743, 884)),\n", 163 | " ('536,', tensor(0.9938), (666, 720, 313, 397)),\n", 164 | " ('Margasuka,', tensor(0.9953), (669, 723, 404, 592)),\n", 165 | " ('J1,', tensor(0.5975), (673, 723, 85, 145)),\n", 166 | " ('Kopo', tensor(0.9966), (673, 723, 155, 239)),\n", 167 | " ('No.', tensor(0.7720), (673, 720, 246, 306)),\n", 168 | " ('KOTA', tensor(0.9945), (710, 763, 367, 454)),\n", 169 | " ('BANDUNG', tensor(0.9989), (710, 763, 458, 599)),\n", 170 | " ('33', tensor(0.7718), (804, 854, 246, 296)),\n", 171 | " ('order:', tensor(0.9085), (804, 858, 120, 240)),\n", 172 | " ('Kode', tensor(0.9972), (851, 901, 48, 135)),\n", 173 | " ('16-07-2021', tensor(0.9492), (888, 942, 212, 400)),\n", 174 | " ('11:53:47', tensor(0.9910), (891, 942, 411, 562)),\n", 175 | " ('Tanggal', tensor(0.9934), (895, 948, 51, 185)),\n", 176 | " ('BDG', tensor(0.9957), (931, 985, 407, 474)),\n", 177 | " ('Kasiri', tensor(0.5920), (938, 989, 51, 169)),\n", 178 | " ('Kasir', tensor(0.9999), (938, 989, 175, 279)),\n", 179 
| " ('Kopo', tensor(0.9991), (938, 989, 316, 404)),\n", 180 | " ('1', tensor(0.5032), (945, 982, 286, 306)),\n", 181 | " ('Pelanggant', tensor(0.6637), (980, 1039, 49, 241)),\n", 182 | " ('gjk', tensor(0.9945), (982, 1036, 246, 313)),\n", 183 | " ('wahyu', tensor(0.9993), (982, 1032, 320, 424)),\n", 184 | " ('GOFOO', tensor(0.9861), (1059, 1116, 673, 780)),\n", 185 | " ('Paket', tensor(0.9999), (1066, 1120, 48, 152)),\n", 186 | " ('Geprek', tensor(0.9998), (1068, 1126, 157, 281)),\n", 187 | " ('Bensu', tensor(0.9999), (1069, 1120, 283, 387)),\n", 188 | " ('Nasi', tensor(0.9995), (1069, 1120, 394, 471)),\n", 189 | " ('Daun', tensor(0.9998), (1069, 1116, 481, 565)),\n", 190 | " ('Jeruk', tensor(0.9994), (1069, 1116, 572, 673)),\n", 191 | " ('27', tensor(0.9975), (1110, 1160, 787, 834)),\n", 192 | " ('Level', tensor(0.9845), (1116, 1167, 125, 222)),\n", 193 | " ('D', tensor(0.9934), (1120, 1160, 51, 81)),\n", 194 | " ('I', tensor(0.6730), (1123, 1160, 236, 256)),\n", 195 | " ('X', tensor(0.8183), (1123, 1160, 266, 296)),\n", 196 | " ('(27', tensor(0.9810), (1160, 1210, 195, 263)),\n", 197 | " ('Harga', tensor(1.0000), (1163, 1214, 88, 189)),\n", 198 | " ('t', tensor(0.5064), (1173, 1204, 58, 78)),\n", 199 | " ('Dada', tensor(0.9999), (1207, 1254, 88, 172)),\n", 200 | " ('X', tensor(0.9218), (1214, 1251, 179, 209)),\n", 201 | " ('17.500', tensor(0.8386), (1241, 1291, 787, 908)),\n", 202 | " ('Original', tensor(0.9989), (1241, 1301, 532, 689)),\n", 203 | " ('GOFOOD', tensor(0.9986), (1244, 1298, 370, 498)),\n", 204 | " ('Bensu', tensor(0.9998), (1247, 1298, 269, 370)),\n", 205 | " ('Ayam', tensor(0.9882), (1251, 1301, 54, 138)),\n", 206 | " ('Geprek', tensor(0.9999), (1251, 1301, 142, 263)),\n", 207 | " ('X', tensor(0.9251), (1251, 1291, 696, 726)),\n", 208 | " ('(17', tensor(0.9848), (1291, 1345, 195, 263)),\n", 209 | " ('Harga', tensor(1.0000), (1293, 1351, 85, 194)),\n", 210 | " ('Dada', tensor(0.9999), (1338, 1389, 88, 175)),\n", 211 | " ('X', tensor(0.9166), (1345, 1385, 179, 209)),\n", 212 | " ('+', tensor(0.9094), (1352, 1382, 58, 81)),\n", 213 | " ('4.000', tensor(0.9562), (1375, 1426, 807, 908)),\n", 214 | " ('Charge', tensor(0.9999), (1380, 1438, 230, 355)),\n", 215 | " ('Take', tensor(0.9999), (1382, 1432, 58, 138)),\n", 216 | " ('Away', tensor(0.9999), (1384, 1437, 143, 231)),\n", 217 | " ('X', tensor(0.8809), (1389, 1426, 357, 390)),\n", 218 | " ('49.000', tensor(0.6948), (1463, 1513, 790, 911)),\n", 219 | " ('Subtotal', tensor(0.9988), (1469, 1523, 51, 209)),\n", 220 | " ('4.500', tensor(0.9441), (1510, 1560, 807, 911)),\n", 221 | " ('PB1', tensor(0.9708), (1516, 1567, 54, 118)),\n", 222 | " ('(10%)', tensor(0.9886), (1516, 1567, 128, 229)),\n", 223 | " ('53,500', tensor(0.7126), (1597, 1647, 790, 911)),\n", 224 | " ('Total', tensor(0.9979), (1607, 1657, 58, 158)),\n", 225 | " ('53.500', tensor(0.8583), (1684, 1738, 790, 911)),\n", 226 | " ('Gobiz', tensor(0.9990), (1692, 1749, 73, 181)),\n", 227 | " ('Kembal', tensor(0.9943), (1738, 1792, 58, 162)),\n", 228 | " ('Ii', tensor(0.7184), (1738, 1789, 152, 195)),\n", 229 | " ('LUNAS', tensor(0.9925), (1825, 1876, 444, 548)),\n", 230 | " ('*x', tensor(0.5476), (1829, 1869, 552, 602)),\n", 231 | " ('Kasih', tensor(0.9972), (1910, 1963, 505, 612)),\n", 232 | " ('Terima', tensor(0.9996), (1913, 1963, 384, 505)),\n", 233 | " ('POS', tensor(0.9836), (2041, 2094, 622, 693)),\n", 234 | " ('by', tensor(0.9994), (2047, 2101, 448, 498)),\n", 235 | " ('Pawoon', tensor(0.9994), (2047, 2098, 498, 622)),\n", 236 | " ('Powered', tensor(0.9999), 
(2051, 2101, 303, 444))]" 237 | ] 238 | }, 239 | "execution_count": 8, 240 | "metadata": {}, 241 | "output_type": "execute_result" 242 | } 243 | ], 244 | "source": [ 245 | "engine.raw_output" 246 | ] 247 | }, 248 | { 249 | "cell_type": "code", 250 | "execution_count": 9, 251 | "id": "helpful-madrid", 252 | "metadata": { 253 | "scrolled": true 254 | }, 255 | "outputs": [ 256 | { 257 | "data": { 258 | "text/plain": [ 259 | "['Geprek Bensu Kopo Bandung',\n", 260 | " 'J1, Kopo No. 536, Margasuka, Babakan Ciparay',\n", 261 | " 'KOTA BANDUNG',\n", 262 | " 'order: 33',\n", 263 | " 'Kode',\n", 264 | " 'Tanggal 16-07-2021 11:53:47',\n", 265 | " 'Kasiri Kasir 1 Kopo BDG',\n", 266 | " 'Pelanggant gjk wahyu',\n", 267 | " 'Paket Geprek Bensu Nasi Daun Jeruk GOFOO',\n", 268 | " 'D Level I X 27',\n", 269 | " 't Harga (27',\n", 270 | " 'Dada X',\n", 271 | " 'Ayam Geprek Bensu GOFOOD Original X 17.500',\n", 272 | " 'Harga (17',\n", 273 | " '+ Dada X',\n", 274 | " 'Take Away Charge X 4.000',\n", 275 | " 'Subtotal 49.000',\n", 276 | " 'PB1 (10%) 4.500',\n", 277 | " 'Total 53,500',\n", 278 | " 'Gobiz 53.500',\n", 279 | " 'Kembal Ii',\n", 280 | " 'LUNAS *x',\n", 281 | " 'Terima Kasih']" 282 | ] 283 | }, 284 | "execution_count": 9, 285 | "metadata": {}, 286 | "output_type": "execute_result" 287 | } 288 | ], 289 | "source": [ 290 | "engine.result" 291 | ] 292 | }, 293 | { 294 | "cell_type": "code", 295 | "execution_count": null, 296 | "id": "valid-dominican", 297 | "metadata": {}, 298 | "outputs": [], 299 | "source": [] 300 | } 301 | ], 302 | "metadata": { 303 | "kernelspec": { 304 | "display_name": "Python [conda env:receipt-ocr]", 305 | "language": "python", 306 | "name": "conda-env-receipt-ocr-py" 307 | }, 308 | "language_info": { 309 | "codemirror_mode": { 310 | "name": "ipython", 311 | "version": 3 312 | }, 313 | "file_extension": ".py", 314 | "mimetype": "text/x-python", 315 | "name": "python", 316 | "nbconvert_exporter": "python", 317 | "pygments_lexer": "ipython3", 318 | "version": "3.9.7" 319 | } 320 | }, 321 | "nbformat": 4, 322 | "nbformat_minor": 5 323 | } 324 | -------------------------------------------------------------------------------- /src/text_recognizer/modules/feature_extraction.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | 4 | 5 | class VGG_FeatureExtractor(nn.Module): 6 | """ FeatureExtractor of CRNN (https://arxiv.org/pdf/1507.05717.pdf) """ 7 | 8 | def __init__(self, input_channel, output_channel=512): 9 | super(VGG_FeatureExtractor, self).__init__() 10 | self.output_channel = [int(output_channel / 8), int(output_channel / 4), 11 | int(output_channel / 2), output_channel] # [64, 128, 256, 512] 12 | self.ConvNet = nn.Sequential( 13 | nn.Conv2d(input_channel, 14 | self.output_channel[0], 3, 1, 1), nn.ReLU(True), 15 | nn.MaxPool2d(2, 2), # 64x16x50 16 | nn.Conv2d( 17 | self.output_channel[0], self.output_channel[1], 3, 1, 1), nn.ReLU(True), 18 | nn.MaxPool2d(2, 2), # 128x8x25 19 | nn.Conv2d(self.output_channel[1], self.output_channel[2], 3, 1, 1), nn.ReLU( 20 | True), # 256x8x25 21 | nn.Conv2d( 22 | self.output_channel[2], self.output_channel[2], 3, 1, 1), nn.ReLU(True), 23 | nn.MaxPool2d((2, 1), (2, 1)), # 256x4x25 24 | nn.Conv2d( 25 | self.output_channel[2], self.output_channel[3], 3, 1, 1, bias=False), 26 | nn.BatchNorm2d(self.output_channel[3]), nn.ReLU(True), # 512x4x25 27 | nn.Conv2d( 28 | self.output_channel[3], self.output_channel[3], 3, 1, 1, bias=False), 29 | 
nn.BatchNorm2d(self.output_channel[3]), nn.ReLU(True), 30 | nn.MaxPool2d((2, 1), (2, 1)), # 512x2x25 31 | nn.Conv2d(self.output_channel[3], self.output_channel[3], 2, 1, 0), nn.ReLU(True)) # 512x1x24 32 | 33 | def forward(self, input): 34 | return self.ConvNet(input) 35 | 36 | 37 | class RCNN_FeatureExtractor(nn.Module): 38 | """ FeatureExtractor of GRCNN (https://papers.nips.cc/paper/6637-gated-recurrent-convolution-neural-network-for-ocr.pdf) """ 39 | 40 | def __init__(self, input_channel, output_channel=512): 41 | super(RCNN_FeatureExtractor, self).__init__() 42 | self.output_channel = [int(output_channel / 8), int(output_channel / 4), 43 | int(output_channel / 2), output_channel] # [64, 128, 256, 512] 44 | self.ConvNet = nn.Sequential( 45 | nn.Conv2d(input_channel, 46 | self.output_channel[0], 3, 1, 1), nn.ReLU(True), 47 | nn.MaxPool2d(2, 2), # 64 x 16 x 50 48 | GRCL(self.output_channel[0], self.output_channel[0], 49 | num_iteration=5, kernel_size=3, pad=1), 50 | nn.MaxPool2d(2, 2), # 64 x 8 x 25 51 | GRCL(self.output_channel[0], self.output_channel[1], 52 | num_iteration=5, kernel_size=3, pad=1), 53 | nn.MaxPool2d(2, (2, 1), (0, 1)), # 128 x 4 x 26 54 | GRCL(self.output_channel[1], self.output_channel[2], 55 | num_iteration=5, kernel_size=3, pad=1), 56 | nn.MaxPool2d(2, (2, 1), (0, 1)), # 256 x 2 x 27 57 | nn.Conv2d( 58 | self.output_channel[2], self.output_channel[3], 2, 1, 0, bias=False), 59 | nn.BatchNorm2d(self.output_channel[3]), nn.ReLU(True)) # 512 x 1 x 26 60 | 61 | def forward(self, input): 62 | return self.ConvNet(input) 63 | 64 | 65 | class ResNet_FeatureExtractor(nn.Module): 66 | """ FeatureExtractor of FAN (http://openaccess.thecvf.com/content_ICCV_2017/papers/Cheng_Focusing_Attention_Towards_ICCV_2017_paper.pdf) """ 67 | 68 | def __init__(self, input_channel, output_channel=512): 69 | super(ResNet_FeatureExtractor, self).__init__() 70 | self.ConvNet = ResNet(input_channel, output_channel, 71 | BasicBlock, [1, 2, 5, 3]) 72 | 73 | def forward(self, input): 74 | return self.ConvNet(input) 75 | 76 | 77 | # For Gated RCNN 78 | class GRCL(nn.Module): 79 | 80 | def __init__(self, input_channel, output_channel, num_iteration, kernel_size, pad): 81 | super(GRCL, self).__init__() 82 | self.wgf_u = nn.Conv2d( 83 | input_channel, output_channel, 1, 1, 0, bias=False) 84 | self.wgr_x = nn.Conv2d( 85 | output_channel, output_channel, 1, 1, 0, bias=False) 86 | self.wf_u = nn.Conv2d(input_channel, output_channel, 87 | kernel_size, 1, pad, bias=False) 88 | self.wr_x = nn.Conv2d(output_channel, output_channel, 89 | kernel_size, 1, pad, bias=False) 90 | 91 | self.BN_x_init = nn.BatchNorm2d(output_channel) 92 | 93 | self.num_iteration = num_iteration 94 | self.GRCL = [GRCL_unit(output_channel) for _ in range(num_iteration)] 95 | self.GRCL = nn.Sequential(*self.GRCL) 96 | 97 | def forward(self, input): 98 | """ The input of GRCL is consistant over time t, which is denoted by u(0) 99 | thus wgf_u / wf_u is also consistant over time t. 
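        In each iteration below, GRCL_unit computes a gate
        G = sigmoid(BN(wgf_u) + BN(wgr_x(x))) from the feed-forward and recurrent
        projections, then updates x = relu(BN(wf_u) + BN(BN(wr_x(x)) * G)), so the
        feed-forward term is re-injected at every step while the recurrent term is
        modulated by the gate.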
100 | """ 101 | wgf_u = self.wgf_u(input) 102 | wf_u = self.wf_u(input) 103 | x = F.relu(self.BN_x_init(wf_u)) 104 | 105 | for i in range(self.num_iteration): 106 | x = self.GRCL[i](wgf_u, self.wgr_x(x), wf_u, self.wr_x(x)) 107 | 108 | return x 109 | 110 | 111 | class GRCL_unit(nn.Module): 112 | 113 | def __init__(self, output_channel): 114 | super(GRCL_unit, self).__init__() 115 | self.BN_gfu = nn.BatchNorm2d(output_channel) 116 | self.BN_grx = nn.BatchNorm2d(output_channel) 117 | self.BN_fu = nn.BatchNorm2d(output_channel) 118 | self.BN_rx = nn.BatchNorm2d(output_channel) 119 | self.BN_Gx = nn.BatchNorm2d(output_channel) 120 | 121 | def forward(self, wgf_u, wgr_x, wf_u, wr_x): 122 | G_first_term = self.BN_gfu(wgf_u) 123 | G_second_term = self.BN_grx(wgr_x) 124 | G = F.sigmoid(G_first_term + G_second_term) 125 | 126 | x_first_term = self.BN_fu(wf_u) 127 | x_second_term = self.BN_Gx(self.BN_rx(wr_x) * G) 128 | x = F.relu(x_first_term + x_second_term) 129 | 130 | return x 131 | 132 | 133 | class BasicBlock(nn.Module): 134 | expansion = 1 135 | 136 | def __init__(self, inplanes, planes, stride=1, downsample=None): 137 | super(BasicBlock, self).__init__() 138 | self.conv1 = self._conv3x3(inplanes, planes) 139 | self.bn1 = nn.BatchNorm2d(planes) 140 | self.conv2 = self._conv3x3(planes, planes) 141 | self.bn2 = nn.BatchNorm2d(planes) 142 | self.relu = nn.ReLU(inplace=True) 143 | self.downsample = downsample 144 | self.stride = stride 145 | 146 | def _conv3x3(self, in_planes, out_planes, stride=1): 147 | "3x3 convolution with padding" 148 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 149 | padding=1, bias=False) 150 | 151 | def forward(self, x): 152 | residual = x 153 | 154 | out = self.conv1(x) 155 | out = self.bn1(out) 156 | out = self.relu(out) 157 | 158 | out = self.conv2(out) 159 | out = self.bn2(out) 160 | 161 | if self.downsample is not None: 162 | residual = self.downsample(x) 163 | out += residual 164 | out = self.relu(out) 165 | 166 | return out 167 | 168 | 169 | class ResNet(nn.Module): 170 | 171 | def __init__(self, input_channel, output_channel, block, layers): 172 | super(ResNet, self).__init__() 173 | 174 | self.output_channel_block = [ 175 | int(output_channel / 4), int(output_channel / 2), output_channel, output_channel] 176 | 177 | self.inplanes = int(output_channel / 8) 178 | self.conv0_1 = nn.Conv2d(input_channel, int(output_channel / 16), 179 | kernel_size=3, stride=1, padding=1, bias=False) 180 | self.bn0_1 = nn.BatchNorm2d(int(output_channel / 16)) 181 | self.conv0_2 = nn.Conv2d(int(output_channel / 16), self.inplanes, 182 | kernel_size=3, stride=1, padding=1, bias=False) 183 | self.bn0_2 = nn.BatchNorm2d(self.inplanes) 184 | self.relu = nn.ReLU(inplace=True) 185 | 186 | self.maxpool1 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0) 187 | self.layer1 = self._make_layer( 188 | block, self.output_channel_block[0], layers[0]) 189 | self.conv1 = nn.Conv2d(self.output_channel_block[0], self.output_channel_block[ 190 | 0], kernel_size=3, stride=1, padding=1, bias=False) 191 | self.bn1 = nn.BatchNorm2d(self.output_channel_block[0]) 192 | 193 | self.maxpool2 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0) 194 | self.layer2 = self._make_layer( 195 | block, self.output_channel_block[1], layers[1], stride=1) 196 | self.conv2 = nn.Conv2d(self.output_channel_block[1], self.output_channel_block[ 197 | 1], kernel_size=3, stride=1, padding=1, bias=False) 198 | self.bn2 = nn.BatchNorm2d(self.output_channel_block[1]) 199 | 200 | self.maxpool3 = nn.MaxPool2d( 
201 | kernel_size=2, stride=(2, 1), padding=(0, 1)) 202 | self.layer3 = self._make_layer( 203 | block, self.output_channel_block[2], layers[2], stride=1) 204 | self.conv3 = nn.Conv2d(self.output_channel_block[2], self.output_channel_block[ 205 | 2], kernel_size=3, stride=1, padding=1, bias=False) 206 | self.bn3 = nn.BatchNorm2d(self.output_channel_block[2]) 207 | 208 | self.layer4 = self._make_layer( 209 | block, self.output_channel_block[3], layers[3], stride=1) 210 | self.conv4_1 = nn.Conv2d(self.output_channel_block[3], self.output_channel_block[ 211 | 3], kernel_size=2, stride=(2, 1), padding=(0, 1), bias=False) 212 | self.bn4_1 = nn.BatchNorm2d(self.output_channel_block[3]) 213 | self.conv4_2 = nn.Conv2d(self.output_channel_block[3], self.output_channel_block[ 214 | 3], kernel_size=2, stride=1, padding=0, bias=False) 215 | self.bn4_2 = nn.BatchNorm2d(self.output_channel_block[3]) 216 | 217 | def _make_layer(self, block, planes, blocks, stride=1): 218 | downsample = None 219 | if stride != 1 or self.inplanes != planes * block.expansion: 220 | downsample = nn.Sequential( 221 | nn.Conv2d(self.inplanes, planes * block.expansion, 222 | kernel_size=1, stride=stride, bias=False), 223 | nn.BatchNorm2d(planes * block.expansion), 224 | ) 225 | 226 | layers = [] 227 | layers.append(block(self.inplanes, planes, stride, downsample)) 228 | self.inplanes = planes * block.expansion 229 | for i in range(1, blocks): 230 | layers.append(block(self.inplanes, planes)) 231 | 232 | return nn.Sequential(*layers) 233 | 234 | def forward(self, x): 235 | x = self.conv0_1(x) 236 | x = self.bn0_1(x) 237 | x = self.relu(x) 238 | x = self.conv0_2(x) 239 | x = self.bn0_2(x) 240 | x = self.relu(x) 241 | 242 | x = self.maxpool1(x) 243 | x = self.layer1(x) 244 | x = self.conv1(x) 245 | x = self.bn1(x) 246 | x = self.relu(x) 247 | 248 | x = self.maxpool2(x) 249 | x = self.layer2(x) 250 | x = self.conv2(x) 251 | x = self.bn2(x) 252 | x = self.relu(x) 253 | 254 | x = self.maxpool3(x) 255 | x = self.layer3(x) 256 | x = self.conv3(x) 257 | x = self.bn3(x) 258 | x = self.relu(x) 259 | 260 | x = self.layer4(x) 261 | x = self.conv4_1(x) 262 | x = self.bn4_1(x) 263 | x = self.relu(x) 264 | x = self.conv4_2(x) 265 | x = self.bn4_2(x) 266 | x = self.relu(x) 267 | 268 | return x 269 | -------------------------------------------------------------------------------- /notebooks/export_onnx_model.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "id": "enormous-reward", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "import sys\n", 11 | "import warnings\n", 12 | "import onnx\n", 13 | "import torch\n", 14 | "import torch.onnx\n", 15 | "\n", 16 | "warnings.filterwarnings(\"ignore\")\n", 17 | "sys.path.append('..')\n", 18 | "\n", 19 | "from src.model import DefaultModel" 20 | ] 21 | }, 22 | { 23 | "cell_type": "code", 24 | "execution_count": 2, 25 | "id": "stock-christian", 26 | "metadata": {}, 27 | "outputs": [ 28 | { 29 | "name": "stdout", 30 | "output_type": "stream", 31 | "text": [ 32 | "cpu\n" 33 | ] 34 | } 35 | ], 36 | "source": [ 37 | "device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')\n", 38 | "print(device)" 39 | ] 40 | }, 41 | { 42 | "cell_type": "markdown", 43 | "id": "partial-slope", 44 | "metadata": {}, 45 | "source": [ 46 | "## Load model network and weight" 47 | ] 48 | }, 49 | { 50 | "cell_type": "code", 51 | "execution_count": 3, 52 | "id": "coastal-brunei", 
53 | "metadata": {}, 54 | "outputs": [], 55 | "source": [ 56 | "detector_cfg = '../configs/craft_config.yaml'\n", 57 | "detector_model = '../models/text_detector/craft_mlt_25k.pth'\n", 58 | "recognizer_cfg = '../configs/star_config.yaml'\n", 59 | "recognizer_model = '../models/text_recognizer/TPS-ResNet-BiLSTM-Attn-case-sensitive.pth'" 60 | ] 61 | }, 62 | { 63 | "cell_type": "code", 64 | "execution_count": 4, 65 | "id": "brief-victor", 66 | "metadata": {}, 67 | "outputs": [ 68 | { 69 | "name": "stdout", 70 | "output_type": "stream", 71 | "text": [ 72 | "Loading weights from checkpoint (../models/text_detector/craft_mlt_25k.pth)\n", 73 | "Loading weights from checkpoint (../models/text_recognizer/TPS-ResNet-BiLSTM-Attn-case-sensitive.pth)\n" 74 | ] 75 | } 76 | ], 77 | "source": [ 78 | "model = DefaultModel(detector_cfg, detector_model, \n", 79 | " recognizer_cfg, recognizer_model)" 80 | ] 81 | }, 82 | { 83 | "cell_type": "markdown", 84 | "id": "unusual-fence", 85 | "metadata": {}, 86 | "source": [ 87 | "# Detector" 88 | ] 89 | }, 90 | { 91 | "cell_type": "markdown", 92 | "id": "congressional-enforcement", 93 | "metadata": {}, 94 | "source": [ 95 | "## Exporter Model\n", 96 | "Batch Size X Channel X Height X Width" 97 | ] 98 | }, 99 | { 100 | "cell_type": "code", 101 | "execution_count": 5, 102 | "id": "lasting-identification", 103 | "metadata": {}, 104 | "outputs": [], 105 | "source": [ 106 | "detector_dummy_input = torch.randn(1, 3, 1280, 720)" 107 | ] 108 | }, 109 | { 110 | "cell_type": "code", 111 | "execution_count": 6, 112 | "id": "expanded-enzyme", 113 | "metadata": { 114 | "scrolled": true 115 | }, 116 | "outputs": [ 117 | { 118 | "data": { 119 | "text/plain": [ 120 | "(tensor([[[[0.0010, 0.0002],\n", 121 | " [0.0085, 0.0020],\n", 122 | " [0.0010, 0.0002],\n", 123 | " ...,\n", 124 | " [0.0010, 0.0002],\n", 125 | " [0.0019, 0.0016],\n", 126 | " [0.0010, 0.0002]],\n", 127 | " \n", 128 | " [[0.0022, 0.0016],\n", 129 | " [0.0010, 0.0002],\n", 130 | " [0.0010, 0.0002],\n", 131 | " ...,\n", 132 | " [0.0010, 0.0002],\n", 133 | " [0.0010, 0.0002],\n", 134 | " [0.0010, 0.0002]],\n", 135 | " \n", 136 | " [[0.0010, 0.0002],\n", 137 | " [0.0010, 0.0002],\n", 138 | " [0.0010, 0.0002],\n", 139 | " ...,\n", 140 | " [0.0010, 0.0002],\n", 141 | " [0.0010, 0.0002],\n", 142 | " [0.0010, 0.0002]],\n", 143 | " \n", 144 | " ...,\n", 145 | " \n", 146 | " [[0.0010, 0.0002],\n", 147 | " [0.0010, 0.0002],\n", 148 | " [0.0010, 0.0002],\n", 149 | " ...,\n", 150 | " [0.0010, 0.0002],\n", 151 | " [0.0010, 0.0002],\n", 152 | " [0.0010, 0.0002]],\n", 153 | " \n", 154 | " [[0.0128, 0.0013],\n", 155 | " [0.0056, 0.0013],\n", 156 | " [0.0065, 0.0015],\n", 157 | " ...,\n", 158 | " [0.0010, 0.0002],\n", 159 | " [0.0010, 0.0002],\n", 160 | " [0.0020, 0.0006]],\n", 161 | " \n", 162 | " [[0.0065, 0.0086],\n", 163 | " [0.0084, 0.0021],\n", 164 | " [0.0093, 0.0018],\n", 165 | " ...,\n", 166 | " [0.0037, 0.0011],\n", 167 | " [0.0024, 0.0007],\n", 168 | " [0.0100, 0.0079]]]], grad_fn=),\n", 169 | " tensor([[[[0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0552],\n", 170 | " [0.0870, 0.0165, 0.2143, ..., 0.0000, 0.0000, 0.0886],\n", 171 | " [0.0047, 0.0641, 0.2276, ..., 0.0000, 0.0000, 0.0913],\n", 172 | " ...,\n", 173 | " [0.0000, 0.0000, 0.2815, ..., 1.2193, 1.1758, 1.3355],\n", 174 | " [0.0000, 0.0000, 0.4303, ..., 1.3450, 1.2745, 1.4509],\n", 175 | " [0.0000, 0.0000, 0.2416, ..., 0.8716, 0.9751, 1.0762]],\n", 176 | " \n", 177 | " [[0.1752, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.3734],\n", 178 | " [0.1021, 0.0000, 0.0000, 
..., 0.0000, 0.0000, 0.0797],\n", 179 | " [0.1667, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0534],\n", 180 | " ...,\n", 181 | " [0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],\n", 182 | " [0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],\n", 183 | " [0.0576, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.1832]],\n", 184 | " \n", 185 | " [[1.0205, 0.6065, 0.8373, ..., 0.2779, 0.2694, 0.6539],\n", 186 | " [0.9442, 0.3013, 0.7273, ..., 0.6779, 0.8180, 1.2405],\n", 187 | " [1.3239, 0.6361, 0.8955, ..., 0.7086, 0.7718, 1.0392],\n", 188 | " ...,\n", 189 | " [0.0000, 0.0000, 0.0000, ..., 0.0000, 0.2606, 2.2630],\n", 190 | " [0.0000, 0.0000, 0.0000, ..., 0.4381, 0.4628, 1.9584],\n", 191 | " [0.0000, 0.0000, 0.0000, ..., 0.2608, 0.3710, 1.2387]],\n", 192 | " \n", 193 | " ...,\n", 194 | " \n", 195 | " [[0.0000, 0.0000, 0.0221, ..., 0.0000, 0.0000, 0.0000],\n", 196 | " [0.0000, 0.2290, 0.2829, ..., 0.0000, 0.0000, 0.0000],\n", 197 | " [0.0000, 0.0408, 0.0449, ..., 0.0000, 0.0000, 0.0000],\n", 198 | " ...,\n", 199 | " [0.7339, 1.3421, 1.8549, ..., 2.7119, 2.3538, 1.3433],\n", 200 | " [0.2865, 0.6955, 0.9729, ..., 1.9042, 1.6367, 1.0503],\n", 201 | " [0.3265, 0.5312, 0.6013, ..., 1.1321, 0.9693, 0.7332]],\n", 202 | " \n", 203 | " [[0.3608, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],\n", 204 | " [0.0601, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],\n", 205 | " [0.0870, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],\n", 206 | " ...,\n", 207 | " [0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],\n", 208 | " [0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],\n", 209 | " [0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]],\n", 210 | " \n", 211 | " [[0.0937, 0.1586, 0.0410, ..., 0.1801, 0.1988, 0.4503],\n", 212 | " [0.0000, 0.0638, 0.1298, ..., 0.0000, 0.0304, 0.4700],\n", 213 | " [0.0000, 0.1438, 0.1875, ..., 0.0000, 0.1855, 0.5608],\n", 214 | " ...,\n", 215 | " [0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],\n", 216 | " [0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],\n", 217 | " [0.0997, 0.0000, 0.0000, ..., 0.0000, 0.2321, 0.6662]]]],\n", 218 | " grad_fn=))" 219 | ] 220 | }, 221 | "execution_count": 6, 222 | "metadata": {}, 223 | "output_type": "execute_result" 224 | } 225 | ], 226 | "source": [ 227 | "model.detector(detector_dummy_input)" 228 | ] 229 | }, 230 | { 231 | "cell_type": "code", 232 | "execution_count": 7, 233 | "id": "devoted-information", 234 | "metadata": {}, 235 | "outputs": [], 236 | "source": [ 237 | "out_detector_model = '../models/text_detector/craft.onnx'" 238 | ] 239 | }, 240 | { 241 | "cell_type": "code", 242 | "execution_count": 8, 243 | "id": "arbitrary-valuation", 244 | "metadata": {}, 245 | "outputs": [], 246 | "source": [ 247 | "# Export the model\n", 248 | "torch.onnx.export(model.detector, \n", 249 | " detector_dummy_input,\n", 250 | " out_detector_model,\n", 251 | " export_params=True,\n", 252 | " opset_version=13,\n", 253 | " do_constant_folding=True,\n", 254 | " input_names = ['input'],\n", 255 | " output_names = ['output'],\n", 256 | " dynamic_axes={'input' : {0:'batch_size', 2:'height', 3:'width'},\n", 257 | " 'output' : {0:'batch_size'}})" 258 | ] 259 | }, 260 | { 261 | "cell_type": "markdown", 262 | "id": "reflected-librarian", 263 | "metadata": {}, 264 | "source": [ 265 | "## Inspecting Model" 266 | ] 267 | }, 268 | { 269 | "cell_type": "code", 270 | "execution_count": 9, 271 | "id": "cognitive-albania", 272 | "metadata": { 273 | "scrolled": true 274 | }, 275 | "outputs": [ 276 | { 277 | "name": "stdout", 278 | "output_type": "stream", 279 
| "text": [ 280 | "graph torch-jit-export (\n", 281 | " %input[FLOAT, batch_sizex3xheightxwidth]\n", 282 | ") initializers (\n", 283 | " %basenet.slice5.1.weight[FLOAT, 1024x512x3x3]\n", 284 | " %basenet.slice5.1.bias[FLOAT, 1024]\n", 285 | " %basenet.slice5.2.weight[FLOAT, 1024x1024x1x1]\n", 286 | " %basenet.slice5.2.bias[FLOAT, 1024]\n", 287 | " %conv_cls.0.weight[FLOAT, 32x32x3x3]\n", 288 | " %conv_cls.0.bias[FLOAT, 32]\n", 289 | " %conv_cls.2.weight[FLOAT, 32x32x3x3]\n", 290 | " %conv_cls.2.bias[FLOAT, 32]\n", 291 | " %conv_cls.4.weight[FLOAT, 16x32x3x3]\n", 292 | " %conv_cls.4.bias[FLOAT, 16]\n", 293 | " %conv_cls.6.weight[FLOAT, 16x16x1x1]\n", 294 | " %conv_cls.6.bias[FLOAT, 16]\n", 295 | " %conv_cls.8.weight[FLOAT, 2x16x1x1]\n", 296 | " %conv_cls.8.bias[FLOAT, 2]\n", 297 | " %299[FLOAT, 64x3x3x3]\n", 298 | " %300[FLOAT, 64]\n", 299 | " %302[FLOAT, 64x64x3x3]\n", 300 | " %303[FLOAT, 64]\n", 301 | " %305[FLOAT, 128x64x3x3]\n", 302 | " %306[FLOAT, 128]\n", 303 | " %308[FLOAT, 128x128x3x3]\n", 304 | " %309[FLOAT, 128]\n", 305 | " %311[FLOAT, 256x128x3x3]\n", 306 | " %312[FLOAT, 256]\n", 307 | " %314[FLOAT, 256x256x3x3]\n", 308 | " %315[FLOAT, 256]\n", 309 | " %317[FLOAT, 256x256x3x3]\n", 310 | " %318[FLOAT, 256]\n", 311 | " %320[FLOAT, 512x256x3x3]\n", 312 | " %321[FLOAT, 512]\n", 313 | " %323[FLOAT, 512x512x3x3]\n", 314 | " %324[FLOAT, 512]\n", 315 | " %326[FLOAT, 512x512x3x3]\n", 316 | " %327[FLOAT, 512]\n", 317 | " %329[FLOAT, 512x512x3x3]\n", 318 | " %330[FLOAT, 512]\n", 319 | " %332[FLOAT, 512x512x3x3]\n", 320 | " %333[FLOAT, 512]\n", 321 | " %335[FLOAT, 512x1536x1x1]\n", 322 | " %336[FLOAT, 512]\n", 323 | " %338[FLOAT, 256x512x3x3]\n", 324 | " %339[FLOAT, 256]\n", 325 | " %341[FLOAT, 256x768x1x1]\n", 326 | " %342[FLOAT, 256]\n", 327 | " %344[FLOAT, 128x256x3x3]\n", 328 | " %345[FLOAT, 128]\n", 329 | " %347[FLOAT, 128x384x1x1]\n", 330 | " %348[FLOAT, 128]\n", 331 | " %350[FLOAT, 64x128x3x3]\n", 332 | " %351[FLOAT, 64]\n", 333 | " %353[FLOAT, 64x192x1x1]\n", 334 | " %354[FLOAT, 64]\n", 335 | " %356[FLOAT, 32x64x3x3]\n", 336 | " %357[FLOAT, 32]\n", 337 | ") {\n", 338 | " %298 = Conv[dilations = [1, 1], group = 1, kernel_shape = [3, 3], pads = [1, 1, 1, 1], strides = [1, 1]](%input, %299, %300)\n", 339 | " %157 = Relu(%298)\n", 340 | " %301 = Conv[dilations = [1, 1], group = 1, kernel_shape = [3, 3], pads = [1, 1, 1, 1], strides = [1, 1]](%157, %302, %303)\n", 341 | " %160 = Relu(%301)\n", 342 | " %161 = MaxPool[ceil_mode = 0, kernel_shape = [2, 2], pads = [0, 0, 0, 0], strides = [2, 2]](%160)\n", 343 | " %304 = Conv[dilations = [1, 1], group = 1, kernel_shape = [3, 3], pads = [1, 1, 1, 1], strides = [1, 1]](%161, %305, %306)\n", 344 | " %164 = Relu(%304)\n", 345 | " %307 = Conv[dilations = [1, 1], group = 1, kernel_shape = [3, 3], pads = [1, 1, 1, 1], strides = [1, 1]](%164, %308, %309)\n", 346 | " %167 = Relu(%307)\n", 347 | " %168 = MaxPool[ceil_mode = 0, kernel_shape = [2, 2], pads = [0, 0, 0, 0], strides = [2, 2]](%167)\n", 348 | " %310 = Conv[dilations = [1, 1], group = 1, kernel_shape = [3, 3], pads = [1, 1, 1, 1], strides = [1, 1]](%168, %311, %312)\n", 349 | " %171 = Relu(%310)\n", 350 | " %313 = Conv[dilations = [1, 1], group = 1, kernel_shape = [3, 3], pads = [1, 1, 1, 1], strides = [1, 1]](%171, %314, %315)\n", 351 | " %174 = Relu(%313)\n", 352 | " %316 = Conv[dilations = [1, 1], group = 1, kernel_shape = [3, 3], pads = [1, 1, 1, 1], strides = [1, 1]](%174, %317, %318)\n", 353 | " %177 = Relu(%316)\n", 354 | " %178 = MaxPool[ceil_mode = 0, kernel_shape = [2, 2], pads = [0, 
0, 0, 0], strides = [2, 2]](%177)\n", 355 | " %319 = Conv[dilations = [1, 1], group = 1, kernel_shape = [3, 3], pads = [1, 1, 1, 1], strides = [1, 1]](%178, %320, %321)\n", 356 | " %181 = Relu(%319)\n", 357 | " %322 = Conv[dilations = [1, 1], group = 1, kernel_shape = [3, 3], pads = [1, 1, 1, 1], strides = [1, 1]](%181, %323, %324)\n", 358 | " %184 = Relu(%322)\n", 359 | " %325 = Conv[dilations = [1, 1], group = 1, kernel_shape = [3, 3], pads = [1, 1, 1, 1], strides = [1, 1]](%184, %326, %327)\n", 360 | " %187 = Relu(%325)\n", 361 | " %188 = MaxPool[ceil_mode = 0, kernel_shape = [2, 2], pads = [0, 0, 0, 0], strides = [2, 2]](%187)\n", 362 | " %328 = Conv[dilations = [1, 1], group = 1, kernel_shape = [3, 3], pads = [1, 1, 1, 1], strides = [1, 1]](%188, %329, %330)\n", 363 | " %191 = Relu(%328)\n", 364 | " %331 = Conv[dilations = [1, 1], group = 1, kernel_shape = [3, 3], pads = [1, 1, 1, 1], strides = [1, 1]](%191, %332, %333)\n", 365 | " %194 = MaxPool[ceil_mode = 0, kernel_shape = [3, 3], pads = [1, 1, 1, 1], strides = [1, 1]](%331)\n", 366 | " %195 = Conv[dilations = [6, 6], group = 1, kernel_shape = [3, 3], pads = [6, 6, 6, 6], strides = [1, 1]](%194, %basenet.slice5.1.weight, %basenet.slice5.1.bias)\n", 367 | " %196 = Conv[dilations = [1, 1], group = 1, kernel_shape = [1, 1], pads = [0, 0, 0, 0], strides = [1, 1]](%195, %basenet.slice5.2.weight, %basenet.slice5.2.bias)\n", 368 | " %197 = Concat[axis = 1](%196, %331)\n", 369 | " %334 = Conv[dilations = [1, 1], group = 1, kernel_shape = [1, 1], pads = [0, 0, 0, 0], strides = [1, 1]](%197, %335, %336)\n", 370 | " %200 = Relu(%334)\n", 371 | " %337 = Conv[dilations = [1, 1], group = 1, kernel_shape = [3, 3], pads = [1, 1, 1, 1], strides = [1, 1]](%200, %338, %339)\n", 372 | " %203 = Relu(%337)\n", 373 | " %204 = Shape(%184)\n", 374 | " %205 = Constant[value = ]()\n", 375 | " %206 = Gather[axis = 0](%204, %205)\n", 376 | " %207 = Shape(%184)\n", 377 | " %208 = Constant[value = ]()\n", 378 | " %209 = Gather[axis = 0](%207, %208)\n", 379 | " %210 = Constant[value = ]()\n", 380 | " %211 = Unsqueeze(%206, %210)\n", 381 | " %212 = Constant[value = ]()\n", 382 | " %213 = Unsqueeze(%209, %212)\n", 383 | " %214 = Concat[axis = 0](%211, %213)\n", 384 | " %215 = Shape(%203)\n", 385 | " %216 = Constant[value = ]()\n", 386 | " %217 = Constant[value = ]()\n", 387 | " %218 = Constant[value = ]()\n", 388 | " %219 = Slice(%215, %217, %218, %216)\n", 389 | " %220 = Cast[to = 7](%214)\n", 390 | " %221 = Concat[axis = 0](%219, %220)\n", 391 | " %224 = Resize[coordinate_transformation_mode = 'pytorch_half_pixel', cubic_coeff_a = -0.75, mode = 'linear', nearest_mode = 'floor'](%203, %, %, %221)\n", 392 | " %225 = Concat[axis = 1](%224, %184)\n", 393 | " %340 = Conv[dilations = [1, 1], group = 1, kernel_shape = [1, 1], pads = [0, 0, 0, 0], strides = [1, 1]](%225, %341, %342)\n", 394 | " %228 = Relu(%340)\n", 395 | " %343 = Conv[dilations = [1, 1], group = 1, kernel_shape = [3, 3], pads = [1, 1, 1, 1], strides = [1, 1]](%228, %344, %345)\n", 396 | " %231 = Relu(%343)\n", 397 | " %232 = Shape(%174)\n", 398 | " %233 = Constant[value = ]()\n", 399 | " %234 = Gather[axis = 0](%232, %233)\n", 400 | " %235 = Shape(%174)\n", 401 | " %236 = Constant[value = ]()\n", 402 | " %237 = Gather[axis = 0](%235, %236)\n", 403 | " %238 = Constant[value = ]()\n", 404 | " %239 = Unsqueeze(%234, %238)\n", 405 | " %240 = Constant[value = ]()\n", 406 | " %241 = Unsqueeze(%237, %240)\n", 407 | " %242 = Concat[axis = 0](%239, %241)\n", 408 | " %243 = Shape(%231)\n", 409 | " %244 = 
Constant[value = ]()\n", 410 | " %245 = Constant[value = ]()\n", 411 | " %246 = Constant[value = ]()\n", 412 | " %247 = Slice(%243, %245, %246, %244)\n", 413 | " %248 = Cast[to = 7](%242)\n", 414 | " %249 = Concat[axis = 0](%247, %248)\n", 415 | " %252 = Resize[coordinate_transformation_mode = 'pytorch_half_pixel', cubic_coeff_a = -0.75, mode = 'linear', nearest_mode = 'floor'](%231, %, %, %249)\n", 416 | " %253 = Concat[axis = 1](%252, %174)\n", 417 | " %346 = Conv[dilations = [1, 1], group = 1, kernel_shape = [1, 1], pads = [0, 0, 0, 0], strides = [1, 1]](%253, %347, %348)\n", 418 | " %256 = Relu(%346)\n", 419 | " %349 = Conv[dilations = [1, 1], group = 1, kernel_shape = [3, 3], pads = [1, 1, 1, 1], strides = [1, 1]](%256, %350, %351)\n", 420 | " %259 = Relu(%349)\n", 421 | " %260 = Shape(%167)\n", 422 | " %261 = Constant[value = ]()\n", 423 | " %262 = Gather[axis = 0](%260, %261)\n", 424 | " %263 = Shape(%167)\n", 425 | " %264 = Constant[value = ]()\n", 426 | " %265 = Gather[axis = 0](%263, %264)\n", 427 | " %266 = Constant[value = ]()\n", 428 | " %267 = Unsqueeze(%262, %266)\n", 429 | " %268 = Constant[value = ]()\n", 430 | " %269 = Unsqueeze(%265, %268)\n", 431 | " %270 = Concat[axis = 0](%267, %269)\n", 432 | " %271 = Shape(%259)\n", 433 | " %272 = Constant[value = ]()\n", 434 | " %273 = Constant[value = ]()\n", 435 | " %274 = Constant[value = ]()\n", 436 | " %275 = Slice(%271, %273, %274, %272)\n", 437 | " %276 = Cast[to = 7](%270)\n", 438 | " %277 = Concat[axis = 0](%275, %276)\n", 439 | " %280 = Resize[coordinate_transformation_mode = 'pytorch_half_pixel', cubic_coeff_a = -0.75, mode = 'linear', nearest_mode = 'floor'](%259, %, %, %277)\n", 440 | " %281 = Concat[axis = 1](%280, %167)\n", 441 | " %352 = Conv[dilations = [1, 1], group = 1, kernel_shape = [1, 1], pads = [0, 0, 0, 0], strides = [1, 1]](%281, %353, %354)\n", 442 | " %284 = Relu(%352)\n", 443 | " %355 = Conv[dilations = [1, 1], group = 1, kernel_shape = [3, 3], pads = [1, 1, 1, 1], strides = [1, 1]](%284, %356, %357)\n", 444 | " %287 = Relu(%355)\n", 445 | " %288 = Conv[dilations = [1, 1], group = 1, kernel_shape = [3, 3], pads = [1, 1, 1, 1], strides = [1, 1]](%287, %conv_cls.0.weight, %conv_cls.0.bias)\n", 446 | " %289 = Relu(%288)\n", 447 | " %290 = Conv[dilations = [1, 1], group = 1, kernel_shape = [3, 3], pads = [1, 1, 1, 1], strides = [1, 1]](%289, %conv_cls.2.weight, %conv_cls.2.bias)\n", 448 | " %291 = Relu(%290)\n", 449 | " %292 = Conv[dilations = [1, 1], group = 1, kernel_shape = [3, 3], pads = [1, 1, 1, 1], strides = [1, 1]](%291, %conv_cls.4.weight, %conv_cls.4.bias)\n", 450 | " %293 = Relu(%292)\n", 451 | " %294 = Conv[dilations = [1, 1], group = 1, kernel_shape = [1, 1], pads = [0, 0, 0, 0], strides = [1, 1]](%293, %conv_cls.6.weight, %conv_cls.6.bias)\n", 452 | " %295 = Relu(%294)\n", 453 | " %296 = Conv[dilations = [1, 1], group = 1, kernel_shape = [1, 1], pads = [0, 0, 0, 0], strides = [1, 1]](%295, %conv_cls.8.weight, %conv_cls.8.bias)\n", 454 | " %output = Transpose[perm = [0, 2, 3, 1]](%296)\n", 455 | " return %output, %287\n", 456 | "}\n" 457 | ] 458 | } 459 | ], 460 | "source": [ 461 | "# Load the ONNX model\n", 462 | "onnx_model = onnx.load(out_detector_model)\n", 463 | "\n", 464 | "# Check that the IR is well formed\n", 465 | "onnx.checker.check_model(onnx_model)\n", 466 | "\n", 467 | "# Print a human readable representation of the graph\n", 468 | "print(onnx.helper.printable_graph(onnx_model.graph))" 469 | ] 470 | }, 471 | { 472 | "cell_type": "markdown", 473 | "id": "israeli-northwest", 474 | 
"metadata": {}, 475 | "source": [ 476 | "# Recognizer\n", 477 | "\n", 478 | "ERROR UNSOLVED BY CREATOR\n", 479 | "https://github.com/pytorch/pytorch/issues/27212" 480 | ] 481 | }, 482 | { 483 | "cell_type": "markdown", 484 | "id": "eleven-institution", 485 | "metadata": {}, 486 | "source": [ 487 | "## Exporter Model\n", 488 | "Batch Size X Channel X Height X Width" 489 | ] 490 | }, 491 | { 492 | "cell_type": "code", 493 | "execution_count": 10, 494 | "id": "italian-trailer", 495 | "metadata": {}, 496 | "outputs": [], 497 | "source": [ 498 | "recognizer_dummy_input = torch.randn(100, 1, 32, 100)\n", 499 | "recognizer_dummy_text = torch.LongTensor(100, 26).fill_(0)" 500 | ] 501 | }, 502 | { 503 | "cell_type": "code", 504 | "execution_count": 11, 505 | "id": "sized-truck", 506 | "metadata": { 507 | "scrolled": true 508 | }, 509 | "outputs": [ 510 | { 511 | "data": { 512 | "text/plain": [ 513 | "tensor([[[ -6.9061, -7.5613, -5.4793, ..., -2.4262, -4.9112, -4.1666],\n", 514 | " [-11.9275, -13.6367, -11.6911, ..., -10.5464, -11.6802, -11.4881],\n", 515 | " [-15.8177, -16.6920, -15.7986, ..., -15.0340, -15.5713, -15.5707],\n", 516 | " ...,\n", 517 | " [-13.8261, -2.1892, -13.1530, ..., -13.9286, -13.0425, -13.7967],\n", 518 | " [-13.8304, -2.1870, -13.1481, ..., -13.9520, -13.0352, -13.8002],\n", 519 | " [-13.8197, -2.1547, -13.1142, ..., -13.9470, -13.0154, -13.7876]],\n", 520 | "\n", 521 | " [[ -8.2746, -8.6355, -6.8791, ..., -6.2013, -8.2145, -6.9954],\n", 522 | " [-13.3459, -18.1712, -12.9131, ..., -12.4395, -12.8243, -11.7817],\n", 523 | " [-15.9471, -18.3970, -13.9544, ..., -14.9989, -15.8966, -14.6564],\n", 524 | " ...,\n", 525 | " [-13.8625, -3.7098, -11.3744, ..., -13.9167, -12.4986, -14.3229],\n", 526 | " [-13.8717, -3.6356, -11.3872, ..., -13.9434, -12.5487, -14.3171],\n", 527 | " [-13.8459, -3.6648, -11.3504, ..., -13.9306, -12.5294, -14.2871]],\n", 528 | "\n", 529 | " [[ -9.2234, -8.9115, -7.3134, ..., -8.0188, -9.4038, -7.2299],\n", 530 | " [-14.8482, -16.4109, -13.4717, ..., -14.9134, -15.2613, -14.8936],\n", 531 | " [-18.2801, -17.6252, -17.8121, ..., -17.7865, -18.4960, -17.5789],\n", 532 | " ...,\n", 533 | " [-14.0301, -4.4127, -15.5972, ..., -13.7900, -13.4122, -14.4714],\n", 534 | " [-14.0370, -4.2523, -15.5777, ..., -13.7668, -13.4058, -14.4595],\n", 535 | " [-14.0503, -4.2084, -15.6115, ..., -13.8018, -13.4401, -14.5031]],\n", 536 | "\n", 537 | " ...,\n", 538 | "\n", 539 | " [[ -7.1346, -8.7684, -3.9511, ..., -6.0389, -7.8837, -6.5712],\n", 540 | " [-13.3955, -14.4589, -13.1261, ..., -12.2595, -12.9576, -12.9921],\n", 541 | " [-16.3747, -14.4627, -16.1623, ..., -15.8627, -17.5044, -15.9218],\n", 542 | " ...,\n", 543 | " [-12.4704, -3.2038, -12.0584, ..., -13.4266, -12.3316, -12.8561],\n", 544 | " [-12.4187, -3.2129, -11.9786, ..., -13.3757, -12.2918, -12.8280],\n", 545 | " [-12.3670, -3.2296, -11.9009, ..., -13.3300, -12.2517, -12.7935]],\n", 546 | "\n", 547 | " [[ -8.4369, -8.5261, -6.7348, ..., -7.9101, -9.0706, -8.3969],\n", 548 | " [-13.5088, -12.4055, -13.0897, ..., -12.9229, -13.2860, -13.7020],\n", 549 | " [-15.6351, -13.3638, -15.9259, ..., -15.3991, -15.9885, -16.2208],\n", 550 | " ...,\n", 551 | " [-11.8698, -4.2842, -11.2330, ..., -13.6111, -12.6553, -13.3193],\n", 552 | " [-11.8772, -4.2873, -11.2386, ..., -13.6150, -12.6616, -13.3284],\n", 553 | " [-11.8831, -4.3029, -11.2439, ..., -13.6189, -12.6683, -13.3351]],\n", 554 | "\n", 555 | " [[ -5.7706, -4.6843, -6.4071, ..., -3.7921, -4.4570, -4.4655],\n", 556 | " [-12.4181, -13.1821, -10.1006, ..., -9.4514, 
-11.3294, -11.0276],\n", 557 | " [-13.5660, -13.7678, -9.6513, ..., -13.9703, -14.8214, -13.9399],\n", 558 | " ...,\n", 559 | " [-13.8621, -3.4651, -13.2679, ..., -14.5769, -13.8149, -14.4902],\n", 560 | " [-13.8660, -3.4625, -13.2367, ..., -14.5793, -13.8194, -14.4901],\n", 561 | " [-13.8639, -3.4490, -13.2186, ..., -14.5810, -13.8137, -14.4848]]],\n", 562 | " grad_fn=)" 563 | ] 564 | }, 565 | "execution_count": 11, 566 | "metadata": {}, 567 | "output_type": "execute_result" 568 | } 569 | ], 570 | "source": [ 571 | "model.recognizer.module(recognizer_dummy_input, recognizer_dummy_text)" 572 | ] 573 | }, 574 | { 575 | "cell_type": "code", 576 | "execution_count": 12, 577 | "id": "enabling-philippines", 578 | "metadata": {}, 579 | "outputs": [], 580 | "source": [ 581 | "out_recognizer_model = '../models/text_recognizer/star.onnx'" 582 | ] 583 | }, 584 | { 585 | "cell_type": "code", 586 | "execution_count": 13, 587 | "id": "roman-medicare", 588 | "metadata": {}, 589 | "outputs": [ 590 | { 591 | "ename": "RuntimeError", 592 | "evalue": "Exporting the operator grid_sampler to ONNX opset version 13 is not supported. Please feel free to request support or submit a pull request on PyTorch GitHub.", 593 | "output_type": "error", 594 | "traceback": [ 595 | "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", 596 | "\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)", 597 | "\u001b[0;32m/tmp/ipykernel_17483/4145972323.py\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0;31m# Export the model\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m torch.onnx.export(model.recognizer.module, \n\u001b[0m\u001b[1;32m 3\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mrecognizer_dummy_input\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mrecognizer_dummy_text\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0mout_recognizer_model\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0mexport_params\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 598 | "\u001b[0;32m~/miniconda3/envs/receipt-ocr/lib/python3.9/site-packages/torch/onnx/__init__.py\u001b[0m in \u001b[0;36mexport\u001b[0;34m(model, args, f, export_params, verbose, training, input_names, output_names, aten, export_raw_ir, operator_export_type, opset_version, _retain_param_name, do_constant_folding, example_outputs, strip_doc_string, dynamic_axes, keep_initializers_as_inputs, custom_opsets, enable_onnx_checker, use_external_data_format)\u001b[0m\n\u001b[1;32m 273\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 274\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0monnx\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mutils\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 275\u001b[0;31m return utils.export(model, args, f, export_params, verbose, training,\n\u001b[0m\u001b[1;32m 276\u001b[0m \u001b[0minput_names\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moutput_names\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0maten\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mexport_raw_ir\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 277\u001b[0m \u001b[0moperator_export_type\u001b[0m\u001b[0;34m,\u001b[0m 
\u001b[0mopset_version\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0m_retain_param_name\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 599 | "\u001b[0;32m~/miniconda3/envs/receipt-ocr/lib/python3.9/site-packages/torch/onnx/utils.py\u001b[0m in \u001b[0;36mexport\u001b[0;34m(model, args, f, export_params, verbose, training, input_names, output_names, aten, export_raw_ir, operator_export_type, opset_version, _retain_param_name, do_constant_folding, example_outputs, strip_doc_string, dynamic_axes, keep_initializers_as_inputs, custom_opsets, enable_onnx_checker, use_external_data_format)\u001b[0m\n\u001b[1;32m 86\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 87\u001b[0m \u001b[0moperator_export_type\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mOperatorExportTypes\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mONNX\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 88\u001b[0;31m _export(model, args, f, export_params, verbose, training, input_names, output_names,\n\u001b[0m\u001b[1;32m 89\u001b[0m \u001b[0moperator_export_type\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0moperator_export_type\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mopset_version\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mopset_version\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 90\u001b[0m \u001b[0m_retain_param_name\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0m_retain_param_name\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdo_constant_folding\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mdo_constant_folding\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 600 | "\u001b[0;32m~/miniconda3/envs/receipt-ocr/lib/python3.9/site-packages/torch/onnx/utils.py\u001b[0m in \u001b[0;36m_export\u001b[0;34m(model, args, f, export_params, verbose, training, input_names, output_names, operator_export_type, export_type, example_outputs, opset_version, _retain_param_name, do_constant_folding, strip_doc_string, dynamic_axes, keep_initializers_as_inputs, fixed_batch_size, custom_opsets, add_node_names, enable_onnx_checker, use_external_data_format, onnx_shape_inference)\u001b[0m\n\u001b[1;32m 687\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 688\u001b[0m \u001b[0mgraph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mparams_dict\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtorch_out\u001b[0m \u001b[0;34m=\u001b[0m\u001b[0;31m \u001b[0m\u001b[0;31m\\\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 689\u001b[0;31m _model_to_graph(model, args, verbose, input_names,\n\u001b[0m\u001b[1;32m 690\u001b[0m \u001b[0moutput_names\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moperator_export_type\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 691\u001b[0m \u001b[0mexample_outputs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0m_retain_param_name\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 601 | "\u001b[0;32m~/miniconda3/envs/receipt-ocr/lib/python3.9/site-packages/torch/onnx/utils.py\u001b[0m in \u001b[0;36m_model_to_graph\u001b[0;34m(model, args, verbose, input_names, output_names, operator_export_type, example_outputs, _retain_param_name, do_constant_folding, _disable_torch_constant_prop, fixed_batch_size, training, dynamic_axes)\u001b[0m\n\u001b[1;32m 461\u001b[0m \u001b[0mparams_dict\u001b[0m \u001b[0;34m=\u001b[0m 
\u001b[0m_get_named_param_dict\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mgraph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mparams\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 462\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 463\u001b[0;31m graph = _optimize_graph(graph, operator_export_type,\n\u001b[0m\u001b[1;32m 464\u001b[0m \u001b[0m_disable_torch_constant_prop\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0m_disable_torch_constant_prop\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 465\u001b[0m \u001b[0mfixed_batch_size\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mfixed_batch_size\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mparams_dict\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mparams_dict\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 602 | "\u001b[0;32m~/miniconda3/envs/receipt-ocr/lib/python3.9/site-packages/torch/onnx/utils.py\u001b[0m in \u001b[0;36m_optimize_graph\u001b[0;34m(graph, operator_export_type, _disable_torch_constant_prop, fixed_batch_size, params_dict, dynamic_axes, input_names, module)\u001b[0m\n\u001b[1;32m 198\u001b[0m \u001b[0mdynamic_axes\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m{\u001b[0m\u001b[0;34m}\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mdynamic_axes\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mNone\u001b[0m \u001b[0;32melse\u001b[0m \u001b[0mdynamic_axes\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 199\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_C\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_jit_pass_onnx_set_dynamic_input_shape\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mgraph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdynamic_axes\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput_names\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 200\u001b[0;31m \u001b[0mgraph\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_C\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_jit_pass_onnx\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mgraph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moperator_export_type\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 201\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_C\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_jit_pass_lint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mgraph\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 202\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", 603 | "\u001b[0;32m~/miniconda3/envs/receipt-ocr/lib/python3.9/site-packages/torch/onnx/__init__.py\u001b[0m in \u001b[0;36m_run_symbolic_function\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 311\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_run_symbolic_function\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 312\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0monnx\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mutils\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 313\u001b[0;31m \u001b[0;32mreturn\u001b[0m 
\u001b[0mutils\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_run_symbolic_function\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 314\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 315\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", 604 | "\u001b[0;32m~/miniconda3/envs/receipt-ocr/lib/python3.9/site-packages/torch/onnx/utils.py\u001b[0m in \u001b[0;36m_run_symbolic_function\u001b[0;34m(g, block, n, inputs, env, operator_export_type)\u001b[0m\n\u001b[1;32m 988\u001b[0m \u001b[0;31m# Export it regularly\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 989\u001b[0m \u001b[0mdomain\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m''\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 990\u001b[0;31m \u001b[0msymbolic_fn\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m_find_symbolic_in_registry\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdomain\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mop_name\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mopset_version\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moperator_export_type\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 991\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0msymbolic_fn\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 992\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 605 | "\u001b[0;32m~/miniconda3/envs/receipt-ocr/lib/python3.9/site-packages/torch/onnx/utils.py\u001b[0m in \u001b[0;36m_find_symbolic_in_registry\u001b[0;34m(domain, op_name, opset_version, operator_export_type)\u001b[0m\n\u001b[1;32m 942\u001b[0m \u001b[0;31m# Use the original node directly\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 943\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 944\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0msym_registry\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_registered_op\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mop_name\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdomain\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mopset_version\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 945\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 946\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", 606 | "\u001b[0;32m~/miniconda3/envs/receipt-ocr/lib/python3.9/site-packages/torch/onnx/symbolic_registry.py\u001b[0m in \u001b[0;36mget_registered_op\u001b[0;34m(opname, domain, version)\u001b[0m\n\u001b[1;32m 114\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 115\u001b[0m \u001b[0mmsg\u001b[0m \u001b[0;34m+=\u001b[0m \u001b[0;34m\"Please feel free to request support or submit a pull request on PyTorch GitHub.\"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 116\u001b[0;31m \u001b[0;32mraise\u001b[0m \u001b[0mRuntimeError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmsg\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 117\u001b[0m \u001b[0;32mreturn\u001b[0m 
\u001b[0m_registry\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdomain\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mversion\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mopname\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 607 | "\u001b[0;31mRuntimeError\u001b[0m: Exporting the operator grid_sampler to ONNX opset version 13 is not supported. Please feel free to request support or submit a pull request on PyTorch GitHub." 608 | ] 609 | } 610 | ], 611 | "source": [ 612 | "# Export the model\n", 613 | "torch.onnx.export(model.recognizer.module, \n", 614 | " (recognizer_dummy_input, recognizer_dummy_text),\n", 615 | " out_recognizer_model,\n", 616 | " export_params=True,\n", 617 | " opset_version=13,\n", 618 | " do_constant_folding=True,\n", 619 | " input_names = ['input'],\n", 620 | " output_names = ['output'],\n", 621 | " dynamic_axes={'input' : {0:'batch_size', 2:'height', 3:'width'},\n", 622 | " 'output' : {0:'batch_size'}})" 623 | ] 624 | }, 625 | { 626 | "cell_type": "markdown", 627 | "id": "outside-grill", 628 | "metadata": {}, 629 | "source": [ 630 | "## Inspecting Model" 631 | ] 632 | }, 633 | { 634 | "cell_type": "code", 635 | "execution_count": null, 636 | "id": "understood-lobby", 637 | "metadata": {}, 638 | "outputs": [], 639 | "source": [ 640 | "# Load the ONNX model\n", 641 | "onnx_model = onnx.load(out_recognizer_model)\n", 642 | "\n", 643 | "# Check that the IR is well formed\n", 644 | "onnx.checker.check_model(onnx_model)\n", 645 | "\n", 646 | "# Print a human readable representation of the graph\n", 647 | "print(onnx.helper.printable_graph(onnx_model.graph))" 648 | ] 649 | } 650 | ], 651 | "metadata": { 652 | "kernelspec": { 653 | "display_name": "Python [conda env:receipt-ocr]", 654 | "language": "python", 655 | "name": "conda-env-receipt-ocr-py" 656 | }, 657 | "language_info": { 658 | "codemirror_mode": { 659 | "name": "ipython", 660 | "version": 3 661 | }, 662 | "file_extension": ".py", 663 | "mimetype": "text/x-python", 664 | "name": "python", 665 | "nbconvert_exporter": "python", 666 | "pygments_lexer": "ipython3", 667 | "version": "3.9.7" 668 | } 669 | }, 670 | "nbformat": 4, 671 | "nbformat_minor": 5 672 | } 673 | --------------------------------------------------------------------------------
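Editor's note on the recognizer export failure recorded in export_onnx_model.ipynb above: with torch 1.9 the ONNX exporter has no symbolic for `grid_sampler`, which the STAR recognizer's TPS spatial-transformer stage relies on, so the export aborts at opset 13 (see the linked PyTorch issue). Below is a minimal sketch of one possible workaround, not part of this repository, assuming a newer PyTorch/ONNX stack (roughly torch >= 1.12, where `grid_sampler` maps to the `GridSample` op introduced in ONNX opset 16) and reusing the `model`, `recognizer_dummy_input`, and `recognizer_dummy_text` objects defined earlier in that notebook.

    # Sketch only: assumes torch >= 1.12 and onnxruntime with opset-16 support,
    # which is newer than the versions pinned in requirements.txt.
    # `model.recognizer.module`, `recognizer_dummy_input`, and
    # `recognizer_dummy_text` are the objects built in export_onnx_model.ipynb.
    import torch

    torch.onnx.export(
        model.recognizer.module,
        (recognizer_dummy_input, recognizer_dummy_text),
        '../models/text_recognizer/star.onnx',
        export_params=True,
        opset_version=16,          # GridSample is only available from opset 16
        do_constant_folding=True,
        input_names=['input', 'text'],
        output_names=['output'],
        dynamic_axes={'input': {0: 'batch_size'},
                      'output': {0: 'batch_size'}},
    )

If upgrading PyTorch is not an option, an alternative is to keep the TPS transformation stage in PyTorch at inference time and export only the feature-extraction / sequence-modeling / prediction stages, which contain no `grid_sampler` calls.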