├── ocr ├── __init__.py ├── config.py ├── metrics.py ├── models.py ├── utils.py ├── tokenizer.py ├── predictor.py ├── dataset.py └── transforms.py ├── data └── README.md ├── .dockerignore ├── .gitignore ├── setup.py ├── requirements.txt ├── Makefile ├── Dockerfile ├── LICENSE ├── scripts ├── ocr_config.json ├── torch2onnx.py ├── evaluate.py ├── inference.ipynb ├── OCR-GoogleColab.ipynb ├── train.py └── prepare_dataset.py └── README.md /ocr/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /data/README.md: -------------------------------------------------------------------------------- 1 | Data folder 2 | -------------------------------------------------------------------------------- /.dockerignore: -------------------------------------------------------------------------------- 1 | data/ 2 | .pytest_cache/ 3 | **/__pycache__/ 4 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .ipynb_checkpoints 2 | __pycache__/ 3 | data/ 4 | .DS_Store 5 | ._.DS_Store 6 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | 3 | 4 | with open('requirements.txt') as f: 5 | packages = f.read().splitlines() 6 | 7 | setup( 8 | name='ocrmodel', 9 | packages=['ocr'], 10 | install_requires=packages 11 | ) 12 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy==1.23.1 2 | torch>=1.6.0 3 | torchvision>=0.7.0 4 | opencv-python==4.6.0.66 5 | pandas==1.3.4 6 | tqdm==4.62.3 7 | scikit-learn==1.0.1 8 | scipy==1.4.1 9 | matplotlib==3.5.0 10 | Pillow==8.4.0 11 | onnxruntime==1.13.1 12 | openvino==2022.2.0 13 | albumentations==1.1.0 14 | ctcdecode @ git+https://github.com/parlance/ctcdecode 15 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | NAME?=ocr-model 2 | 3 | GPUS?=all 4 | GPUS_OPTION=--gpus=$(GPUS) 5 | 6 | CPUS?=none 7 | ifeq ($(CPUS), none) 8 | CPUS_OPTION= 9 | else 10 | CPUS_OPTION=--cpus=$(CPUS) 11 | endif 12 | 13 | .PHONY: all stop build run 14 | 15 | all: stop build run 16 | 17 | build: 18 | docker build \ 19 | -t $(NAME) . 
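# Editor's note (not in the original Makefile): NAME, GPUS and CPUS are assigned with ?=,
# so they can be overridden per invocation, e.g.
#   make all GPUS=device=0 CPUS=10 NAME=ocr-model
# Leaving CPUS at its default "none" starts the container without a CPU limit.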
20 | 21 | stop: 22 | -docker stop $(NAME) 23 | -docker rm $(NAME) 24 | 25 | run: 26 | docker run --rm -it \ 27 | $(GPUS_OPTION) \ 28 | $(CPUS_OPTION) \ 29 | --net=host \ 30 | --ipc=host \ 31 | -v $(shell pwd):/workdir \ 32 | --name=$(NAME) \ 33 | $(NAME) \ 34 | bash 35 | -------------------------------------------------------------------------------- /ocr/config.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | 4 | class Config: 5 | """Class to handle config.json.""" 6 | 7 | def __init__(self, config_path): 8 | with open(config_path, 'r') as f: 9 | self.config = json.load(f) 10 | 11 | def get(self, key): 12 | return self.config[key] 13 | 14 | def get_train(self, key): 15 | return self.config['train'][key] 16 | 17 | def get_val(self, key): 18 | return self.config['val'][key] 19 | 20 | def get_test(self, key): 21 | return self.config['test'][key] 22 | 23 | def get_image(self, key): 24 | return self.config['image'][key] 25 | 26 | def get_train_datasets(self, key): 27 | return [data[key] for data in self.config['train']['datasets']] 28 | 29 | def get_val_datasets(self, key): 30 | return [data[key] for data in self.config['val']['datasets']] 31 | 32 | def get_test_datasets(self, key): 33 | return [data[key] for data in self.config['test']['datasets']] 34 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM nvidia/cuda:11.4.0-devel-ubuntu20.04 2 | 3 | ENV DEBIAN_FRONTEND noninteractive 4 | 5 | RUN apt-get update &&\ 6 | apt-get -y install \ 7 | build-essential yasm nasm cmake \ 8 | git htop tmux \ 9 | python3 python3-pip python3-dev python3-setuptools python3-opencv &&\ 10 | ln -s /usr/bin/python3 /usr/bin/python &&\ 11 | ln -sf /usr/bin/pip3 /usr/bin/pip &&\ 12 | apt-get clean &&\ 13 | apt-get autoremove &&\ 14 | rm -rf /var/lib/apt/lists/* &&\ 15 | rm -rf /var/cache/apt/archives/* 16 | 17 | # Upgrade pip for cv package instalation 18 | RUN pip3 install --upgrade pip==21.0.1 19 | 20 | # Install PyTorch 21 | RUN pip3 install --no-cache-dir \ 22 | torch==1.9.0+cu111 \ 23 | torchvision==0.10.0+cu111 \ 24 | -f https://download.pytorch.org/whl/torch_stable.html 25 | 26 | ENV PYTHONPATH $PYTHONPATH:/workdir 27 | ENV TORCH_HOME=/workdir/data/.torch 28 | ENV LANG C.UTF-8 29 | 30 | WORKDIR /workdir 31 | 32 | # Install python ML packages 33 | COPY requirements.txt /workdir 34 | RUN pip3 install --no-cache-dir -r requirements.txt 35 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Sber AI 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /scripts/ocr_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "alphabet": " !\"'()*+,-./0123456789:;<=>?ABCDEFGHIJKLMNOPRSTVWY[\\]_abcdefghiklmnoprstuvwxyz|}ЁАБВГДЕЖЗИКЛМНОПРСТУФХЦЧШЩЫЬЭЮЯабвгдежзийклмнопрстуфхцчшщъыьэюяё’№", 3 | "save_dir": "/workdir/data/experiments/test", 4 | "num_epochs": 100, 5 | "pretrain_path": "", 6 | "image": { 7 | "width": 1024, 8 | "height": 128 9 | }, 10 | "train": { 11 | "datasets": [ 12 | { 13 | "csv_path": "train.csv", 14 | "prob": 1 15 | } 16 | 17 | ], 18 | "epoch_size": 100000, 19 | "batch_size": 64 20 | }, 21 | "val": { 22 | "datasets": [ 23 | { 24 | "csv_path": "val.csv", 25 | "prob": 1 26 | } 27 | 28 | ], 29 | "epoch_size": null, 30 | "batch_size": 64 31 | }, 32 | "test": { 33 | "datasets": [ 34 | { 35 | "csv_path": "test.csv", 36 | "prob": 1 37 | } 38 | 39 | ], 40 | "epoch_size": null, 41 | "batch_size": 64 42 | } 43 | } 44 | -------------------------------------------------------------------------------- /scripts/torch2onnx.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from pathlib import Path 4 | import argparse 5 | 6 | from ocr.predictor import OCRTorchModel 7 | from ocr.config import Config 8 | from ocr.utils import configure_logging 9 | 10 | 11 | def main(args): 12 | logger = configure_logging() 13 | 14 | config = Config(args.config_path) 15 | 16 | ocr_torch_model = OCRTorchModel( 17 | model_path=args.model_path, 18 | config=config, 19 | decoder=None, 20 | device='cpu' 21 | ) 22 | 23 | onnx_path = Path(args.model_path) 24 | onnx_path = onnx_path.parents[0] / onnx_path.stem 25 | onnx_path = str(onnx_path) + '.onnx' 26 | 27 | example_forward_input = torch.rand( 28 | 1, 3, config.get_image('height'), config.get_image('width')) 29 | 30 | torch.onnx.export(ocr_torch_model.model, 31 | example_forward_input, 32 | onnx_path, 33 | opset_version=12, 34 | input_names=['input'], 35 | output_names=['output'], 36 | dynamic_axes={'input': {0: 'batch_size'}, 37 | 'output': {0: 'batch_size'}} 38 | ) 39 | logger.info(f"ONNX model was saved to '{onnx_path}'") 40 | 41 | 42 | if __name__ == '__main__': 43 | parser = argparse.ArgumentParser() 44 | parser.add_argument('--config_path', type=str, 45 | default='/workdir/scripts/ocr_config.json', 46 | help='Path to config.json.') 47 | parser.add_argument('--model_path', type=str, 48 | help='Path to torch model weights.') 49 | args = parser.parse_args() 50 | main(args) 51 | -------------------------------------------------------------------------------- /ocr/metrics.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def get_accuracy(y_true, y_pred): 5 | """Calc accuracy between two list of strings.""" 6 | scores = [] 7 | for true, pred in zip(y_true, y_pred): 8 | scores.append(true == pred) 9 | avg_score = np.mean(scores) 10 | return avg_score 11 | 12 | 13 | def 
levenshtein_distance(first, second): 14 | distance = [[0 for _ in range(len(second) + 1)] 15 | for _ in range(len(first) + 1)] 16 | for i in range(len(first) + 1): 17 | for j in range(len(second) + 1): 18 | if i == 0: 19 | distance[i][j] = j 20 | elif j == 0: 21 | distance[i][j] = i 22 | else: 23 | diag = distance[i - 1][j - 1] + (first[i - 1] != second[j - 1]) 24 | upper = distance[i - 1][j] + 1 25 | left = distance[i][j - 1] + 1 26 | distance[i][j] = min(diag, upper, left) 27 | return distance[len(first)][len(second)] 28 | 29 | 30 | def cer(gt_texts, pred_texts): 31 | assert len(pred_texts) == len(gt_texts) 32 | lev_distances, num_gt_chars = 0, 0 33 | for pred_text, gt_text in zip(pred_texts, gt_texts): 34 | lev_distances += levenshtein_distance(pred_text, gt_text) 35 | num_gt_chars += len(gt_text) 36 | return lev_distances / num_gt_chars 37 | 38 | 39 | def wer(gt_texts, pred_texts): 40 | assert len(pred_texts) == len(gt_texts) 41 | lev_distances, num_gt_words = 0, 0 42 | for pred_text, gt_text in zip(pred_texts, gt_texts): 43 | gt_words, pred_words = gt_text.split(), pred_text.split() 44 | lev_distances += levenshtein_distance(pred_words, gt_words) 45 | num_gt_words += len(gt_words) 46 | return lev_distances / num_gt_words 47 | -------------------------------------------------------------------------------- /ocr/models.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torchvision 3 | 4 | 5 | class GlobalMaxPool2d(nn.Module): 6 | def __init__(self): 7 | super().__init__() 8 | 9 | def forward(self, x): 10 | return x.max(dim=-2, keepdim=True)[0] 11 | 12 | 13 | def get_resnet34_backbone(pretrained=True): 14 | m = torchvision.models.resnet34(pretrained=pretrained) 15 | input_conv = nn.Conv2d(3, 64, 7, 1, 3) 16 | blocks = [input_conv, m.bn1, m.relu, 17 | m.maxpool, m.layer1, m.layer2, m.layer3] 18 | return nn.Sequential(*blocks) 19 | 20 | 21 | class BiLSTM(nn.Module): 22 | def __init__(self, input_size, hidden_size, num_layers, dropout=0.1): 23 | super().__init__() 24 | self.lstm = nn.LSTM( 25 | input_size, hidden_size, num_layers, 26 | dropout=dropout, batch_first=True, bidirectional=True) 27 | 28 | def forward(self, x): 29 | out, _ = self.lstm(x) 30 | return out 31 | 32 | 33 | class CRNN(nn.Module): 34 | def __init__( 35 | self, number_class_symbols, time_feature_count=256, lstm_hidden=256, 36 | lstm_len=3, pretrained=True 37 | ): 38 | super().__init__() 39 | self.feature_extractor = get_resnet34_backbone(pretrained=pretrained) 40 | self.global_maxpool = GlobalMaxPool2d() 41 | self.bilstm = BiLSTM(time_feature_count, lstm_hidden, lstm_len) 42 | self.classifier = nn.Sequential( 43 | nn.Linear(lstm_hidden * 2, time_feature_count), 44 | nn.GELU(), 45 | nn.Dropout(0.1), 46 | nn.Linear(time_feature_count, number_class_symbols) 47 | ) 48 | 49 | def forward(self, x): 50 | x = self.feature_extractor(x) 51 | x = self.global_maxpool(x) 52 | x = x.squeeze(2) 53 | x = x.permute(0, 2, 1) 54 | x = self.bilstm(x) 55 | x = self.classifier(x) 56 | x = nn.functional.log_softmax(x, dim=2).permute(1, 0, 2) 57 | return x 58 | -------------------------------------------------------------------------------- /scripts/evaluate.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import logging 3 | import argparse 4 | 5 | from ocr.dataset import get_data_loader 6 | from ocr.utils import val_loop, configure_logging 7 | from ocr.transforms import get_val_transforms 8 | from ocr.tokenizer import 
Tokenizer, BeamSearcDecoder, BestPathDecoder 9 | from ocr.config import Config 10 | from ocr.models import CRNN 11 | 12 | DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 13 | 14 | 15 | def main(args): 16 | config = Config(args.config_path) 17 | tokenizer = Tokenizer(config.get('alphabet')) 18 | logger = configure_logging() 19 | 20 | val_transforms = get_val_transforms( 21 | height=config.get_image('height'), 22 | width=config.get_image('width') 23 | ) 24 | 25 | model = CRNN(number_class_symbols=tokenizer.get_num_chars()) 26 | model.load_state_dict(torch.load(args.model_path)) 27 | model.to(DEVICE) 28 | 29 | csv_paths = config.get_test_datasets('csv_path') 30 | dataset_probs = config.get_test_datasets('prob') 31 | 32 | if args.lm_path: 33 | decoder = BeamSearcDecoder(config.get('alphabet'), args.lm_path) 34 | else: 35 | decoder = BestPathDecoder(config.get('alphabet')) 36 | 37 | acc_avg_all = [] 38 | 39 | for csv_path, dataset_prob in zip(csv_paths, dataset_probs): 40 | 41 | test_loader = get_data_loader( 42 | transforms=val_transforms, 43 | csv_paths=[csv_path], 44 | tokenizer=tokenizer, 45 | dataset_probs=[dataset_prob], 46 | epoch_size=config.get_test('epoch_size'), 47 | batch_size=config.get_test('batch_size'), 48 | drop_last=False 49 | ) 50 | 51 | logger.info(csv_path) 52 | acc_avg = val_loop(test_loader, model, decoder, logger, DEVICE) 53 | acc_avg_all.append(acc_avg) 54 | 55 | logger.info(f'Average accuracy by dataset: {sum(acc_avg_all) / len(acc_avg_all):.4f}') 56 | 57 | 58 | if __name__ == '__main__': 59 | parser = argparse.ArgumentParser() 60 | parser.add_argument('--config_path', type=str, 61 | default='/workdir/scripts/ocr_config.json', 62 | help='Path to config.json.') 63 | parser.add_argument('--model_path', type=str, 64 | help='Path to model weights.') 65 | parser.add_argument('--lm_path', type=str, default='', 66 | help='Path to KenLM language model .arpa.') 67 | args = parser.parse_args() 68 | 69 | main(args) 70 | -------------------------------------------------------------------------------- /scripts/inference.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "id": "c0a671fe", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "import cv2\n", 11 | "from matplotlib import pyplot as plt\n", 12 | "\n", 13 | "from ocr.predictor import OcrPredictor" 14 | ] 15 | }, 16 | { 17 | "cell_type": "code", 18 | "execution_count": null, 19 | "id": "469d9030", 20 | "metadata": {}, 21 | "outputs": [], 22 | "source": [ 23 | "IMG_PATHS = ['']\n", 24 | "\n", 25 | "MODEL_PATH = ''\n", 26 | "CONFIG_PATH = ''\n", 27 | "LM_PATH = ''\n", 28 | "\n", 29 | "NUM_THREADS = 8\n", 30 | "\n", 31 | "DEVICE = 'cuda'\n", 32 | "\n", 33 | "RUNTIME = 'Pytorch'" 34 | ] 35 | }, 36 | { 37 | "cell_type": "code", 38 | "execution_count": null, 39 | "id": "2d4ea798", 40 | "metadata": {}, 41 | "outputs": [], 42 | "source": [ 43 | "predictor = OcrPredictor(\n", 44 | " model_path=MODEL_PATH,\n", 45 | " config_path=CONFIG_PATH,\n", 46 | " lm_path=LM_PATH,\n", 47 | " num_threads=NUM_THREADS,\n", 48 | " device=DEVICE,\n", 49 | " runtime=RUNTIME\n", 50 | ")" 51 | ] 52 | }, 53 | { 54 | "cell_type": "markdown", 55 | "id": "8e2894c4", 56 | "metadata": {}, 57 | "source": [ 58 | "# Predict" 59 | ] 60 | }, 61 | { 62 | "cell_type": "code", 63 | "execution_count": null, 64 | "id": "8bb7f01b", 65 | "metadata": {}, 66 | "outputs": [], 67 | "source": [ 68 | "images = [cv2.imread(i) 
for i in IMG_PATHS]\n", 69 | "\n", 70 | "pred_texts = predictor(images)" 71 | ] 72 | }, 73 | { 74 | "cell_type": "code", 75 | "execution_count": null, 76 | "id": "a585473e", 77 | "metadata": {}, 78 | "outputs": [], 79 | "source": [ 80 | "for idx, pred_text in enumerate(pred_texts):\n", 81 | " print(pred_text)\n", 82 | "\n", 83 | " plt.figure(figsize=(5, 5))\n", 84 | " plt.imshow(images[idx])\n", 85 | " plt.show()" 86 | ] 87 | } 88 | ], 89 | "metadata": { 90 | "kernelspec": { 91 | "display_name": "Python 3 (ipykernel)", 92 | "language": "python", 93 | "name": "python3" 94 | }, 95 | "language_info": { 96 | "codemirror_mode": { 97 | "name": "ipython", 98 | "version": 3 99 | }, 100 | "file_extension": ".py", 101 | "mimetype": "text/x-python", 102 | "name": "python", 103 | "nbconvert_exporter": "python", 104 | "pygments_lexer": "ipython3", 105 | "version": "3.8.10" 106 | } 107 | }, 108 | "nbformat": 4, 109 | "nbformat_minor": 5 110 | } 111 | -------------------------------------------------------------------------------- /scripts/OCR-GoogleColab.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "16261ff8", 6 | "metadata": {}, 7 | "source": [ 8 | "# Google Colab demo\n", 9 | "\n", 10 | "To enable GPU:\n", 11 | "Runtime -> Change runtime type -> GPU" 12 | ] 13 | }, 14 | { 15 | "cell_type": "code", 16 | "execution_count": null, 17 | "id": "25761b2c", 18 | "metadata": {}, 19 | "outputs": [], 20 | "source": [ 21 | "! pip install onnxruntime\n", 22 | "! pip install openvino\n", 23 | "! pip install huggingface_hub\n", 24 | "! pip install git+https://github.com/parlance/ctcdecode\n", 25 | "\n", 26 | "! git clone https://github.com/ai-forever/OCR-model.git" 27 | ] 28 | }, 29 | { 30 | "cell_type": "code", 31 | "execution_count": null, 32 | "id": "d13e3a4d", 33 | "metadata": {}, 34 | "outputs": [], 35 | "source": [ 36 | "import sys\n", 37 | "sys.path.append('OCR-model/')\n", 38 | "\n", 39 | "import cv2\n", 40 | "from matplotlib import pyplot as plt\n", 41 | "import numpy as np\n", 42 | "\n", 43 | "from huggingface_hub import hf_hub_download\n", 44 | "\n", 45 | "from ocr.predictor import OcrPredictor" 46 | ] 47 | }, 48 | { 49 | "cell_type": "code", 50 | "execution_count": null, 51 | "id": "3a01f543", 52 | "metadata": {}, 53 | "outputs": [], 54 | "source": [ 55 | "repo_id = \"sberbank-ai/ReadingPipeline-Peter\"\n", 56 | "\n", 57 | "IMG_PATH = hf_hub_download(repo_id, \"crop.jpg\")\n", 58 | "\n", 59 | "MODEL_PATH = hf_hub_download(repo_id, \"ocr/ocr_model.ckpt\")\n", 60 | "CONFIG_PATH = hf_hub_download(repo_id, \"ocr/ocr_config.json\")\n", 61 | "LM_PATH = ''\n", 62 | "\n", 63 | "NUM_THREADS = 8\n", 64 | "\n", 65 | "DEVICE = 'cuda'\n", 66 | "\n", 67 | "RUNTIME = 'Pytorch'" 68 | ] 69 | }, 70 | { 71 | "cell_type": "code", 72 | "execution_count": null, 73 | "id": "85e23719", 74 | "metadata": {}, 75 | "outputs": [], 76 | "source": [ 77 | "predictor = OcrPredictor(\n", 78 | " model_path=MODEL_PATH,\n", 79 | " config_path=CONFIG_PATH,\n", 80 | " lm_path=LM_PATH,\n", 81 | " num_threads=NUM_THREADS,\n", 82 | " device=DEVICE,\n", 83 | " runtime=RUNTIME\n", 84 | ")" 85 | ] 86 | }, 87 | { 88 | "cell_type": "code", 89 | "execution_count": null, 90 | "id": "e9289300", 91 | "metadata": {}, 92 | "outputs": [], 93 | "source": [ 94 | "image = cv2.imread(IMG_PATH)\n", 95 | "\n", 96 | "pred_texts = predictor([image])" 97 | ] 98 | }, 99 | { 100 | "cell_type": "code", 101 | "execution_count": null, 102 | "id": "a7ace4b1", 103 | 
"metadata": {}, 104 | "outputs": [], 105 | "source": [ 106 | "print(pred_texts[0])\n", 107 | "\n", 108 | "image = cv2.imread(IMG_PATH)\n", 109 | "image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)\n", 110 | "plt.figure(figsize=(10, 10))\n", 111 | "plt.imshow(image)\n", 112 | "plt.show()" 113 | ] 114 | } 115 | ], 116 | "metadata": { 117 | "kernelspec": { 118 | "display_name": "Python 3 (ipykernel)", 119 | "language": "python", 120 | "name": "python3" 121 | }, 122 | "language_info": { 123 | "codemirror_mode": { 124 | "name": "ipython", 125 | "version": 3 126 | }, 127 | "file_extension": ".py", 128 | "mimetype": "text/x-python", 129 | "name": "python", 130 | "nbconvert_exporter": "python", 131 | "pygments_lexer": "ipython3", 132 | "version": "3.9.5" 133 | } 134 | }, 135 | "nbformat": 4, 136 | "nbformat_minor": 5 137 | } 138 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # OCR model 2 | 3 | This is a model for Optical Character Recognition based on [CRNN-arhitecture](https://arxiv.org/abs/1507.05717) and [CTC loss](https://www.cs.toronto.edu/~graves/icml_2006.pdf). 4 | 5 | OCR-model is a part of [ReadingPipeline](https://github.com/ai-forever/ReadingPipeline) repo. 6 | 7 | ## Demo 8 | 9 | In the [demo](scripts/OCR-GoogleColab.ipynb) you can find an example of using of OCR-model (you can run it in your Google Colab). 10 | 11 | ## Quick setup and start 12 | 13 | - Nvidia drivers >= 470, CUDA >= 11.4 14 | - [Docker](https://docs.docker.com/engine/install/ubuntu/), [nvidia-docker](https://github.com/NVIDIA/nvidia-docker) 15 | 16 | The provided [Dockerfile](Dockerfile) is supplied to build an image with CUDA support and cuDNN. 17 | 18 | ### Preparations 19 | 20 | - Clone the repo. 21 | - Download and extract dataset to the `data/` folder. 22 | - `sudo make all` to build a docker image and create a container. 23 | Or `sudo make all GPUS=device=0 CPUS=10` if you want to specify gpu devices and limit CPU-resources. 24 | 25 | If you don't want to use Docker, you can install dependencies via requirements.txt 26 | 27 | ## Configuring the model 28 | 29 | You can change the [ocr_config.json](scripts/ocr_config.json) and set the necessary training and evaluating parameters: alphabet, image size, saving path, etc. 30 | 31 | ``` 32 | "train": { 33 | "datasets": [ 34 | { 35 | "csv_path": "/workdir/data/dataset_1/train.csv", 36 | "prob": 0.5 37 | }, 38 | { 39 | "csv_path": "/workdir/data/dataset_2/train.csv", 40 | "prob": 0.7 41 | }, 42 | ... 43 | ], 44 | "epoch_size": 10000, 45 | "batch_size": 512 46 | } 47 | ``` 48 | - `epoch_size` - the size of an epoch. If you set it to `null`, then the epoch size will be equal to the amount of samples in the all datasets. 49 | - It is also possible to specify several datasets for the train/validation/test, setting the probabilities for each dataset separately (the sum of `prob` can be greater than 1, since normalization occurs inside the processing). 50 | 51 | ## Prepare data 52 | 53 | Datasets must be pre-processed and have a single format: each dataset must contain a folder with images (crop images with text) and csv file with annotations. The csv file should contain two columns: "filename" with the relative path to the images (folder-name/image-name.png), and "text"-column with the image transcription. 
54 | 55 | | filename | text | 56 | | ----------------- | ---- | 57 | | images/4099-0.png | is | 58 | 59 | If you use polygon annotations in COCO format, you can prepare a training dataset using this script: 60 | 61 | ```bash 62 | python scripts/prepare_dataset.py \ 63 | --annotation_json_path path/to/the/annotaions.json \ 64 | --annotation_image_root dir/to/images/from/annotation/file \ 65 | --class_names pupil_text pupil_comment teacher_comment \ 66 | --bbox_scale_x 1 \ 67 | --bbox_scale_y 1 \ 68 | --save_dir dir/to/save/dataset \ 69 | --output_csv_name data.csv 70 | ``` 71 | 72 | ## Training 73 | 74 | To train the model run: 75 | 76 | ```bash 77 | python scripts/train.py --config_path path/to/the/ocr_config.json 78 | ``` 79 | 80 | ## Evaluating 81 | 82 | To test the model run: 83 | 84 | ```bash 85 | python scripts/evaluate.py \ 86 | --config_path path/to/the/ocr_config.json \ 87 | --model_path path/to/the/model-weights.ckpt 88 | ``` 89 | 90 | If you want to use a beam search decoder with LM, you can pass lm_path arg with path to .arpa kenLM file. 91 | --lm_path path/to/the/language-model.arpa 92 | 93 | ## ONNX 94 | 95 | You can convert Torch model to ONNX to speed up inference on cpu. 96 | 97 | ```bash 98 | python scripts/torch2onnx.py \ 99 | --config_path path/to/the/ocr_config.json \ 100 | --model_path path/to/the/model-weights.ckpt 101 | ``` 102 | 103 | -------------------------------------------------------------------------------- /ocr/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | import math 4 | import time 5 | import logging 6 | from tqdm import tqdm 7 | 8 | from ocr.metrics import get_accuracy, wer, cer 9 | from ocr.predictor import predict 10 | 11 | 12 | def configure_logging(log_path=None): 13 | logger = logging.getLogger(__name__) 14 | logger.setLevel(logging.DEBUG) 15 | formatter = logging.Formatter( 16 | fmt='%(asctime)s - %(levelname)s - %(message)s', 17 | datefmt='%d-%b-%y %H:%M:%S' 18 | ) 19 | # Setup console logging 20 | sh = logging.StreamHandler() 21 | sh.setLevel(logging.DEBUG) 22 | sh.setFormatter(formatter) 23 | logger.addHandler(sh) 24 | # Setup file logging as well 25 | if log_path is not None: 26 | fh = logging.FileHandler(log_path) 27 | fh.setLevel(logging.DEBUG) 28 | fh.setFormatter(formatter) 29 | logger.addHandler(fh) 30 | return logger 31 | 32 | 33 | def val_loop(data_loader, model, decoder, logger, device): 34 | acc_avg = AverageMeter() 35 | wer_avg = AverageMeter() 36 | cer_avg = AverageMeter() 37 | strat_time = time.time() 38 | tqdm_data_loader = tqdm(data_loader, total=len(data_loader), leave=False) 39 | for images, texts, _, _ in tqdm_data_loader: 40 | batch_size = len(texts) 41 | text_preds = predict(images, model, decoder, device) 42 | acc_avg.update(get_accuracy(texts, text_preds), batch_size) 43 | wer_avg.update(wer(texts, text_preds), batch_size) 44 | cer_avg.update(cer(texts, text_preds), batch_size) 45 | 46 | loop_time = sec2min(time.time() - strat_time) 47 | logger.info(f'Validation, ' 48 | f'acc: {acc_avg.avg:.4f}, ' 49 | f'wer: {wer_avg.avg:.4f}, ' 50 | f'cer: {cer_avg.avg:.4f}, ' 51 | f'loop_time: {loop_time}') 52 | return acc_avg.avg 53 | 54 | 55 | def sec2min(s): 56 | m = math.floor(s / 60) 57 | s -= m * 60 58 | return '%dm %ds' % (m, s) 59 | 60 | 61 | class AverageMeter: 62 | """Computes and stores the average and current value""" 63 | def __init__(self): 64 | self.reset() 65 | 66 | def reset(self): 67 | self.avg = 0 68 | self.sum = 0 69 | self.count = 0 70 | 
71 | def update(self, val, n=1): 72 | self.sum += val * n 73 | self.count += n 74 | self.avg = self.sum / self.count 75 | 76 | 77 | class FilesLimitControl: 78 | """Delete files from the disk if there are more files than the set limit. 79 | Args: 80 | max_weights_to_save (int, optional): The number of files that will be 81 | stored on the disk at the same time. Default is 3. 82 | """ 83 | def __init__(self, logger=None, max_weights_to_save=2): 84 | self.saved_weights_paths = [] 85 | self.max_weights_to_save = max_weights_to_save 86 | self.logger = logger 87 | if logger is None: 88 | self.logger = configure_logging() 89 | 90 | def __call__(self, save_path): 91 | self.saved_weights_paths.append(save_path) 92 | if len(self.saved_weights_paths) > self.max_weights_to_save: 93 | old_weights_path = self.saved_weights_paths.pop(0) 94 | if os.path.exists(old_weights_path): 95 | os.remove(old_weights_path) 96 | self.logger.info(f"Weigths removed '{old_weights_path}'") 97 | 98 | 99 | def load_pretrain_model(weights_path, model, logger=None): 100 | """Load the entire pretrain model or as many layers as possible. 101 | """ 102 | if logger is None: 103 | logger = configure_logging() 104 | old_dict = torch.load(weights_path) 105 | new_dict = model.state_dict() 106 | for key, weights in new_dict.items(): 107 | if key in old_dict: 108 | if new_dict[key].shape == old_dict[key].shape: 109 | new_dict[key] = old_dict[key] 110 | else: 111 | logger.info('Weights {} were not loaded'.format(key)) 112 | else: 113 | logger.info('Weights {} were not loaded'.format(key)) 114 | return new_dict 115 | -------------------------------------------------------------------------------- /ocr/tokenizer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | from ctcdecode import CTCBeamDecoder 5 | 6 | 7 | OOV_TOKEN = '' 8 | CTC_BLANK = '' 9 | 10 | 11 | def get_char_map(alphabet): 12 | """Make from string alphabet character2int dict. 13 | Add BLANK char fro CTC loss and OOV char for out of vocabulary symbols.""" 14 | char_map = {value: idx + 2 for (idx, value) in enumerate(alphabet)} 15 | char_map[CTC_BLANK] = 0 16 | char_map[OOV_TOKEN] = 1 17 | return char_map 18 | 19 | 20 | class Tokenizer: 21 | """Class for encoding and decoding string word to sequence of int 22 | (and vice versa) using alphabet.""" 23 | 24 | def __init__(self, alphabet): 25 | self.char_map = get_char_map(alphabet) 26 | self.rev_char_map = {val: key for key, val in self.char_map.items()} 27 | 28 | def encode(self, word_list): 29 | """Returns a list of encoded words (int).""" 30 | enc_words = [] 31 | for word in word_list: 32 | enc_words.append( 33 | [self.char_map[char] if char in self.char_map 34 | else self.char_map[OOV_TOKEN] 35 | for char in word] 36 | ) 37 | return enc_words 38 | 39 | def get_num_chars(self): 40 | return len(self.char_map) 41 | 42 | def decode(self, enc_word_list, merge_repeated=True): 43 | """Returns a list of words (str) after removing blanks and collapsing 44 | repeating characters. 
Also skip out of vocabulary tokens.""" 45 | dec_words = [] 46 | for word in enc_word_list: 47 | word_chars = '' 48 | for idx, char_enc in enumerate(word): 49 | # skip blank symbols, oov tokens and repeated characters 50 | if ( 51 | char_enc != self.char_map[OOV_TOKEN] 52 | and char_enc != self.char_map[CTC_BLANK] 53 | # (idx > 0) condition to avoid selecting [-1] item 54 | and not (merge_repeated and idx > 0 55 | and char_enc == word[idx - 1]) 56 | ): 57 | word_chars += self.rev_char_map[char_enc] 58 | dec_words.append(word_chars) 59 | return dec_words 60 | 61 | 62 | class OCRDecoder: 63 | def decode(self): 64 | raise NotImplementedError 65 | 66 | def onnx_cpu_decode(self): 67 | raise NotImplementedError 68 | 69 | 70 | class BeamSearcDecoder(OCRDecoder): 71 | def __init__(self, alphabet, lm_path): 72 | self.tokenizer = Tokenizer(alphabet) 73 | char_map = self.tokenizer.char_map 74 | labels = [ 75 | k for k, v in sorted(char_map.items(), key=lambda item: item[1]) 76 | ] 77 | self.decoder = CTCBeamDecoder( 78 | labels=labels, 79 | model_path=lm_path, 80 | alpha=0.6, 81 | beta=1.1, 82 | cutoff_top_n=10, 83 | cutoff_prob=1, 84 | beam_width=10, 85 | num_processes=6, 86 | blank_id=0, 87 | log_probs_input=True) 88 | 89 | def decode_numpy(self, output): 90 | # parlance/ctcdecode works only with torch arrays 91 | output = torch.from_numpy(output) 92 | return self.decode(output) 93 | 94 | def decode(self, output): 95 | beam_results, _, _, out_lens = \ 96 | self.decoder.decode(output.permute(1, 0, 2)) 97 | encoded_texts = [] 98 | for beam_result, out_len in zip(beam_results, out_lens): 99 | encoded_texts.append( 100 | beam_result[0][:out_len[0]].numpy() 101 | ) 102 | text_preds = self.tokenizer.decode(encoded_texts, merge_repeated=False) 103 | return text_preds 104 | 105 | 106 | class BestPathDecoder(OCRDecoder): 107 | def __init__(self, alphabet): 108 | self.tokenizer = Tokenizer(alphabet) 109 | 110 | def decode_numpy(self, output): 111 | pred = np.argmax(output, -1) 112 | pred = np.transpose(pred, (1, 0)) 113 | text_preds = self.tokenizer.decode(pred) 114 | return text_preds 115 | 116 | def decode(self, output): 117 | pred = torch.argmax(output.detach().cpu(), -1).permute(1, 0).numpy() 118 | text_preds = self.tokenizer.decode(pred) 119 | return text_preds 120 | -------------------------------------------------------------------------------- /scripts/train.py: -------------------------------------------------------------------------------- 1 | from tqdm import tqdm 2 | import os 3 | import time 4 | import torch 5 | import argparse 6 | import numpy as np 7 | 8 | from ocr.utils import ( 9 | val_loop, load_pretrain_model, FilesLimitControl, AverageMeter, sec2min, 10 | configure_logging 11 | ) 12 | 13 | from ocr.dataset import get_data_loader 14 | from ocr.transforms import get_train_transforms, get_val_transforms 15 | from ocr.tokenizer import Tokenizer, BestPathDecoder 16 | from ocr.config import Config 17 | from ocr.models import CRNN 18 | from ocr.metrics import get_accuracy, wer, cer 19 | 20 | DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 21 | 22 | 23 | def train_loop( 24 | data_loader, model, decoder, criterion, optimizer, epoch, scheduler, logger 25 | ): 26 | loss_avg = AverageMeter() 27 | acc_avg = AverageMeter() 28 | wer_avg = AverageMeter() 29 | cer_avg = AverageMeter() 30 | strat_time = time.time() 31 | model.train() 32 | tqdm_data_loader = tqdm(data_loader, total=len(data_loader), leave=False) 33 | for images, texts, enc_pad_texts, text_lens in tqdm_data_loader: 34 | 
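        # Editor's note (added commentary, not in the original source). Shapes produced
        # by collate_fn in ocr/dataset.py for each batch:
        #   images        - float tensor (B, 3, H, W)
        #   texts         - sequence of B ground-truth strings (used only for metrics)
        #   enc_pad_texts - int tensor (B, max_text_len), zero-padded encoded targets
        #   text_lens     - int tensor (B,) with the unpadded target lengths
        # CRNN.forward returns log-probabilities shaped (T, B, num_classes), the layout
        # torch.nn.CTCLoss expects for the loss computed below.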
model.zero_grad() 35 | images = images.to(DEVICE) 36 | batch_size = len(texts) 37 | output = model(images) 38 | text_preds = decoder.decode(output) 39 | acc_avg.update(get_accuracy(texts, text_preds), batch_size) 40 | wer_avg.update(wer(texts, text_preds), batch_size) 41 | cer_avg.update(cer(texts, text_preds), batch_size) 42 | output_lenghts = torch.full( 43 | size=(output.size(1),), 44 | fill_value=output.size(0), 45 | dtype=torch.long 46 | ) 47 | loss = criterion(output, enc_pad_texts, output_lenghts, text_lens) 48 | loss_avg.update(loss.item(), batch_size) 49 | loss.backward() 50 | torch.nn.utils.clip_grad_norm_(model.parameters(), 2) 51 | optimizer.step() 52 | scheduler.step() 53 | loop_time = sec2min(time.time() - strat_time) 54 | for param_group in optimizer.param_groups: 55 | lr = param_group['lr'] 56 | logger.info(f'Epoch {epoch}, Loss: {loss_avg.avg:.4f}, ' 57 | f'acc: {acc_avg.avg:.4f}, ' 58 | f'wer: {wer_avg.avg:.4f}, ' 59 | f'cer: {cer_avg.avg:.4f}, ' 60 | f'LR: {lr:.7f}, loop_time: {loop_time}') 61 | return loss_avg.avg 62 | 63 | 64 | def get_loaders(tokenizer, config): 65 | train_transforms = get_train_transforms( 66 | height=config.get_image('height'), 67 | width=config.get_image('width'), 68 | prob=0.4 69 | ) 70 | train_loader = get_data_loader( 71 | transforms=train_transforms, 72 | csv_paths=config.get_train_datasets('csv_path'), 73 | tokenizer=tokenizer, 74 | dataset_probs=config.get_train_datasets('prob'), 75 | epoch_size=config.get_train('epoch_size'), 76 | batch_size=config.get_train('batch_size'), 77 | drop_last=True 78 | ) 79 | val_transforms = get_val_transforms( 80 | height=config.get_image('height'), 81 | width=config.get_image('width') 82 | ) 83 | val_loader = get_data_loader( 84 | transforms=val_transforms, 85 | csv_paths=config.get_val_datasets('csv_path'), 86 | tokenizer=tokenizer, 87 | dataset_probs=config.get_val_datasets('prob'), 88 | epoch_size=config.get_val('epoch_size'), 89 | batch_size=config.get_val('batch_size'), 90 | drop_last=False 91 | ) 92 | return train_loader, val_loader 93 | 94 | 95 | def main(args): 96 | config = Config(args.config_path) 97 | tokenizer = Tokenizer(config.get('alphabet')) 98 | os.makedirs(config.get('save_dir'), exist_ok=True) 99 | log_path = os.path.join(config.get('save_dir'), "output.log") 100 | logger = configure_logging(log_path) 101 | train_loader, val_loader = get_loaders(tokenizer, config) 102 | 103 | model = CRNN(number_class_symbols=tokenizer.get_num_chars()) 104 | if config.get('pretrain_path'): 105 | states = load_pretrain_model( 106 | config.get('pretrain_path'), model, logger) 107 | model.load_state_dict(states) 108 | logger.info(f"Load pretrained model {config.get('pretrain_path')}") 109 | model.to(DEVICE) 110 | 111 | decoder = BestPathDecoder(config.get('alphabet')) 112 | 113 | criterion = torch.nn.CTCLoss(blank=0, reduction='mean', zero_infinity=True) 114 | optimizer = torch.optim.AdamW(model.parameters(), lr=0.001, 115 | weight_decay=0.01) 116 | scheduler = torch.optim.lr_scheduler.OneCycleLR( 117 | optimizer=optimizer, 118 | epochs=config.get('num_epochs'), 119 | steps_per_epoch=len(train_loader), 120 | max_lr=0.001, 121 | pct_start=0.1, 122 | anneal_strategy='cos', 123 | final_div_factor=10 ** 5 124 | ) 125 | weight_limit_control = FilesLimitControl(logger=logger) 126 | best_acc = -np.inf 127 | 128 | acc_avg = val_loop(val_loader, model, decoder, logger, DEVICE) 129 | for epoch in range(config.get('num_epochs')): 130 | loss_avg = train_loop(train_loader, model, decoder, criterion, optimizer, 131 | epoch, 
scheduler, logger) 132 | acc_avg = val_loop(val_loader, model, decoder, logger, DEVICE) 133 | if acc_avg > best_acc: 134 | best_acc = acc_avg 135 | model_save_path = os.path.join( 136 | config.get('save_dir'), f'model-{epoch}-{acc_avg:.4f}.ckpt') 137 | torch.save(model.state_dict(), model_save_path) 138 | logger.info(f'Model weights saved {model_save_path}') 139 | weight_limit_control(model_save_path) 140 | 141 | 142 | if __name__ == '__main__': 143 | parser = argparse.ArgumentParser() 144 | parser.add_argument('--config_path', type=str, 145 | default='/workdir/scripts/ocr_config.json', 146 | help='Path to config.json.') 147 | args = parser.parse_args() 148 | 149 | main(args) 150 | -------------------------------------------------------------------------------- /ocr/predictor.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import onnxruntime as ort 3 | import openvino.runtime as ov 4 | from enum import Enum 5 | 6 | from ocr.transforms import InferenceTransform 7 | from ocr.tokenizer import Tokenizer, BeamSearcDecoder, BestPathDecoder 8 | from ocr.config import Config 9 | from ocr.models import CRNN 10 | 11 | 12 | def predict(images, model, decoder, device): 13 | """Make model prediction. 14 | 15 | Args: 16 | images (torch.Tensor): Batch with tensor images. 17 | model (ocr.src.models.CRNN): OCR model. 18 | decoder: (ocr.tokenizer.OCRDecoder) 19 | device (torch.device): Torch device. 20 | """ 21 | model.eval() 22 | images = images.to(device) 23 | with torch.no_grad(): 24 | output = model(images) 25 | text_preds = decoder.decode(output) 26 | return text_preds 27 | 28 | 29 | def split_list2batches(lst, batch_size): 30 | """Split list of images to list of bacthes.""" 31 | return [lst[i:i+batch_size] for i in range(0, len(lst), batch_size)] 32 | 33 | 34 | class OCRModel: 35 | def predict(self): 36 | raise NotImplementedError 37 | 38 | 39 | class OCRONNXCPUModel(OCRModel): 40 | def __init__(self, model_path, config, num_threads, decoder): 41 | self.tokenizer = Tokenizer(config.get('alphabet')) 42 | self.decoder = decoder 43 | sess = ort.SessionOptions() 44 | sess.intra_op_num_threads = num_threads 45 | sess.inter_op_num_threads = num_threads 46 | self.model = ort.InferenceSession(model_path, sess) 47 | 48 | self.transforms = InferenceTransform( 49 | height=config.get_image('height'), 50 | width=config.get_image('width'), 51 | return_numpy=True 52 | ) 53 | 54 | def predict(self, images): 55 | transformed_images = self.transforms(images) 56 | output = self.model.run( 57 | None, 58 | {"input": transformed_images}, 59 | )[0] 60 | pred = self.decoder.decode_numpy(output) 61 | return pred 62 | 63 | 64 | class OCROpenVinoCPUModel(OCRModel): 65 | def __init__(self, model_path, config, num_threads, decoder): 66 | self.tokenizer = Tokenizer(config.get('alphabet')) 67 | self.decoder = decoder 68 | ie = ov.Core() 69 | model_onnx = ie.read_model(model_path, "AUTO") 70 | self.model = ie.compile_model( 71 | model=model_onnx, 72 | device_name="CPU", 73 | config={"INFERENCE_NUM_THREADS": str(num_threads)} 74 | ) 75 | self.transforms = InferenceTransform( 76 | height=config.get_image('height'), 77 | width=config.get_image('width'), 78 | return_numpy=True 79 | ) 80 | 81 | def predict(self, images): 82 | transformed_images = self.transforms(images) 83 | infer_request = self.model.create_infer_request() 84 | infer_request.infer([transformed_images]) 85 | output = infer_request.get_output_tensor().data 86 | pred = self.decoder.decode_numpy(output) 87 | return 
pred 88 | 89 | 90 | class OCRTorchModel(OCRModel): 91 | def __init__(self, model_path, config, decoder, device='cuda'): 92 | self.tokenizer = Tokenizer(config.get('alphabet')) 93 | self.device = torch.device(device) 94 | self.decoder = decoder 95 | # load model 96 | self.model = CRNN( 97 | number_class_symbols=self.tokenizer.get_num_chars(), 98 | pretrained=False 99 | ) 100 | self.model.load_state_dict( 101 | torch.load(model_path, map_location=self.device)) 102 | self.model.to(self.device) 103 | 104 | self.transforms = InferenceTransform( 105 | height=config.get_image('height'), 106 | width=config.get_image('width'), 107 | ) 108 | 109 | def predict(self, images): 110 | transformed_images = self.transforms(images) 111 | pred = predict( 112 | transformed_images, self.model, self.decoder, self.device) 113 | return pred 114 | 115 | 116 | class RuntimeType(Enum): 117 | ONNX = "ONNX" 118 | OVINO = "OpenVino" 119 | TORCH = "Pytorch" 120 | 121 | 122 | def validate_value_in_enum(value, enum_cls: Enum): 123 | enum_values = [e.value for e in enum_cls] 124 | if value not in enum_values: 125 | raise Exception(f"{value} is not supported. " 126 | f"Allowed types are: {', '.join(enum_values)}") 127 | 128 | 129 | class OcrPredictor: 130 | """Make OCR prediction. 131 | 132 | Args: 133 | model_path (str): The path to the model weights. 134 | config_path (str): The path to the model config. 135 | num_threads (int): The number of cpu threads to use 136 | (in ONNX and OpenVino runtimes). 137 | runtime (str): The runtime method of the model (Pytorch, ONNX or 138 | OpenVino from the RuntimeType). Default is Pytorch. 139 | device (str): The device for computation. Default is cuda. 140 | """ 141 | 142 | def __init__( 143 | self, model_path, config_path, num_threads, lm_path='', 144 | device='cuda', batch_size=1, runtime='Pytorch' 145 | ): 146 | self.batch_size = batch_size 147 | config = Config(config_path) 148 | if lm_path: 149 | decoder = BeamSearcDecoder(config.get('alphabet'), lm_path) 150 | else: 151 | decoder = BestPathDecoder(config.get('alphabet')) 152 | 153 | validate_value_in_enum(runtime, RuntimeType) 154 | if RuntimeType(runtime) is RuntimeType.TORCH: 155 | self.model = OCRTorchModel(model_path, config, decoder, device) 156 | elif ( 157 | RuntimeType(runtime) is RuntimeType.ONNX 158 | and device == 'cpu' 159 | ): 160 | self.model = OCRONNXCPUModel( 161 | model_path, config, num_threads, decoder) 162 | elif ( 163 | RuntimeType(runtime) is RuntimeType.OVINO 164 | and device == 'cpu' 165 | ): 166 | self.model = OCROpenVinoCPUModel( 167 | model_path, config, num_threads, decoder) 168 | else: 169 | raise Exception(f"Runtime type {runtime} with device {device} " 170 | "are not supported options.") 171 | 172 | def __call__(self, images): 173 | """ 174 | Args: 175 | images (list of np.ndarray): A list of images in BGR format. 176 | 177 | Returns: 178 | pred (str or list of strs): The predicted text for one input 179 | image, and a list with texts if there was a list of images. 
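        Example (editor's illustration, not in the original source; paths are
        placeholders and mirror scripts/inference.ipynb):

            predictor = OcrPredictor(
                model_path='ocr_model.ckpt',
                config_path='ocr_config.json',
                num_threads=8,
                device='cuda',
                runtime='Pytorch',
            )
            pred_texts = predictor([image])  # image is a BGR np.ndarray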
180 | """ 181 | images_batches = split_list2batches(images, self.batch_size) 182 | pred = [] 183 | for images_batch in images_batches: 184 | preds_batch = self.model.predict(images_batch) 185 | pred.extend(preds_batch) 186 | 187 | return pred 188 | -------------------------------------------------------------------------------- /scripts/prepare_dataset.py: -------------------------------------------------------------------------------- 1 | import json 2 | from tqdm import tqdm 3 | import numpy as np 4 | import pandas as pd 5 | import cv2 6 | import os 7 | import argparse 8 | 9 | 10 | def numbers2coords(list_of_numbers): 11 | """Convert list of numbers to list of tuple coords x, y.""" 12 | bbox = [[int(list_of_numbers[i]), int(list_of_numbers[i+1])] 13 | for i in range(0, len(list_of_numbers), 2)] 14 | return np.array(bbox) 15 | 16 | 17 | def upscale_bbox(bbox, upscale_x=1, upscale_y=1): 18 | """Increase size of the bbox.""" 19 | height = bbox[3] - bbox[1] 20 | width = bbox[2] - bbox[0] 21 | 22 | y_change = (height * upscale_y) - height 23 | x_change = (width * upscale_x) - width 24 | 25 | x_min = max(0, bbox[0] - int(x_change/2)) 26 | y_min = max(0, bbox[1] - int(y_change/2)) 27 | x_max = bbox[2] + int(x_change/2) 28 | y_max = bbox[3] + int(y_change/2) 29 | return x_min, y_min, x_max, y_max 30 | 31 | 32 | def polygon2bbox(polygon): 33 | x_min = np.inf 34 | y_min = np.inf 35 | x_max = -np.inf 36 | y_max = -np.inf 37 | for x, y in polygon: 38 | if x > x_max: 39 | x_max = x 40 | if y > y_max: 41 | y_max = y 42 | if x < x_min: 43 | x_min = x 44 | if y < y_min: 45 | y_min = y 46 | return int(x_min), int(y_min), int(x_max), int(y_max) 47 | 48 | 49 | def img_crop(img, bbox): 50 | return img[bbox[1]:bbox[3], bbox[0]:bbox[2]] 51 | 52 | 53 | def class_names2id(class_names, data): 54 | """Match class names to categoty ids using annotation in COCO format.""" 55 | category_ids = [] 56 | for class_name in class_names: 57 | for category_info in data['categories']: 58 | if category_info['name'] == class_name: 59 | category_ids.append(category_info['id']) 60 | return category_ids 61 | 62 | 63 | def get_data_from_image(data, image_id, class_names): 64 | texts = [] 65 | bboxes = [] 66 | polygons = [] 67 | category_ids = class_names2id(class_names, data) 68 | for idx, data_ann in enumerate(data['annotations']): 69 | if ( 70 | data_ann['image_id'] == image_id 71 | and data_ann['category_id'] in category_ids 72 | and data_ann.get('attributes') is not None 73 | and data_ann['attributes']['translation'] 74 | and data_ann['segmentation'] 75 | ): 76 | polygon = numbers2coords(data_ann['segmentation'][0]) 77 | bbox = polygon2bbox(polygon) 78 | bboxes.append(bbox) 79 | polygons.append(polygon) 80 | texts.append(data_ann['attributes']['translation']) 81 | return texts, bboxes, polygons 82 | 83 | 84 | def is_save_crop(remove_turned_crops, crop): 85 | crop_h, crop_w = crop.shape[:2] 86 | if ( 87 | remove_turned_crops 88 | and crop_h > crop_w 89 | ): 90 | return False 91 | 92 | if ( 93 | crop_h < 5 94 | or crop_w < 5 95 | ): 96 | return False 97 | 98 | return True 99 | 100 | 101 | def make_large_bbox_dataset( 102 | input_coco_json, image_root, class_names, bbox_scale_x, bbox_scale_y, 103 | save_dir, save_csv_name, remove_turned_crops, crop_by_mask, 104 | image_folder_name='images' 105 | ): 106 | os.makedirs(save_dir, exist_ok=True) 107 | save_image_dir = os.path.join(save_dir, image_folder_name) 108 | os.makedirs(save_image_dir, exist_ok=True) 109 | 110 | with open(input_coco_json, 'r') as f: 111 | data = json.load(f) 
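    # Editor's note (added commentary, not in the original source). The loaded json is
    # expected to follow the COCO layout consumed by get_data_from_image() above:
    #   data['images']      - entries with 'file_name' and 'id'
    #   data['categories']  - entries with 'name' and 'id', matched against --class_names
    #   data['annotations'] - entries with 'image_id', 'category_id', 'segmentation' and
    #                         attributes['translation'] holding the transcription text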
112 | 113 | crop_texts = [] 114 | crop_names = [] 115 | for data_img in tqdm(data['images']): 116 | img_name = data_img['file_name'] 117 | image_id = data_img['id'] 118 | image = cv2.imread(os.path.join(image_root, img_name)) 119 | 120 | texts, bboxes, polygons = \ 121 | get_data_from_image(data, image_id, class_names) 122 | 123 | crop_data = zip(texts, bboxes, polygons) 124 | for idx, (text, bbox, polygon) in enumerate(crop_data): 125 | upscaled_bbox = upscale_bbox(bbox, bbox_scale_x, bbox_scale_y) 126 | crop = img_crop(image, upscaled_bbox) 127 | crop_h, crop_w = crop.shape[:2] 128 | 129 | if crop_by_mask: 130 | pts = polygon - polygon.min(axis=0) 131 | mask = np.zeros((crop_h, crop_w), np.uint8) 132 | cv2.drawContours(mask, [pts], -1, (255, 255, 255), -1, cv2.LINE_AA) 133 | crop = cv2.bitwise_and(crop, crop, mask=mask) 134 | 135 | if is_save_crop(remove_turned_crops, crop): 136 | crop_name = f'{image_folder_name}/{image_id}-{idx}.png' 137 | crop_path = os.path.join(save_dir, crop_name) 138 | cv2.imwrite(crop_path, crop) 139 | crop_texts.append(text) 140 | crop_names.append(crop_name) 141 | data = pd.DataFrame(zip(crop_names, crop_texts), columns=["filename", "text"]) 142 | csv_path = os.path.join(save_dir, save_csv_name) 143 | data.to_csv(csv_path, index=False) 144 | 145 | 146 | if __name__ == '__main__': 147 | parser = argparse.ArgumentParser() 148 | parser.add_argument('--annotation_json_path', type=str, required=True, 149 | help='Path to json with segmentation dataset' 150 | 'annotation in COCO format.') 151 | parser.add_argument('--annotation_image_root', type=str, required=True, 152 | help='Directory to folder with images from' 153 | 'annotatin.json.') 154 | parser.add_argument("--remove_turned_crops", action='store_true', 155 | help="To remove images with height greater than width.") 156 | parser.add_argument("--crop_by_mask", action='store_true', 157 | help="To crop iamges by mask instead of bbox.") 158 | parser.add_argument('--class_names', nargs='+', type=str, required=True, 159 | help='Class namess (separated by spaces) from ' 160 | 'annotation_json_path to make OCR dataset from them.') 161 | parser.add_argument('--bbox_scale_x', type=float, required=True, 162 | help='Scale parameter for bbox.') 163 | parser.add_argument('--bbox_scale_y', type=float, required=True, 164 | help='Scale parameter for bbox.') 165 | parser.add_argument('--save_dir', type=str, required=True, 166 | help='Directory to save OCR dataset.') 167 | parser.add_argument('--output_csv_name', type=str, required=True, 168 | help='The name of the output csv with OCR annotation' 169 | 'informarion.') 170 | 171 | args = parser.parse_args() 172 | 173 | make_large_bbox_dataset( 174 | input_coco_json=args.annotation_json_path, 175 | image_root=args.annotation_image_root, 176 | class_names=args.class_names, 177 | bbox_scale_x=args.bbox_scale_x, 178 | bbox_scale_y=args.bbox_scale_y, 179 | save_dir=args.save_dir, 180 | remove_turned_crops=args.remove_turned_crops, 181 | crop_by_mask=args.crop_by_mask, 182 | save_csv_name=args.output_csv_name 183 | ) 184 | -------------------------------------------------------------------------------- /ocr/dataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import Dataset, Sampler 3 | from torch.nn.utils.rnn import pad_sequence 4 | 5 | import numpy as np 6 | import pandas as pd 7 | import pathlib 8 | import cv2 9 | 10 | 11 | class SequentialSampler(Sampler): 12 | """Make sequence of dataset indexes for 
batch sampler. 13 | Args: 14 | dataset_len (int): Length of train dataset. 15 | epoch_size (int, optional): Size of train epoch (by default it 16 | is equal to the dataset_len). Can be specified if you need to 17 | reduce the time of the epoch. 18 | init_sample_probs (list, optional): List of samples' probabilities to 19 | be added in batch. If None probs for all samples would be the same. 20 | The length of the list must be equal to the length of the dataset. 21 | """ 22 | def __init__(self, dataset_len, epoch_size=None, init_sample_probs=None): 23 | self.dataset_len = dataset_len 24 | if epoch_size is not None: 25 | self.epoch_size = epoch_size 26 | else: 27 | self.epoch_size = dataset_len 28 | 29 | if init_sample_probs is None: 30 | self.init_sample_probs = \ 31 | np.array([1. for i in range(dataset_len)], dtype=np.float64) 32 | else: 33 | self.init_sample_probs = \ 34 | np.array(init_sample_probs, dtype=np.float64) 35 | assert len(self.init_sample_probs) == dataset_len, "The len " \ 36 | "of the sample_probs must be equal to the dataset_len." 37 | self.init_sample_probs = \ 38 | self._sample_probs_normalization(self.init_sample_probs) 39 | 40 | def _sample_probs_normalization(self, sample_probs): 41 | """Probabilities normalization to make them sum to 1. 42 | Sum might not be equal to 1 if probs are too small. 43 | """ 44 | return sample_probs / sample_probs.sum() 45 | 46 | def __iter__(self): 47 | dataset_indexes = np.random.choice( 48 | a=self.dataset_len, 49 | size=self.epoch_size, 50 | p=self.init_sample_probs, 51 | replace=False, # only unique samples inside an epoch 52 | ) 53 | return iter(dataset_indexes) 54 | 55 | def __len__(self): 56 | return self.epoch_size 57 | 58 | 59 | def collate_fn(batch): 60 | images, texts, enc_texts = zip(*batch) 61 | images = torch.stack(images, 0) 62 | text_lens = torch.LongTensor([len(text) for text in texts]) 63 | enc_pad_texts = pad_sequence(enc_texts, batch_first=True, padding_value=0) 64 | return images, texts, enc_pad_texts, text_lens 65 | 66 | 67 | def get_full_img_path(img_root_path, csv_path): 68 | """Merge csv root path and image name.""" 69 | root_dir = pathlib.Path(csv_path).parent 70 | img_path = root_dir / pathlib.Path(img_root_path) 71 | return str(img_path) 72 | 73 | 74 | def read_and_concat_datasets(csv_paths): 75 | """Read csv files and concatenate them into one pandas DataFrame. 76 | 77 | Args: 78 | csv_paths (list): List of the dataset csv paths. 79 | 80 | Return: 81 | data (pandas.DataFrame): Concatenated datasets. 
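    Note (editor's addition): the 'filename' values are resolved relative to the
    parent directory of the csv they come from (see get_full_img_path above), so
    each dataset keeps its images next to its own annotation file.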
82 | """ 83 | data = [] 84 | for csv_path in csv_paths: 85 | csv_data = pd.read_csv( 86 | csv_path, dtype={'text': 'str'}, keep_default_na=False) 87 | csv_data['dataset_name'] = csv_path 88 | csv_data['filename'] = csv_data['filename'].apply( 89 | get_full_img_path, csv_path=csv_path) 90 | data.append(csv_data[['filename', 'dataset_name', 'text']]) 91 | data = pd.concat(data, ignore_index=True) 92 | return data 93 | 94 | 95 | def get_data_loader( 96 | transforms, csv_paths, tokenizer, dataset_probs, epoch_size, 97 | batch_size, drop_last 98 | ): 99 | data = read_and_concat_datasets(csv_paths) 100 | data['enc_text'] = tokenizer.encode(data['text'].values) 101 | 102 | dataset_prob2sample_prob = DatasetProb2SampleProb(csv_paths, dataset_probs) 103 | data = dataset_prob2sample_prob(data) 104 | 105 | dataset = OCRDataset(data, transforms) 106 | sampler = SequentialSampler( 107 | dataset_len=len(data), 108 | epoch_size=epoch_size, 109 | init_sample_probs=data['sample_prob'].values 110 | ) 111 | batcher = torch.utils.data.BatchSampler(sampler, batch_size=batch_size, 112 | drop_last=drop_last) 113 | data_loader = torch.utils.data.DataLoader( 114 | dataset=dataset, 115 | collate_fn=collate_fn, 116 | batch_sampler=batcher, 117 | num_workers=8, 118 | ) 119 | return data_loader 120 | 121 | 122 | class DatasetProb2SampleProb: 123 | """Convert dataset sampling probability to probability for each sample 124 | in the datset. 125 | 126 | Args: 127 | dataset_names (list): A list of the dataset names. 128 | dataset_probs (list of float): A list of dataset sample probs 129 | corresponding to the datasets from dataset_names list. 130 | """ 131 | 132 | def __init__(self, dataset_names, dataset_probs): 133 | assert len(dataset_names) == len(dataset_probs), "Length of " \ 134 | "csv_paths should be equal to the length of the dataset_probs." 135 | self.dataset2dataset_prob = dict(zip(dataset_names, dataset_probs)) 136 | 137 | def _dataset2sample_count(self, data): 138 | """Calculate samples in each dataset from data using.""" 139 | dataset2sample_count = {} 140 | for dataset_name in self.dataset2dataset_prob: 141 | dataset2sample_count[dataset_name] = \ 142 | (data['dataset_name'] == dataset_name).sum() 143 | return dataset2sample_count 144 | 145 | def _dataset2sample_prob(self, dataset2sample_count): 146 | """Convert dataaset prob to sample prob.""" 147 | dataset2sample_prob = {} 148 | for dataset_name, dataset_prob in self.dataset2dataset_prob.items(): 149 | sample_count = dataset2sample_count[dataset_name] 150 | dataset2sample_prob[dataset_name] = dataset_prob / sample_count 151 | return dataset2sample_prob 152 | 153 | def __call__(self, data): 154 | """Add sampling prob column to data. 155 | 156 | Args: 157 | data (pandas.DataFrame): Dataset with 'dataset_name' column. 158 | """ 159 | dataset2sample_count = self._dataset2sample_count(data) 160 | dataset2sample_prob = \ 161 | self._dataset2sample_prob(dataset2sample_count) 162 | data['sample_prob'] = data['dataset_name'].apply( 163 | lambda x: dataset2sample_prob[x]) 164 | return data 165 | 166 | 167 | class OCRDataset(Dataset): 168 | """OCR torch.Dataset. 169 | 170 | Args: 171 | data (pandas.DataFrame): Dataset with 'filename', 'text' and 172 | 'enc_text' columns. 173 | transform (torchvision.Compose): Image transforms, default is None. 
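    Each item is a tuple (image, text, enc_text), where enc_text is a
    torch.LongTensor of token ids produced by the Tokenizer (editor's note
    summarising __getitem__ below).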
174 | """ 175 | 176 | def __init__(self, data, transform=None): 177 | super().__init__() 178 | self.transform = transform 179 | self.data_len = len(data) 180 | self.img_paths = data['filename'].values 181 | self.texts = data['text'].values 182 | self.enc_texts = data['enc_text'].values 183 | 184 | def __len__(self): 185 | return self.data_len 186 | 187 | def __getitem__(self, idx): 188 | img_path = self.img_paths[idx] 189 | text = self.texts[idx] 190 | enc_text = torch.LongTensor(self.enc_texts[idx]) 191 | image = cv2.imread(img_path) 192 | if self.transform is not None: 193 | image = self.transform(image) 194 | return image, text, enc_text 195 | -------------------------------------------------------------------------------- /ocr/transforms.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | import torchvision 4 | import cv2 5 | import random 6 | import numpy as np 7 | from albumentations import augmentations 8 | 9 | 10 | class RescalePaddingImage: 11 | def __init__(self, output_height, output_width): 12 | self.output_height = output_height 13 | self.output_width = output_width 14 | 15 | def __call__(self, image): 16 | h, w = image.shape[:2] 17 | # width proportional to change in height 18 | new_width = int(w*(self.output_height/h)) 19 | # new_width cannot be bigger than output_width 20 | new_width = min(new_width, self.output_width) 21 | image = cv2.resize(image, (new_width, self.output_height), 22 | interpolation=cv2.INTER_LINEAR) 23 | if new_width < self.output_width: 24 | image = np.pad( 25 | image, ((0, 0), (0, self.output_width - new_width), (0, 0)), 26 | 'constant', constant_values=0) 27 | return image 28 | 29 | 30 | class Normalize: 31 | def __call__(self, img): 32 | img = img.astype(np.float32) / 255 33 | return img 34 | 35 | 36 | class ToTensor: 37 | def __call__(self, arr): 38 | arr = torch.from_numpy(arr) 39 | return arr 40 | 41 | 42 | class Rotate: 43 | def __init__(self, max_ang, prob): 44 | self.aug = augmentations.geometric.rotate.Rotate(limit=max_ang, p=prob) 45 | 46 | def __call__(self, img): 47 | augmented = self.aug(image=img) 48 | return augmented['image'] 49 | 50 | 51 | class SafeRotate: 52 | def __init__(self, max_ang, prob): 53 | self.aug = augmentations.geometric.rotate.SafeRotate( 54 | limit=max_ang, p=prob) 55 | 56 | def __call__(self, img): 57 | augmented = self.aug(image=img) 58 | return augmented['image'] 59 | 60 | 61 | class MoveChannels: 62 | """Move the channel axis to the zero position as required in pytorch.""" 63 | 64 | def __init__(self, to_channels_first=True): 65 | self.to_channels_first = to_channels_first 66 | 67 | def __call__(self, image): 68 | if self.to_channels_first: 69 | return np.moveaxis(image, -1, 0) 70 | else: 71 | return np.moveaxis(image, 0, -1) 72 | 73 | 74 | class UseWithProb: 75 | def __init__(self, transform, prob=0.5): 76 | self.transform = transform 77 | self.prob = prob 78 | 79 | def __call__(self, image): 80 | if random.random() < self.prob: 81 | image = self.transform(image) 82 | return image 83 | 84 | 85 | class OneOf: 86 | def __init__(self, transforms): 87 | self.transforms = transforms 88 | 89 | def __call__(self, image): 90 | return random.choice(self.transforms)(image) 91 | 92 | 93 | def img_crop(img, bbox): 94 | return img[bbox[1]:bbox[3], bbox[0]:bbox[2]] 95 | 96 | 97 | def random_crop(img, size): 98 | tw = size[0] 99 | th = size[1] 100 | h, w = img.shape[:2] 101 | if ((w - tw) > 0) and ((h - th) > 0): 102 | x1 = random.randint(0, w - tw) 103 | y1 = 
random.randint(0, h - th) 104 | else: 105 | x1 = 0 106 | y1 = 0 107 | img_return = img_crop(img, (x1, y1, x1 + tw, y1 + th)) 108 | return img_return, x1, y1 109 | 110 | 111 | class RandomCrop: 112 | def __init__(self, rnd_crop_min, rnd_crop_max=1): 113 | self.factor_max = rnd_crop_max 114 | self.factor_min = rnd_crop_min 115 | 116 | def __call__(self, img): 117 | factor = random.uniform(self.factor_min, self.factor_max) 118 | size = ( 119 | int(img.shape[1]*factor), 120 | int(img.shape[0]*factor) 121 | ) 122 | img, x1, y1 = random_crop(img, size) 123 | return img 124 | 125 | 126 | def largest_rotated_rect(w, h, angle): 127 | """ 128 | https://stackoverflow.com/a/16770343 129 | Given a rectangle of size wxh that has been rotated by 'angle' (in 130 | radians), computes the width and height of the largest possible 131 | axis-aligned rectangle within the rotated rectangle. 132 | Original JS code by 'Andri' and Magnus Hoff from Stack Overflow 133 | Converted to Python by Aaron Snoswell 134 | """ 135 | 136 | quadrant = int(math.floor(angle / (math.pi / 2))) & 3 137 | sign_alpha = angle if ((quadrant & 1) == 0) else math.pi - angle 138 | alpha = (sign_alpha % math.pi + math.pi) % math.pi 139 | 140 | bb_w = w * math.cos(alpha) + h * math.sin(alpha) 141 | bb_h = w * math.sin(alpha) + h * math.cos(alpha) 142 | 143 | gamma = math.atan2(bb_w, bb_w) if (w < h) else math.atan2(bb_w, bb_w) 144 | 145 | delta = math.pi - alpha - gamma 146 | 147 | length = h if (w < h) else w 148 | 149 | d = length * math.cos(alpha) 150 | a = d * math.sin(alpha) / math.sin(delta) 151 | 152 | y = a * math.cos(gamma) 153 | x = y * math.tan(gamma) 154 | 155 | return ( 156 | bb_w - 2 * x, 157 | bb_h - 2 * y 158 | ) 159 | 160 | 161 | def crop_around_center(image, width, height): 162 | """ 163 | https://stackoverflow.com/a/16770343 164 | Given a NumPy / OpenCV 2 image, crops it to the given width and height, 165 | around it's centre point 166 | """ 167 | 168 | image_size = (image.shape[1], image.shape[0]) 169 | image_center = (int(image_size[0] * 0.5), int(image_size[1] * 0.5)) 170 | 171 | if(width > image_size[0]): 172 | width = image_size[0] 173 | 174 | if(height > image_size[1]): 175 | height = image_size[1] 176 | 177 | x1 = int(image_center[0] - width * 0.5) 178 | x2 = int(image_center[0] + width * 0.5) 179 | y1 = int(image_center[1] - height * 0.5) 180 | y2 = int(image_center[1] + height * 0.5) 181 | 182 | return image[y1:y2, x1:x2] 183 | 184 | 185 | class RotateAndCrop: 186 | """Random image rotate around the image center 187 | 188 | Args: 189 | max_ang (float): Max angle of rotation in deg 190 | """ 191 | 192 | def __init__(self, max_ang=0): 193 | self.max_ang = max_ang 194 | 195 | def __call__(self, img): 196 | h, w, _ = img.shape 197 | 198 | ang = np.random.uniform(-self.max_ang, self.max_ang) 199 | M = cv2.getRotationMatrix2D((w/2, h/2), ang, 1) 200 | img = cv2.warpAffine(img, M, (w, h)) 201 | 202 | w_cropped, h_cropped = largest_rotated_rect(w, h, math.radians(ang)) 203 | #to fix cases of too small or negative image height when cropping 204 | h_cropped = max(h_cropped, 10) 205 | img = crop_around_center(img, w_cropped, h_cropped) 206 | return img 207 | 208 | 209 | class InferenceTransform: 210 | def __init__(self, height, width, return_numpy=False): 211 | self.transforms = torchvision.transforms.Compose([ 212 | RescalePaddingImage(height, width), 213 | MoveChannels(to_channels_first=True), 214 | Normalize(), 215 | ]) 216 | self.return_numpy = return_numpy 217 | self.to_tensor = ToTensor() 218 | 219 | def __call__(self, 
images): 220 | transformed_images = [self.transforms(image) for image in images] 221 | transformed_array = np.stack(transformed_images, 0) 222 | if not self.return_numpy: 223 | transformed_array = self.to_tensor(transformed_array) 224 | return transformed_array 225 | 226 | 227 | class CLAHE: 228 | def __init__(self, prob): 229 | self.aug = augmentations.transforms.CLAHE(p=prob) 230 | 231 | def __call__(self, img): 232 | img = self.aug(image=img)['image'] 233 | return img 234 | 235 | 236 | class GaussNoise: 237 | def __init__(self, prob): 238 | self.aug = augmentations.transforms.GaussNoise( 239 | var_limit=100, p=prob) 240 | 241 | def __call__(self, img): 242 | img = self.aug(image=img)['image'] 243 | return img 244 | 245 | 246 | class ISONoise: 247 | def __init__(self, prob): 248 | self.aug = augmentations.transforms.ISONoise( 249 | p=prob) 250 | 251 | def __call__(self, img): 252 | img = self.aug(image=img)['image'] 253 | return img 254 | 255 | 256 | class MultiplicativeNoise: 257 | def __init__(self, prob): 258 | self.aug = augmentations.transforms.MultiplicativeNoise( 259 | multiplier=(0.85, 1.15), p=prob) 260 | 261 | def __call__(self, img): 262 | img = self.aug(image=img)['image'] 263 | return img 264 | 265 | 266 | class ImageCompression: 267 | def __init__(self, prob): 268 | self.aug = augmentations.transforms.ImageCompression( 269 | quality_lower=60, quality_upper=90, p=prob) 270 | 271 | def __call__(self, img): 272 | img = self.aug(image=img)['image'] 273 | return img 274 | 275 | 276 | class Sharpen: 277 | def __init__(self, prob): 278 | self.aug = augmentations.Sharpen( 279 | p=prob) 280 | 281 | def __call__(self, img): 282 | img = self.aug(image=img)['image'] 283 | return img 284 | 285 | 286 | class ElasticTransform: 287 | def __init__(self, prob): 288 | self.aug = augmentations.geometric.transforms.ElasticTransform( 289 | alpha_affine=2.5, p=prob) 290 | 291 | def __call__(self, img): 292 | augmented = self.aug(image=img) 293 | return augmented['image'] 294 | 295 | 296 | class GridDistortion: 297 | def __init__(self, prob): 298 | self.aug = augmentations.transforms.GridDistortion(p=prob) 299 | 300 | def __call__(self, img): 301 | augmented = self.aug(image=img) 302 | return augmented['image'] 303 | 304 | 305 | class OpticalDistortion: 306 | def __init__(self, prob): 307 | self.aug = augmentations.transforms.OpticalDistortion( 308 | distort_limit=0.2, p=prob) 309 | 310 | def __call__(self, img): 311 | augmented = self.aug(image=img) 312 | return augmented['image'] 313 | 314 | 315 | class Perspective: 316 | def __init__(self, prob): 317 | self.aug = augmentations.geometric.transforms.Perspective( 318 | pad_mode=2, fit_output=True, p=prob) 319 | 320 | def __call__(self, img): 321 | augmented = self.aug(image=img) 322 | return augmented['image'] 323 | 324 | 325 | class ChannelDropout: 326 | def __init__(self, prob): 327 | self.aug = augmentations.ChannelDropout(p=prob) 328 | 329 | def __call__(self, img): 330 | img = self.aug(image=img)['image'] 331 | return img 332 | 333 | 334 | class ChannelShuffle: 335 | def __init__(self, prob): 336 | self.aug = augmentations.transforms.ChannelShuffle(p=prob) 337 | 338 | def __call__(self, img): 339 | img = self.aug(image=img)['image'] 340 | return img 341 | 342 | 343 | class RGBShift: 344 | def __init__(self, prob): 345 | self.aug = augmentations.transforms.RGBShift(p=prob) 346 | 347 | def __call__(self, img): 348 | img = self.aug(image=img)['image'] 349 | return img 350 | 351 | 352 | class ToGray: 353 | def __init__(self, prob): 354 | self.aug 
= augmentations.transforms.ToGray(p=prob) 355 | 356 | def __call__(self, img): 357 | img = self.aug(image=img)['image'] 358 | return img 359 | 360 | 361 | class ToSepia: 362 | def __init__(self, prob): 363 | self.aug = augmentations.transforms.ToSepia(p=prob) 364 | 365 | def __call__(self, img): 366 | img = self.aug(image=img)['image'] 367 | return img 368 | 369 | 370 | class RandomBrightnessContrast: 371 | def __init__(self, prob): 372 | self.aug = augmentations.transforms.RandomBrightnessContrast(p=prob) 373 | 374 | def __call__(self, img): 375 | img = self.aug(image=img)['image'] 376 | return img 377 | 378 | 379 | class RandomSnow: 380 | def __init__(self, prob): 381 | self.aug = augmentations.transforms.RandomSnow( 382 | brightness_coeff=1.5, p=prob) 383 | 384 | def __call__(self, img): 385 | img = self.aug(image=img)['image'] 386 | return img 387 | 388 | 389 | class HueSaturationValue: 390 | def __init__(self, prob): 391 | self.aug = augmentations.transforms.HueSaturationValue(p=prob) 392 | 393 | def __call__(self, img): 394 | img = self.aug(image=img)['image'] 395 | return img 396 | 397 | 398 | class RandomShadow: 399 | def __init__(self): 400 | pass 401 | 402 | def __call__(self, image, mask=None): 403 | row, col, ch = image.shape 404 | # We take a random point at the top for the x coordinate and then 405 | # another random x-coordinate at the bottom and join them to create 406 | # a shadow zone on the image. 407 | top_y = col * np.random.uniform() 408 | top_x = 0 409 | bot_x = row 410 | bot_y = col * np.random.uniform() 411 | img_hls = cv2.cvtColor(image, cv2.COLOR_RGB2HLS) 412 | shadow_mask = 0 * img_hls[:, :, 1] 413 | X_m = np.mgrid[0:image.shape[0], 0:image.shape[1]][0] 414 | Y_m = np.mgrid[0:image.shape[0], 0:image.shape[1]][1] 415 | 416 | shadow_mask[((X_m - top_x) * (bot_y - top_y) - (bot_x - top_x) * (Y_m - top_y) >= 0)] = 1 417 | 418 | random_bright = .25 + .7 * np.random.uniform() 419 | cond0 = shadow_mask == 0 420 | cond1 = shadow_mask == 1 421 | 422 | if np.random.randint(2) == 1: 423 | img_hls[:, :, 1][cond1] = img_hls[:, :, 1][cond1] * random_bright 424 | else: 425 | img_hls[:, :, 1][cond0] = img_hls[:, :, 1][cond0] * random_bright 426 | image = cv2.cvtColor(img_hls, cv2.COLOR_HLS2RGB) 427 | 428 | image = np.clip(image, 0, 255) 429 | image = image.astype(np.uint8) 430 | 431 | if mask is not None: 432 | return image, mask 433 | else: 434 | return image 435 | 436 | 437 | class RandomGamma: 438 | def __init__(self, prob): 439 | self.aug = augmentations.transforms.RandomGamma( 440 | gamma_limit=(50, 150), p=prob) 441 | 442 | def __call__(self, img): 443 | img = self.aug(image=img)['image'] 444 | return img 445 | 446 | 447 | class MotionBlur: 448 | def __init__(self, prob): 449 | self.aug = augmentations.transforms.MotionBlur( 450 | blur_limit=7, p=prob) 451 | 452 | def __call__(self, img): 453 | img = self.aug(image=img)['image'] 454 | return img 455 | 456 | 457 | class MedianBlur: 458 | def __init__(self, prob): 459 | self.aug = augmentations.transforms.MedianBlur( 460 | blur_limit=5, p=prob) 461 | 462 | def __call__(self, img): 463 | img = self.aug(image=img)['image'] 464 | return img 465 | 466 | 467 | class GlassBlur: 468 | def __init__(self, prob): 469 | self.aug = augmentations.transforms.GlassBlur( 470 | sigma=0.7, max_delta=2, p=prob) 471 | 472 | def __call__(self, img): 473 | img = self.aug(image=img)['image'] 474 | return img 475 | 476 | 477 | def get_train_transforms(height, width, prob): 478 | transforms = torchvision.transforms.Compose([ 479 | OneOf([ 480 | 
CLAHE(prob), 481 | GaussNoise(prob), 482 | ISONoise(prob), 483 | MultiplicativeNoise(prob), 484 | ImageCompression(prob), 485 | Sharpen(prob), 486 | MotionBlur(prob), 487 | MedianBlur(prob) 488 | ]), 489 | UseWithProb(RandomCrop(rnd_crop_min=0.80), 1), 490 | OneOf([ 491 | UseWithProb(RotateAndCrop(2), prob), 492 | Rotate(2, prob), 493 | SafeRotate(5, prob), 494 | ElasticTransform(prob), 495 | GridDistortion(prob), 496 | OpticalDistortion(prob), 497 | Perspective(prob) 498 | ]), 499 | OneOf([ 500 | RandomBrightnessContrast(prob), 501 | RandomGamma(prob), 502 | HueSaturationValue(prob), 503 | RandomSnow(prob), 504 | UseWithProb(RandomShadow(), prob) 505 | ]), 506 | RescalePaddingImage(height, width), 507 | MoveChannels(to_channels_first=True), 508 | Normalize(), 509 | ToTensor() 510 | ]) 511 | return transforms 512 | 513 | 514 | def get_val_transforms(height, width): 515 | transforms = torchvision.transforms.Compose([ 516 | RescalePaddingImage(height, width), 517 | MoveChannels(to_channels_first=True), 518 | Normalize(), 519 | ToTensor() 520 | ]) 521 | return transforms 522 | --------------------------------------------------------------------------------
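Usage sketch (not a file from the repository): a minimal example of wiring the helpers above together. The function signatures match ocr/transforms.py and ocr/dataset.py shown above; the csv path, alphabet, image size, augmentation probability, and the Tokenizer import and constructor (defined in ocr/tokenizer.py, not shown here) are assumptions and may need adjusting.

from ocr.tokenizer import Tokenizer          # assumed class name; see ocr/tokenizer.py
from ocr.transforms import get_train_transforms
from ocr.dataset import get_data_loader

# Hypothetical inputs: a csv with 'filename' and 'text' columns, and a character alphabet.
csv_paths = ['data/train.csv']
tokenizer = Tokenizer('abcdefghijklmnopqrstuvwxyz ')   # constructor signature is an assumption

# Augmentations, then resize/pad to a fixed height and width, then CHW float tensor.
transforms = get_train_transforms(height=64, width=256, prob=0.2)

# One sampling probability per csv; epoch_size=None keeps the full dataset length.
data_loader = get_data_loader(
    transforms=transforms,
    csv_paths=csv_paths,
    tokenizer=tokenizer,
    dataset_probs=[1.0],
    epoch_size=None,
    batch_size=32,
    drop_last=True,
)

images, texts, enc_pad_texts, text_lens = next(iter(data_loader))
print(images.shape)   # torch.Size([32, 3, 64, 256]) with the assumed sizes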