├── .gitignore ├── 0_handwriting_ocr.ipynb ├── 1_a_paragraph_segmentation_msers.ipynb ├── 1_b_paragraph_segmentation_dcnn.ipynb ├── 2_line_word_segmentation.ipynb ├── 3_handwriting_recognition.ipynb ├── 4_text_denoising.ipynb ├── 5_a_character_error_distance.ipynb ├── 5_b_visual_distance.ipynb ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── NOTICE ├── README.md ├── credentials.json.example ├── get_models.py ├── images ├── love.png ├── output_13_0.png ├── output_55_0.png ├── output_59_1.png └── output_6_1.png ├── localizing.md ├── model_checkpoint └── README.md ├── ocr ├── __init__.py ├── evaluate_cer.py ├── handwriting_line_recognition.py ├── paragraph_segmentation_dcnn.py ├── utils │ ├── __init__.py │ ├── beam_search.py │ ├── denoiser_utils.py │ ├── draw_box_on_image.py │ ├── draw_text_on_image.py │ ├── encoder_decoder.py │ ├── expand_bounding_box.py │ ├── iam_dataset.py │ ├── iou_loss.py │ ├── lexicon_search.py │ ├── max_flow.py │ ├── ngram_dataset.py │ ├── noisy_forms_dataset.py │ ├── sclite_helper.py │ ├── test_iam_dataset.ipynb │ └── word_to_line.py └── word_and_line_segmentation.py └── requirements.txt /.gitignore: -------------------------------------------------------------------------------- 1 | .ipynb_checkpoints 2 | *~ 3 | handwritting/* 4 | dataset/* 5 | credentials.json 6 | logs/* 7 | nohup.out 8 | model_checkpoint/*.params 9 | *.bin.gz 10 | *__pycache__* 11 | tmp/* 12 | ocr/utils/sctk-*/ 13 | models/* 14 | -------------------------------------------------------------------------------- /5_a_character_error_distance.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Model Distance between characters" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": 4, 13 | "metadata": {}, 14 | "outputs": [], 15 | "source": [ 16 | "import numpy as np\n", 17 | "import mxnet as mx\n", 18 | "import difflib\n", 19 | "\n", 20 | "from ocr.handwriting_line_recognition import Network as BiLSTMNetwork, decode as topK_decode\n", 21 | "from ocr.utils.noisy_forms_dataset import Noisy_forms_dataset\n", 22 | "from ocr.utils.ngram_dataset import Ngram_dataset\n", 23 | "from ocr.utils.iam_dataset import resize_image" 24 | ] 25 | }, 26 | { 27 | "cell_type": "markdown", 28 | "metadata": {}, 29 | "source": [ 30 | "## Decode noisy forms" 31 | ] 32 | }, 33 | { 34 | "cell_type": "markdown", 35 | "metadata": {}, 36 | "source": [ 37 | "We want to find what characters are more likely to be confused with each others to build a distance model between them" 38 | ] 39 | }, 40 | { 41 | "cell_type": "markdown", 42 | "metadata": {}, 43 | "source": [ 44 | "For that we do a diff of the predictions vs the form" 45 | ] 46 | }, 47 | { 48 | "cell_type": "code", 49 | "execution_count": 6, 50 | "metadata": {}, 51 | "outputs": [], 52 | "source": [ 53 | "line_image_size = (60, 800)\n", 54 | "def handwriting_recognition_transform(image):\n", 55 | " image, _ = resize_image(image, line_image_size)\n", 56 | " image = mx.nd.array(image)/255.\n", 57 | " image = (image - 0.942532484060557) / 0.15926149044640417\n", 58 | " image = image.as_in_context(ctx)\n", 59 | " image = image.expand_dims(0).expand_dims(0)\n", 60 | " return image\n", 61 | "\n", 62 | "def get_ns(is_train):\n", 63 | " network = BiLSTMNetwork(rnn_hidden_states=512, rnn_layers=2, max_seq_len=160, ctx=ctx)\n", 64 | " network.load_parameters(\"models/handwriting_line_sl_160_a_512_o_2.params\", ctx=ctx)\n", 65 | 
"\n", 66 | " def noise_source_transform(image, text):\n", 67 | " image = handwriting_recognition_transform(image)\n", 68 | " output = network(image)\n", 69 | " predict_probs = output.softmax().asnumpy()\n", 70 | " return predict_probs\n", 71 | " ns = Noisy_forms_dataset(noise_source_transform, train=is_train, name=\"OCR_noise2\", topK_decode=topK_decode)\n", 72 | " return ns" 73 | ] 74 | }, 75 | { 76 | "cell_type": "code", 77 | "execution_count": 9, 78 | "metadata": {}, 79 | "outputs": [], 80 | "source": [ 81 | "ctx = mx.gpu(0) if mx.context.num_gpus() > 0 else mx.cpu()" 82 | ] 83 | }, 84 | { 85 | "cell_type": "code", 86 | "execution_count": 8, 87 | "metadata": {}, 88 | "outputs": [], 89 | "source": [ 90 | "train_ns = get_ns(is_train=True)\n", 91 | "ng_train_ds = Ngram_dataset(train_ns, \"word_5train\", output_type=\"word\", n=5)" 92 | ] 93 | }, 94 | { 95 | "cell_type": "markdown", 96 | "metadata": {}, 97 | "source": [ 98 | "#### Using ndiff to diff the expected result and the predicted results" 99 | ] 100 | }, 101 | { 102 | "cell_type": "code", 103 | "execution_count": 13, 104 | "metadata": {}, 105 | "outputs": [], 106 | "source": [ 107 | "insertions = []\n", 108 | "deletions = []\n", 109 | "substitutions = []\n", 110 | "\n", 111 | "for i in range(len(ng_train_ds)):\n", 112 | " _, _, noisy, actual = ng_train_ds[i]\n", 113 | " diffs = []\n", 114 | " for diff in difflib.ndiff(noisy, actual):\n", 115 | " if diff[0] == \"+\" or diff[0] == \"-\":\n", 116 | " diffs.append(diff)\n", 117 | " if len(diffs) == 1:\n", 118 | " if diffs[0][0] == \"+\":\n", 119 | " insertions.append(diffs[0][-1])\n", 120 | " if diffs[0][0] == \"-\":\n", 121 | " deletions.append(diffs[0][-1])\n", 122 | " if len(diffs) == 2:\n", 123 | " if diffs[0][0] == \"+\" and diffs[1][0] == \"-\" or diffs[0][0] == \"-\" and diffs[1][0] == \"+\":\n", 124 | " changes1 = (diffs[0][-1], diffs[1][-1])\n", 125 | " changes2 = (diffs[1][-1], diffs[0][-1])\n", 126 | " substitutions.append(changes1)\n", 127 | " substitutions.append(changes2)" 128 | ] 129 | }, 130 | { 131 | "cell_type": "markdown", 132 | "metadata": {}, 133 | "source": [ 134 | "#### Using SequenceMatcher to diff the expected result and the predicted results" 135 | ] 136 | }, 137 | { 138 | "cell_type": "code", 139 | "execution_count": 14, 140 | "metadata": {}, 141 | "outputs": [], 142 | "source": [ 143 | "insertions = []\n", 144 | "deletions = []\n", 145 | "substitutions = []\n", 146 | "output = []\n", 147 | "for i in range(len(ng_train_ds)):\n", 148 | " _, _, noisy, actual = ng_train_ds[i]\n", 149 | " seqm = difflib.SequenceMatcher(None, noisy, actual)\n", 150 | " for opcode, a0, a1, b0, b1 in seqm.get_opcodes():\n", 151 | " if opcode == 'equal':\n", 152 | " output.append(seqm.a[a0:a1])\n", 153 | " elif opcode == 'insert':\n", 154 | " for char in seqm.b[b0:b1]:\n", 155 | " insertions.append(char)\n", 156 | " elif opcode == 'delete':\n", 157 | " for char in seqm.a[a0:a1]:\n", 158 | " deletions.append(char)\n", 159 | " elif opcode == 'replace':\n", 160 | " # seqm.a[a0:a1] -> seqm.b[b0:b1]\n", 161 | " if len(seqm.a[a0:a1]) == len(seqm.b[b0:b1]):\n", 162 | " for charA, charB in zip(seqm.a[a0:a1], seqm.b[b0:b1]):\n", 163 | " substitutions.append((charA, charB))\n", 164 | " else:\n", 165 | " pass" 166 | ] 167 | }, 168 | { 169 | "cell_type": "code", 170 | "execution_count": 20, 171 | "metadata": {}, 172 | "outputs": [ 173 | { 174 | "name": "stdout", 175 | "output_type": "stream", 176 | "text": [ 177 | "[1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n", 178 | " 1. 1. 1. 1. 1. 
1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0.9 0.9 1.\n", 179 | " 1. 1. 0.8 0.8 0.8 0.8 1. 1. 0.8 1. 0.8 1. 0.9 0.8 1. 0.9 1. 0.9\n", 180 | " 0.9 0.9 1. 0.9 0.8 0.8 1. 1. 1. 0.9 1. 0.8 0.9 0.8 0.8 0.8 0.9 0.8\n", 181 | " 0.8 0.8 0.9 0.8 0.8 0.9 0.9 0.9 0.9 0.9 0.8 0.8 0.8 0.9 0.9 0.8 1. 0.9\n", 182 | " 1. 1. 1. 1. 1. 1. 1. 0.8 0.8 0.8 0.8 0.8 0.8 0.8 0.8 0.8 0.9 0.8\n", 183 | " 0.8 0.8 0.8 0.8 0.8 0.8 0.8 0.8 0.8 0.8 0.8 0.8 0.9 0.8 1. 1. 1. 1.\n", 184 | " 1. 1. ]\n" 185 | ] 186 | } 187 | ], 188 | "source": [ 189 | "insertion_dict = {}\n", 190 | "for insertion in insertions:\n", 191 | " if insertion not in insertion_dict:\n", 192 | " insertion_dict[insertion] = 0\n", 193 | " insertion_dict[insertion] += 1\n", 194 | "insertion_costs = np.ones(128, dtype=np.float64)\n", 195 | "for key in insertion_dict:\n", 196 | " insertion_costs[ord(key)] = 0.9 if insertion_dict[key] <= 4 else 0.8\n", 197 | "print(insertion_costs)\n", 198 | "np.savetxt(\"models/insertion_costs.txt\", insertion_costs, fmt='%4.6f')" 199 | ] 200 | }, 201 | { 202 | "cell_type": "code", 203 | "execution_count": 21, 204 | "metadata": {}, 205 | "outputs": [ 206 | { 207 | "name": "stdout", 208 | "output_type": "stream", 209 | "text": [ 210 | "{'h': 40, 'r': 22, 'i': 17, 'W': 3, 'y': 8, 't': 51, 'n': 21, 'l': 14, 'e': 39, 'a': 23, 'A': 7, 's': 24, '.': 8, 'H': 2, 'u': 6, 'o': 14, 'm': 13, 'p': 4, 'S': 2, 'w': 20, 'x': 1, 'F': 3, 'T': 9, '1': 12, '5': 11, 'c': 12, 'M': 5, 'f': 2, 'G': 2, 'b': 4, 'g': 1, 'd': 8, ',': 3, '0': 1, 'B': 2, 'C': 3, '\"': 1, 'I': 1, 'v': 1}\n", 211 | "[1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n", 212 | " 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0.9 1.\n", 213 | " 1. 1. 1. 1. 1. 1. 1. 1. 0.9 1. 0.8 1. 0.9 0.8 1. 1. 1. 0.8\n", 214 | " 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0.8 0.9 0.9 1. 1. 0.9 0.9\n", 215 | " 0.9 0.9 1. 1. 1. 0.8 1. 1. 1. 1. 1. 0.9 0.8 1. 1. 0.9 1. 1.\n", 216 | " 1. 1. 1. 1. 1. 1. 1. 0.8 0.9 0.8 0.8 0.8 0.9 0.9 0.8 0.8 1. 1.\n", 217 | " 0.8 0.8 0.8 0.8 0.9 1. 0.8 0.8 0.8 0.8 0.9 0.8 0.9 0.8 1. 1. 1. 1.\n", 218 | " 1. 1. 
]\n" 219 | ] 220 | } 221 | ], 222 | "source": [ 223 | "deletion_dict = {}\n", 224 | "for deletion in deletions:\n", 225 | " if deletion not in deletion_dict:\n", 226 | " deletion_dict[deletion] = 0\n", 227 | " deletion_dict[deletion] += 1\n", 228 | "print(deletion_dict)\n", 229 | "deletion_costs = np.ones(128, dtype=np.float64)\n", 230 | "for key in deletion_dict:\n", 231 | " deletion_costs[ord(key)] = 0.9 if deletion_dict[key] <= 4 else 0.8\n", 232 | "print(deletion_costs)\n", 233 | "np.savetxt(\"models/deletion_costs.txt\", deletion_costs, fmt='%4.6f')" 234 | ] 235 | }, 236 | { 237 | "cell_type": "code", 238 | "execution_count": 22, 239 | "metadata": {}, 240 | "outputs": [ 241 | { 242 | "name": "stdout", 243 | "output_type": "stream", 244 | "text": [ 245 | "{('r', 's'): 5, ('l', 't'): 8, ('t', 'h'): 5, ('t', 'l'): 13, ('n', 'm'): 18, ('M', 'U'): 1, ('f', 't'): 1, ('A', 'N'): 1, ('e', 'o'): 13, ('e', 'u'): 2, ('n', 'r'): 9, ('h', 'k'): 4, ('e', 'a'): 18, ('c', 'e'): 3, ('.', ','): 21, ('H', 'M'): 1, ('c', 'C'): 3, ('t', 'r'): 4, ('L', 'h'): 1, ('W', 'b'): 1, ('r', 'e'): 3, ('r', 'R'): 1, ('r', 'n'): 10, ('r', 'v'): 5, ('P', 'R'): 1, ('o', 'e'): 6, ('v', 'r'): 4, ('t', 'd'): 4, ('n', 'a'): 1, ('h', 'L'): 1, ('W', 'S'): 1, ('W', 'w'): 3, ('r', 'x'): 2, ('c', 't'): 3, ('C', 'G'): 1, ('L', 't'): 1, ('a', 'b'): 1, ('e', 'M'): 3, ('y', 'g'): 6, ('e', 'm'): 1, ('a', 'o'): 24, ('S', 'I'): 1, ('r', 'i'): 3, ('w', 's'): 2, ('j', 'S'): 1, ('e', 'E'): 4, ('k', 'l'): 2, ('n', 't'): 2, ('t', 'k'): 2, ('e', 'w'): 1, ('h', '\"'): 1, ('t', 'M'): 1, ('\"', \"'\"): 6, (',', '.'): 13, ('w', 'a'): 1, ('l', 'L'): 2, ('l', 'h'): 3, ('e', 'n'): 3, ('u', 'n'): 3, ('f', 'F'): 1, ('f', 'P'): 1, ('t', 'n'): 1, ('l', 'n'): 1, ('n', 'u'): 5, ('o', 'a'): 6, ('t', 'f'): 4, ('W', 'I'): 1, ('t', 'b'): 2, ('w', 'I'): 1, ('l', 'k'): 1, ('c', 'o'): 2, ('t', 'H'): 2, ('s', 'o'): 2, ('c', 'r'): 1, ('a', 'e'): 6, ('i', 'a'): 1, ('a', 'A'): 9, ('o', 's'): 2, ('w', 'v'): 5, ('d', 'l'): 1, ('e', 'y'): 3, ('a', 'c'): 1, ('t', 'A'): 3, ('o', 'r'): 1, ('d', 'D'): 1, ('E', 'r'): 1, ('g', 'q'): 1, ('l', 's'): 2, ('S', 's'): 1, ('u', 'o'): 3, ('A', 'b'): 2, (',', ';'): 5, ('a', 'n'): 2, ('t', 's'): 2, ('F', 'f'): 1, ('o', 'O'): 3, ('y', 'e'): 1, ('n', 'c'): 1, ('t', '.'): 1, ('k', 'x'): 1, ('A', 'I'): 1, ('c', 's'): 2, ('e', 'c'): 4, ('l', 'b'): 2, ('e', 's'): 2, ('M', 'l'): 2, ('L', 'R'): 1, ('t', 'T'): 4, ('o', 'y'): 1, ('m', 'n'): 5, ('3', '8'): 2, ('s', 'g'): 1, ('e', 'i'): 2, ('.', 'I'): 1, ('s', 'k'): 1, ('B', 'b'): 2, ('a', 'u'): 3, ('I', 'i'): 1, ('w', 't'): 2, ('h', 'b'): 2, ('M', 'H'): 1, ('u', 'i'): 1, ('T', 't'): 1, ('w', 'r'): 1, ('T', 'i'): 1, ('n', 's'): 4, ('s', 'r'): 1, ('.', ':'): 1, ('g', 'G'): 1, ('m', 'v'): 1, ('h', 'n'): 2, ('i', 'o'): 1, ('w', 'i'): 1, ('N', 'M'): 1, ('H', 't'): 1, ('1', 'S'): 1, ('a', 'i'): 2, ('e', 'k'): 2, ('n', 'h'): 1, ('9', '3'): 1, ('f', 'G'): 1, ('w', 'W'): 1, ('h', 'H'): 3, ('n', 'N'): 2, ('l', 'U'): 1, ('i', \"'\"): 1, ('M', 'L'): 1, ('k', 'c'): 1, ('e', 'f'): 1, ('v', 'w'): 1, ('k', 'd'): 1, ('t', 'e'): 1, ('n', 'v'): 1, ('s', 't'): 1, ('e', 'p'): 2, ('k', 'w'): 1, ('s', 'd'): 1, ('r', 'b'): 1, ('t', ')'): 1, ('e', 'l'): 1, ('f', 'p'): 1, ('o', 'i'): 2, ('c', 'g'): 1, ('a', 't'): 3, ('t', 'z'): 1, ('i', 'e'): 1, ('p', 'b'): 4, ('s', 'S'): 1, ('r', 'w'): 1, ('t', 'c'): 1, ('C', 'c'): 1, ('E', '1'): 1, ('S', '5'): 1, ('s', 'c'): 1, ('z', 't'): 1, ('b', 'o'): 1, ('o', 'b'): 2, ('e', 't'): 1, ('a', 'l'): 2, ('r', 'z'): 1, ('i', 'u'): 1, ('o', 'u'): 1, ('c', 'd'): 1, ('c', 'p'): 1, ('G', 
'c'): 1, ('M', 'n'): 1, ('p', 'r'): 1, ('B', 'h'): 1, ('s', 'n'): 2, ('x', 't'): 1, ('l', 'd'): 1, ('c', 'v'): 1, ('C', 'L'): 1, ('l', 'u'): 1, ('w', '\"'): 1, ('n', 'w'): 1, ('t', 'm'): 1}\n", 246 | "[[1. 1. 1. ... 1. 1. 1.]\n", 247 | " [1. 1. 1. ... 1. 1. 1.]\n", 248 | " [1. 1. 1. ... 1. 1. 1.]\n", 249 | " ...\n", 250 | " [1. 1. 1. ... 1. 1. 1.]\n", 251 | " [1. 1. 1. ... 1. 1. 1.]\n", 252 | " [1. 1. 1. ... 1. 1. 1.]]\n" 253 | ] 254 | } 255 | ], 256 | "source": [ 257 | "substitution_dict = {}\n", 258 | "for subs in substitutions:\n", 259 | " if subs not in substitution_dict:\n", 260 | " substitution_dict[subs] = 0\n", 261 | " substitution_dict[subs] += 1\n", 262 | "print(substitution_dict)\n", 263 | "substitute_costs = np.ones((128, 128), dtype=np.float64)\n", 264 | "for key in substitution_dict:\n", 265 | " key1, key2 = key\n", 266 | " substitute_costs[ord(key1), ord(key2)] = 0.9 if substitution_dict[key] <= 4 else 0.8\n", 267 | "print(substitute_costs)\n", 268 | "np.savetxt(\"models/substitute_costs.txt\", substitute_costs, fmt='%4.6f')" 269 | ] 270 | }, 271 | { 272 | "cell_type": "code", 273 | "execution_count": null, 274 | "metadata": {}, 275 | "outputs": [], 276 | "source": [] 277 | } 278 | ], 279 | "metadata": { 280 | "kernelspec": { 281 | "display_name": "Python 3", 282 | "language": "python", 283 | "name": "python3" 284 | }, 285 | "language_info": { 286 | "codemirror_mode": { 287 | "name": "ipython", 288 | "version": 3 289 | }, 290 | "file_extension": ".py", 291 | "mimetype": "text/x-python", 292 | "name": "python", 293 | "nbconvert_exporter": "python", 294 | "pygments_lexer": "ipython3", 295 | "version": "3.6.4" 296 | } 297 | }, 298 | "nbformat": 4, 299 | "nbformat_minor": 2 300 | } 301 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | ## Code of Conduct 2 | This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct). 3 | For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact 4 | opensource-codeofconduct@amazon.com with any additional questions or comments. 5 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing Guidelines 2 | 3 | Thank you for your interest in contributing to our project. Whether it's a bug report, new feature, correction, or additional 4 | documentation, we greatly value feedback and contributions from our community. 5 | 6 | Please read through this document before submitting any issues or pull requests to ensure we have all the necessary 7 | information to effectively respond to your bug report or contribution. 8 | 9 | 10 | ## Reporting Bugs/Feature Requests 11 | 12 | We welcome you to use the GitHub issue tracker to report bugs or suggest features. 13 | 14 | When filing an issue, please check [existing open](https://github.com/awslabs/handwritten-text-recognition-for-apache-mxnet/issues), or [recently closed](https://github.com/awslabs/handwritten-text-recognition-for-apache-mxnet/issues?utf8=%E2%9C%93&q=is%3Aissue%20is%3Aclosed%20), issues to make sure somebody else hasn't already 15 | reported the issue. Please try to include as much information as you can. 
Details like these are incredibly useful: 16 | 17 | * A reproducible test case or series of steps 18 | * The version of our code being used 19 | * Any modifications you've made relevant to the bug 20 | * Anything unusual about your environment or deployment 21 | 22 | 23 | ## Contributing via Pull Requests 24 | Contributions via pull requests are much appreciated. Before sending us a pull request, please ensure that: 25 | 26 | 1. You are working against the latest source on the *master* branch. 27 | 2. You check existing open, and recently merged, pull requests to make sure someone else hasn't addressed the problem already. 28 | 3. You open an issue to discuss any significant work - we would hate for your time to be wasted. 29 | 30 | To send us a pull request, please: 31 | 32 | 1. Fork the repository. 33 | 2. Modify the source; please focus on the specific change you are contributing. If you also reformat all the code, it will be hard for us to focus on your change. 34 | 3. Ensure local tests pass. 35 | 4. Commit to your fork using clear commit messages. 36 | 5. Send us a pull request, answering any default questions in the pull request interface. 37 | 6. Pay attention to any automated CI failures reported in the pull request, and stay involved in the conversation. 38 | 39 | GitHub provides additional document on [forking a repository](https://help.github.com/articles/fork-a-repo/) and 40 | [creating a pull request](https://help.github.com/articles/creating-a-pull-request/). 41 | 42 | 43 | ## Finding contributions to work on 44 | Looking at the existing issues is a great way to find something to contribute on. As our projects, by default, use the default GitHub issue labels (enhancement/bug/duplicate/help wanted/invalid/question/wontfix), looking at any ['help wanted'](https://github.com/awslabs/handwritten-text-recognition-for-apache-mxnet/labels/help%20wanted) issues is a great place to start. 45 | 46 | 47 | ## Code of Conduct 48 | This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct). 49 | For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact 50 | opensource-codeofconduct@amazon.com with any additional questions or comments. 51 | 52 | 53 | ## Security issue notifications 54 | If you discover a potential security issue in this project we ask that you notify AWS/Amazon Security via our [vulnerability reporting page](http://aws.amazon.com/security/vulnerability-reporting/). Please do **not** create a public github issue. 55 | 56 | 57 | ## Licensing 58 | 59 | See the [LICENSE](https://github.com/awslabs/handwritten-text-recognition-for-apache-mxnet/blob/master/LICENSE) file for our project's licensing. We will ask you to confirm the licensing of your contribution. 60 | 61 | We may ask you to sign a [Contributor License Agreement (CLA)](http://en.wikipedia.org/wiki/Contributor_License_Agreement) for larger changes. 62 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 
12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. 
Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 
135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright [yyyy] [name of copyright owner] 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 
194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. -------------------------------------------------------------------------------- /NOTICE: -------------------------------------------------------------------------------- 1 | [handwritten-text-recognition-for-apache-mxnet] 2 | Copyright [2018]-[2019] Amazon.com, Inc. or its affiliates. All Rights Reserved. 3 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Handwritten Text Recognition (OCR) with MXNet Gluon 2 | 3 | ## Local Setup 4 | 5 | `git clone https://github.com/awslabs/handwritten-text-recognition-for-apache-mxnet --recursive` 6 | 7 | You need to install SCLITE for WER evaluation. 8 | You can install it with the following bash script, run from this folder: 9 | 10 | ```bash 11 | cd .. 12 | git clone https://github.com/usnistgov/SCTK 13 | cd SCTK 14 | export CXXFLAGS="-std=c++11" && make config 15 | make all 16 | make check 17 | make install 18 | make doc 19 | cd - 20 | ``` 21 | 22 | You also need hnswlib. 23 | 24 | ```bash 25 | pip install pybind11 numpy setuptools 26 | cd .. 27 | git clone https://github.com/nmslib/hnswlib 28 | cd hnswlib/python_bindings 29 | python setup.py install 30 | cd ../.. 31 | ``` 32 | If "AssertionError: Please enter credentials for the IAM dataset in credentials.json or as arguments" occurs, rename credentials.json.example to credentials.json and fill in your username and password. 33 | 34 | ## Overview 35 | 36 | ![](https://cdn-images-1.medium.com/max/1000/1*nJ-ePgwhOjOhFH3lJuSuFA.png) 37 | 38 | The pipeline is composed of 3 steps: 39 | - Detecting the handwritten area in a form [[blog post](https://medium.com/apache-mxnet/page-segmentation-with-gluon-dcb4e5955e2)], [[jupyter notebook](https://github.com/awslabs/handwritten-text-recognition-for-apache-mxnet/blob/master/1_b_paragraph_segmentation_dcnn.ipynb)], [[python script](https://github.com/awslabs/handwritten-text-recognition-for-apache-mxnet/blob/master/ocr/scripts/paragraph_segmentation_dcnn.py)] 40 | - Detecting lines of handwritten text [[blog post](https://medium.com/apache-mxnet/handwriting-ocr-line-segmentation-with-gluon-7af419f3a3d8)], [[jupyter notebook](https://github.com/awslabs/handwritten-text-recognition-for-apache-mxnet/blob/master/2_line_word_segmentation.ipynb)], [[python script](https://github.com/awslabs/handwritten-text-recognition-for-apache-mxnet/blob/master/word_and_line_segmentation.py)] 41 | - Recognising characters and applying a language model to correct errors.
[[blog post](https://medium.com/apache-mxnet/handwriting-ocr-handwriting-recognition-and-language-modeling-with-mxnet-gluon-4c7165788c67)], [[jupyter notebook](https://github.com/awslabs/handwritten-text-recognition-for-apache-mxnet/blob/master/3_handwriting_recognition.ipynb)], [[python script](https://github.com/awslabs/handwritten-text-recognition-for-apache-mxnet/blob/master/ocr/scripts/handwriting_line_recognition.py)] 42 | 43 | The entire inference pipeline can be found in this [notebook](https://github.com/awslabs/handwritten-text-recognition-for-apache-mxnet/blob/master/0_handwriting_ocr.ipynb). See the *pretrained models* section for the pretrained models. 44 | 45 | A recorded talk detailing the approach is available on YouTube. [[video](https://www.youtube.com/watch?v=xDcOdif4lj0)] 46 | 47 | The corresponding slides are available on SlideShare. [[slides](https://www.slideshare.net/apachemxnet/ocr-with-mxnet-gluon)] 48 | 49 | ## Pretrained models: 50 | 51 | You can get the models by running `python get_models.py`. 52 | 53 | ## Sample results 54 | 55 | ![](https://cdn-images-1.medium.com/max/2000/1*8lnqqlqomgdGshJB12dW1Q.png) 56 | 57 | The greedy, lexicon search, and beam search outputs present similar and reasonable predictions for the selected examples. In Figure 6, interesting examples are presented. The first line of Figure 6 shows cases where the lexicon search algorithm provided fixes that corrected the words. In the top example, “tovely” (as it was written) was corrected to “lovely” and “woved” was corrected to “waved”. In addition, the beam search output corrected “a” into “all”; however, it missed a space between “lovely” and “things”. In the second example, “selt” was converted to “salt” with the lexicon search output. However, “selt” was erroneously converted to “self” in the beam search output. Therefore, in this example, beam search performed worse. In the third example, none of the three methods provided comprehensible results. Finally, in the fourth example, the lexicon search algorithm incorrectly converted “forhim” into “forum”; however, the beam search algorithm correctly identified “for him”. 58 | 59 | ## Dataset: 60 | * To use test_iam_dataset.ipynb, create credentials.json using credentials.json.example and editing the appropriate fields. The username and password can be obtained from http://www.fki.inf.unibe.ch/DBs/iamDB/iLogin/index.php. 61 | 62 | * **It is recommended to use an instance with 32GB+ RAM and 100GB disk size; a GPU is also recommended.
A p3.2xlarge would be the recommended starter instance on AWS for this project** 63 | 64 | ## Appendix 65 | 66 | ### 1) Handwritten area 67 | 68 | ##### Model architecture 69 | 70 | ![](https://cdn-images-1.medium.com/max/1000/1*AggJmOXhjSySPf_4rPk4FA.png) 71 | 72 | ##### Results 73 | 74 | ![](https://cdn-images-1.medium.com/max/800/1*HEb82jJp93I0EFgYlJhfAw.png) 75 | 76 | ### 2) Line Detection 77 | 78 | ##### Model architecture 79 | 80 | ![](https://cdn-images-1.medium.com/max/800/1*jMkO7hy-1f0ZFHT3S2iH0Q.png) 81 | 82 | ##### Results 83 | 84 | ![](https://cdn-images-1.medium.com/max/1000/1*JJGwLXJL-bV7zsfrfw84ew.png) 85 | 86 | ### 3) Handwritten text recognition 87 | 88 | ##### Model architecture 89 | 90 | ![](https://cdn-images-1.medium.com/max/800/1*JTbCUnKgAySN--zJqzqy0Q.png) 91 | 92 | ##### Results 93 | 94 | ![](https://cdn-images-1.medium.com/max/2000/1*8lnqqlqomgdGshJB12dW1Q.png) 95 | 96 | -------------------------------------------------------------------------------- /credentials.json.example: -------------------------------------------------------------------------------- 1 | { 2 | "username": "", 3 | "password": "" 4 | } 5 | -------------------------------------------------------------------------------- /get_models.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | import os 5 | from os import path 6 | import zipfile 7 | 8 | import mxnet as mx 9 | 10 | dirname = 'dataset' 11 | if not path.isdir(dirname): 12 | os.makedirs(dirname) 13 | 14 | dirname = 'models' 15 | if not path.isdir(dirname): 16 | os.makedirs(dirname) 17 | 18 | print("Downloading Paragraph Segmentation parameters") 19 | mx.test_utils.download('https://s3.us-east-2.amazonaws.com/gluon-ocr/models/paragraph_segmentation2.params', dirname=dirname) 20 | 21 | print("Downloading Word Segmentation parameters") 22 | mx.test_utils.download('https://s3.us-east-2.amazonaws.com/gluon-ocr/models/word_segmentation2.params', dirname=dirname) 23 | 24 | print("Downloading Handwriting Line Recognition parameters") 25 | mx.test_utils.download('https://s3.us-east-2.amazonaws.com/gluon-ocr/models/handwriting_line8.params', dirname=dirname) 26 | 27 | print("Downloading Denoiser parameters") 28 | mx.test_utils.download('https://s3.us-east-2.amazonaws.com/gluon-ocr/models/denoiser2.params', dirname=dirname) 29 | 30 | print("Downloading cost matrices") 31 | mx.test_utils.download('https://s3.us-east-2.amazonaws.com/gluon-ocr/models/deletion_costs.txt', dirname=dirname) 32 | mx.test_utils.download('https://s3.us-east-2.amazonaws.com/gluon-ocr/models/substitute_costs.txt', dirname=dirname) 33 | mx.test_utils.download('https://s3.us-east-2.amazonaws.com/gluon-ocr/models/insertion_costs.txt', dirname=dirname) 34 | mx.test_utils.download('https://s3.us-east-2.amazonaws.com/gluon-ocr/models/substitute_probs.json', dirname=dirname) 35 | 36 | print("Downloading fonts") 37 | dirname = path.join('dataset','fonts') 38 | if not path.isdir(dirname): 39 | os.makedirs(dirname) 40 | mx.test_utils.download('https://s3.us-east-2.amazonaws.com/gluon-ocr/models/fonts.zip', dirname=dirname) 41 | with zipfile.ZipFile(path.join(dirname, "fonts.zip"),"r") as zip_ref: 42 | zip_ref.extractall(dirname) 43 | 44 | print("Downloading text datasets") 45 | dirname = path.join('dataset','typo') 46 | if not path.isdir(dirname): 47 | os.makedirs(dirname) 48 | 49 | 
mx.test_utils.download('https://s3.us-east-2.amazonaws.com/gluon-ocr/models/alicewonder.txt', dirname=dirname) 50 | mx.test_utils.download('https://s3.us-east-2.amazonaws.com/gluon-ocr/models/all.txt', dirname=dirname) 51 | mx.test_utils.download('https://s3.us-east-2.amazonaws.com/gluon-ocr/models/text_train.txt', dirname=dirname) 52 | mx.test_utils.download('https://s3.us-east-2.amazonaws.com/gluon-ocr/models/validating.json', dirname=dirname) 53 | mx.test_utils.download('https://s3.us-east-2.amazonaws.com/gluon-ocr/models/typo-corpus-r1.txt', dirname=dirname) 54 | -------------------------------------------------------------------------------- /images/love.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/awslabs/handwritten-text-recognition-for-apache-mxnet/d0079cb9a8775b8be583b213f2f03e5e2cf9adac/images/love.png -------------------------------------------------------------------------------- /images/output_13_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/awslabs/handwritten-text-recognition-for-apache-mxnet/d0079cb9a8775b8be583b213f2f03e5e2cf9adac/images/output_13_0.png -------------------------------------------------------------------------------- /images/output_55_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/awslabs/handwritten-text-recognition-for-apache-mxnet/d0079cb9a8775b8be583b213f2f03e5e2cf9adac/images/output_55_0.png -------------------------------------------------------------------------------- /images/output_59_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/awslabs/handwritten-text-recognition-for-apache-mxnet/d0079cb9a8775b8be583b213f2f03e5e2cf9adac/images/output_59_1.png -------------------------------------------------------------------------------- /images/output_6_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/awslabs/handwritten-text-recognition-for-apache-mxnet/d0079cb9a8775b8be583b213f2f03e5e2cf9adac/images/output_6_1.png -------------------------------------------------------------------------------- /localizing.md: -------------------------------------------------------------------------------- 1 | # Localizing for other languages 2 | To localize the handwriting OCR pipeline for another language you are going to need: 3 | 1. A dataset of images. You will need a set of images of pages that contain handwritten text. For each line in the text you will need 4 | the transcription. 5 | 2. A language model trained on the text for your language. It may be helpful to train the language model on text from other sources, not just the text for the images in your data set. 6 | The Gluon NLP toolkit ( http://gluon-nlp.mxnet.io/ ) comes with state-of-the-art language models that you can train on text for a language. Or perhaps if you can tolerate a slightly higher error rate you could skip the steps in the pipeline that use language modeling, e.g. there are use-cases where maximum accuracy is not necessary such as building a search index to return relevant documents. 7 | 3. Modifications to the code to handle the character set. 8 | 9 | The code could be refactored in future to make it easier to adapt to different writing sets, but for now we give some tips to get you started adapting it to your needs. 
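As a quick first check when adapting the pipeline to a new language, it helps to verify that every character in your transcriptions can be encoded with the alphabet the recogniser is trained on (the variable alphabet\_encoding defined in ocr/handwriting_line_recognition.py). The snippet below is a minimal sketch of such a check; the missing_characters helper and the sample strings are illustrative and not part of the repository:

```python
import unicodedata

# alphabet_encoding is the string of supported characters defined in
# ocr/handwriting_line_recognition.py (importing it also pulls in the mxnet dependencies).
from ocr.handwriting_line_recognition import alphabet_encoding

def missing_characters(transcriptions):
    """Return the set of characters that the current alphabet cannot encode."""
    missing = set()
    for line in transcriptions:
        # The pipeline expects atomic (precomposed) characters, so normalise to NFC first.
        for char in unicodedata.normalize("NFC", line):
            if char not in alphabet_encoding:
                missing.add(char)
    return missing

# Illustrative usage: "ekpɔ wò" is the Ewe phrase used as an example later in this document.
print(missing_characters(["ekpɔ wò", "he saw you"]))
```

Any characters reported by such a check would need to be added to alphabet\_encoding (in precomposed form), as discussed in the sections that follow.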
10 | 11 | Before we dive deep on the code, we give a few tips and tricks for Unicode processing in Python. 12 | 13 | It is easy to get the official Unicode name for each character using unicodedata.name(): 14 | 15 | ```python 16 | import unicodedata 17 | for c in "ekpɔ wò": # Ewe: he saw you 18 | print(unicodedata.name(c)) 19 | ``` 20 | **LATIN SMALL LETTER E** 21 | 22 | **LATIN SMALL LETTER K** 23 | 24 | **LATIN SMALL LETTER P** 25 | 26 | **LATIN SMALL LETTER OPEN O** 27 | 28 | **SPACE** 29 | 30 | **LATIN SMALL LETTER W** 31 | 32 | **LATIN SMALL LETTER O WITH GRAVE** 33 | 34 | You can use the official Unicode names in strings rather than having to decipher hexadecimal values: 35 | 36 | ```python 37 | print("\N{LATIN SMALL LETTER E}\N{LATIN SMALL LETTER K}\N{LATIN SMALL LETTER P}\N{LATIN SMALL LETTER OPEN O}") 38 | ``` 39 | **ekpɔ** 40 | 41 | These named Unicode codepoints can be mixed and matched with simple text in the interest of brevity, e.g. 42 | ```python 43 | print("ekp\N{LATIN SMALL LETTER OPEN O}") 44 | ``` 45 | **ekpɔ** 46 | 47 | Some characters have alternate representations in Unicode: precomposed and decomposed. For example, the following character 48 | 49 | á 50 | 51 | can be represented as a single Unicode codepoint: U+00E1 LATIN SMALL LETTER A WITH ACUTE. This is called the composed representation. 52 | 53 | Alternatively we could represent this using two Unicode codepoints: a lowercase a and a zero width combining accent. This is called the decomposed form. The rendering engine in the operating system of your computer knows to put the diacritic on top of the letter a: 54 | 55 | ```python 56 | print("a\N{Combining Acute Accent}") 57 | ``` 58 | **á** 59 | 60 | 61 | There are a few assumptions in the current implementation: 62 | 1. Text is written horizontally. 63 | 2. Text is written left-to-right. The visual sequence of characters on the page corresponds to how the sequence of characters is encoded in memory. 64 | 2. Text is written using the English Latin alphabet. 65 | 3. Words consist of a series of individual atomic characters i.e. precomposed characters. The pipeline does not recognize a letter and an accent diacritic as separate entities. 66 | 67 | We address how to modify the pipeline for each of those assumptions in the sections below. 68 | 69 | ## Text is written horizontally 70 | The line segmentation would need to be changed for vertical writing systems e.g. traditional Mongolian script (https://en.wikipedia.org/wiki/Mongolian_script). 71 | 72 | ## Text is written left-to-right 73 | The current pipeline assumes that the text is written left-to-right AND that the visual sequence of characters on the page 74 | corresponds to how the sequence of characters is encoded in memory. The writing systems for some languages, such as Arabic, 75 | violate both these assumptions. 76 | 77 | Some languages, e.g. Arabic, are written (mostly) right-to-left. But the in-memory encoding of the characters follows the 78 | logical order i.e. the first sound of the first word in a line is encoded in memory as the zeroth character. A line of text 79 | that ends with a question mark will have the question mark as the final codepoint in the in-memory string, but visually it 80 | will be represented as the leftmost glyph. 81 | 82 | To handle these languages, the stages in the OCR pipeline 83 | that locate the text on the page and segment it into horizontal lines can be used as-is. 
84 | The point where the process becomes sensitive is calculating the CTC loss – comparing the character guesses from the model 85 | to the reference string. 86 | 87 | The reference string gets encoded character by character here: 88 | https://github.com/ThomasDelteil/HandwrittenTextRecognition_MXNet/blob/master/ocr/handwriting_line_recognition.py#L213 89 | 90 | For English it’s fairly straightforward – you just go left to right through the characters for each word, noting the index of the character. At the top of that file you’ll see where we list off the characters of the English alphabet: 91 | 92 | https://github.com/ThomasDelteil/HandwrittenTextRecognition_MXNet/blob/master/ocr/handwriting_line_recognition.py#L29 93 | 94 | So for Arabic you would need to update the set of characters. And then you would need to encode the labels so that they match the visual order that the algorithm would be encountering them. For the simple case where all the text is written right to left it’s just a matter of saying that the zeroth label is the last char of the reference string etc. 95 | Of course, Arabic writing is actually bidirectional so you would need handle that as you were encoding the string. 96 | 97 | ## Text is written using the English Latin alphabet 98 | For languages that use the same Latin alphabet characters as English with no additional characters or diacritics, 99 | e.g. Swahili, you will not need to make deep changes. Change the data set and language model and retrain. 100 | 101 | For languages that use Latin script with some additions e.g. if they add Ɖ 102 | ( https://en.wikipedia.org/wiki/African_D ) you would need to add those characters to the string assigned to the variable alphabet\_encoding here: https://github.com/ThomasDelteil/HandwrittenTextRecognition_MXNet/blob/master/ocr/handwriting_line_recognition.py#L29. 103 | 104 | For languages that use Latin script plus additional diacritics you will need to use precomposed forms. For many languages, 105 | the precomposed codepoints are what is used by default so no additional processing will be necessary. However, if the 106 | text is represented as decomposed codepoints you can convert it to a precomposed representation using the 107 | unicodedate.normalize() function in Python: 108 | 109 | ```python 110 | import unicodedata 111 | 112 | # an 'a' followed by an accent. The rendering engine will put the accent on top of the a 113 | decomposed = "a\N{Combining Acute Accent}" 114 | print(decomposed) 115 | 116 | precomposed = unicodedata.normalize("NFC", decomposed) 117 | print(precomposed) 118 | print(unicodedata.name(precomposed)) 119 | 120 | print(len(precomposed) == 1) 121 | ``` 122 | **á** 123 | 124 | **á** 125 | 126 | **LATIN SMALL LETTER A WITH ACUTE** 127 | 128 | **True** 129 | 130 | For languages that use non-Latin scripts, you will need to specify the characters. 131 | 132 | ## Specific examples 133 | Putting all these notes together, here's what you would need to do to adapt the OCR pipeline to some specific languages. 134 | 135 | Languages that use a Brahmic script, e.g. Hindi. These scripts are [abugidas](https://en.wikipedia.org/wiki/Abugida). 136 | There are symbols for syllables and diacritics that modify these symbols e.g. to replace the default vowel. For these 137 | languages you would need to convert the sequence of syllable plus diacritic(s) to a precomposed form. All possible 138 | precomposed forms would have to be defined in the variable alphabet\_encoding. 
These languages are written left-to-right 139 | so no additional changes should be necessary. 140 | 141 | For a language like [Ewe](https://en.wikipedia.org/wiki/Ewe_language), which uses a modified Latin script, you would need 142 | to add the extra characters in all their precomposed forms. For example, the phrase "ekpɔ wò" ("he saw you") shown previously 143 | tells us that we would need to add the LATIN SMALL LETTER OPEN O codepoint and also the precomposed codepoint 144 | LATIN SMALL LETTER O WITH GRAVE. 145 | -------------------------------------------------------------------------------- /model_checkpoint/README.md: -------------------------------------------------------------------------------- 1 | # Intermediary Model Checkpoints 2 | -------------------------------------------------------------------------------- /ocr/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: Apache-2.0 -------------------------------------------------------------------------------- /ocr/evaluate_cer.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | import difflib 5 | import logging 6 | import math 7 | import string 8 | import random 9 | 10 | import numpy as np 11 | import mxnet as mx 12 | from tqdm import tqdm 13 | 14 | 15 | from .paragraph_segmentation_dcnn import SegmentationNetwork, paragraph_segmentation_transform 16 | from .word_and_line_segmentation import SSD as WordSegmentationNet, predict_bounding_boxes 17 | from .handwriting_line_recognition import Network as HandwritingRecognitionNet, handwriting_recognition_transform 18 | from .handwriting_line_recognition import decode as decoder_handwriting, alphabet_encoding 19 | 20 | from .utils.iam_dataset import IAMDataset, crop_handwriting_page 21 | from .utils.sclite_helper import ScliteHelper 22 | from .utils.word_to_line import sort_bbs_line_by_line, crop_line_images 23 | 24 | # Setup 25 | logging.basicConfig(level=logging.DEBUG) 26 | random.seed(123) 27 | np.random.seed(123) 28 | mx.random.seed(123) 29 | 30 | # Input sizes 31 | form_size = (1120, 800) 32 | segmented_paragraph_size = (800, 800) 33 | line_image_size = (60, 800) 34 | 35 | # Parameters 36 | min_c = 0.01 37 | overlap_thres = 0.001 38 | topk = 400 39 | rnn_hidden_states = 512 40 | rnn_layers = 2 41 | max_seq_len = 160 42 | 43 | recognition_model = "models/handwriting_line8.params" 44 | paragraph_segmentation_model = "models/paragraph_segmentation2.params" 45 | word_segmentation_model = "models/word_segmentation2.params" 46 | 47 | def get_arg_max(prob): 48 | ''' 49 | The greedy algorithm converts the output of the handwriting recognition network 50 | into strings.
51 | ''' 52 | arg_max = prob.topk(axis=2).asnumpy() 53 | return decoder_handwriting(arg_max)[0] 54 | 55 | denoise_func = get_arg_max 56 | 57 | if __name__ == '__main__': 58 | 59 | # Compute context 60 | ctx = mx.gpu(1) 61 | 62 | # Models 63 | logging.info("Loading models...") 64 | paragraph_segmentation_net = SegmentationNetwork(ctx=ctx) 65 | paragraph_segmentation_net.load_parameters(paragraph_segmentation_model, ctx) 66 | 67 | word_segmentation_net = WordSegmentationNet(2, ctx=ctx) 68 | word_segmentation_net.load_parameters(word_segmentation_model, ctx) 69 | 70 | handwriting_line_recognition_net = HandwritingRecognitionNet(rnn_hidden_states=rnn_hidden_states, 71 | rnn_layers=rnn_layers, 72 | max_seq_len=max_seq_len, 73 | ctx=ctx) 74 | handwriting_line_recognition_net.load_parameters(recognition_model, ctx) 75 | logging.info("models loaded.") 76 | 77 | # Data 78 | logging.info("loading data...") 79 | test_ds = IAMDataset("form_original", train=False) 80 | logging.info("data loaded.") 81 | 82 | 83 | sclite = ScliteHelper() 84 | for i in tqdm(range(len(test_ds))): 85 | image, text = test_ds[i] 86 | resized_image = paragraph_segmentation_transform(image, image_size=form_size) 87 | paragraph_bb = paragraph_segmentation_net(resized_image.as_in_context(ctx)) 88 | paragraph_segmented_image = crop_handwriting_page(image, paragraph_bb[0].asnumpy(), image_size=segmented_paragraph_size) 89 | word_bb = predict_bounding_boxes(word_segmentation_net, paragraph_segmented_image, min_c, overlap_thres, topk, ctx) 90 | line_bbs = sort_bbs_line_by_line(word_bb) 91 | line_images = crop_line_images(paragraph_segmented_image, line_bbs) 92 | 93 | predicted_text = [] 94 | for line_image in line_images: 95 | line_image = handwriting_recognition_transform(line_image, line_image_size) 96 | character_probabilities = handwriting_line_recognition_net(line_image.as_in_context(ctx)) 97 | decoded_text = denoise_func(character_probabilities) 98 | predicted_text.append(decoded_text) 99 | 100 | actual_text = text[0].replace("&quot;", '\"').replace("&amp;", "&").replace('";', '\"')[:-1] 101 | actual_text = actual_text.split("\n") 102 | if len(predicted_text) > len(actual_text): 103 | predicted_text = predicted_text[:len(actual_text)] 104 | sclite.add_text(predicted_text, actual_text) 105 | 106 | _, er = sclite.get_cer() 107 | print("Mean CER = {}".format(er)) 108 | -------------------------------------------------------------------------------- /ocr/handwriting_line_recognition.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | import time 5 | import random 6 | import os 7 | import matplotlib.pyplot as plt 8 | import argparse 9 | 10 | import mxnet as mx 11 | import numpy as np 12 | from skimage import transform as skimage_tf 13 | from skimage import exposure 14 | 15 | from mxnet import nd, autograd, gluon 16 | from mxboard import SummaryWriter 17 | from mxnet.gluon.model_zoo.vision import resnet34_v1 18 | np.seterr(all='raise') 19 | 20 | import multiprocessing 21 | mx.random.seed(1) 22 | 23 | from .utils.iam_dataset import IAMDataset, resize_image 24 | from .utils.draw_text_on_image import draw_text_on_image 25 | 26 | print_every_n = 1 27 | send_image_every_n = 20 28 | 29 | # Best results: 30 | # python handwriting_line_recognition.py --epochs 251 -n handwriting_line.params -g 0 -l 0.0001 -x 0.1 -y 0.1 -j 0.15 -k 0.15 -p 0.75 -o 2 -a 128 31 | 32 | alphabet_encoding = r' !"#&\'()*+,-./0123456789:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz' 33 | alphabet_dict = {alphabet_encoding[i]:i for i in range(len(alphabet_encoding))} 34 | 35 | class EncoderLayer(gluon.HybridBlock): 36 | ''' 37 | The encoder layer takes the image features from a CNN. The image features are transposed so that the LSTM 38 | slices of the image features can be sequentially fed into the LSTM from left to right (and back via the 39 | bidirectional LSTM). 40 | ''' 41 | def __init__(self, hidden_states=200, rnn_layers=1, max_seq_len=100, **kwargs): 42 | self.max_seq_len = max_seq_len 43 | super(EncoderLayer, self).__init__(**kwargs) 44 | with self.name_scope(): 45 | self.lstm = mx.gluon.rnn.LSTM(hidden_states, rnn_layers, bidirectional=True) 46 | 47 | def hybrid_forward(self, F, x): 48 | x = x.transpose((0, 3, 1, 2)) 49 | x = x.flatten() 50 | x = x.split(num_outputs=self.max_seq_len, axis=1) # (SEQ_LEN, N, CHANNELS) 51 | x = F.concat(*[elem.expand_dims(axis=0) for elem in x], dim=0) 52 | x = self.lstm(x) 53 | x = x.transpose((1, 0, 2)) #(N, SEQ_LEN, HIDDEN_UNITS) 54 | return x 55 | 56 | class Network(gluon.HybridBlock): 57 | ''' 58 | The CNN-biLSTM to recognise handwriting text given an image of handwriten text. 59 | Parameters 60 | ---------- 61 | num_downsamples: int, default 2 62 | The number of times to downsample the image features. Each time the features are downsampled, a new LSTM 63 | is created. 
64 | resnet_layer_id: int, default 4 65 | The layer ID to obtain features from the resnet34 66 | lstm_hidden_states: int, default 200 67 | The number of hidden states used in the LSTMs 68 | lstm_layers: int, default 1 69 | The number of layers of LSTMs to use 70 | ''' 71 | FEATURE_EXTRACTOR_FILTER = 64 72 | def __init__(self, num_downsamples=2, resnet_layer_id=4, rnn_hidden_states=200, rnn_layers=1, max_seq_len=100, ctx=mx.gpu(0), **kwargs): 73 | super(Network, self).__init__(**kwargs) 74 | self.p_dropout = 0.5 75 | self.num_downsamples = num_downsamples 76 | self.max_seq_len = max_seq_len 77 | self.ctx = ctx 78 | with self.name_scope(): 79 | self.body = self.get_body(resnet_layer_id=resnet_layer_id) 80 | 81 | self.encoders = gluon.nn.HybridSequential() 82 | with self.encoders.name_scope(): 83 | for i in range(self.num_downsamples): 84 | encoder = self.get_encoder(rnn_hidden_states=rnn_hidden_states, rnn_layers=rnn_layers, max_seq_len=max_seq_len) 85 | self.encoders.add(encoder) 86 | self.decoder = self.get_decoder() 87 | self.downsampler = self.get_down_sampler(self.FEATURE_EXTRACTOR_FILTER) 88 | 89 | def get_down_sampler(self, num_filters): 90 | ''' 91 | Creates a two-stacked Conv-BatchNorm-Relu and then a pooling layer to 92 | downsample the image features by half. 93 | 94 | Parameters 95 | ---------- 96 | num_filters: int 97 | To select the number of filters in used the downsampling convolutional layer. 98 | Returns 99 | ------- 100 | network: gluon.nn.HybridSequential 101 | The downsampler network that decreases the width and height of the image features by half. 102 | 103 | ''' 104 | out = gluon.nn.HybridSequential() 105 | with out.name_scope(): 106 | for _ in range(2): 107 | out.add(gluon.nn.Conv2D(num_filters, 3, strides=1, padding=1)) 108 | out.add(gluon.nn.BatchNorm(in_channels=num_filters)) 109 | out.add(gluon.nn.Activation('relu')) 110 | out.add(gluon.nn.MaxPool2D(2)) 111 | out.collect_params().initialize(mx.init.Normal(), ctx=self.ctx) 112 | out.hybridize() 113 | return out 114 | 115 | def get_body(self, resnet_layer_id): 116 | ''' 117 | Create the feature extraction network based on resnet34. 118 | The first layer of the res-net is converted into grayscale by averaging the weights of the 3 channels 119 | of the original resnet. 120 | 121 | Parameters 122 | ---------- 123 | resnet_layer_id: int 124 | The resnet_layer_id specifies which layer to take from 125 | the bottom of the network. 126 | Returns 127 | ------- 128 | network: gluon.nn.HybridSequential 129 | The body network for feature extraction based on resnet 130 | ''' 131 | 132 | pretrained = resnet34_v1(pretrained=True, ctx=self.ctx) 133 | pretrained_2 = resnet34_v1(pretrained=True, ctx=mx.cpu(0)) 134 | first_weights = pretrained_2.features[0].weight.data().mean(axis=1).expand_dims(axis=1) 135 | # First weights could be replaced with individual channels. 136 | 137 | body = gluon.nn.HybridSequential() 138 | with body.name_scope(): 139 | first_layer = gluon.nn.Conv2D(channels=64, kernel_size=(7, 7), padding=(3, 3), strides=(2, 2), in_channels=1, use_bias=False) 140 | first_layer.initialize(mx.init.Xavier(), ctx=self.ctx) 141 | first_layer.weight.set_data(first_weights) 142 | body.add(first_layer) 143 | body.add(*pretrained.features[1:-resnet_layer_id]) 144 | return body 145 | 146 | def get_encoder(self, rnn_hidden_states, rnn_layers, max_seq_len): 147 | ''' 148 | Creates an LSTM to learn the sequential component of the image features. 
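Illustrative shape sketch, assuming the default arguments: encoder = self.get_encoder(rnn_hidden_states=200, rnn_layers=1, max_seq_len=100) returns a HybridSequential whose output for a batch of N feature maps has shape (N, 100, 400), i.e. (N, max_seq_len, 2 * rnn_hidden_states), because the LSTM is bidirectional.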
149 | 150 | Parameters 151 | ---------- 152 | 153 | rnn_hidden_states: int 154 | The number of hidden states in the RNN 155 | 156 | rnn_layers: int 157 | The number of layers to stack the RNN 158 | Returns 159 | ------- 160 | 161 | network: gluon.nn.Sequential 162 | The encoder network to learn the sequential information of the image features 163 | ''' 164 | 165 | encoder = gluon.nn.HybridSequential() 166 | with encoder.name_scope(): 167 | encoder.add(EncoderLayer(hidden_states=rnn_hidden_states, rnn_layers=rnn_layers, max_seq_len=max_seq_len)) 168 | encoder.add(gluon.nn.Dropout(self.p_dropout)) 169 | encoder.collect_params().initialize(mx.init.Xavier(), ctx=self.ctx) 170 | return encoder 171 | 172 | def get_decoder(self): 173 | ''' 174 | Creates a network to convert the output of the encoder into characters. 175 | ''' 176 | 177 | alphabet_size = len(alphabet_encoding) + 1 178 | decoder = mx.gluon.nn.Dense(units=alphabet_size, flatten=False) 179 | decoder.collect_params().initialize(mx.init.Xavier(), ctx=self.ctx) 180 | return decoder 181 | 182 | def hybrid_forward(self, F, x): 183 | features = self.body(x) 184 | hidden_states = [] 185 | hs = self.encoders[0](features) 186 | hidden_states.append(hs) 187 | for i, _ in enumerate(range(self.num_downsamples - 1)): 188 | features = self.downsampler(features) 189 | hs = self.encoders[i+1](features) 190 | hidden_states.append(hs) 191 | hs = F.concat(*hidden_states, dim=2) 192 | output = self.decoder(hs) 193 | return output 194 | 195 | def handwriting_recognition_transform(image, line_image_size): 196 | ''' 197 | Resize and normalise the image to be fed into the network. 198 | ''' 199 | image, _ = resize_image(image, line_image_size) 200 | image = mx.nd.array(image)/255. 201 | image = (image - 0.942532484060557) / 0.15926149044640417 202 | image = image.expand_dims(0).expand_dims(0) 203 | return image 204 | 205 | def transform(image, label): 206 | ''' 207 | This function resizes the input image and converts it so that it can be fed into the network. 208 | Furthermore, the label (text) is encoded as a sequence of character indices. 209 | ''' 210 | image = np.expand_dims(image, axis=0).astype(np.float32) 211 | if image[0, 0, 0] > 1: 212 | image = image/255. 213 | image = (image - 0.942532484060557) / 0.15926149044640417 214 | label_encoded = np.zeros(max_seq_len, dtype=np.float32)-1 215 | i = 0 216 | for word in label: 217 | word = word.replace("&quot;", r'"') 218 | word = word.replace("&amp;", r'&') 219 | word = word.replace('";', '\"') 220 | for letter in word: 221 | label_encoded[i] = alphabet_dict[letter] 222 | i += 1 223 | return image, label_encoded 224 | 225 | def augment_transform(image, label): 226 | ''' 227 | This function randomly: 228 | - translates the input image by +-width_range and +-height_range (percentage). 229 | - scales the image by y_scaling and x_scaling (percentage) 230 | - shears the image by shearing_factor (radians) 231 | ''' 232 | 233 | ty = random.uniform(-random_y_translation, random_y_translation) 234 | tx = random.uniform(-random_x_translation, random_x_translation) 235 | 236 | sx = random.uniform(1. - random_y_scaling, 1. + random_y_scaling) 237 | sy = random.uniform(1. - random_x_scaling, 1.
+ random_x_scaling) 238 | 239 | s = random.uniform(-random_shearing, random_shearing) 240 | gamma = random.uniform(0.001, random_gamma) 241 | image = exposure.adjust_gamma(image, gamma) 242 | 243 | st = skimage_tf.AffineTransform(scale=(sx, sy), 244 | shear=s, 245 | translation=(tx*image.shape[1], ty*image.shape[0])) 246 | augmented_image = skimage_tf.warp(image, st, cval=1.0) 247 | return transform(augmented_image*255., label) 248 | 249 | 250 | def decode(prediction): 251 | ''' 252 | Returns the string given one-hot encoded vectors. 253 | ''' 254 | 255 | results = [] 256 | for word in prediction: 257 | result = [] 258 | for i, index in enumerate(word): 259 | if i < len(word) - 1 and word[i] == word[i+1] and word[-1] != -1: #Hack to decode label as well 260 | continue 261 | if index == len(alphabet_dict) or index == -1: 262 | continue 263 | else: 264 | result.append(alphabet_encoding[int(index)]) 265 | results.append(result) 266 | words = [''.join(word) for word in results] 267 | return words 268 | 269 | def run_epoch(e, network, dataloader, trainer, log_dir, print_name, is_train): 270 | ''' 271 | Run one epoch to train or test the CNN-biLSTM network 272 | 273 | Parameters 274 | ---------- 275 | 276 | e: int 277 | The epoch number 278 | network: nn.Gluon.HybridSequential 279 | The CNN-biLSTM network 280 | dataloader: gluon.data.DataLoader 281 | The train or testing dataloader that is wrapped around the iam_dataset 282 | 283 | log_dir: Str 284 | The directory to store the log files for mxboard 285 | print_name: Str 286 | Name to print for associating with the data. usually this will be "train" and "test" 287 | 288 | is_train: bool 289 | Boolean to indicate whether or not the network should be updated. is_train should only be set to true for the training data 290 | Returns 291 | ------- 292 | 293 | epoch_loss: float 294 | The loss of the current epoch 295 | ''' 296 | 297 | total_loss = [nd.zeros(1, ctx_) for ctx_ in ctx] 298 | for i, (x_, y_) in enumerate(dataloader): 299 | X = gluon.utils.split_and_load(x_, ctx) 300 | Y = gluon.utils.split_and_load(y_, ctx) 301 | with autograd.record(train_mode=is_train): 302 | output = [network(x) for x in X] 303 | loss_ctc = [ctc_loss(o, y) for o, y in zip(output, Y)] 304 | 305 | if is_train: 306 | [l.backward() for l in loss_ctc] 307 | trainer.step(x_.shape[0]) 308 | 309 | if i == 0 and e % send_image_every_n == 0 and e > 0: 310 | predictions = output[0][:4].softmax().topk(axis=2).asnumpy() 311 | decoded_text = decode(predictions) 312 | image = X[0][:4].asnumpy() 313 | image = image * 0.15926149044640417 + 0.942532484060557 314 | output_image = draw_text_on_image(image, decoded_text) 315 | print("{} first decoded text = {}".format(print_name, decoded_text[0])) 316 | with SummaryWriter(logdir=log_dir, verbose=False, flush_secs=5) as sw: 317 | sw.add_image('bb_{}_image'.format(print_name), output_image, global_step=e) 318 | 319 | for i, l in enumerate(loss_ctc): 320 | total_loss[i] += l.mean() 321 | 322 | epoch_loss = float(sum([tl.asscalar() for tl in total_loss]))/(len(dataloader)*len(ctx)) 323 | 324 | with SummaryWriter(logdir=log_dir, verbose=False, flush_secs=5) as sw: 325 | sw.add_scalar('loss', {print_name: epoch_loss}, global_step=e) 326 | 327 | return epoch_loss 328 | 329 | if __name__ == "__main__": 330 | parser = argparse.ArgumentParser() 331 | parser.add_argument("-g", "--gpu_id", default="0", 332 | help="IDs of the GPU to use, -1 for CPU") 333 | 334 | parser.add_argument("-t", "--line_or_word", default="line", 335 | help="to choose the 
handwriting to train on words or lines") 336 | 337 | parser.add_argument("-u", "--num_downsamples", default=2, 338 | help="Number of downsamples for the res net") 339 | parser.add_argument("-q", "--resnet_layer_id", default=4, 340 | help="layer ID to obtain features from the resnet34") 341 | parser.add_argument("-a", "--rnn_hidden_states", default=200, 342 | help="Number of hidden states for the RNN encoder") 343 | parser.add_argument("-o", "--rnn_layers", default=1, 344 | help="Number of layers for the RNN") 345 | 346 | parser.add_argument("-e", "--epochs", default=121, 347 | help="Number of epochs to run") 348 | parser.add_argument("-l", "--learning_rate", default=0.0001, 349 | help="Learning rate for training") 350 | parser.add_argument("-w", "--lr_scale", default=1, 351 | help="Amount the divide the learning rate") 352 | parser.add_argument("-r", "--lr_period", default=30, 353 | help="Divides the learning rate after period") 354 | 355 | parser.add_argument("-s", "--batch_size", default=64, 356 | help="Batch size") 357 | 358 | parser.add_argument("-x", "--random_x_translation", default=0.03, 359 | help="Randomly translation the image in the x direction (+ or -)") 360 | parser.add_argument("-y", "--random_y_translation", default=0.03, 361 | help="Randomly translation the image in the y direction (+ or -)") 362 | parser.add_argument("-j", "--random_x_scaling", default=0.10, 363 | help="Randomly scale the image in the x direction") 364 | parser.add_argument("-k", "--random_y_scaling", default=0.10, 365 | help="Randomly scale the image in the y direction") 366 | parser.add_argument("-p", "--random_shearing", default=0.5, 367 | help="Randomly shear the image in radians (+ or -)") 368 | parser.add_argument("-ga", "--random_gamma", default=1, 369 | help="Randomly update gamma of image (+ or -)") 370 | 371 | parser.add_argument("-d", "--log_dir", default="./logs", 372 | help="Directory to store the log files") 373 | parser.add_argument("-c", "--checkpoint_dir", default="model_checkpoint", 374 | help="Directory to store the checkpoints") 375 | parser.add_argument("-n", "--checkpoint_name", default="handwriting_line.params", 376 | help="Name to store the checkpoints") 377 | parser.add_argument("-m", "--load_model", default=None, 378 | help="Name of model to load") 379 | parser.add_argument("-sl", "--max-seq-len", default=None, 380 | help="Maximum sequence length") 381 | args = parser.parse_args() 382 | 383 | print(args) 384 | 385 | gpu_ids = [int(elem) for elem in args.gpu_id.split(",")] 386 | 387 | if gpu_ids == [-1]: 388 | ctx=[mx.cpu()] 389 | else: 390 | ctx=[mx.gpu(i) for i in gpu_ids] 391 | 392 | line_or_word = args.line_or_word 393 | assert line_or_word in ["line", "word"], "{} is not a value option in [\"line\", \"word\"]" 394 | 395 | num_downsamples = int(args.num_downsamples) 396 | resnet_layer_id = int(args.resnet_layer_id) 397 | rnn_hidden_states = int(args.rnn_hidden_states) 398 | rnn_layers = int(args.rnn_layers) 399 | 400 | epochs = int(args.epochs) 401 | learning_rate = float(args.learning_rate) 402 | lr_scale = float(args.lr_scale) 403 | lr_period = float(args.lr_period) 404 | batch_size = int(args.batch_size) 405 | 406 | random_y_translation, random_x_translation = float(args.random_x_translation), float(args.random_y_translation) 407 | random_y_scaling, random_x_scaling = float(args.random_y_scaling), float(args.random_x_scaling) 408 | random_shearing = float(args.random_shearing) 409 | random_gamma = float(args.random_gamma) 410 | 411 | log_dir = args.log_dir 412 | 
checkpoint_dir, checkpoint_name = args.checkpoint_dir, args.checkpoint_name 413 | load_model = args.load_model 414 | max_seq_len = args.max_seq_len 415 | 416 | if max_seq_len is not None: 417 | max_seq_len = int(max_seq_len) 418 | elif line_or_word == "line": 419 | max_seq_len = 100 420 | else: 421 | max_seq_len = 32 422 | 423 | net = Network(num_downsamples=num_downsamples, resnet_layer_id=resnet_layer_id , rnn_hidden_states=rnn_hidden_states, rnn_layers=rnn_layers, 424 | max_seq_len=max_seq_len, ctx=ctx) 425 | 426 | if load_model is not None and os.path.isfile(os.path.join(checkpoint_dir,load_model)): 427 | net.load_parameters(os.path.join(checkpoint_dir,load_model)) 428 | 429 | train_ds = IAMDataset(line_or_word, output_data="text", train=True) 430 | print("Number of training samples: {}".format(len(train_ds))) 431 | 432 | test_ds = IAMDataset(line_or_word, output_data="text", train=False) 433 | print("Number of testing samples: {}".format(len(test_ds))) 434 | 435 | train_data = gluon.data.DataLoader(train_ds.transform(augment_transform), batch_size, shuffle=True, last_batch="rollover", num_workers=4*len(ctx)) 436 | test_data = gluon.data.DataLoader(test_ds.transform(transform), batch_size, shuffle=True, last_batch="discard", num_workers=4*len(ctx)) 437 | 438 | schedule = mx.lr_scheduler.FactorScheduler(step=lr_period*len(train_data), factor=lr_scale) 439 | schedule.base_lr = learning_rate 440 | 441 | trainer = gluon.Trainer(net.collect_params(), 'adam', {'learning_rate': learning_rate, "lr_scheduler": schedule, 'clip_gradient': 2}) 442 | 443 | ctc_loss = gluon.loss.CTCLoss() 444 | 445 | best_test_loss = 10e10 446 | for e in range(epochs): 447 | train_loss = run_epoch(e, net, train_data, trainer, log_dir, print_name="train", is_train=True) 448 | test_loss = run_epoch(e, net, test_data, trainer, log_dir, print_name="test", is_train=False) 449 | if test_loss < best_test_loss: 450 | print("Saving network, previous best test loss {:.6f}, current test loss {:.6f}".format(best_test_loss, test_loss)) 451 | net.save_parameters(os.path.join(checkpoint_dir, checkpoint_name)) 452 | best_test_loss = test_loss 453 | 454 | if e % print_every_n == 0 and e > 0: 455 | print("Epoch {0}, train_loss {1:.6f}, test_loss {2:.6f}".format(e, train_loss, test_loss)) 456 | -------------------------------------------------------------------------------- /ocr/paragraph_segmentation_dcnn.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. 
2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | import argparse 5 | import multiprocessing 6 | import time 7 | import random 8 | import os 9 | 10 | import matplotlib.pyplot as plt 11 | import matplotlib.patches as patches 12 | 13 | import mxnet as mx 14 | import numpy as np 15 | from skimage.draw import line_aa 16 | from skimage import transform as skimage_transform 17 | 18 | from mxnet import nd, autograd, gluon 19 | from mxnet.image import resize_short 20 | from mxboard import SummaryWriter 21 | 22 | mx.random.seed(1) 23 | 24 | from .utils.iam_dataset import IAMDataset, resize_image 25 | from .utils.iou_loss import IOU_loss 26 | from .utils.draw_box_on_image import draw_box_on_image 27 | 28 | print_every_n = 20 29 | send_image_every_n = 20 30 | save_every_n = 100 31 | 32 | # pre-training: python paragraph_segmentation_dcnn.py -r 0.001 -e 181 -n cnn_mse.params -y 0.15 33 | # fine-tuning: python paragraph_segmentation_dcnn.py -r 0.0001 -l iou -e 150 -n cnn_iou.params -f cnn_mse.params -x 0 -y 0 34 | 35 | def paragraph_segmentation_transform(image, image_size): 36 | ''' 37 | Function used for inference to resize the image for paragraph segmentation 38 | ''' 39 | resized_image, _ = resize_image(image, image_size) 40 | 41 | resized_image = mx.nd.array(resized_image).expand_dims(axis=2) 42 | resized_image = mx.image.resize_short(resized_image, int(800/3)) 43 | resized_image = resized_image.transpose([2, 0, 1])/255. 44 | resized_image = resized_image.expand_dims(axis=0) 45 | return resized_image 46 | 47 | def transform(data, label): 48 | ''' 49 | Function that converts "data"" into the input image tensor for a CNN 50 | Label is converted into a float tensor. 51 | ''' 52 | image = mx.nd.array(data).expand_dims(axis=2) 53 | image = resize_short(image, int(800/3)) 54 | image = image.transpose([2, 0, 1])/255. 55 | label = label[0].astype(np.float32) 56 | 57 | bb = label.copy() 58 | new_w = (1 + expand_bb_scale) * bb[2] 59 | new_h = (1 + expand_bb_scale) * bb[3] 60 | 61 | bb[0] = bb[0] - (new_w - bb[2])/2 62 | bb[1] = bb[1] - (new_h - bb[3])/2 63 | bb[2] = new_w 64 | bb[3] = new_h 65 | 66 | return image, mx.nd.array(bb) 67 | 68 | def augment_transform(data, label): 69 | ''' 70 | Function that randomly translates the input image by +-width_range and +-height_range. 71 | The labels (bounding boxes) are also translated by the same amount. 
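Illustrative example with hypothetical draws, given random_x_translation = random_y_translation = 0.05: a draw of tx = 0.02, ty = -0.01 translates the image content by those fractions of its width and height, and the bounding-box label is offset by the same amounts, label[0][0] - tx and label[0][1] - ty, so that the box stays aligned with the paragraph.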
72 | ''' 73 | ty = random.uniform(-random_y_translation, random_y_translation) 74 | tx = random.uniform(-random_x_translation, random_x_translation) 75 | st = skimage_transform.SimilarityTransform(translation=(tx*data.shape[1], ty*data.shape[0])) 76 | data = skimage_transform.warp(data, st) 77 | label = label.copy() 78 | label[0][0] = label[0][0] - tx 79 | label[0][1] = label[0][1] - ty 80 | return transform(data*255., label) 81 | 82 | class SegmentationNetwork(gluon.nn.HybridBlock): 83 | 84 | def __init__(self, p_dropout = 0.5, ctx=mx.cpu()): 85 | super(SegmentationNetwork, self).__init__() 86 | 87 | pretrained = gluon.model_zoo.vision.resnet34_v1(pretrained=True, ctx=ctx) 88 | first_weights = pretrained.features[0].weight.data().mean(axis=1).expand_dims(axis=1) 89 | 90 | body = gluon.nn.HybridSequential(prefix="SegmentationNetwork_") 91 | with body.name_scope(): 92 | first_layer = gluon.nn.Conv2D(channels=64, kernel_size=(7, 7), padding=(3, 3), strides=(2, 2), in_channels=1, use_bias=False) 93 | first_layer.initialize(mx.init.Normal(), ctx=ctx) 94 | first_layer.weight.set_data(first_weights) 95 | body.add(first_layer) 96 | body.add(*pretrained.features[1:6]) 97 | 98 | output = gluon.nn.HybridSequential() 99 | with output.name_scope(): 100 | output.add(gluon.nn.Flatten()) 101 | output.add(gluon.nn.Dense(64, activation='relu')) 102 | output.add(gluon.nn.Dropout(p_dropout)) 103 | output.add(gluon.nn.Dense(64, activation='relu')) 104 | output.add(gluon.nn.Dropout(p_dropout)) 105 | output.add(gluon.nn.Dense(4, activation='sigmoid')) 106 | 107 | output.collect_params().initialize(mx.init.Normal(), ctx=ctx) 108 | body.add(output) 109 | self.cnn = body 110 | 111 | def hybrid_forward(self, F, x): 112 | return self.cnn(x) 113 | 114 | def make_cnn_old(): 115 | p_dropout = 0.5 116 | 117 | cnn = gluon.nn.HybridSequential() 118 | cnn.add(gluon.nn.Conv2D(kernel_size=(3,3), padding=(1,1), channels=16, activation="relu")) 119 | cnn.add(gluon.nn.BatchNorm()) 120 | 121 | cnn.add(gluon.nn.Conv2D(kernel_size=(3,3), padding=(1,1), channels=16, activation="relu")) 122 | cnn.add(gluon.nn.BatchNorm()) 123 | 124 | cnn.add(gluon.nn.Conv2D(kernel_size=(3,3), padding=(1,1), channels=16, activation="relu")) 125 | cnn.add(gluon.nn.BatchNorm()) 126 | 127 | cnn.add(gluon.nn.Conv2D(kernel_size=(3,3), padding=(1,1), channels=16, activation="relu")) 128 | cnn.add(gluon.nn.MaxPool2D(pool_size=(2,2), strides=(2,2))) 129 | cnn.add(gluon.nn.BatchNorm()) 130 | 131 | cnn.add(gluon.nn.Conv2D(kernel_size=(3,3), padding=(1,1), channels=16, activation="relu")) 132 | cnn.add(gluon.nn.MaxPool2D(pool_size=(2,2), strides=(2,2))) 133 | cnn.add(gluon.nn.BatchNorm()) 134 | 135 | cnn.add(gluon.nn.Flatten()) 136 | cnn.add(gluon.nn.Dense(64, activation='relu')) 137 | cnn.add(gluon.nn.Dropout(p_dropout)) 138 | cnn.add(gluon.nn.Dense(64, activation='relu')) 139 | cnn.add(gluon.nn.Dropout(p_dropout)) 140 | cnn.add(gluon.nn.Dense(4, activation='sigmoid')) 141 | 142 | cnn.hybridize() 143 | cnn.collect_params().initialize(mx.init.Normal(), ctx=ctx) 144 | return cnn 145 | 146 | def run_epoch(e, network, dataloader, loss_function, trainer, log_dir, print_name, is_train): 147 | total_loss = nd.zeros(1, ctx) 148 | for i, (data, label) in enumerate(dataloader): 149 | data = data.as_in_context(ctx) 150 | label = label.as_in_context(ctx) 151 | 152 | with autograd.record(train_mode=is_train): 153 | output = network(data) 154 | loss_i = loss_function(output, label) 155 | if is_train: 156 | loss_i.backward() 157 | trainer.step(data.shape[0]) 158 | 159 | 
total_loss += loss_i.mean() 160 | 161 | if e % send_image_every_n == 0 and e > 0 and i == 0: 162 | output_image = draw_box_on_image(output.asnumpy(), label.asnumpy(), data.asnumpy()) 163 | epoch_loss = float(total_loss .asscalar())/len(dataloader) 164 | 165 | with SummaryWriter(logdir=log_dir, verbose=False, flush_secs=5) as sw: 166 | sw.add_scalar('loss', {print_name: epoch_loss}, global_step=e) 167 | if e % send_image_every_n == 0 and e > 0: 168 | output_image[output_image<0] = 0 169 | output_image[output_image>1] = 1 170 | sw.add_image('bb_{}_image'.format(print_name), output_image, global_step=e) 171 | 172 | return epoch_loss 173 | 174 | def main(ctx=mx.gpu()): 175 | if not os.path.isdir(checkpoint_dir): 176 | os.makedirs(checkpoint_dir) 177 | 178 | train_ds = IAMDataset("form", output_data="bb", output_parse_method="form", train=True) 179 | print("Number of training samples: {}".format(len(train_ds))) 180 | 181 | test_ds = IAMDataset("form", output_data="bb", output_parse_method="form", train=False) 182 | print("Number of testing samples: {}".format(len(test_ds))) 183 | 184 | train_data = gluon.data.DataLoader(train_ds.transform(augment_transform), batch_size, 185 | shuffle=True, num_workers=int(multiprocessing.cpu_count()/2)) 186 | test_data = gluon.data.DataLoader(test_ds.transform(transform), batch_size, 187 | shuffle=False, num_workers=int(multiprocessing.cpu_count()/2)) 188 | 189 | net = SegmentationNetwork() 190 | net.hybridize() 191 | net.collect_params().reset_ctx(ctx) 192 | if restore_checkpoint_name: 193 | net.load_parameters("{}/{}".format(checkpoint_dir, restore_checkpoint_name), ctx=ctx) 194 | 195 | trainer = gluon.Trainer(net.collect_params(), 'adam', {'learning_rate': learning_rate}) 196 | best_test_loss = 10e5 197 | for e in range(epochs): 198 | train_loss = run_epoch(e, net, train_data, loss_function=loss_function, log_dir=log_dir, 199 | trainer=trainer, print_name="train", is_train=True) 200 | test_loss = run_epoch(e, net, test_data, loss_function=loss_function, log_dir=log_dir, 201 | trainer=trainer, print_name="test", is_train=False) 202 | if test_loss < best_test_loss: 203 | print("Saving network, previous best test loss {:.6f}, current test loss {:.6f}".format(best_test_loss, test_loss)) 204 | net.save_parameters(os.path.join(checkpoint_dir, checkpoint_name)) 205 | best_test_loss = test_loss 206 | if e % print_every_n == 0 and e > 0: 207 | print("Epoch {0}, train_loss {1:.6f}, test_loss {2:.6f}".format(e, train_loss, test_loss)) 208 | 209 | if __name__ == "__main__": 210 | loss_options = ["mse", "iou"] 211 | parser = argparse.ArgumentParser() 212 | parser.add_argument("-g", "--gpu_id", default=0, 213 | help="Gpu ID to use, -1 CPU") 214 | parser.add_argument("-l", "--loss", default="mse", 215 | help="Set loss function of the network. 
Options: {}".format(loss_options)) 216 | parser.add_argument("-e", "--epochs", default=300, 217 | help="The number of epochs to run") 218 | parser.add_argument("-b", "--batch_size", default=32, 219 | help="The batch size used for training") 220 | parser.add_argument("-r", "--learning_rate", default=0.001, 221 | help="The learning rate used for training") 222 | 223 | parser.add_argument("-c", "--checkpoint_dir", default="model_checkpoint", 224 | help="Directory name for the model checkpoint") 225 | parser.add_argument("-n", "--checkpoint_name", default="cnn.params", 226 | help="Name for the model checkpoint") 227 | parser.add_argument("-f", "--restore_checkpoint_name", default=None, 228 | help="Name for the model to restore from") 229 | 230 | parser.add_argument("-d", "--log_dir", default="./logs", 231 | help="Location to save the MXBoard logs") 232 | 233 | parser.add_argument("-s", "--expand_bb_scale", default=0.03, 234 | help="Scale to expand the bounding box") 235 | parser.add_argument("-x", "--random_x_translation", default=0.05, 236 | help="Randomly translate the image by x%") 237 | parser.add_argument("-y", "--random_y_translation", default=0.05, 238 | help="Randomly translate the image by y%") 239 | 240 | args = parser.parse_args() 241 | print(args) 242 | loss = args.loss 243 | epochs = int(args.epochs) 244 | batch_size = int(args.batch_size) 245 | learning_rate = float(args.learning_rate) 246 | checkpoint_dir = args.checkpoint_dir 247 | checkpoint_name = args.checkpoint_name 248 | restore_checkpoint_name = args.restore_checkpoint_name 249 | log_dir = args.log_dir 250 | expand_bb_scale = float(args.expand_bb_scale) 251 | random_x_translation = float(args.random_x_translation) 252 | random_y_translation = float(args.random_y_translation) 253 | gpu_id = int(args.gpu_id) 254 | ctx = mx.gpu(gpu_id) if gpu_id != -1 else mx.cpu() 255 | 256 | assert loss in loss_options, "{} is not an available option {}".format(loss, loss_options) 257 | 258 | if loss == "iou": 259 | loss_function = IOU_loss() 260 | elif loss == "mse": 261 | loss_function = gluon.loss.L2Loss() 262 | 263 | if restore_checkpoint_name: 264 | restore_checkpoint = os.path.join(checkpoint_dir, restore_checkpoint_name) 265 | assert os.path.isfile(restore_checkpoint), "{} does not exist".format(os.path.join(checkpoint_dir, restore_checkpoint_name)) 266 | main(ctx) 267 | -------------------------------------------------------------------------------- /ocr/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. 
2 | # SPDX-License-Identifier: Apache-2.0 -------------------------------------------------------------------------------- /ocr/utils/beam_search.py: -------------------------------------------------------------------------------- 1 | # From https://github.com/githubharald/CTCDecoder 2 | # 3 | #MIT License 4 | 5 | #Copyright (c) 2018 Harald Scheidl 6 | 7 | #Permission is hereby granted, free of charge, to any person obtaining a copy 8 | #of this software and associated documentation files (the "Software"), to deal 9 | #in the Software without restriction, including without limitation the rights 10 | #to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 11 | #copies of the Software, and to permit persons to whom the Software is 12 | #furnished to do so, subject to the following conditions: 13 | 14 | #The above copyright notice and this permission notice shall be included in all 15 | #copies or substantial portions of the Software. 16 | 17 | #THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 18 | #IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 19 | #FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 20 | #AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 21 | #LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 22 | #OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 23 | #SOFTWARE. 24 | 25 | from __future__ import division 26 | from __future__ import print_function 27 | import numpy as np 28 | 29 | class BeamEntry: 30 | "information about one single beam at specific time-step" 31 | def __init__(self): 32 | self.prTotal = 0 # blank and non-blank 33 | self.prNonBlank = 0 # non-blank 34 | self.prBlank = 0 # blank 35 | self.prText = 1 # LM score 36 | self.lmApplied = False # flag if LM was already applied to this beam 37 | self.labeling = () # beam-labeling 38 | 39 | class BeamState: 40 | "information about the beams at specific time-step" 41 | def __init__(self): 42 | self.entries = {} 43 | 44 | def norm(self): 45 | "length-normalise LM score" 46 | for (k, _) in self.entries.items(): 47 | labelingLen = len(self.entries[k].labeling) 48 | self.entries[k].prText = self.entries[k].prText ** (1.0 / (labelingLen if labelingLen else 1.0)) 49 | 50 | def sort(self): 51 | "return beam-labelings, sorted by probability" 52 | beams = [v for (_, v) in self.entries.items()] 53 | sortedBeams = sorted(beams, reverse=True, key=lambda x: x.prTotal*x.prText) 54 | return [x.labeling for x in sortedBeams] 55 | 56 | def applyLM(parentBeam, childBeam, classes, lm): 57 | "calculate LM score of child beam by taking score from parent beam and bigram probability of last two chars" 58 | if lm and not childBeam.lmApplied: 59 | c1 = classes[parentBeam.labeling[-1] if parentBeam.labeling else classes.index(' ')] # first char 60 | c2 = classes[childBeam.labeling[-1]] # second char 61 | lmFactor = 0.01 # influence of language model 62 | bigramProb = lm.getCharBigram(c1, c2) ** lmFactor # probability of seeing first and second char next to each other 63 | childBeam.prText = parentBeam.prText * bigramProb # probability of char sequence 64 | childBeam.lmApplied = True # only apply LM once per beam entry 65 | 66 | def addBeam(beamState, labeling): 67 | "add beam if it does not yet exist" 68 | if labeling not in beamState.entries: 69 | beamState.entries[labeling] = BeamEntry() 70 | 71 | def ctcBeamSearch(mat, classes, lm, beamWidth): 72 | "beam search as described 
by the paper of Hwang et al. and the paper of Graves et al." 73 | 74 | blankIdx = len(classes) 75 | maxT, maxC = mat.shape 76 | 77 | # initialise beam state 78 | last = BeamState() 79 | labeling = () 80 | last.entries[labeling] = BeamEntry() 81 | last.entries[labeling].prBlank = 1 82 | last.entries[labeling].prTotal = 1 83 | 84 | # go over all time-steps 85 | for t in range(maxT): 86 | curr = BeamState() 87 | 88 | # get beam-labelings of best beams 89 | bestLabelings = last.sort()[0:beamWidth] 90 | 91 | # go over best beams 92 | for labeling in bestLabelings: 93 | 94 | # probability of paths ending with a non-blank 95 | prNonBlank = 0 96 | # in case of non-empty beam 97 | if labeling: 98 | # probability of paths with repeated last char at the end 99 | try: 100 | prNonBlank = last.entries[labeling].prNonBlank * mat[t, labeling[-1]] 101 | except FloatingPointError: 102 | prNonBlank = 0 103 | 104 | # probability of paths ending with a blank 105 | prBlank = (last.entries[labeling].prTotal) * mat[t, blankIdx] 106 | 107 | # add beam at current time-step if needed 108 | addBeam(curr, labeling) 109 | 110 | # fill in data 111 | curr.entries[labeling].labeling = labeling 112 | curr.entries[labeling].prNonBlank += prNonBlank 113 | curr.entries[labeling].prBlank += prBlank 114 | curr.entries[labeling].prTotal += prBlank + prNonBlank 115 | curr.entries[labeling].prText = last.entries[labeling].prText # beam-labeling not changed, therefore also LM score unchanged from 116 | curr.entries[labeling].lmApplied = True # LM already applied at previous time-step for this beam-labeling 117 | 118 | # extend current beam-labeling 119 | for c in range(maxC - 1): 120 | # add new char to current beam-labeling 121 | newLabeling = labeling + (c,) 122 | 123 | # if new labeling contains duplicate char at the end, only consider paths ending with a blank 124 | if labeling and labeling[-1] == c: 125 | prNonBlank = mat[t, c] * last.entries[labeling].prBlank 126 | else: 127 | prNonBlank = mat[t, c] * last.entries[labeling].prTotal 128 | 129 | # add beam at current time-step if needed 130 | addBeam(curr, newLabeling) 131 | 132 | # fill in data 133 | curr.entries[newLabeling].labeling = newLabeling 134 | curr.entries[newLabeling].prNonBlank += prNonBlank 135 | curr.entries[newLabeling].prTotal += prNonBlank 136 | 137 | # apply LM 138 | applyLM(curr.entries[labeling], curr.entries[newLabeling], classes, lm) 139 | 140 | # set new beam state 141 | last = curr 142 | 143 | # normalise LM scores according to beam-labeling-length 144 | last.norm() 145 | 146 | # sort by probability 147 | bestLabelings = last.sort()[:beamWidth] # get most probable labeling 148 | 149 | output = [] 150 | for bestLabeling in bestLabelings: 151 | # map labels to chars 152 | res = '' 153 | for l in bestLabeling: 154 | res += classes[l] 155 | output.append(res) 156 | return output 157 | 158 | def testBeamSearch(): 159 | "test decoder" 160 | classes = 'ab' 161 | mat = np.array([[0.4, 0, 0.6], [0.4, 0, 0.6]]) 162 | print('Test beam search') 163 | expected = 'a' 164 | actual = ctcBeamSearch(mat, classes, None) 165 | print('Expected: "' + expected + '"') 166 | print('Actual: "' + actual + '"') 167 | print('OK' if expected == actual else 'ERROR') 168 | 169 | if __name__ == '__main__': 170 | testBeamSearch() -------------------------------------------------------------------------------- /ocr/utils/denoiser_utils.py: -------------------------------------------------------------------------------- 1 | import gluonnlp as nlp 2 | import leven 3 | import mxnet as mx 4 
| import numpy as np 5 | 6 | from ocr.utils.encoder_decoder import decode_char 7 | 8 | class SequenceGenerator: 9 | 10 | def __init__(self, sampler, language_model, vocab, ctx_nlp, tokenizer=nlp.data.SacreMosesTokenizer(), detokenizer=nlp.data.SacreMosesDetokenizer()): 11 | self.sampler = sampler 12 | self.language_model = language_model 13 | self.ctx_nlp = ctx_nlp 14 | self.vocab = vocab 15 | self.tokenizer = tokenizer 16 | self.detokenizer = detokenizer 17 | 18 | def generate_sequences(self, inputs, begin_states, sentence): 19 | samples, scores, valid_lengths = self.sampler(inputs, begin_states) 20 | samples = samples[0].asnumpy() 21 | scores = scores[0].asnumpy() 22 | valid_lengths = valid_lengths[0].asnumpy() 23 | max_score = -10e20 24 | 25 | # Heuristic #1 26 | #If the sentence is correct, let's not try to change it 27 | sentence_tokenized = [i.replace("&quot;", '"').replace("&apos;","'").replace("&amp;", "&") for i in self.tokenizer(sentence)] 28 | sentence_correct = True 29 | for token in sentence_tokenized: 30 | if (token not in self.vocab or self.vocab[token] > 400000) and token.lower() not in ["don't", "doesn't", "can't", "won't", "ain't", "couldn't", "i'd", "you'd", "he's", "she's", "it's", "i've", "you've", "she'd"]: 31 | sentence_correct = False 32 | break 33 | if sentence_correct: 34 | return sentence 35 | 36 | # Heuristic #2 37 | # We want sentences that have the most in-vocabulary words 38 | # and we penalize sentences that have out of vocabulary words 39 | # that do not start with a capital letter 40 | for i, sample in enumerate(samples): 41 | tokens = decode_char(sample[:valid_lengths[i]]) 42 | tokens = [i.replace("&quot;", '"').replace("&apos;","'").replace("&amp;", "&") for i in self.tokenizer(tokens)] 43 | score = 0 44 | 45 | for t in tokens: 46 | # Boosting names 47 | if (t in self.vocab and self.vocab[t] < 450000) or (len(t) > 0 and t.istitle()): 48 | score += 0 49 | else: 50 | score -= 1 51 | score -= 0 52 | if score == max_score: 53 | max_score = score 54 | best_tokens.append(tokens) 55 | elif score > max_score: 56 | max_score = score 57 | best_tokens = [tokens] 58 | 59 | # Heuristic #3 60 | # Smallest edit distance 61 | # We then take the sentence with the lowest edit distance 62 | # From the predicted original sentence 63 | best_dist = 1000 64 | output_tokens = best_tokens[0] 65 | best_tokens_ = [] 66 | for tokens in best_tokens: 67 | dist = leven.levenshtein(sentence, ''.join(self.detokenizer(tokens))) 68 | if dist < best_dist: 69 | best_dist = dist 70 | best_tokens_ = [tokens] 71 | elif dist == best_dist: 72 | best_tokens_.append(tokens) 73 | 74 | # Heuristic #4 75 | # We take the sentence with the smallest number of tokens 76 | # to avoid splitting up composed words 77 | min_len = 10e20 78 | for tokens in best_tokens_: 79 | if len(tokens) < min_len: 80 | min_len = len(tokens) 81 | best_tokens__ = [tokens] 82 | elif len(tokens) == min_len: 83 | best_tokens__.append(tokens) 84 | 85 | # Heuristic #5 86 | # Lowest ppl 87 | # If we still have ties we take the sentence with the lowest 88 | # Perplexity score according to the language model 89 | best_ppl = 10e20 90 | for tokens in best_tokens__: 91 | if len(tokens) > 1: 92 | inputs = self.vocab[tokens] 93 | hidden = self.language_model.begin_state(batch_size=1, func=mx.nd.zeros, ctx=self.ctx_nlp) 94 | output, _ = self.language_model(mx.nd.array(inputs).expand_dims(axis=1).as_in_context(self.ctx_nlp), hidden) 95 | output = output.softmax() 96 | l = 0 97 | for i in range(1, len(inputs)): 98 | l += -output[i-1][0][inputs[i]].log() 99 | ppl =
(l/len(inputs)).exp() 100 | if ppl < best_ppl: 101 | output_tokens = tokens 102 | best_ppl = ppl 103 | output = ''.join(self.detokenizer(output_tokens)) 104 | 105 | 106 | # Heuristic #6 107 | # Sometimes there are artefact at the end of the corrected sentence 108 | # We cut the end of the sentence 109 | if len(output) > len(sentence) + 10: 110 | output = output[:len(sentence)+2] 111 | return output 112 | -------------------------------------------------------------------------------- /ocr/utils/draw_box_on_image.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | import numpy as np 5 | from skimage.draw import line_aa 6 | import matplotlib.pyplot as plt 7 | 8 | def draw_line(image, y1, x1, y2, x2, line_type): 9 | rr, cc, val = line_aa(y1, x1, y2, x2) 10 | if line_type == "dotted": 11 | rr = np.delete(rr, np.arange(0, rr.size, 5)) 12 | cc = np.delete(cc, np.arange(0, cc.size, 5)) 13 | image[rr, cc] = 0 14 | return image 15 | 16 | def draw_box(bounding_box, image, line_type, is_xywh=True): 17 | image_h, image_w = image.shape[-2:] 18 | if is_xywh: 19 | (x, y, w, h) = bounding_box 20 | (x1, y1, x2, y2) = (x, y, x + w, y + h) 21 | else: 22 | (x1, y1, x2, y2) = bounding_box 23 | (x1, y1, x2, y2) = (int(x1), int(y1), int(x2), int(y2)) 24 | if y2 >= image_h: 25 | y2 = image_h - 1 26 | if x2 >= image_w: 27 | x2 = image_w - 1 28 | if y1 >= image_h: 29 | y1 = image_h - 1 30 | if x1 >= image_w: 31 | x1 = image_w - 1 32 | if y2 < 0: 33 | y2 = 0 34 | if x2 < 0: 35 | x2 =0 36 | if y1 < 0: 37 | y1 = 0 38 | if x1 < 0: 39 | x1 = 0 40 | 41 | image = draw_line(image, y1, x1, y2, x1, line_type) 42 | image = draw_line(image, y2, x1, y2, x2, line_type) 43 | image = draw_line(image, y2, x2, y1, x2, line_type) 44 | image = draw_line(image, y1, x2, y1, x1, line_type) 45 | return image 46 | 47 | def draw_boxes_on_image(pred, label, images): 48 | ''' Function to draw multiple bounding boxes on the images. Predicted bounding boxes will be 49 | presented with a dotted line and actual boxes are presented with a solid line. 50 | 51 | Parameters 52 | ---------- 53 | 54 | pred: [n x [x, y, w, h]] 55 | The predicted bounding boxes in percentages. 56 | n is the number of bounding boxes predicted on an image 57 | 58 | label: [n x [x, y, w, h]] 59 | The actual bounding boxes in percentages 60 | n is the number of bounding boxes predicted on an image 61 | 62 | images: [[np.array]] 63 | The correponding images. 64 | 65 | Returns 66 | ------- 67 | 68 | images: [[np.array]] 69 | Images with bounding boxes printed on them. 70 | ''' 71 | image_h, image_w = images.shape[-2:] 72 | label[:, :, 0], label[:, :, 1] = label[:, :, 0] * image_w, label[:, :, 1] * image_h 73 | label[:, :, 2], label[:, :, 3] = label[:, :, 2] * image_w, label[:, :, 3] * image_h 74 | for i in range(len(pred)): 75 | pred_b = pred[i] 76 | pred_b[:, 0], pred_b[:, 1] = pred_b[:, 0] * image_w, pred_b[:, 1] * image_h 77 | pred_b[:, 2], pred_b[:, 3] = pred_b[:, 2] * image_w, pred_b[:, 3] * image_h 78 | 79 | image = images[i, 0] 80 | for j in range(pred_b.shape[0]): 81 | image = draw_box(pred_b[j, :], image, line_type="dotted") 82 | 83 | for k in range(label.shape[1]): 84 | image = draw_box(label[i, k, :], image, line_type="solid") 85 | images[i, 0, :, :] = image 86 | return images 87 | 88 | def draw_box_on_image(pred, label, images): 89 | ''' Function to draw bounding boxes on the images. 
Predicted bounding boxes will be 90 | presented with a dotted line and actual boxes are presented with a solid line. 91 | 92 | Parameters 93 | ---------- 94 | 95 | pred: [[x, y, w, h]] 96 | The predicted bounding boxes in percentages 97 | 98 | label: [[x, y, w, h]] 99 | The actual bounding boxes in percentages 100 | 101 | images: [[np.array]] 102 | The correponding images. 103 | 104 | Returns 105 | ------- 106 | 107 | images: [[np.array]] 108 | Images with bounding boxes printed on them. 109 | ''' 110 | 111 | image_h, image_w = images.shape[-2:] 112 | pred[:, 0], pred[:, 1] = pred[:, 0] * image_w, pred[:, 1] * image_h 113 | pred[:, 2], pred[:, 3] = pred[:, 2] * image_w, pred[:, 3] * image_h 114 | 115 | label[:, 0], label[:, 1] = label[:, 0] * image_w, label[:, 1] * image_h 116 | label[:, 2], label[:, 3] = label[:, 2] * image_w, label[:, 3] * image_h 117 | 118 | for i in range(images.shape[0]): 119 | image = images[i, 0] 120 | image = draw_box(pred[i, :], image, line_type="dotted") 121 | image = draw_box(label[i, :], image, line_type="solid") 122 | images[i, 0, :, :] = image 123 | return images 124 | -------------------------------------------------------------------------------- /ocr/utils/draw_text_on_image.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | import cv2 5 | import numpy as np 6 | 7 | def draw_text_on_image(images, text): 8 | output_image_shape = (images.shape[0], images.shape[1], images.shape[2] * 2, images.shape[3]) # Double the output_image_shape to print the text in the bottom 9 | 10 | output_images = np.zeros(shape=output_image_shape) 11 | for i in range(images.shape[0]): 12 | white_image_shape = (images.shape[2], images.shape[3]) 13 | white_image = np.ones(shape=white_image_shape)*1.0 14 | text_image = cv2.putText(white_image, text[i], org=(5, 30), fontFace=cv2.FONT_HERSHEY_SIMPLEX, fontScale=1, color=0.0, thickness=1) 15 | output_images[i, :, :images.shape[2], :] = images[i] 16 | output_images[i, :, images.shape[2]:, :] = text_image 17 | return output_images 18 | -------------------------------------------------------------------------------- /ocr/utils/expand_bounding_box.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | def expand_bounding_box(bb, expand_bb_scale_x=0.05, expand_bb_scale_y=0.05): 5 | (x, y, w, h) = bb 6 | new_w = (1 + expand_bb_scale_x) * w 7 | new_h = (1 + expand_bb_scale_y) * h 8 | 9 | x = x - (new_w - w)/2 10 | y = y - (new_h - h)/2 11 | w = new_w 12 | h = new_h 13 | return (x, y, w, h) 14 | -------------------------------------------------------------------------------- /ocr/utils/iam_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. 
2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | import os 5 | import tarfile 6 | import urllib 7 | import sys 8 | import time 9 | import glob 10 | import pickle 11 | import xml.etree.ElementTree as ET 12 | import cv2 13 | import json 14 | import numpy as np 15 | import pandas as pd 16 | import zipfile 17 | import matplotlib.pyplot as plt 18 | import logging 19 | import requests 20 | 21 | from mxnet.gluon.data import dataset 22 | from mxnet import nd 23 | 24 | from .expand_bounding_box import expand_bounding_box 25 | 26 | def crop_image(image, bb): 27 | ''' Helper function to crop the image by the bounding box (in percentages) 28 | ''' 29 | (x, y, w, h) = bb 30 | x = x * image.shape[1] 31 | y = y * image.shape[0] 32 | w = w * image.shape[1] 33 | h = h * image.shape[0] 34 | (x1, y1, x2, y2) = (x, y, x + w, y + h) 35 | (x1, y1, x2, y2) = (int(x1), int(y1), int(x2), int(y2)) 36 | return image[y1:y2, x1:x2] 37 | 38 | def resize_image(image, desired_size): 39 | ''' Helper function to resize an image while keeping the aspect ratio. 40 | Parameter 41 | --------- 42 | 43 | image: np.array 44 | The image to be resized. 45 | 46 | desired_size: (int, int) 47 | The (height, width) of the resized image 48 | 49 | Return 50 | ------ 51 | 52 | image: np.array 53 | The image of size = desired_size 54 | 55 | bounding box: (int, int, int, int) 56 | (x, y, w, h) in percentages of the resized image of the original 57 | ''' 58 | size = image.shape[:2] 59 | if size[0] > desired_size[0] or size[1] > desired_size[1]: 60 | ratio_w = float(desired_size[0])/size[0] 61 | ratio_h = float(desired_size[1])/size[1] 62 | ratio = min(ratio_w, ratio_h) 63 | new_size = tuple([int(x*ratio) for x in size]) 64 | image = cv2.resize(image, (new_size[1], new_size[0])) 65 | size = image.shape 66 | 67 | delta_w = max(0, desired_size[1] - size[1]) 68 | delta_h = max(0, desired_size[0] - size[0]) 69 | top, bottom = delta_h//2, delta_h-(delta_h//2) 70 | left, right = delta_w//2, delta_w-(delta_w//2) 71 | 72 | color = image[0][0] 73 | if color < 230: 74 | color = 230 75 | image = cv2.copyMakeBorder(image, top, bottom, left, right, cv2.BORDER_CONSTANT, value=float(color)) 76 | crop_bb = (left/image.shape[1], top/image.shape[0], (image.shape[1] - right - left)/image.shape[1], 77 | (image.shape[0] - bottom - top)/image.shape[0]) 78 | image[image > 230] = 255 79 | return image, crop_bb 80 | 81 | def crop_handwriting_page(image, bb, image_size): 82 | ''' 83 | Given an image and bounding box (bb) crop the input image based on the bounding box. 84 | The final output image was scaled based on the image size. 85 | 86 | Parameters 87 | ---------- 88 | image: np.array 89 | Input form image 90 | 91 | bb: (x, y, w, h) 92 | The bounding box in percentages to crop 93 | 94 | image_size: (h, w) 95 | Image size to scale the output image to. 96 | 97 | Returns 98 | ------- 99 | output_image: np.array 100 | cropped image of size image_size. 101 | ''' 102 | image = crop_image(image, bb) 103 | 104 | image, _ = resize_image(image, desired_size=image_size) 105 | return image 106 | 107 | class IAMDataset(dataset.ArrayDataset): 108 | """ The IAMDataset provides images of handwritten passages written by multiple 109 | individuals. The data is available at http://www.fki.inf.unibe.ch 110 | 111 | The passages can be parsed into separate words, lines, or the whole form. 112 | The dataset should be separated into writer independent training and testing sets. 
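A minimal usage sketch, assuming a valid credentials.json is present: train_ds = IAMDataset("line", output_data="text", train=True) downloads and caches the line crops on first use, after which image, text = train_ds[0] returns an image array and its transcription.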
113 | 114 | Parameters 115 | ---------- 116 | parse_method: str, Required 117 | To select the method of parsing the images of the passage 118 | Available options: [form, form_bb, line, word] 119 | 120 | credentials: (str, str), Default None 121 | Your (username, password) for the IAM dataset. Register at 122 | http://www.fki.inf.unibe.ch/DBs/iamDB/iLogin/index.php 123 | By default, IAMDataset will read it from credentials.json 124 | 125 | root: str, default: dataset/iamdataset 126 | Location to save the database 127 | 128 | train: bool, default True 129 | Whether to load the training or testing set of writers. 130 | 131 | output_data_type: str, default text 132 | What type of data you want as an output: Text or bounding box. 133 | Available options are: [text, bb] 134 | 135 | output_parse_method: str, default None 136 | If the bounding box (bb) was selected as an output_data_type, 137 | this parameter can select which bb you want to obtain. 138 | Available options: [form, line, word] 139 | 140 | output_form_text_as_array: bool, default False 141 | When output_data is set to text and the parse method is set to form or form_original, 142 | if output_form_text_as_array is true, the output text will be a list of lines string 143 | """ 144 | MAX_IMAGE_SIZE_FORM = (1120, 800) 145 | MAX_IMAGE_SIZE_LINE = (60, 800) 146 | MAX_IMAGE_SIZE_WORD = (30, 140) 147 | def __init__(self, parse_method, credentials=None, 148 | root=os.path.join(os.path.dirname(__file__), '..', '..','dataset', 'iamdataset'), 149 | train=True, output_data="text", 150 | output_parse_method=None, 151 | output_form_text_as_array=False): 152 | 153 | _parse_methods = ["form", "form_original", "form_bb", "line", "word"] 154 | error_message = "{} is not a possible parsing method: {}".format( 155 | parse_method, _parse_methods) 156 | assert parse_method in _parse_methods, error_message 157 | self._parse_method = parse_method 158 | url_partial = "http://fki.tic.heia-fr.ch/DBs/iamDB/data/{filename}.tgz" 159 | if self._parse_method == "form": 160 | self._data_urls = [url_partial.format(data_type="forms", filename="forms" + a) for a in ["A-D", "E-H", "I-Z"]] 161 | elif self._parse_method == "form_bb": 162 | self._data_urls = [url_partial.format(data_type="forms", filename="forms" + a) for a in ["A-D", "E-H", "I-Z"]] 163 | elif self._parse_method == "form_original": 164 | self._data_urls = [url_partial.format(data_type="forms", filename="forms" + a) for a in ["A-D", "E-H", "I-Z"]] 165 | elif self._parse_method == "line": 166 | self._data_urls = [url_partial.format(data_type="lines", filename="lines")] 167 | elif self._parse_method == "word": 168 | self._data_urls = [url_partial.format(data_type="words", filename="words")] 169 | self._xml_url = "http://fki.tic.heia-fr.ch/DBs/iamDB/data/xml.tgz" 170 | 171 | if credentials == None: 172 | if os.path.isfile(os.path.join(os.path.dirname(__file__), '..','..', 'credentials.json')): 173 | with open(os.path.join(os.path.dirname(__file__), '..','..', 'credentials.json')) as f: 174 | credentials = json.load(f) 175 | self._credentials = (credentials["username"], credentials["password"]) 176 | else: 177 | assert False, "Please enter credentials for the IAM dataset in credentials.json or as arguments" 178 | else: 179 | self._credentials = credentials 180 | 181 | self._train = train 182 | 183 | _output_data_types = ["text", "bb"] 184 | error_message = "{} is not a possible output data: {}".format( 185 | output_data, _output_data_types) 186 | assert output_data in _output_data_types, error_message 187 | 
self._output_data = output_data 188 | 189 | if self._output_data == "bb": 190 | assert self._parse_method in ["form", "form_bb"], "Bounding box only works with form." 191 | _parse_methods = ["form", "line", "word"] 192 | error_message = "{} is not a possible output parsing method: {}".format( 193 | output_parse_method, _parse_methods) 194 | assert output_parse_method in _parse_methods, error_message 195 | self._output_parse_method = output_parse_method 196 | 197 | self.image_data_file_name = os.path.join(root, "image_data-{}-{}-{}*.plk".format( 198 | self._parse_method, self._output_data, self._output_parse_method)) 199 | else: 200 | self.image_data_file_name = os.path.join(root, "image_data-{}-{}*.plk".format(self._parse_method, self._output_data)) 201 | 202 | self._root = root 203 | if not os.path.isdir(root): 204 | os.makedirs(root) 205 | self._output_form_text_as_array = output_form_text_as_array 206 | 207 | data = self._get_data() 208 | super(IAMDataset, self).__init__(data) 209 | 210 | @staticmethod 211 | def _reporthook(count, block_size, total_size): 212 | ''' Prints a process bar that is compatible with urllib.request.urlretrieve 213 | ''' 214 | toolbar_width = 40 215 | percentage = float(count * block_size) / total_size * 100 216 | # Taken from https://gist.github.com/sibosutd/c1d9ef01d38630750a1d1fe05c367eb8 217 | sys.stdout.write('\r') 218 | sys.stdout.write("Completed: [{:{}}] {:>3}%" 219 | .format('-' * int(percentage / (100.0 / toolbar_width)), 220 | toolbar_width, int(percentage))) 221 | sys.stdout.flush() 222 | 223 | def _extract(self, archive_file, archive_type, output_dir): 224 | ''' Helper function to extract archived files. Available for tar.tgz and zip files 225 | Parameter 226 | --------- 227 | archive_file: str 228 | Filepath to the archive file 229 | archive_type: str, options: [tar, zip] 230 | Select the type of file you want to extract 231 | output_dir: str 232 | Location where you want to extract the files to 233 | ''' 234 | logging.info("Extracting {}".format(archive_file)) 235 | _available_types = ["tar", "zip"] 236 | error_message = "Archive_type {} is not an available option ({})".format(archive_type, _available_types) 237 | assert archive_type in _available_types, error_message 238 | if archive_type == "tar": 239 | tar = tarfile.open(archive_file, "r:gz") 240 | tar.extractall(os.path.join(self._root, output_dir)) 241 | tar.close() 242 | elif archive_type == "zip": 243 | zip_ref = zipfile.ZipFile(archive_file, 'r') 244 | zip_ref.extractall(os.path.join(self._root, output_dir)) 245 | zip_ref.close() 246 | 247 | def _download(self, url): 248 | ''' Helper function to download using the credentials provided 249 | Parameter 250 | --------- 251 | url: str 252 | The url of the file you want to download. 253 | ''' 254 | session = requests.Session() 255 | data = {"email": self._credentials[0], "password": self._credentials[1]} 256 | login_url = "https://fki.tic.heia-fr.ch/login" 257 | login_response = session.post(login_url, data=data) 258 | filename = os.path.basename(url) 259 | print("Downloading {}: ".format(filename)) 260 | with session.get(url, stream=True) as get_response: 261 | get_response.raise_for_status() 262 | with open(os.path.join(self._root, filename), 'wb') as f: 263 | for count, chunk in enumerate(get_response.iter_content(chunk_size=8192)): 264 | self._reporthook(count=count, block_size=8192, total_size=float(get_response.headers["Content-Length"])) 265 | # If you have chunk encoded response uncomment if 266 | # and set chunk_size parameter to None. 
267 | # if chunk: 268 | f.write(chunk) 269 | sys.stdout.write("\n") 270 | 271 | def _download_xml(self): 272 | ''' Helper function to download and extract the xml of the IAM database 273 | ''' 274 | archive_file = os.path.join(self._root, os.path.basename(self._xml_url)) 275 | logging.info("Downloding xml from {}".format(self._xml_url)) 276 | if not os.path.isfile(archive_file): 277 | self._download(self._xml_url) 278 | self._extract(archive_file, archive_type="tar", output_dir="xml") 279 | 280 | def _download_data(self): 281 | ''' Helper function to download and extract the data of the IAM database 282 | ''' 283 | for url in self._data_urls: 284 | logging.info("Downloding data from {}".format(url)) 285 | archive_file = os.path.join(self._root, os.path.basename(url)) 286 | if not os.path.isfile(archive_file): 287 | self._download(url) 288 | self._extract(archive_file, archive_type="tar", output_dir=self._parse_method.split("_")[0]) 289 | 290 | def _download_subject_list(self): 291 | ''' Helper function to download and extract the subject list of the IAM database 292 | ''' 293 | url = "https://fki.tic.heia-fr.ch/static/zip/largeWriterIndependentTextLineRecognitionTask.zip" 294 | archive_file = os.path.join(self._root, os.path.basename(url)) 295 | if not os.path.isfile(archive_file): 296 | logging.info("Downloding subject list from {}".format(url)) 297 | self._download(url) 298 | self._extract(archive_file, archive_type="zip", output_dir="subject") 299 | 300 | def _pre_process_image(self, img_in): 301 | im = cv2.imread(img_in, cv2.IMREAD_GRAYSCALE) 302 | if np.size(im) == 1: # skip if the image data is corrupt. 303 | return None 304 | # reduce the size of form images so that it can fit in memory. 305 | if self._parse_method in ["form", "form_bb"]: 306 | im, _ = resize_image(im, self.MAX_IMAGE_SIZE_FORM) 307 | if self._parse_method == "line": 308 | im, _ = resize_image(im, self.MAX_IMAGE_SIZE_LINE) 309 | if self._parse_method == "word": 310 | im, _ = resize_image(im, self.MAX_IMAGE_SIZE_WORD) 311 | img_arr = np.asarray(im) 312 | return img_arr 313 | 314 | def _get_bb_of_item(self, item, height, width): 315 | ''' Helper function to find the bounding box (bb) of an item in the xml file. 316 | All the characters within the item are found and the left-most (min) and right-most (max + length) 317 | are found. 318 | The bounding box emcompasses the left and right most characters in the x and y direction. 319 | 320 | Parameter 321 | --------- 322 | item: xml.etree object for a word/line/form. 323 | 324 | height: int 325 | Height of the form to calculate percentages of bounding boxes 326 | 327 | width: int 328 | Width of the form to calculate percentages of bounding boxes 329 | 330 | Returns 331 | ------- 332 | list 333 | The bounding box [x, y, w, h] in percentages that encompasses the item. 
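Illustrative example with hypothetical numbers: for a form of height 2000 and width 1000 pixels whose characters span x = 100 to x = 500 and y = 200 to y = 260, the returned box is [0.1, 0.1, 0.4, 0.03].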
334 | ''' 335 | 336 | character_list = [a for a in item.iter("cmp")] 337 | if len(character_list) == 0: # To account for some punctuations that have no words 338 | return None 339 | x1 = np.min([int(a.attrib['x']) for a in character_list]) 340 | y1 = np.min([int(a.attrib['y']) for a in character_list]) 341 | x2 = np.max([int(a.attrib['x']) + int(a.attrib['width']) for a in character_list]) 342 | y2 = np.max([int(a.attrib['y']) + int(a.attrib['height'])for a in character_list]) 343 | 344 | x1 = float(x1) / width 345 | x2 = float(x2) / width 346 | y1 = float(y1) / height 347 | y2 = float(y2) / height 348 | bb = [x1, y1, x2 - x1, y2 - y1] 349 | return bb 350 | 351 | def _get_output_data(self, item, height, width): 352 | ''' Function to obtain the output data (both text and bounding boxes). 353 | Note that the bounding boxes are rescaled based on the rescale_ratio parameter. 354 | 355 | Parameter 356 | --------- 357 | item: xml.etree 358 | XML object for a word/line/form. 359 | 360 | height: int 361 | Height of the form to calculate percentages of bounding boxes 362 | 363 | width: int 364 | Width of the form to calculate percentages of bounding boxes 365 | 366 | Returns 367 | ------- 368 | 369 | np.array 370 | A numpy array ouf the output requested (text or the bounding box) 371 | ''' 372 | 373 | output_data = [] 374 | if self._output_data == "text": 375 | if self._parse_method in ["form", "form_bb", "form_original"]: 376 | text = "" 377 | for line in item.iter('line'): 378 | text += line.attrib["text"] + "\n" 379 | output_data.append(text) 380 | else: 381 | output_data.append(item.attrib['text']) 382 | else: 383 | for item_output in item.iter(self._output_parse_method): 384 | bb = self._get_bb_of_item(item_output, height, width) 385 | if bb == None: # Account for words with no letters 386 | continue 387 | output_data.append(bb) 388 | output_data = np.array(output_data) 389 | return output_data 390 | 391 | def _change_bb_reference(self, bb, relative_bb, bb_reference_size, relative_bb_reference_size, output_size, operator): 392 | ''' Helper function to convert bounding boxes relative into another bounding bounding box. 393 | Parameter 394 | -------- 395 | bb: [[int, int, int, int]] 396 | Bounding boxes (x, y, w, h) in percentages to be converted. 397 | 398 | relative_bb: [int, int, int, int] 399 | Reference bounding box (in percentages) to convert bb to 400 | 401 | bb_reference_size: (int, int) 402 | Size (h, w) in pixels of the image containing bb 403 | 404 | relative_bb_reference_size: (int, int) 405 | Size (h, w) in pixels of the image containing relative_bb 406 | 407 | output_size: (int, int) 408 | Size (h, w) in pixels of the output image 409 | 410 | operator: string 411 | Options ["plus", "minus"]. 
"plus" if relative_bb is within bb and "minus" if bb is within relative_bb 412 | 413 | Returns 414 | ------- 415 | bb: [[int, int, int, int]] 416 | Bounding boxes (x, y, w, h) in percentages that are converted 417 | 418 | ''' 419 | (x1, y1, x2, y2) = (bb[:, 0], bb[:, 1], bb[:, 0] + bb[:, 2], bb[:, 1] + bb[:, 3]) 420 | (x1, y1, x2, y2) = (x1 * bb_reference_size[1], y1 * bb_reference_size[0], 421 | x2 * bb_reference_size[1], y2 * bb_reference_size[0]) 422 | 423 | if operator == "plus": 424 | new_x1 = (x1 + relative_bb[0] * relative_bb_reference_size[1]) / output_size[1] 425 | new_y1 = (y1 + relative_bb[1] * relative_bb_reference_size[0]) / output_size[0] 426 | new_x2 = (x2 + relative_bb[0] * relative_bb_reference_size[1]) / output_size[1] 427 | new_y2 = (y2 + relative_bb[1] * relative_bb_reference_size[0]) / output_size[0] 428 | else: 429 | new_x1 = (x1 - relative_bb[0] * relative_bb_reference_size[1]) / output_size[1] 430 | new_y1 = (y1 - relative_bb[1] * relative_bb_reference_size[0]) / output_size[0] 431 | new_x2 = (x2 - relative_bb[0] * relative_bb_reference_size[1]) / output_size[1] 432 | new_y2 = (y2 - relative_bb[1] * relative_bb_reference_size[0]) / output_size[0] 433 | 434 | new_bbs = np.zeros(shape=bb.shape) 435 | new_bbs[:, 0] = new_x1 436 | new_bbs[:, 1] = new_y1 437 | new_bbs[:, 2] = new_x2 - new_x1 438 | new_bbs[:, 3] = new_y2 - new_y1 439 | return new_bbs 440 | 441 | def _crop_and_resize_form_bb(self, item, image_arr, output_data, height, width): 442 | bb = self._get_bb_of_item(item, height, width) 443 | 444 | # Expand the form bounding box by 5% 445 | expand_bb_scale = 0.05 446 | new_w = (1 + expand_bb_scale) * bb[2] 447 | new_h = (1 + expand_bb_scale) * bb[3] 448 | 449 | bb[0] = bb[0] - (new_w - bb[2])/2 450 | bb[1] = bb[1] - (new_h - bb[3])/2 451 | bb[2] = new_w 452 | bb[3] = new_h 453 | 454 | image_arr_bb = crop_image(image_arr, bb) 455 | 456 | if self._output_data == "bb": 457 | output_data = self._change_bb_reference(output_data, bb, image_arr.shape, image_arr.shape, image_arr_bb.shape, "minus") 458 | 459 | image_arr_bb_, bb = resize_image(image_arr_bb, desired_size=(700, 700)) 460 | 461 | if self._output_data == "bb": 462 | output_data = self._change_bb_reference(output_data, bb, image_arr_bb.shape, image_arr_bb_.shape, image_arr_bb_.shape, "plus") 463 | image_arr = image_arr_bb_ 464 | return image_arr, output_data 465 | 466 | def _save_dataframe_chunks(self, df, name): 467 | for i, df_split in enumerate(np.array_split(df, 4)): 468 | filename = name[:-5] + str(i) + ".plk" # remove *.plk in the filename 469 | df_split.to_pickle(filename, protocol=2) 470 | 471 | def _load_dataframe_chunks(self, name): 472 | image_data_chunks = [] 473 | for fn in sorted(glob.glob(name)): 474 | df = pickle.load(open(fn, 'rb')) 475 | image_data_chunks.append(df) 476 | image_data = pd.concat(image_data_chunks) 477 | return image_data 478 | 479 | def _process_data(self): 480 | ''' Function that iterates through the downloaded xml file to gather the input images and the 481 | corresponding output. 482 | 483 | Returns 484 | ------- 485 | pd.DataFrame 486 | A pandas dataframe that contains the subject, image and output requested. 
487 | ''' 488 | image_data = [] 489 | xml_files = glob.glob(self._root + "/xml/*.xml") 490 | print("Processing data:") 491 | logging.info("Processing data") 492 | for i, xml_file in enumerate(xml_files): 493 | tree = ET.parse(xml_file) 494 | root = tree.getroot() 495 | height, width = int(root.attrib["height"]), int(root.attrib["width"]) 496 | for item in root.iter(self._parse_method.split("_")[0]): 497 | # Split _ to account for only taking the base "form", "line", "word" that is available in the IAM dataset 498 | if self._parse_method in ["form", "form_bb", "form_original"]: 499 | image_id = item.attrib["id"] 500 | else: 501 | tmp_id = item.attrib["id"] 502 | tmp_id_split = tmp_id.split("-") 503 | image_id = os.path.join(tmp_id_split[0], tmp_id_split[0] + "-" + tmp_id_split[1], tmp_id) 504 | image_filename = os.path.join(self._root, self._parse_method.split("_")[0], image_id + ".png") 505 | image_arr = self._pre_process_image(image_filename) 506 | if image_arr is None: 507 | continue 508 | output_data = self._get_output_data(item, height, width) 509 | if self._parse_method == "form_bb": 510 | image_arr, output_data = self._crop_and_resize_form_bb(item, image_arr, output_data, height, width) 511 | image_data.append([item.attrib["id"], image_arr, output_data]) 512 | self._reporthook(i, 1, len(xml_files)) 513 | image_data = pd.DataFrame(image_data, columns=["subject", "image", "output"]) 514 | self._save_dataframe_chunks(image_data, self.image_data_file_name) 515 | return image_data 516 | 517 | def _process_subjects(self, train_subject_lists = ["trainset", "validationset1", "validationset2"], 518 | test_subject_lists = ["testset"]): 519 | ''' Function to organise the list of subjects to training and testing. 520 | The IAM dataset provides 4 files: trainset, validationset1, validationset2, and testset each 521 | with a list of subjects. 522 | 523 | Parameters 524 | ---------- 525 | 526 | train_subject_lists: [str], default ["trainset", "validationset1", "validationset2"] 527 | The filenames of the list of subjects to be used for training the model 528 | 529 | test_subject_lists: [str], default ["testset"] 530 | The filenames of the list of subjects to be used for testing the model 531 | 532 | Returns 533 | ------- 534 | 535 | train_subjects: [str] 536 | A list of subjects used for training 537 | 538 | test_subjects: [str] 539 | A list of subjects used for testing 540 | ''' 541 | 542 | train_subjects = [] 543 | test_subjects = [] 544 | for train_list in train_subject_lists: 545 | subject_list = pd.read_csv(os.path.join(self._root, "subject", train_list+".txt")) 546 | train_subjects.append(subject_list.values) 547 | for test_list in test_subject_lists: 548 | subject_list = pd.read_csv(os.path.join(self._root, "subject", test_list+".txt")) 549 | test_subjects.append(subject_list.values) 550 | 551 | train_subjects = np.concatenate(train_subjects) 552 | test_subjects = np.concatenate(test_subjects) 553 | if self._parse_method in ["form", "form_bb", "form_original"]: 554 | # For the form method, the "subject names" do not match the ones provided 555 | # in the file. This clause transforms the subject names to match the file. 
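            # e.g. a subject-list entry such as "a01-000u-00" is reduced to the form id "a01-000u".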
556 | new_train_subjects = [] 557 | for i in train_subjects: 558 | form_subject_number = i[0].split("-")[0] + "-" + i[0].split("-")[1] 559 | new_train_subjects.append(form_subject_number) 560 | new_test_subjects = [] 561 | for i in test_subjects: 562 | form_subject_number = i[0].split("-")[0] + "-" + i[0].split("-")[1] 563 | new_test_subjects.append(form_subject_number) 564 | train_subjects, test_subjects = new_train_subjects, new_test_subjects 565 | return train_subjects, test_subjects 566 | 567 | def _convert_subject_list(self, subject_list): 568 | ''' Function to convert the list of subjects for the "word" parse method 569 | 570 | Parameters 571 | ---------- 572 | 573 | subject_lists: [str] 574 | A list of subjects 575 | 576 | Returns 577 | ------- 578 | 579 | subject_lists: [str] 580 | A list of subjects that is compatible with the "word" parse method 581 | 582 | ''' 583 | 584 | if self._parse_method == "word": 585 | new_subject_list = [] 586 | for sub in subject_list: 587 | new_subject_number = "-".join(sub.split("-")[:3]) 588 | new_subject_list.append(new_subject_number) 589 | return new_subject_list 590 | else: 591 | return subject_list 592 | 593 | def _get_data(self): 594 | ''' Function to get the data and to extract the data for training or testing 595 | 596 | Returns 597 | ------- 598 | 599 | pd.DataFram 600 | A dataframe (subject, image, and output) that contains only the training/testing data 601 | 602 | ''' 603 | 604 | # Get the data 605 | if not os.path.isdir(self._root): 606 | os.makedirs(self._root) 607 | 608 | if len(glob.glob(self.image_data_file_name)) > 0: 609 | logging.info("Loading data from pickle") 610 | images_data = self._load_dataframe_chunks(self.image_data_file_name) 611 | else: 612 | self._download_xml() 613 | self._download_data() 614 | images_data = self._process_data() 615 | 616 | # Extract train or test data out 617 | self._download_subject_list() 618 | train_subjects, test_subjects = self._process_subjects() 619 | if self._train: 620 | data = images_data[np.in1d(self._convert_subject_list(images_data["subject"]), 621 | train_subjects)] 622 | else: 623 | data = images_data[np.in1d(self._convert_subject_list(images_data["subject"]), 624 | test_subjects)] 625 | return data 626 | 627 | def __getitem__(self, idx): 628 | return (self._data[0].iloc[idx].image, self._data[0].iloc[idx].output) 629 | -------------------------------------------------------------------------------- /ocr/utils/iou_loss.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | from mxnet.gluon.loss import Loss 5 | from mxnet import gluon 6 | 7 | class IOU_loss(Loss): 8 | r"""Calculates the iou between `pred` and `label`. 9 | 10 | Implementation based on: 11 | Yu, J., Jiang, Y., Wang, Z., Cao, Z., & Huang, T. (2016, October). Unitbox: An advanced object detection network. 12 | # In Proceedings of the 2016 ACM on Multimedia Conference (pp. 516-520). ACM. 13 | 14 | Parameters 15 | ---------- 16 | weight : float or None 17 | Global scalar weight for loss. 18 | batch_axis : int, default 0 19 | The axis that represents mini-batch. 20 | 21 | Inputs: 22 | - **pred**: prediction tensor with shape [x, y, w, h] each in percentages 23 | - **label**: target tensor with the shape [x, y, w, h] each in percentages 24 | - **sample_weight**: element-wise weighting tensor. Must be broadcastable 25 | to the same shape as pred. 
For example, if pred has shape (64, 10) 26 | and you want to weigh each sample in the batch separately, 27 | sample_weight should have shape (64, 1). 28 | 29 | Outputs: 30 | - **loss**: IOU loss tensor with shape (batch_size,). 31 | """ 32 | 33 | def __init__(self, weight=1., batch_axis=0, **kwargs): 34 | super(IOU_loss, self).__init__(weight, batch_axis, **kwargs) 35 | 36 | def hybrid_forward(self, F, pred, label, sample_weight=None): 37 | ''' 38 | Interpreted from: https://www.pyimagesearch.com/2016/11/07/intersection-over-union-iou-for-object-detection/ 39 | Steps to calculate IOU 40 | 1) Calculate the area of the predicted and actual bounding boxes 41 | 2) Calculate the area of the intersection between the predicting and actual bounding box 42 | 3) Calculate the log IOU by: log(intersection area / (union area)) 43 | 3) If the bounding boxes do not overlap with one another, set the iou to zero 44 | 4) Calculate the negative mean of the IOU 45 | ''' 46 | 47 | pred_area = pred[:, 2] * pred[:, 3] 48 | label_area = label[:, 2] * label[:, 3] 49 | 50 | x1_1, y1_1 = pred[:, 0], pred[:, 1] 51 | x2_1, y2_1 = pred[:, 0] + pred[:, 2], pred[:, 1] + pred[:, 3] 52 | 53 | x1_2, y1_2 = label[:, 0], label[:, 1] 54 | x2_2, y2_2 = label[:, 0] + label[:, 2], label[:, 1] + label[:, 3] 55 | 56 | x_overlaps = F.logical_or(x2_1 < x1_2, x1_1 > x2_2) 57 | y_overlaps = F.logical_or(y2_1 < y1_2, y1_1 > y2_2) 58 | overlaps = F.logical_not(F.logical_or(x_overlaps, y_overlaps)) 59 | 60 | x1_1 = x1_1.expand_dims(0) 61 | y1_1 = y1_1.expand_dims(0) 62 | x2_1 = x2_1.expand_dims(0) 63 | y2_1 = y2_1.expand_dims(0) 64 | x1_2 = x1_2.expand_dims(0) 65 | y1_2 = y1_2.expand_dims(0) 66 | x2_2 = x2_2.expand_dims(0) 67 | y2_2 = y2_2.expand_dims(0) 68 | 69 | x_a = F.max(F.concat(x1_1, x1_2, dim=0), axis=0) 70 | y_a = F.max(F.concat(y1_1, y1_2, dim=0), axis=0) 71 | x_b = F.min(F.concat(x2_1, x2_2, dim=0), axis=0) 72 | y_b = F.min(F.concat(y2_1, y2_2, dim=0), axis=0) 73 | 74 | inter_area = (x_b - x_a) * (y_b - y_a) 75 | 76 | iou = F.log(inter_area) - F.log(pred_area + label_area - inter_area) 77 | 78 | loss = gluon.loss._apply_weighting(F, iou, self._weight, sample_weight) 79 | loss = F.where(F.logical_not(overlaps), F.zeros(shape=overlaps.shape), loss) 80 | mean_loss = F.mean(loss, axis=self._batch_axis, exclude=True) 81 | return -mean_loss 82 | -------------------------------------------------------------------------------- /ocr/utils/lexicon_search.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | import nltk 5 | import numpy as np 6 | from nltk.metrics import * 7 | from nltk.util import ngrams 8 | import enchant # spell checker library pyenchant 9 | from enchant.checker import SpellChecker 10 | from nltk.stem import PorterStemmer 11 | from nltk.corpus import words 12 | import string 13 | import re 14 | from collections import Counter 15 | from nltk.corpus import brown 16 | 17 | from nltk.probability import FreqDist 18 | from nltk.metrics import edit_distance 19 | 20 | from sympound import sympound 21 | import re 22 | from weighted_levenshtein import lev 23 | 24 | class WordSuggestor(): 25 | ''' 26 | Code obtained from http://norvig.com/spell-correct.html. 27 | ''' 28 | def __init__(self): 29 | self.words = Counter(brown.words()) 30 | 31 | def P(self, word): 32 | "Probability of `word`." 
33 |         N = sum(self.words.values())
34 |         return self.words[word] / N
35 | 
36 |     def correction(self, word): 
37 |         "Most probable spelling correction for word." 
38 |         return max(self.candidates(word), key=self.P) 
39 | 
40 |     def candidates(self, word): 
41 |         "Generate possible spelling corrections for word." 
42 |         return (self.known([word]) or self.known(self.edits1(word)) or self.known(self.edits2(word)) or [word]) 
43 | 
44 |     def known(self, words): 
45 |         "The subset of `words` that appear in the dictionary of WORDS." 
46 |         return set(w for w in words if w in self.words) 
47 | 
48 |     def edits1(self, word): 
49 |         "All edits that are one edit away from `word`." 
50 |         letters = 'abcdefghijklmnopqrstuvwxyz' 
51 |         splits = [(word[:i], word[i:]) for i in range(len(word) + 1)] 
52 |         deletes = [L + R[1:] for L, R in splits if R] 
53 |         transposes = [L + R[1] + R[0] + R[2:] for L, R in splits if len(R)>1] 
54 |         replaces = [L + c + R[1:] for L, R in splits if R for c in letters] 
55 |         inserts = [L + c + R for L, R in splits for c in letters] 
56 |         return set(deletes + transposes + replaces + inserts) 
57 | 
58 |     def edits2(self, word): 
59 |         "All edits that are two edits away from `word`." 
60 |         return (e2 for e1 in self.edits1(word) for e2 in self.edits1(e1)) 
61 | 
62 | class OcrDistanceMeasure(): 
63 |     # Helper class to obtain a handwriting error weighted edit distance. The weighted edit distance class can be found in 
64 |     # https://github.com/infoscout/weighted-levenshtein. 
65 |     # Substitute_costs.txt, insertion_costs and deletion_costs are calculated in 
66 |     # https://github.com/ThomasDelteil/Gluon_OCR_LSTM_CTC/blob/language_model/model_distance.ipynb 
67 | 
68 |     def __init__(self): 
69 |         self.substitute_costs = self.make_substitute_costs() 
70 |         self.insertion_costs = self.make_insertion_costs() 
71 |         self.deletion_costs = self.make_deletion_costs() 
72 | 
73 |     def make_substitute_costs(self): 
74 |         substitute_costs = np.loadtxt('models/substitute_costs.txt', dtype=float) 
75 |         #substitute_costs = np.ones((128, 128), dtype=np.float64) 
76 |         return substitute_costs 
77 | 
78 |     def make_insertion_costs(self): 
79 |         insertion_costs = np.loadtxt('models/insertion_costs.txt', dtype=float) 
80 |         #insertion_costs = np.ones(128, dtype=np.float64) 
81 |         return insertion_costs 
82 | 
83 |     def make_deletion_costs(self): 
84 |         deletion_costs = np.loadtxt('models/deletion_costs.txt', dtype=float) 
85 |         #deletion_costs = np.ones(128, dtype=np.float64) 
86 |         return deletion_costs 
87 | 
88 |     def __call__(self, input1, input2): 
89 |         return lev(input1, input2, substitute_costs=self.substitute_costs, 
90 |                    insert_costs=self.insertion_costs, 
91 |                    delete_costs=self.deletion_costs) 
92 | 
93 | class LexiconSearch: 
94 |     ''' 
95 |     Lexicon search was based on https://github.com/rameshjesswani/Semantic-Textual-Similarity/blob/master/nlp_basics/nltk/string_similarity.ipynb 
96 |     ''' 
97 |     def __init__(self): 
98 |         self.dictionary = enchant.Dict('en') 
99 |         self.word_suggestor = WordSuggestor() 
100 |         self.distance_measure = OcrDistanceMeasure() 
101 | 
102 |     def suggest_words(self, word): 
103 |         candidates = list(self.word_suggestor.candidates(word)) 
104 |         output = [] 
105 |         for word in candidates: 
106 |             if word[0].isupper(): 
107 |                 output.append(word) 
108 |             else: 
109 |                 if self.dictionary.check(word): 
110 |                     output.append(word) 
111 |         return output 
112 | 
113 |     def minimumEditDistance_spell_corrector(self, word): 
114 |         max_distance = 3 
115 | 
116 |         if (self.dictionary.check(word.lower())): 
117 |             return word 
118 | 
119 |         suggested_words = self.suggest_words(word) 
120 |         num_modified_characters = 
[] 121 | 122 | if len(suggested_words) != 0: 123 | for sug_words in suggested_words: 124 | num_modified_characters.append(self.distance_measure(word, sug_words)) 125 | 126 | minimum_edit_distance = min(num_modified_characters) 127 | best_arg = num_modified_characters.index(minimum_edit_distance) 128 | if max_distance > minimum_edit_distance: 129 | best_suggestion = suggested_words[best_arg] 130 | return best_suggestion 131 | else: 132 | return word 133 | else: 134 | return word 135 | -------------------------------------------------------------------------------- /ocr/utils/max_flow.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | # source: https://brilliant.org/wiki/ford-fulkerson-algorithm/ 5 | 6 | 7 | class Edge(object): 8 | def __init__(self, u, v, w): 9 | self.source = u 10 | self.sink = v 11 | self.capacity = w 12 | 13 | def __repr__(self): 14 | return "%s->%s:%s" % (self.source, self.sink, self.capacity) 15 | 16 | 17 | class FlowNetwork(object): 18 | def __init__(self): 19 | self.adj = {} 20 | self.flow = {} 21 | 22 | def add_vertex(self, vertex): 23 | self.adj[vertex] = [] 24 | 25 | def get_edges(self, v): 26 | return self.adj[v] 27 | 28 | def add_edge(self, u, v, w=0): 29 | if u == v: 30 | raise ValueError("u == v") 31 | edge = Edge(u, v, w) 32 | redge = Edge(v, u, 0) 33 | edge.redge = redge 34 | redge.redge = edge 35 | self.adj[u].append(edge) 36 | self.adj[v].append(redge) 37 | self.flow[edge] = 0 38 | self.flow[redge] = 0 39 | 40 | def find_path(self, source, sink, path): 41 | if source == sink: 42 | return path 43 | for edge in self.get_edges(source): 44 | residual = edge.capacity - self.flow[edge] 45 | if residual > 0 and edge not in path: 46 | result = self.find_path(edge.sink, sink, path + [edge]) 47 | if result is not None: 48 | return result 49 | 50 | def max_flow(self, source, sink): 51 | path = self.find_path(source, sink, []) 52 | while path is not None: 53 | residuals = [edge.capacity - self.flow[edge] for edge in path] 54 | flow = min(residuals) 55 | for edge in path: 56 | self.flow[edge] += flow 57 | self.flow[edge.redge] -= flow 58 | path = self.find_path(source, sink, []) 59 | return sum(self.flow[edge] for edge in self.get_edges(source)) 60 | -------------------------------------------------------------------------------- /ocr/utils/ngram_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | import numpy as np 5 | import os 6 | import pickle 7 | from mxnet.gluon.data import dataset 8 | 9 | class Ngram_dataset(dataset.ArrayDataset): 10 | ''' 11 | The ngram_dataset takes a noisy_forms_dataset object and converts it into an ngram dataset. 12 | That is, for each word/characters i a tuple is created that contains: 13 | * n words/characters before 14 | * n words/characters after 15 | * Noisy word/characters i 16 | * The actual word/character i 17 | 18 | Parameters 19 | ---------- 20 | passage_ds: Noisy_forms_dataset 21 | A noisy_forms_dataset object 22 | 23 | name: string 24 | An identifier for the temporary save file. 
25 | 
26 |     output_type: string, options: ["word", "character"]
27 |         Output data in terms of n characters or words
28 | 
29 |     n: int, default 3
30 |         The number of words or characters of context to output before and after the target
31 |     '''
32 |     def __init__(self, passage_ds, name, output_type, n=3):
33 |         self.passage_ds = passage_ds
34 |         self.n = n
35 | 
36 |         output_type_options = ["word", "character"]
37 |         assert output_type in output_type_options, "{} is not a valid type. Available options: {}".format(
38 |             output_type, output_type_options)
39 |         self.output_type = output_type
40 | 
41 |         root = os.path.join("dataset", "noisy_forms")
42 |         if not os.path.isdir(root):
43 |             os.makedirs(root)
44 |         ds_location = os.path.join(root, "ngram_{}.pickle".format(name))
45 |         if os.path.isfile(ds_location):
46 |             data = pickle.load(open(ds_location, 'rb'))
47 |         else:
48 |             data = self._get_data()
49 |             pickle.dump(data, open(ds_location, 'wb'), protocol=2)
50 |         super(Ngram_dataset, self).__init__(data)
51 | 
52 |     def _get_n_grams(self, text_arr, idx, pre, n):
53 |         output = []
54 |         if pre:
55 |             indexes = range(idx - n, idx)
56 |         else:
57 |             indexes = range(idx + 1, idx + n + 1)
58 |         for i in indexes:
59 |             if 0 <= i and i < len(text_arr):
60 |                 word = text_arr[i]
61 |                 if len(word) == 0:
62 |                     output.append(0)
63 |                 else:
64 |                     output.append(word)
65 |             else:
66 |                 output.append(0)
67 |         return output
68 | 
69 |     def _remove_empty_words(self, text_arr):
70 |         # Helper function to remove empty words
71 |         output = []
72 |         for word in text_arr:
73 |             if len(word) > 0:
74 |                 output.append(word)
75 |         return output
76 | 
77 |     def _separate_word_breaks(self, text_arr):
78 |         # Helper function to separate words that are split into two with "-"
79 |         output = []
80 | 
81 |         for word in text_arr:
82 |             if "-" in word:
83 |                 words = word.split("-")
84 |                 for word_i in words:
85 |                     if len(word_i) > 0:
86 |                         output.append(word_i)
87 |             else:
88 |                 output.append(word)
89 |         return output
90 | 
91 |     def _is_ngram_similar(self, ngram1, ngram2, p1=0.8, p2=0.8):
92 |         '''
93 |         Helper function to check if ngram1 is similar to ngram2.
94 |         Parameters
95 |         ----------
96 |         ngram1: [str]
97 |             A list of strings (or 0 for the null character) that is of size n.
98 |         ngram2: [str]
99 |             A list of strings (or 0 for the null character) that is of size n.
100 | 
101 |         p1: float
102 |             The percentage of characters that are the same within 2 words to be considered the same word.
103 | 
104 |         p2: float
105 |             The percentage of words that are the same for ngram1 and ngram2 to be considered similar
106 | 
107 |         Return
108 |         ------
109 |         is_ngram_similar: bool
110 |             Boolean that indicates ngram1 and ngram2 are similar.
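        Example
        -------
        Illustrative only; assumes ``ds`` is an Ngram_dataset instance and 0 marks a padded
        (null) position:

        >>> ds._is_ngram_similar(["the", "quick", 0], ["the", "quikc", 0])
        True
        >>> ds._is_ngram_similar(["the", "quick", 0], ["a", "lazy", "dog"])
        False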
111 | ''' 112 | 113 | in_count = [] 114 | for ngram1_i, ngram2_i in zip(ngram1, ngram2): 115 | if ngram1_i == 0 or ngram2_i == 0: 116 | if ngram1_i == ngram2_i: 117 | in_count.append(1) 118 | else: 119 | in_count.append(0) 120 | else: 121 | ngram1_i_np = np.array(list(ngram1_i)) 122 | ngram2_i_np = np.array(list(ngram2_i)) 123 | if np.mean(np.in1d(ngram1_i_np, ngram2_i_np)) > p1: 124 | in_count.append(1) 125 | else: 126 | in_count.append(0) 127 | is_ngram_similar = np.mean(in_count) > p2 128 | return is_ngram_similar 129 | 130 | def _get_data(self): 131 | ngrams = [] 132 | for i in range(len(self.passage_ds)): 133 | noisy_text, text = self.passage_ds[i] 134 | noisy_text_arr, text_arr = noisy_text.split(" "), text.split(" ") 135 | 136 | # Heuristics 137 | noisy_text_arr = self._separate_word_breaks(noisy_text_arr) 138 | text_arr = self._separate_word_breaks(text_arr) 139 | 140 | noisy_text_arr = self._remove_empty_words(noisy_text_arr) 141 | text_arr = self._remove_empty_words(text_arr) 142 | 143 | for j in range(len(noisy_text_arr)): 144 | pre_values_j = self._get_n_grams(noisy_text_arr, j, pre=True, n=3) 145 | post_values_j = self._get_n_grams(noisy_text_arr, j, pre=False, n=3) 146 | 147 | for k in range(len(text_arr)): 148 | pre_values_k = self._get_n_grams(text_arr, k, pre=True, n=3) 149 | post_values_k = self._get_n_grams(text_arr, k, pre=False, n=3) 150 | if self._is_ngram_similar(pre_values_j, pre_values_k) and self._is_ngram_similar(post_values_j, post_values_k): 151 | noisy_value = noisy_text_arr[j] 152 | actual_value = text_arr[k] 153 | pre_values = self._get_n_grams(text_arr, k, pre=True, n=self.n) 154 | post_values = self._get_n_grams(text_arr, k, pre=False, n=self.n) 155 | if self.output_type == "word": 156 | ngrams.append([pre_values, post_values, noisy_value, actual_value]) 157 | elif self.output_type == "character": 158 | pre_values = [str(a) for a in pre_values] 159 | post_values = [str(a) for a in post_values] 160 | 161 | noisy_full_string = " ".join(pre_values) + " " + noisy_value + " " + " ".join(post_values) 162 | actual_full_string = " ".join(pre_values) + " " + actual_value + " " + " ".join(post_values) 163 | noisy_index = len(" ".join(pre_values)) + 1 164 | for c in range(len(noisy_value)): 165 | idx = c + noisy_index 166 | new_pre_values = actual_full_string [idx-self.n:idx] 167 | new_post_values = actual_full_string [idx+1:idx+self.n + 1] 168 | new_noisy_values = noisy_full_string[idx] 169 | new_actual_values = actual_full_string[idx] 170 | ngrams.append([new_pre_values, new_post_values, new_noisy_values, new_actual_values]) 171 | return ngrams 172 | 173 | def __getitem__(self, idx): 174 | pre_values, post_values, noisy_value, actual_value = self._data[0][idx] 175 | return pre_values, post_values, noisy_value, actual_value 176 | 177 | def __len__(self): 178 | return len(self._data[0]) 179 | -------------------------------------------------------------------------------- /ocr/utils/noisy_forms_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | import numpy as np 5 | import pickle 6 | import os 7 | 8 | from .iam_dataset import IAMDataset 9 | 10 | class Noisy_forms_dataset: 11 | ''' 12 | The noisy_forms_dataset provides pairs of identical passages, one of the passages is noisy. 13 | The noise includes random replacements, insertions and deletions. 
14 | 15 | Parameters 16 | ---------- 17 | noise_source_transform: (np.array, str) -> str 18 | The noise_source_transform is a function that takes an image containing a single line of handwritten text 19 | and outputs a noisy version of the text string. 20 | 21 | train: bool 22 | Indicates if the data should be used for training or testing 23 | 24 | name: str 25 | An identifier to save a temporary version of the database 26 | ''' 27 | train_size = 0.8 28 | def __init__(self, noise_source_transform, train, name, topK_decode): 29 | self.noise_source_transform = noise_source_transform 30 | self.iam_dataset_form = IAMDataset("form", output_data="text", train=True) 31 | self.iam_dataset_line = IAMDataset("line", output_data="text", train=True) 32 | self.train = train 33 | self.seed = np.random.uniform(0, 1.0, size=len(self.iam_dataset_form)) 34 | self.topK_decode = topK_decode 35 | root = os.path.join("dataset", "noisy_forms") 36 | if not os.path.isdir(root): 37 | os.makedirs(root) 38 | 39 | ds_location = os.path.join(root, "ns{}.pickle".format(name)) 40 | if os.path.isfile(ds_location): 41 | self.train_data, self.test_data = pickle.load(open(ds_location, 'rb')) 42 | else: 43 | self.train_data, self.test_data = self._get_data() 44 | pickle.dump((self.train_data, self.test_data), open(ds_location, 'wb'), protocol=2) 45 | 46 | def _is_line_in_form(self, line_text, form_text, p=0.8): 47 | ''' 48 | Helper function to check if a line of text is within a form. 49 | Since there are differences in punctuations, spaces, etc. The line was split into separate words and if 50 | more than probability of the line is within the form, it's considered in the form. 51 | 52 | Parameters 53 | ---------- 54 | line_text: str 55 | A string of a line. 56 | 57 | form_text: str 58 | A string of a whole passage. 59 | 60 | p: float, default=0.8 61 | the probability of words of a line that is within a form to consider the line is in the form. 62 | 63 | Return 64 | ------ 65 | is_line_in_form: bool 66 | If the line is considered in the form, return true, return false otherwise. 67 | ''' 68 | line_text_array = np.array(line_text.split(" ")) 69 | form_text_array = np.array(form_text.split(" ")) 70 | in_form = np.in1d(line_text_array, form_text_array) 71 | if np.mean(in_form) > p: 72 | return True 73 | else: 74 | return False 75 | 76 | def _get_data(self): 77 | ''' 78 | Generates a noisy text using the noise_source_transform then organises the data from multiple lines 79 | into a single form (to keep the context of the form consistent). 80 | 81 | Returns 82 | ------- 83 | train_data: [(str, str)] 84 | Contains a list of tuples that contains a two passages that are the same but one is noisy. 85 | 86 | test_data: [(str, str)] 87 | Contains a list of tuples that contains a two passages that are the same but one is noisy. This 88 | list of tuples contains independent samples compared to train_data. 89 | ''' 90 | train_data = [] 91 | test_data = [] 92 | 93 | for idx_form in range(len(self.iam_dataset_form)): 94 | print("{}/{}".format(idx_form, len(self.iam_dataset_form))) 95 | _, form_text = self.iam_dataset_form[idx_form] 96 | form_text = form_text[0].replace("\n", " ") 97 | 98 | _, full_form_text = self.iam_dataset_form[idx_form] 99 | full_form_text = full_form_text[0].replace("\n", " ") 100 | 101 | lines_in_form = [] 102 | for idx_line in range(len(self.iam_dataset_line)): 103 | # Iterates through every line data to check if it's within the form. 
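                # This is an O(len(forms) * len(lines)) scan: each line image is matched to its
                # parent form by word overlap (see _is_line_in_form) rather than by id.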
104 |                 image, line_text = self.iam_dataset_line[idx_line]
105 |                 line_text = line_text[0]
106 | 
107 |                 if self._is_line_in_form(line_text, form_text):
108 |                     prob = self.noise_source_transform(image, line_text)
109 |                     predicted_text = self.topK_decode(np.argmax(prob, axis=2))[0]
110 |                     lines_in_form.append(predicted_text)
111 |                     form_text = form_text.replace(line_text, "")
112 | 
113 |             predicted_form_text = ' '.join(lines_in_form)
114 |             if len(predicted_text) > 500:
115 |                 import pdb; pdb.set_trace();
116 | 
117 |             if self.seed[idx_form] < self.train_size:
118 |                 train_data.append([predicted_form_text, full_form_text])
119 |             else:
120 |                 test_data.append([predicted_form_text, full_form_text])
121 | 
122 |         return train_data, test_data
123 | 
124 |     def __getitem__(self, idx):
125 |         if self.train:
126 |             noisy_text, actual_text = self.train_data[idx]
127 |             return noisy_text, actual_text
128 |         else:
129 |             noisy_text, actual_text = self.test_data[idx]
130 |             return noisy_text, actual_text
131 | 
132 |     def __len__(self):
133 |         if self.train:
134 |             return len(self.train_data)
135 |         else:
136 |             return len(self.test_data)
137 | 
--------------------------------------------------------------------------------
/ocr/utils/sclite_helper.py:
--------------------------------------------------------------------------------
1 | import os
2 | import subprocess
3 | import re
4 | import uuid
5 | 
6 | class ScliteHelper():
7 |     '''
8 |     The Sclite helper class calculates the word error rate (WER) and character error rate (CER)
9 |     given a predicted and actual text.
10 |     This class uses sclite 2.4 (ftp://jaguar.ncsl.nist.gov/pub/sctk-2.4.10-20151007-1312Z.tar.bz2)
11 |     and formats the data accordingly.
12 |     Parameters
13 |     ----------
14 |     sclite_location: optional, default="sctk-2.4.10/bin"
15 |         Location of the sclite program
16 |     tmp_file_location: optional, default=tmp
17 |         folder to store the temporary text files.
18 |     '''
19 | 
20 |     def __init__(self, sclite_location=os.path.join("..", "SCTK", "bin"),
21 |                  tmp_file_location="tmp", use_uuid=True):
22 |         # Check if sclite exists
23 |         assert os.path.isdir(sclite_location), "{} does not exist".format(sclite_location)
24 |         sclite_error = "{} doesn't contain sclite".format(sclite_location)
25 |         retries = 10
26 |         for i in range(retries):
27 |             if self._test_sclite(sclite_location):
28 |                 break
29 |             elif i == retries-1:
30 |                 raise RuntimeError(sclite_error)
31 |         self.sclite_location = sclite_location
32 |         if use_uuid:
33 |             tmp_file_location += "/" + str(uuid.uuid4())
34 |         # Check if tmp_file_location exists
35 |         if not os.path.isdir(tmp_file_location):
36 |             os.makedirs(tmp_file_location)
37 |         self.tmp_file_location = tmp_file_location
38 |         self.predicted_text = []
39 |         self.actual_text = []
40 | 
41 |     def clear(self):
42 |         '''
43 |         Clear the class for new calculations.
44 | ''' 45 | self.predicted_text = [] 46 | self.actual_text = [] 47 | 48 | def _test_sclite(self, sclite_location): 49 | sclite_path = os.path.join(sclite_location, "sclite") 50 | command_line_options = [sclite_path] 51 | try: 52 | subprocess.check_output(command_line_options, stderr=subprocess.STDOUT) 53 | except OSError: 54 | return False 55 | except subprocess.CalledProcessError: 56 | return True 57 | return True 58 | 59 | def _write_string_to_sclite_file(self, sentences_arrays, filename): 60 | SPEAKER_LABEL = "(spk{}_{})" 61 | # Split string into sentences 62 | converted_string = '' 63 | for i, sentences_array in enumerate(sentences_arrays): 64 | for line, sentence in enumerate(sentences_array): 65 | converted_string += sentence + SPEAKER_LABEL.format(i+1, line+1) + "\n" 66 | 67 | # Write converted_string into file 68 | filepath = os.path.join(self.tmp_file_location, filename) 69 | with open(filepath, "w") as f: 70 | f.write(converted_string) 71 | 72 | def _run_sclite(self, predicted_filename, actual_filename, mode, output): 73 | ''' 74 | Run command line for sclite. 75 | Parameters 76 | --------- 77 | predicted_filename: str 78 | file containing output string of the network 79 | actual_filename: str 80 | file containing string of the label 81 | mode: string, Options = ["CER", "WER"] 82 | Choose between CER or WER 83 | output: string, Options = ["print", "string"] 84 | Choose between printing the output or returning a string 85 | Returns 86 | ------- 87 | 88 | stdoutput 89 | If string was chosen as the output option, this function will return a file 90 | containing the stdout 91 | ''' 92 | assert mode in ["CER", "WER"], "mode {} is not in ['CER', 'WER]".format(mode) 93 | assert output in ["print", "string"], "output {} is not in ['print', 'string']".format( 94 | output) 95 | 96 | command_line = [os.path.join(self.sclite_location, "sclite"), 97 | "-h", os.path.join(self.tmp_file_location, predicted_filename), 98 | "-r", os.path.join(self.tmp_file_location, actual_filename), 99 | "-i", "rm"] 100 | if mode == "WER": 101 | pass # Word error rate is by default 102 | 103 | retries = 10 104 | 105 | for i in range(retries): 106 | try: 107 | if mode == "CER": 108 | command_line.append("-c") 109 | if output == "print": 110 | subprocess.call(command_line) 111 | elif output == "string": 112 | cmd = subprocess.Popen(command_line, stdout=subprocess.PIPE) 113 | return cmd.stdout 114 | except: 115 | print("There was an error") 116 | 117 | def _print_error_rate_summary(self, mode, predicted_filename="predicted.txt", 118 | actual_filename="actual.txt"): 119 | ''' 120 | Print the error rate summary of sclite 121 | 122 | Parameters 123 | ---------- 124 | mode: string, Options = ["CER", "WER"] 125 | Choose between CER or WER 126 | ''' 127 | self._run_sclite(predicted_filename, actual_filename, mode, output="print") 128 | 129 | def _get_error_rate(self, mode, predicted_filename="predicted.txt", 130 | actual_filename="actual.txt"): 131 | ''' 132 | Get the error rate by analysing the output of sclite 133 | Parameters 134 | ---------- 135 | mode: string, Options = ["CER", "WER"] 136 | Choose between CER or WER 137 | Returns 138 | ------- 139 | number: int 140 | The number of characters or words depending on the mode selected. 
141 | error_rate: float 142 | ''' 143 | number = None 144 | er = None 145 | output_file = self._run_sclite(predicted_filename, actual_filename, 146 | mode, output="string") 147 | 148 | match_tar = r'.*Mean.*\|.* (\d*.\d) .* (\d*.\d).* \|' 149 | for line in output_file.readlines(): 150 | match = re.match(match_tar, line.decode('utf-8'), re.M|re.I) 151 | if match: 152 | number = match.group(1) 153 | er = match.group(2) 154 | assert number != None and er != None, "Error in parsing output." 155 | return float(number), 100.0 - float(er) 156 | 157 | def _make_sclite_files(self, predicted_filename="predicted.txt", 158 | actual_filename="actual.txt"): 159 | ''' 160 | Run command line for sclite. 161 | Parameters 162 | --------- 163 | predicted_filename: str, default: predicted.txt 164 | filename of the predicted file 165 | actual_filename: str, default: actual.txt 166 | filename of the actual file 167 | ''' 168 | self._write_string_to_sclite_file(self.predicted_text, filename=predicted_filename) 169 | self._write_string_to_sclite_file(self.actual_text, filename=actual_filename) 170 | 171 | def add_text(self, predicted_text, actual_text): 172 | ''' 173 | Function to save predicted and actual text pairs in memory. 174 | Running the future fuctions will generate the required text files. 175 | ''' 176 | self.predicted_text.append(predicted_text) 177 | self.actual_text.append(actual_text) 178 | 179 | def print_wer_summary(self): 180 | ''' 181 | see _print_error_rate_summary for docstring 182 | ''' 183 | self._make_sclite_files() 184 | self._print_error_rate_summary(mode="WER") 185 | 186 | def print_cer_summary(self): 187 | ''' 188 | see _print_error_rate_summary for docstring 189 | ''' 190 | self._make_sclite_files() 191 | self._print_error_rate_summary(mode="CER") 192 | 193 | def get_wer(self): 194 | ''' 195 | See _get_error_rate for docstring 196 | ''' 197 | self._make_sclite_files() 198 | return self._get_error_rate(mode="WER") 199 | 200 | def get_cer(self): 201 | ''' 202 | See _get_error_rate for docstring 203 | ''' 204 | self._make_sclite_files() 205 | return self._get_error_rate(mode="CER") 206 | 207 | if __name__ == "__main__": 208 | cls = ScliteHelper() 209 | actual1 = 'Jonathan loves to eat apples. This is the second sentence.' 210 | predicted1 = 'Jonothon loves to eot. This is the second santense.' 211 | 212 | cls.add_text(predicted1, actual1) 213 | actual2 = 'Jonathan loves to eat apples. This is the second sentence.' 214 | predicted2 = 'Jonothan loves to eot. This is the second santense.' 215 | cls.add_text(predicted2, actual2) 216 | 217 | cls.print_cer_summary() 218 | num, er = cls.get_cer() 219 | print(num, er) -------------------------------------------------------------------------------- /ocr/utils/word_to_line.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | import math 5 | 6 | import numpy as np 7 | from scipy.cluster.hierarchy import fcluster 8 | 9 | from .expand_bounding_box import expand_bounding_box 10 | 11 | def _clip_value(value, max_value): 12 | ''' 13 | Helper function to make sure that "value" will not be greater than max_value 14 | or lower than 0. 
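    Example
    -------
    The returned value is clipped to [0, max_value] and cast to int:

    >>> _clip_value(-0.3, max_value=100)
    0
    >>> _clip_value(42.2, max_value=100)
    42
    >>> _clip_value(125.7, max_value=100)
    100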
15 | ''' 16 | output = value 17 | if output < 0: 18 | output = 0 19 | if output > max_value: 20 | output = max_value 21 | return int(output) 22 | 23 | def _get_max_coord(bbs, x_or_y): 24 | ''' 25 | Helper function to find the largest coordinate given a list of 26 | bounding boxes in the x or y direction. 27 | ''' 28 | assert x_or_y in ["x", "y"], "x_or_y can only be x or y" 29 | max_value = 0.0 30 | for bb in bbs: 31 | if x_or_y == "x": 32 | value = bb[0] + bb[2] 33 | else: 34 | value = bb[1] + bb[3] 35 | if value > max_value: 36 | max_value = value 37 | return max_value 38 | 39 | def _get_min_coord(bbs, x_or_y): 40 | ''' 41 | Helper function to find the largest coordinate given a list of 42 | bounding boxes in the x or y direction. 43 | ''' 44 | assert x_or_y in ["x", "y"], "x_or_y can only be x or y" 45 | min_value = 100 46 | for bb in bbs: 47 | if x_or_y == "x": 48 | value = bb[0] 49 | else: 50 | value = bb[1] 51 | if value < min_value: 52 | min_value = value 53 | return min_value 54 | 55 | def _get_bounding_box_of_bb_list(bbs_in_a_line): 56 | ''' 57 | Given a list of bounding boxes, find the maximum x, y and 58 | minimum x, y coordinates. This is the bounding box that 59 | emcompasses all the words. Return this bounding box in the form 60 | (x', y', w', h'). 61 | ''' 62 | max_x = _get_max_coord(bbs_in_a_line, x_or_y="x") 63 | min_x = _get_min_coord(bbs_in_a_line, x_or_y="x") 64 | 65 | max_y = _get_max_coord(bbs_in_a_line, x_or_y="y") 66 | min_y = _get_min_coord(bbs_in_a_line, x_or_y="y") 67 | 68 | line_bb = (min_x, min_y, max_x - min_x, max_y - min_y) 69 | return line_bb 70 | 71 | def _filter_bbs(bbs, min_size=0.005): 72 | ''' 73 | Remove bounding boxes that are too small 74 | ''' 75 | output_bbs = [] 76 | for bb in bbs: 77 | if bb[2] * bb[3] > min_size: 78 | output_bbs.append(bb) 79 | return np.array(output_bbs) 80 | 81 | def _get_line_overlap_percentage(y1, h1, y2, h2): 82 | ''' 83 | Calculates how much (percentage) y2->y2+h2 overlaps with y1->y1+h1. 84 | Algorithm assumes that y2 is larger than y1 85 | ''' 86 | if y2 > y1 and (y1 + h1) > y2: 87 | # Is y2 enclosed in y1 88 | if (y1 + h1) > (y2 + h2): 89 | return 1.0 90 | else: 91 | return ((y1 + h1) - (y2))/h1 92 | else: 93 | return 0.0 94 | 95 | def _get_rect_overlap_percentage(x1, y1, w1, h1, x2, y2, w2, h2): 96 | ''' 97 | Calculate how much (in percentage) that rect2 overlaps with rect1 98 | ''' 99 | # Check if rect overlaps 100 | x_overlap = (x1 + w1 >= x2 and x2 >= x1) or (x2 + w2 >= x1 and x1 >= x2) 101 | y_overlap = (y1 + h1 >= y2 and y2 >= y1) or (y2 + h2 >= y1 and y1 >= y2) 102 | if x_overlap and y_overlap: 103 | intersect_size = max(0, min(x1 + w1, x2 + w2) - min(x1, x2)) * max(0, min(y1 + h1, y2 + h2) - max(y1, y2)) 104 | s1 = w1 * h1 105 | return intersect_size / s1 106 | else: 107 | return 0 108 | 109 | def combine_bbs_into_lines(bbs, y_overlap=0.2): 110 | ''' 111 | Algorithm to group word crops into lines. 112 | Iterates over every bb, if the overlap in the y direction 113 | between 2 boxes has less than y_overlap overlap, then group the previous words together. 114 | ''' 115 | line_bbs = [] 116 | bbs_in_a_line = [] 117 | y_indexes = np.argsort(bbs[:, 1]) 118 | # Iterate through the sorted bounding box. 
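    # e.g. with the default y_overlap=0.2, two word boxes whose y-ranges are (0.10, 0.14) and
    # (0.105, 0.145) overlap by ~88% and stay on the same line, while a box starting at y=0.16
    # (no overlap) closes the current line and starts a new one.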
119 | previous_y_coords = None 120 | for y_index in y_indexes: 121 | y_coords = (bbs[y_index, 1], bbs[y_index, 3]) # y and height 122 | 123 | # new line if the overlap is more than y_overlap 124 | if previous_y_coords is not None: 125 | line_overlap_percentage1 = _get_line_overlap_percentage( 126 | previous_y_coords[0], previous_y_coords[1], 127 | y_coords[0], y_coords[1]) 128 | line_overlap_percentage2 = _get_line_overlap_percentage( 129 | y_coords[0], y_coords[1], 130 | previous_y_coords[0], previous_y_coords[1]) 131 | line_overlap_percentage = max(line_overlap_percentage1, line_overlap_percentage2) 132 | if line_overlap_percentage < y_overlap: 133 | line_bb = _get_bounding_box_of_bb_list(bbs_in_a_line) 134 | line_bbs.append(line_bb) 135 | bbs_in_a_line = [] 136 | bbs_in_a_line.append(bbs[y_index, :]) 137 | previous_y_coords = y_coords 138 | 139 | # process the last line 140 | line_bb = _get_bounding_box_of_bb_list(bbs_in_a_line) 141 | line_bbs.append(line_bb) 142 | return line_bbs 143 | 144 | def sort_bbs_line_by_line(bbs, y_overlap=0.2): 145 | ''' 146 | Function to combine word bbs into lines. 147 | ''' 148 | line_bbs = _filter_bbs(bbs, min_size=0.0001) #Filter small word BBs 149 | line_bbs = combine_bbs_into_lines(line_bbs, y_overlap) 150 | line_bb_expanded = [] 151 | for line_bb in line_bbs: 152 | line_bb_i = expand_bounding_box(line_bb, expand_bb_scale_x=0.1, 153 | expand_bb_scale_y=0.05) 154 | line_bb_expanded.append(line_bb_i) 155 | line_bbs = np.array(line_bb_expanded) 156 | 157 | # X start heuristics 158 | # Remove lines that start more than 150% away 159 | x_start_within_boundary = line_bbs[:, 0] < 0.5 160 | line_bbs = line_bbs[x_start_within_boundary] 161 | 162 | # Remove lines that start 20% away from the average 163 | x_start_line_bbs = line_bbs[:, 0] 164 | x_start_diff = np.abs(x_start_line_bbs - np.median(x_start_line_bbs)) 165 | x_start_remove = x_start_diff < 0.2 166 | line_bbs = line_bbs[x_start_remove] 167 | 168 | # X length heuristics 169 | # Remove lines that are 50% shorter excluding the last element 170 | if len(line_bbs) > 1: 171 | x_length_line_bbs = line_bbs[:-1, 0] - line_bbs[:-1, 2] 172 | x_length_diff = np.abs(x_length_line_bbs - np.median(x_length_line_bbs)) 173 | x_length_remove = x_length_diff < 0.5 174 | last_line = line_bbs[-1] 175 | line_bbs = line_bbs[:-1][x_length_remove] 176 | line_bbs = np.vstack([line_bbs, last_line]) 177 | 178 | # Y height heuristics 179 | # Split lines that are more than 1.5 of the others 180 | y_height = line_bbs[:, 3] 181 | y_height_diff = np.abs(y_height/np.median(y_height)) 182 | y_height_remove = y_height_diff > 1.65 183 | 184 | new_line_bbs = [] 185 | for i in range(line_bbs.shape[0]): 186 | if y_height_remove[i]: 187 | # split line into 2 188 | new_line_top = np.copy(line_bbs[i]) 189 | new_line_top[3] = new_line_top[3] / 2 190 | 191 | new_line_bottom = np.copy(line_bbs[i]) 192 | new_line_bottom[1] = new_line_bottom[1] + new_line_bottom[3]/2 193 | new_line_bottom[3] = new_line_bottom[3] / 2 194 | 195 | new_line_bbs.append(new_line_top) 196 | new_line_bbs.append(new_line_bottom) 197 | else: 198 | new_line_bbs.append(line_bbs[i]) 199 | line_bbs = np.vstack(new_line_bbs) 200 | 201 | # Y consistency heuristics 202 | # Remove lines that overlap by 40% with other lines 203 | line_total_overlap = [] 204 | for i in range(line_bbs.shape[0]): 205 | overlap_i = 0.0 206 | for j in range(line_bbs.shape[0]): 207 | if i != j: 208 | line_i, line_j = line_bbs[i], line_bbs[j] 209 | overlap_i += _get_rect_overlap_percentage(line_i[0], 
line_i[1], line_i[2], line_i[3], 210 | line_j[0], line_j[1], line_j[2], line_j[3]) 211 | line_total_overlap.append(overlap_i) 212 | overlap_remove = np.array(line_total_overlap) < 1 213 | line_bbs = line_bbs[overlap_remove] 214 | return line_bbs 215 | 216 | def crop_line_images(image, line_bbs): 217 | ''' 218 | Given the input form image, crop the image given a list of bounding boxes. 219 | ''' 220 | line_images = [] 221 | for line_bb in line_bbs: 222 | (x, y, w, h) = line_bb 223 | image_h, image_w = image.shape[-2:] 224 | (x, y, w, h) = (x * image_w, y * image_h, w * image_w, h * image_h) 225 | x1 = _clip_value(x, max_value=image_w) 226 | x2 = _clip_value(x + w, max_value=image_w) 227 | y1 = _clip_value(y, max_value=image_h) 228 | y2 = _clip_value(y + h, max_value=image_h) 229 | 230 | line_image = image[y1:y2, x1:x2] 231 | if line_image.shape[0] > 0 and line_image.shape[1] > 0: 232 | line_images.append(line_image) 233 | return line_images 234 | -------------------------------------------------------------------------------- /ocr/word_and_line_segmentation.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | import time 5 | import random 6 | import os 7 | import cv2 8 | import matplotlib.pyplot as plt 9 | import matplotlib.patches as patches 10 | import argparse 11 | 12 | import mxnet as mx 13 | from mxnet.contrib.ndarray import MultiBoxPrior, MultiBoxTarget, MultiBoxDetection, box_nms 14 | import numpy as np 15 | from skimage.draw import line_aa 16 | from skimage import transform as skimage_tf 17 | 18 | from mxnet import nd, autograd, gluon 19 | from mxnet.image import resize_short 20 | from mxboard import SummaryWriter 21 | from mxnet.gluon.model_zoo.vision import resnet34_v1 22 | np.seterr(all='raise') 23 | 24 | import multiprocessing 25 | mx.random.seed(1) 26 | 27 | from .utils.iam_dataset import IAMDataset 28 | from .utils.draw_box_on_image import draw_boxes_on_image 29 | 30 | print_every_n = 5 31 | send_image_every_n = 20 32 | save_every_n = 50 33 | 34 | # To run: 35 | # python word_segmentation.py --min_c 0.01 --overlap_thres 0.10 --topk 150 --epoch 401 --checkpoint_name ssd_400.params 36 | # For fine_tuning: 37 | # python word_segmentation.py -p ssd_550.params 38 | 39 | # python word_segmentation.py --min_c 0.05 --overlap_thres 0.001 --topk 400 --epoch 401 --checkpoint_name word_seg.params 40 | 41 | class SSD(gluon.Block): 42 | def __init__(self, num_classes, ctx, **kwargs): 43 | super(SSD, self).__init__(**kwargs) 44 | 45 | # Seven sets of anchor boxes are defined. For each set, n=2 sizes and m=3 ratios are defined. 46 | # Four anchor boxes (n + m - 1) are generated: 2 square anchor boxes based on the n=2 sizes and 2 rectanges based on 47 | # the sizes and the ratios. See https://discuss.mxnet.io/t/question-regarding-ssd-algorithm/1307 for more information. 
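        # For example, sizes=[.1, .2] and ratios=[1, 3, 5] give 2 + 3 - 1 = 4 anchor boxes
        # per feature-map position.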
48 | 49 | #self.anchor_sizes = [[.1, .2], [.2, .3], [.2, .4], [.4, .6], [.5, .7], [.6, .8], [.7, .9]] 50 | #self.anchor_ratios = [[1, 3, 5], [1, 3, 5], [1, 6, 8], [1, 5, 7], [1, 6, 8], [1, 7, 9], [1, 7, 10]] 51 | 52 | self.anchor_sizes = [[.1, .2], [.2, .3], [.2, .4], [.3, .4], [.3, .5], [.4, .6]] 53 | self.anchor_ratios = [[1, 3, 5], [1, 3, 5], [1, 6, 8], [1, 4, 7], [1, 6, 8], [1, 5, 7]] 54 | 55 | self.num_anchors = len(self.anchor_sizes) 56 | self.num_classes = num_classes 57 | self.ctx = ctx 58 | with self.name_scope(): 59 | self.body, self.downsamples, self.class_preds, self.box_preds = self.get_ssd_model() 60 | self.downsamples.initialize(mx.init.Normal(), ctx=self.ctx) 61 | self.class_preds.initialize(mx.init.Normal(), ctx=self.ctx) 62 | self.box_preds.initialize(mx.init.Normal(), ctx=self.ctx) 63 | 64 | def get_body(self): 65 | ''' 66 | Create the feature extraction network of the SSD based on resnet34. 67 | The first layer of the res-net is converted into grayscale by averaging the weights of the 3 channels 68 | of the original resnet. 69 | 70 | Returns 71 | ------- 72 | network: gluon.nn.HybridSequential 73 | The body network for feature extraction based on resnet 74 | 75 | ''' 76 | pretrained = resnet34_v1(pretrained=True, ctx=self.ctx) 77 | pretrained_2 = resnet34_v1(pretrained=True, ctx=mx.cpu(0)) 78 | first_weights = pretrained_2.features[0].weight.data().mean(axis=1).expand_dims(axis=1) 79 | # First weights could be replaced with individual channels. 80 | 81 | body = gluon.nn.HybridSequential() 82 | with body.name_scope(): 83 | first_layer = gluon.nn.Conv2D(channels=64, kernel_size=(7, 7), padding=(3, 3), strides=(2, 2), in_channels=1, use_bias=False) 84 | first_layer.initialize(mx.init.Normal(), ctx=self.ctx) 85 | first_layer.weight.set_data(first_weights) 86 | body.add(first_layer) 87 | body.add(*pretrained.features[1:-3]) 88 | return body 89 | 90 | def get_class_predictor(self, num_anchors_predicted): 91 | ''' 92 | Creates the category prediction network (takes input from each downsampled feature) 93 | 94 | Parameters 95 | ---------- 96 | 97 | num_anchors_predicted: int 98 | Given n sizes and m ratios, the number of boxes predicted is n+m-1. 99 | e.g., sizes=[.1, .2], ratios=[1, 3, 5] the number of anchors predicted is 4. 100 | 101 | Returns 102 | ------- 103 | 104 | network: gluon.nn.HybridSequential 105 | The class predictor network 106 | ''' 107 | return gluon.nn.Conv2D(num_anchors_predicted*(self.num_classes + 1), kernel_size=3, padding=1) 108 | 109 | def get_box_predictor(self, num_anchors_predicted): 110 | ''' 111 | Creates the bounding box prediction network (takes input from each downsampled feature) 112 | 113 | Parameters 114 | ---------- 115 | 116 | num_anchors_predicted: int 117 | Given n sizes and m ratios, the number of boxes predicted is n+m-1. 118 | e.g., sizes=[.1, .2], ratios=[1, 3, 5] the number of anchors predicted is 4. 119 | 120 | Returns 121 | ------- 122 | 123 | pred: gluon.nn.HybridSequential 124 | The box predictor network 125 | ''' 126 | pred = gluon.nn.HybridSequential() 127 | with pred.name_scope(): 128 | pred.add(gluon.nn.Conv2D(channels=num_anchors_predicted*4, kernel_size=3, padding=1)) 129 | return pred 130 | 131 | def get_down_sampler(self, num_filters): 132 | ''' 133 | Creates a two-stacked Conv-BatchNorm-Relu and then a pooling layer to 134 | downsample the image features by half. 
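        Example
        -------
        Illustrative only; assumes ``net`` is a constructed SSD instance:

        >>> block = net.get_down_sampler(32)
        >>> block.initialize(ctx=mx.cpu())
        >>> block(mx.nd.zeros((1, 1, 64, 64))).shape
        (1, 32, 32, 32)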
135 | ''' 136 | out = gluon.nn.HybridSequential() 137 | for _ in range(2): 138 | out.add(gluon.nn.Conv2D(num_filters, 3, strides=1, padding=1)) 139 | out.add(gluon.nn.BatchNorm(in_channels=num_filters)) 140 | out.add(gluon.nn.Activation('relu')) 141 | out.add(gluon.nn.MaxPool2D(2)) 142 | out.hybridize() 143 | return out 144 | 145 | def get_ssd_model(self): 146 | ''' 147 | Creates the SSD model that includes the image feature, downsample, category 148 | and bounding boxes prediction networks. 149 | ''' 150 | body = self.get_body() 151 | downsamples = gluon.nn.HybridSequential() 152 | class_preds = gluon.nn.HybridSequential() 153 | box_preds = gluon.nn.HybridSequential() 154 | 155 | downsamples.add(self.get_down_sampler(32)) 156 | downsamples.add(self.get_down_sampler(32)) 157 | downsamples.add(self.get_down_sampler(32)) 158 | 159 | for scale in range(self.num_anchors): 160 | num_anchors_predicted = len(self.anchor_sizes[0]) + len(self.anchor_ratios[0]) - 1 161 | class_preds.add(self.get_class_predictor(num_anchors_predicted)) 162 | box_preds.add(self.get_box_predictor(num_anchors_predicted)) 163 | 164 | return body, downsamples, class_preds, box_preds 165 | 166 | def ssd_forward(self, x): 167 | ''' 168 | Helper function of the forward pass of the sdd 169 | ''' 170 | x = self.body(x) 171 | 172 | default_anchors = [] 173 | predicted_boxes = [] 174 | predicted_classes = [] 175 | 176 | for i in range(self.num_anchors): 177 | default_anchors.append(MultiBoxPrior(x, sizes=self.anchor_sizes[i], ratios=self.anchor_ratios[i])) 178 | predicted_boxes.append(self._flatten_prediction(self.box_preds[i](x))) 179 | predicted_classes.append(self._flatten_prediction(self.class_preds[i](x))) 180 | if i < len(self.downsamples): 181 | x = self.downsamples[i](x) 182 | elif i == 3: 183 | x = nd.Pooling(x, global_pool=True, pool_type='max', kernel=(4, 4)) 184 | return default_anchors, predicted_classes, predicted_boxes 185 | 186 | def forward(self, x): 187 | default_anchors, predicted_classes, predicted_boxes = self.ssd_forward(x) 188 | # we want to concatenate anchors, class predictions, box predictions from different layers 189 | anchors = nd.concat(*default_anchors, dim=1) 190 | box_preds = nd.concat(*predicted_boxes, dim=1) 191 | class_preds = nd.concat(*predicted_classes, dim=1) 192 | class_preds = nd.reshape(class_preds, shape=(0, -1, self.num_classes + 1)) 193 | return anchors, class_preds, box_preds 194 | 195 | def _flatten_prediction(self, pred): 196 | ''' 197 | Helper function to flatten the predicted bounding boxes and categories 198 | ''' 199 | return nd.flatten(nd.transpose(pred, axes=(0, 2, 3, 1))) 200 | 201 | def training_targets(self, default_anchors, class_predicts, labels): 202 | ''' 203 | Helper function to obtain the bounding boxes from the anchors. 
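        For a batch of B images and N anchors, MultiBoxTarget returns box_target and box_mask
        with shape (B, N * 4) and cls_target with shape (B, N): the box regression targets,
        the mask that zeroes out negative anchors, and the per-anchor class labels used by
        the losses during training.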
204 |         '''
205 |         class_predicts = nd.transpose(class_predicts, axes=(0, 2, 1))
206 |         box_target, box_mask, cls_target = MultiBoxTarget(default_anchors, labels, class_predicts)
207 |         return box_target, box_mask, cls_target
208 | 
209 | class SmoothL1Loss(gluon.loss.Loss):
210 |     '''
211 |     A SmoothL1Loss function as defined in https://gluon.mxnet.io/chapter08_computer-vision/object-detection.html
212 |     '''
213 |     def __init__(self, batch_axis=0, **kwargs):
214 |         super(SmoothL1Loss, self).__init__(None, batch_axis, **kwargs)
215 | 
216 |     def hybrid_forward(self, F, output, label, mask):
217 |         loss = F.smooth_l1((output - label) * mask, scalar=1.0)
218 |         return F.mean(loss, self._batch_axis, exclude=True)
219 | 
220 | def augment_transform(image, label):
221 |     '''
222 |     1) Function that randomly translates the input image by up to +-random_x_translation and
223 |     +-random_y_translation (fractions of the image width and height). The labels (bounding boxes) are translated by the same amount.
224 |     2) Each bounding box (and the image region it covers) can also be randomly removed for augmentation. The labels are
225 |     reduced accordingly, and the image and labels are converted into tensors by calling the "transform" function.
226 |     '''
227 |     ty = random.uniform(-random_y_translation, random_y_translation)
228 |     tx = random.uniform(-random_x_translation, random_x_translation)
229 | 
230 |     st = skimage_tf.SimilarityTransform(translation=(tx*image.shape[1], ty*image.shape[0]))
231 |     image = skimage_tf.warp(image, st, cval=1.0)
232 | 
233 |     label[:, 0] = label[:, 0] - tx/2 #NOTE: Check why it has to be halved (found experimentally)
234 |     label[:, 1] = label[:, 1] - ty/2
235 | 
236 |     index = np.random.uniform(0, 1.0, size=label.shape[0]) > random_remove_box
237 |     for i, should_output_bb in enumerate(index):
238 |         if should_output_bb == False:
239 |             (x, y, w, h) = label[i]
240 |             (x1, y1, x2, y2) = (x, y, x + w, y + h)
241 |             (x1, y1, x2, y2) = (x1 * image.shape[1], y1 * image.shape[0],
242 |                                 x2 * image.shape[1], y2 * image.shape[0])
243 |             (x1, y1, x2, y2) = (int(x1), int(y1), int(x2), int(y2))
244 |             x1 = 0 if x1 < 0 else x1
245 |             y1 = 0 if y1 < 0 else y1
246 |             x2 = 0 if x2 < 0 else x2
247 |             y2 = 0 if y2 < 0 else y2
248 |             image_h, image_w = image.shape
249 |             x1 = image_w-1 if x1 >= image_w else x1
250 |             y1 = image_h-1 if y1 >= image_h else y1
251 |             x2 = image_w-1 if x2 >= image_w else x2
252 |             y2 = image_h-1 if y2 >= image_h else y2
253 |             image[y1:y2, x1:x2] = image[y1, x1]
254 | 
255 |     augmented_labels = label[index, :]
256 |     return transform(image*255., augmented_labels)
257 | 
258 | def transform(image, label):
259 |     '''
260 |     Function that resizes and converts the image into the input image tensor for the CNN.
261 |     The labels (bounding boxes) are expanded, converted into (x, y, x+w, y+h), and
262 |     zero padded to the maximum number of labels. Finally, they are converted into a float
263 |     tensor.
264 |     '''
265 |     max_label_n = 128 if detection_box == "word" else 13
266 | 
267 |     # Resize the image
268 |     image = np.expand_dims(image, axis=2)
269 |     image = mx.nd.array(image)
270 |     image = resize_short(image, image_size)
271 |     image = image.transpose([2, 0, 1])/255.
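    # Illustrative worked example (hypothetical numbers, not taken from the dataset):
    # with the default expand_bb_scale of 0.05, the expansion below widens each box
    # slightly while keeping its centre fixed, e.g. a box (x, y, w, h) = (0.20, 0.30, 0.40, 0.10)
    # becomes roughly (0.19, 0.2975, 0.42, 0.105):
    #     new_w = 1.05 * 0.40 = 0.42,   x' = 0.20 - (0.42 - 0.40) / 2 = 0.19
    #     new_h = 1.05 * 0.10 = 0.105,  y' = 0.30 - (0.105 - 0.10) / 2 = 0.2975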
272 | 
273 |     # Expand the bounding box by expand_bb_scale
274 |     bb = label.copy()
275 |     new_w = (1 + expand_bb_scale) * bb[:, 2]
276 |     new_h = (1 + expand_bb_scale) * bb[:, 3]
277 | 
278 |     bb[:, 0] = bb[:, 0] - (new_w - bb[:, 2])/2
279 |     bb[:, 1] = bb[:, 1] - (new_h - bb[:, 3])/2
280 |     bb[:, 2] = new_w
281 |     bb[:, 3] = new_h
282 |     label = bb
283 | 
284 |     # Convert the bounding box from (x, y, w, h) to (x, y, x + w, y + h)
285 |     label = label.astype(np.float32)
286 |     label[:, 2] = label[:, 0] + label[:, 2]
287 |     label[:, 3] = label[:, 1] + label[:, 3]
288 | 
289 |     # Zero pad the data
290 |     label_n = label.shape[0]
291 |     label_padded = np.zeros(shape=(max_label_n, 5))
292 |     label_padded[:label_n, 1:] = label
293 |     label_padded[:label_n, 0] = np.ones(shape=(1, label_n))
294 |     label_padded = mx.nd.array(label_padded)
295 |     return image, label_padded
296 | 
297 | 
298 | def generate_output_image(box_predictions, default_anchors, cls_probs, box_target, box_mask, cls_target, x, y):
299 |     '''
300 |     Generate the image with the predicted and actual bounding boxes.
301 |     Parameters
302 |     ----------
303 |     box_predictions: nd.array
304 |         Bounding box predictions relative to the anchor boxes, output of the network
305 | 
306 |     default_anchors: nd.array
307 |         Anchors used, output of the network
308 | 
309 |     cls_probs: nd.array
310 |         Output of nd.SoftmaxActivation(nd.transpose(class_predictions, (0, 2, 1)), mode='channel')
311 |         where class_predictions is the output of the network.
312 | 
313 |     box_target: nd.array
314 |         Bounding box offset targets from network.training_targets(default_anchors, class_predictions, y)
315 | 
316 |     box_mask: nd.array
317 |         Bounding box masks from network.training_targets(default_anchors, class_predictions, y)
318 | 
319 |     cls_target: nd.array
320 |         Classification targets from network.training_targets(default_anchors, class_predictions, y)
321 | 
322 |     x: nd.array
323 |         The input images
324 | 
325 |     y: nd.array
326 |         The actual labels
327 | 
328 |     Returns
329 |     -------
330 |     output_image: np.array
331 |         The images with the predicted and actual bounding boxes drawn on
332 | 
333 |     number_of_bbs: int
334 |         The number of predicted bounding boxes
335 |     '''
336 |     output = MultiBoxDetection(*[cls_probs, box_predictions, default_anchors], force_suppress=True, clip=False)
337 |     output = box_nms(output, overlap_thresh=overlap_thres, valid_thresh=min_c, topk=topk)
338 |     output = output.asnumpy()
339 | 
340 |     number_of_bbs = 0
341 |     predicted_bb = []
342 |     for b in range(output.shape[0]):
343 |         predicted_bb_ = output[b, output[b, :, 0] != -1]
344 |         predicted_bb_ = predicted_bb_[:, 2:]
345 |         number_of_bbs += predicted_bb_.shape[0]
346 |         predicted_bb_[:, 2] = predicted_bb_[:, 2] - predicted_bb_[:, 0]
347 |         predicted_bb_[:, 3] = predicted_bb_[:, 3] - predicted_bb_[:, 1]
348 |         predicted_bb.append(predicted_bb_)
349 | 
350 |     labels = y[:, :, 1:].asnumpy()
351 |     labels[:, :, 2] = labels[:, :, 2] - labels[:, :, 0]
352 |     labels[:, :, 3] = labels[:, :, 3] - labels[:, :, 1]
353 | 
354 |     output_image = draw_boxes_on_image(predicted_bb, labels, x.asnumpy())
355 |     output_image[output_image<0] = 0
356 |     output_image[output_image>1] = 1
357 | 
358 |     return output_image, number_of_bbs
359 | 
360 | def predict_bounding_boxes(net, image, min_c, overlap_thres, topk, ctx=mx.gpu()):
361 |     '''
362 |     Given an image of a handwriting passage and the trained network, the predicted
363 |     bounding boxes are returned.
364 | 
365 |     Parameters
366 |     ----------
367 |     net: SSD
368 |         The trained SSD network.
369 | 
370 |     image: np.array
371 |         A grayscale image of the handwriting passages.
372 | 
373 |     Returns
374 |     -------
375 |     predicted_bb: [(x, y, w, h)]
376 |         The predicted bounding boxes.
377 |     '''
378 |     image = mx.nd.array(image).expand_dims(axis=2)
379 |     image = mx.image.resize_short(image, 350)
380 |     image = image.transpose([2, 0, 1])/255.
381 | 
382 |     image = image.as_in_context(ctx)
383 |     image = image.expand_dims(0)
384 | 
385 |     bb = np.zeros(shape=(13, 5))
386 |     bb = mx.nd.array(bb)
387 |     bb = bb.as_in_context(ctx)
388 |     bb = bb.expand_dims(axis=0)
389 | 
390 |     default_anchors, class_predictions, box_predictions = net(image)
391 | 
392 |     box_target, box_mask, cls_target = net.training_targets(default_anchors,
393 |                                                             class_predictions, bb)
394 | 
395 |     cls_probs = mx.nd.SoftmaxActivation(mx.nd.transpose(class_predictions, (0, 2, 1)), mode='channel')
396 | 
397 |     predicted_bb = MultiBoxDetection(*[cls_probs, box_predictions, default_anchors], force_suppress=True, clip=False)
398 |     predicted_bb = box_nms(predicted_bb, overlap_thresh=overlap_thres, valid_thresh=min_c, topk=topk)
399 |     predicted_bb = predicted_bb.asnumpy()
400 |     predicted_bb = predicted_bb[0, predicted_bb[0, :, 0] != -1]
401 |     predicted_bb = predicted_bb[:, 2:]
402 |     predicted_bb[:, 2] = predicted_bb[:, 2] - predicted_bb[:, 0]
403 |     predicted_bb[:, 3] = predicted_bb[:, 3] - predicted_bb[:, 1]
404 | 
405 |     return predicted_bb
406 | 
407 | 
408 | def run_epoch(e, network, dataloader, trainer, log_dir, print_name, is_train, update_metric):
409 |     '''
410 |     Run one epoch to train or test the SSD network
411 | 
412 |     Parameters
413 |     ----------
414 | 
415 |     e: int
416 |         The epoch number
417 | 
418 |     network: gluon.nn.HybridSequential
419 |         The SSD network
420 | 
421 |     dataloader: gluon.data.DataLoader
422 |         The training or testing dataloader that is wrapped around the IAMDataset
423 | 
424 |     log_dir: str
425 |         The directory to store the log files for mxboard
426 | 
427 |     print_name: str
428 |         Name to print for associating with the data. Usually this will be "train" or "test"
429 | 
430 |     is_train: bool
431 |         Boolean to indicate whether or not the CNN should be updated. is_train should only be set to True for the training data
432 | 
433 |     Returns
434 |     -------
435 | 
436 |     total_loss: float
437 |         The average loss over the epoch
438 |     '''
439 | 
440 |     total_losses = [0 for ctx_i in ctx]
441 |     for i, (X, Y) in enumerate(dataloader):
442 |         X = gluon.utils.split_and_load(X, ctx)
443 |         Y = gluon.utils.split_and_load(Y, ctx)
444 | 
445 |         with autograd.record(train_mode=is_train):
446 |             losses = []
447 |             for x, y in zip(X, Y):
448 |                 default_anchors, class_predictions, box_predictions = network(x)
449 |                 box_target, box_mask, cls_target = network.training_targets(default_anchors, class_predictions, y)
450 |                 # losses
451 |                 loss_class = cls_loss(class_predictions, cls_target)
452 |                 loss_box = box_loss(box_predictions, box_target, box_mask)
453 |                 # sum all losses
454 |                 loss = loss_class + loss_box
455 |                 losses.append(loss)
456 | 
457 |         if is_train:
458 |             for loss in losses:
459 |                 loss.backward()
460 |             step_size = 0
461 |             for x in X:
462 |                 step_size += x.shape[0]
463 |             trainer.step(step_size)
464 | 
465 |         for index, loss in enumerate(losses):
466 |             total_losses[index] += loss.mean().asscalar()
467 | 
468 |         if update_metric:
469 |             cls_metric.update([cls_target], [nd.transpose(class_predictions, (0, 2, 1))])
470 |             box_metric.update([box_target], [box_predictions * box_mask])
471 | 
472 |         if i == 0 and e % send_image_every_n == 0 and e > 0:
473 |             cls_probs = nd.SoftmaxActivation(nd.transpose(class_predictions, (0, 2, 1)), mode='channel')
474 |             output_image, number_of_bbs = generate_output_image(box_predictions, default_anchors,
475 |                                                                  cls_probs, box_target, box_mask,
476 |                                                                  cls_target, x, y)
477 |             print("Number of predicted {} BBs = {}".format(print_name, number_of_bbs))
478 |             with SummaryWriter(logdir=log_dir, verbose=False, flush_secs=5) as sw:
479 |                 sw.add_image('bb_{}_image'.format(print_name), output_image, global_step=e)
480 | 
481 | 
482 |     total_loss = 0
483 |     for loss in total_losses:
484 |         total_loss += loss / (len(dataloader)*len(total_losses))
485 | 
486 |     with SummaryWriter(logdir=log_dir, verbose=False, flush_secs=5) as sw:
487 |         if update_metric:
488 |             name1, val1 = cls_metric.get()
489 |             name2, val2 = box_metric.get()
490 |             sw.add_scalar(name1, {"test": val1}, global_step=e)
491 |             sw.add_scalar(name2, {"test": val2}, global_step=e)
492 |         sw.add_scalar('loss', {print_name: total_loss}, global_step=e)
493 | 
494 |     return total_loss
495 | 
496 | if __name__ == "__main__":
497 |     parser = argparse.ArgumentParser()
498 |     parser.add_argument("-g", "--gpu_count", default=4,
499 |                         help="Number of GPUs to use")
500 | 
501 |     parser.add_argument("-b", "--expand_bb_scale", default=0.05,
502 |                         help="Scale to expand the bounding box")
503 |     parser.add_argument("-m", "--min_c", default=0.01,
504 |                         help="Minimum probability to be considered a bounding box (used in box_nms)")
505 |     parser.add_argument("-o", "--overlap_thres", default=0.1,
506 |                         help="Maximum overlap between bounding boxes")
507 |     parser.add_argument("-t", "--topk", default=150,
508 |                         help="Maximum number of bounding boxes on one image")
509 | 
510 |     parser.add_argument("-e", "--epochs", default=351,
511 |                         help="Number of epochs to run")
512 |     parser.add_argument("-l", "--learning_rate", default=0.0001,
513 |                         help="Learning rate for training")
514 |     parser.add_argument("-s", "--batch_size", default=32,
515 |                         help="Batch size")
516 |     parser.add_argument("-w", "--image_size", default=350,
517 |                         help="Size of the input image (width and height); the value must be less than 700 pixels")
518 | 
519 |     parser.add_argument("-x", "--random_x_translation", default=0.03,
520 |                         help="Randomly translate the image in the x direction (+ or -)")
521 |     parser.add_argument("-y", "--random_y_translation", default=0.03,
522 |                         help="Randomly translate the image in the y direction (+ or -)")
523 |     parser.add_argument("-r", "--random_remove_box", default=0.15,
524 |                         help="Randomly remove bounding boxes and the text they contain with probability r")
525 | 
526 |     parser.add_argument("-d", "--log_dir", default="./logs",
527 |                         help="Directory to store the log files")
528 |     parser.add_argument("-c", "--checkpoint_dir", default="model_checkpoint",
529 |                         help="Directory to store the checkpoints")
530 |     parser.add_argument("-n", "--checkpoint_name", default="ssd.params",
531 |                         help="Name to store the checkpoints")
532 |     parser.add_argument("-db", "--detection_box", default="word",
533 |                         help="word or line")
534 |     parser.add_argument("-p", "--load_model", default=None,
535 |                         help="Model to load from")
536 | 
537 |     args = parser.parse_args()
538 | 
539 |     print(args)
540 | 
541 |     gpu_count = int(args.gpu_count)
542 | 
543 |     ctx = [mx.gpu(i) for i in range(gpu_count)]
544 | 
545 |     expand_bb_scale = float(args.expand_bb_scale)
546 |     min_c = float(args.min_c)
547 |     overlap_thres = float(args.overlap_thres)
548 |     topk = int(args.topk)
549 | 
550 |     epochs = int(args.epochs)
551 |     learning_rate = float(args.learning_rate)
552 |     batch_size = int(args.batch_size) * len(ctx)
553 |     image_size = int(args.image_size)
554 | 
555 |     random_y_translation, random_x_translation = float(args.random_y_translation), float(args.random_x_translation)
556 |     random_remove_box = float(args.random_remove_box)
557 | 
558 |     log_dir = args.log_dir
559 |     load_model = args.load_model
560 |     detection_box = args.detection_box
561 |     checkpoint_dir, checkpoint_name = args.checkpoint_dir, detection_box+"_"+args.checkpoint_name
562 | 
563 |     train_ds = IAMDataset("form_bb", output_data="bb", output_parse_method=detection_box, train=True)
564 |     print("Number of training samples: {}".format(len(train_ds)))
565 | 
566 |     test_ds = IAMDataset("form_bb", output_data="bb", output_parse_method=detection_box, train=False)
567 |     print("Number of testing samples: {}".format(len(test_ds)))
568 | 
569 |     train_data = gluon.data.DataLoader(train_ds.transform(augment_transform), batch_size, shuffle=True, last_batch="rollover", num_workers=multiprocessing.cpu_count()-4)
570 |     test_data = gluon.data.DataLoader(test_ds.transform(transform), batch_size, shuffle=False, last_batch="keep", num_workers=multiprocessing.cpu_count()-4)
571 | 
572 |     net = SSD(2, ctx=ctx)
573 |     net.hybridize()
574 |     if load_model is not None:
575 |         net.load_parameters(os.path.join(checkpoint_dir, load_model))
576 | 
577 |     trainer = gluon.Trainer(net.collect_params(), 'adam', {'learning_rate': learning_rate, })
578 | 
579 |     cls_loss = gluon.loss.SoftmaxCrossEntropyLoss()
580 | 
581 |     box_loss = SmoothL1Loss()
582 | 
583 |     best_test_loss = 10e5
584 |     for e in range(epochs):
585 |         cls_metric = mx.metric.Accuracy()
586 |         box_metric = mx.metric.MAE()
587 |         train_loss = run_epoch(e, net, train_data, trainer, log_dir, print_name="train", is_train=True, update_metric=False)
588 |         test_loss = run_epoch(e, net, test_data, trainer, log_dir, print_name="test", is_train=False, update_metric=True)
589 |         if test_loss < best_test_loss:
590 |             print("Saving network, previous best test loss {:.6f}, current test loss {:.6f}".format(best_test_loss, test_loss))
591 |             net.save_parameters(os.path.join(checkpoint_dir, checkpoint_name))
592 |             best_test_loss = test_loss
593 | 594 | if e % print_every_n == 0: 595 | name1, val1 = cls_metric.get() 596 | name2, val2 = box_metric.get() 597 | print("Epoch {0}, train_loss {1:.6f}, test_loss {2:.6f}, test {3}={4:.6f}, {5}={6:.6f}".format(e, train_loss, test_loss, name1, val1, name2, val2)) 598 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | leven==1.0.4 2 | mxnet==1.4.0 3 | gluonnlp==0.9.0 4 | protobuf==3.8.0 5 | mxboard==0.1.0 6 | sacremoses==0.0.43 --------------------------------------------------------------------------------