├── .gitignore ├── LICENSE.md ├── README.md ├── data ├── corpus.txt ├── line.png └── word.png ├── doc ├── decoder_comparison.png ├── graphics.svg └── htr.png ├── model ├── .gitignore └── wordCharList.txt ├── requirements.txt └── src ├── create_lmdb.py ├── dataloader_iam.py ├── main.py ├── model.py └── preprocessor.py /.gitignore: -------------------------------------------------------------------------------- 1 | data/words* 2 | data/words.txt 3 | src/__pycache__/ 4 | notes/ 5 | *.so 6 | *.pyc 7 | .idea/ 8 | dump/ -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Harald Scheidl 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Handwritten Text Recognition with TensorFlow 2 | 3 | * **Update 2023/2: a [web demo](https://githubharald.github.io/text_reader.html) is available** 4 | * **Update 2023/1: see [HTRPipeline](https://github.com/githubharald/HTRPipeline) for a package to read full pages** 5 | * **Update 2021/2: recognize text on line level (multiple words)** 6 | * **Update 2021/1: more robust model, faster dataloader, word beam search decoder also available for Windows** 7 | * **Update 2020: code is compatible with TF2** 8 | 9 | 10 | Handwritten Text Recognition (HTR) system implemented with TensorFlow (TF) and trained on the IAM off-line HTR dataset. 11 | The model takes **images of single words or text lines (multiple words) as input** and **outputs the recognized text**. 12 | 3/4 of the words from the validation-set are correctly recognized, and the character error rate is around 10%. 
13 | 14 | ![htr](./doc/htr.png) 15 | 16 | 17 | ## Run demo 18 | 19 | * Download one of the pretrained models 20 | * [Model trained on word images](https://www.dropbox.com/s/mya8hw6jyzqm0a3/word-model.zip?dl=1): 21 | only handles single words per image, but gives better results on the IAM word dataset 22 | * [Model trained on text line images](https://www.dropbox.com/s/7xwkcilho10rthn/line-model.zip?dl=1): 23 | can handle multiple words in one image 24 | * Put the contents of the downloaded zip file into the `model` directory of the repository 25 | * Go to the `src` directory 26 | * Run inference code: 27 | * Execute `python main.py` to run the model on an image of a word 28 | * Execute `python main.py --img_file ../data/line.png` to run the model on an image of a text line 29 | 30 | The input images and the expected outputs are shown below (both produced with the text line model). 31 | 32 | ![test](./data/word.png) 33 | ``` 34 | > python main.py 35 | Init with stored values from ../model/snapshot-13 36 | Recognized: "word" 37 | Probability: 0.9806370139122009 38 | ``` 39 | 40 | ![test](./data/line.png) 41 | 42 | ``` 43 | > python main.py --img_file ../data/line.png 44 | Init with stored values from ../model/snapshot-13 45 | Recognized: "or work on line level" 46 | Probability: 0.6674373149871826 47 | ``` 48 | 49 | ## Command line arguments 50 | * `--mode`: select between "train", "validate" and "infer". Defaults to "infer". 51 | * `--decoder`: select from the CTC decoders "bestpath", "beamsearch" and "wordbeamsearch". Defaults to "bestpath". For the option "wordbeamsearch" see details below. 52 | * `--batch_size`: batch size. 53 | * `--data_dir`: directory containing the IAM dataset (with subdirectories `img` and `gt`). 54 | * `--fast`: use LMDB to load images faster. 55 | * `--line_mode`: train on text lines instead of single words. 56 | * `--img_file`: image that is used for inference. 57 | * `--dump`: dumps the output of the NN to CSV file(s) saved in the `dump` folder. Can be used as input for the [CTCDecoder](https://github.com/githubharald/CTCDecoder). 58 | 59 | 60 | ## Integrate word beam search decoding 61 | 62 | The [word beam search decoder](https://repositum.tuwien.ac.at/obvutwoa/download/pdf/2774578) can be used instead of the two decoders shipped with TF. 63 | Words are constrained to those contained in a dictionary, but arbitrary non-word character strings (numbers, punctuation marks) can still be recognized. 64 | The following illustration shows a sample for which word beam search is able to recognize the correct text, while the other decoders fail. 65 | 66 | ![decoder_comparison](./doc/decoder_comparison.png) 67 | 68 | Follow these instructions to integrate word beam search decoding: 69 | 70 | 1. Clone the repository [CTCWordBeamSearch](https://github.com/githubharald/CTCWordBeamSearch) 71 | 2. Compile and install it by running `pip install .` at the root level of the CTCWordBeamSearch repository 72 | 3. Specify the command line option `--decoder wordbeamsearch` when executing `main.py` to actually use the decoder 73 | 74 | The dictionary is automatically created in training and validation mode from all words contained in the IAM dataset (i.e. including words from the validation set) and is saved into the file `data/corpus.txt`. 75 | Further, the manually created list of word characters can be found in the file `model/wordCharList.txt`. 76 | The beam width is set to 50 to match the beam width of vanilla beam search decoding.
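For orientation, the following sketch condenses how `src/model.py` constructs and uses the decoder (run from the `src` directory, after training has produced `model/charList.txt` and `data/corpus.txt`); it is a minimal illustration of the moving parts, not a replacement for the code in this repository:

```
# condensed from src/model.py: construct the word beam search decoder
from word_beam_search import WordBeamSearch

chars = open('../model/charList.txt').read()                           # all chars the model can output
word_chars = open('../model/wordCharList.txt').read().splitlines()[0]  # chars that may form words
corpus = open('../data/corpus.txt').read()                             # dictionary created during training

# beam width 50, "Words" mode: dictionary words plus arbitrary non-word strings
decoder = WordBeamSearch(50, 'Words', 0.0, corpus.encode('utf8'),
                         chars.encode('utf8'), word_chars.encode('utf8'))

# decoding takes the softmax of the RNN output (shape TxBxC) and returns label strings:
# label_strings = decoder.compute(softmax_output)
```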
77 | 78 | 79 | ## Train model on IAM dataset 80 | 81 | ### Prepare dataset 82 | Follow these instructions to get the IAM dataset: 83 | 84 | * Register for free at this [website](http://www.fki.inf.unibe.ch/databases/iam-handwriting-database) 85 | * Download `words/words.tgz` 86 | * Download `ascii/words.txt` 87 | * Create a directory for the dataset on your disk, and create two subdirectories: `img` and `gt` 88 | * Put `words.txt` into the `gt` directory 89 | * Put the content (directories `a01`, `a02`, ...) of `words.tgz` into the `img` directory 90 | 91 | ### Run training 92 | 93 | * Delete files from the `model` directory if you want to train from scratch 94 | * Go to the `src` directory and execute `python main.py --mode train --data_dir path/to/IAM` 95 | * The IAM dataset is split into 95% training data and 5% validation data 96 | * If the option `--line_mode` is specified, 97 | the model is trained on text line images created by combining multiple word images into one 98 | * Training stops after a fixed number of epochs without improvement 99 | 100 | The pretrained word model was trained with this command on a GTX 1050 Ti: 101 | ``` 102 | python main.py --mode train --fast --data_dir path/to/iam --batch_size 500 --early_stopping 15 103 | ``` 104 | 105 | And the line model with: 106 | ``` 107 | python main.py --mode train --fast --data_dir path/to/iam --batch_size 250 --early_stopping 10 108 | ``` 109 | 110 | 111 | ### Fast image loading 112 | Loading and decoding the PNG image files from disk is the bottleneck even when using only a small GPU. 113 | The LMDB database is used to speed up image loading: 114 | * Go to the `src` directory and run `python create_lmdb.py --data_dir path/to/iam` with the IAM data directory specified 115 | * A subfolder `lmdb` is created in the IAM data directory containing the LMDB files 116 | * When training the model, add the command line option `--fast` 117 | 118 | The dataset should be located on an SSD drive. 119 | Using the `--fast` option and a GTX 1050 Ti, training on single words takes around 3h with a batch size of 500. 120 | Training on text lines takes a bit longer. 121 | 122 | 123 | ## Information about model 124 | 125 | The model is a stripped-down version of the HTR system I implemented for [my thesis](https://repositum.tuwien.ac.at/obvutwhs/download/pdf/2874742). 126 | What remains is the bare minimum to recognize text with an acceptable accuracy. 127 | It consists of 5 CNN layers, 2 RNN (LSTM) layers and the CTC loss and decoding layer. 128 | For more details see this [Medium article](https://towardsdatascience.com/2326a3487cd5).
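As a rough overview (tensor shapes derived from the layer parameters in `src/model.py`, assuming the default 128×32 single-word input size), data flows through the network as follows:

```
input image:   128 x 32 (width x height, grayscale)
5 CNN layers:  kernels 5x5, 5x5, 3x3, 3x3, 3x3; each with batch norm, ReLU and pooling
   -> feature maps of size 32 x 1 x 256 (width /4, height /32)
collapse the height dimension and use the width as time axis
   -> sequence of 32 time-steps with 256 features each
2 LSTM layers (256 hidden units each), applied bidirectionally
   -> sequence of 32 x 512, projected to 32 x (C+1), C = charset size, +1 for the CTC blank
CTC layer:     computes the loss during training, decodes the sequence during inference
```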
129 | 130 | 131 | ## References 132 | * [Build a Handwritten Text Recognition System using TensorFlow](https://towardsdatascience.com/2326a3487cd5) 133 | * [Scheidl - Handwritten Text Recognition in Historical Documents](https://repositum.tuwien.ac.at/obvutwhs/download/pdf/2874742) 134 | * [Scheidl - Word Beam Search: A Connectionist Temporal Classification Decoding Algorithm](https://repositum.tuwien.ac.at/obvutwoa/download/pdf/2774578) 135 | 136 | -------------------------------------------------------------------------------- /data/line.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/githubharald/SimpleHTR/63424013cfd8cc377e45a257a2cb62add6b38b87/data/line.png -------------------------------------------------------------------------------- /data/word.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/githubharald/SimpleHTR/63424013cfd8cc377e45a257a2cb62add6b38b87/data/word.png -------------------------------------------------------------------------------- /doc/decoder_comparison.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/githubharald/SimpleHTR/63424013cfd8cc377e45a257a2cb62add6b38b87/doc/decoder_comparison.png -------------------------------------------------------------------------------- /doc/htr.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/githubharald/SimpleHTR/63424013cfd8cc377e45a257a2cb62add6b38b87/doc/htr.png -------------------------------------------------------------------------------- /model/.gitignore: -------------------------------------------------------------------------------- 1 | # Ignore everything in this directory 2 | * 3 | # Except this file and wordCharList.txt 4 | !.gitignore 5 | !wordCharList.txt -------------------------------------------------------------------------------- /model/wordCharList.txt: -------------------------------------------------------------------------------- 1 | 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz 2 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | editdistance==0.5.2 2 | lmdb==1.0.0 3 | matplotlib==3.2.1 4 | numpy==1.19.5 5 | opencv-python==4.4.0.46 6 | path==15.0.0 7 | tensorflow==2.4.0 -------------------------------------------------------------------------------- /src/create_lmdb.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import pickle 3 | 4 | import cv2 5 | import lmdb 6 | from path import Path 7 | 8 | parser = argparse.ArgumentParser() 9 | parser.add_argument('--data_dir', type=Path, required=True) 10 | args = parser.parse_args() 11 | 12 | # create the LMDB folder (must not exist yet); a map size of 2GB is enough for the IAM dataset 13 | assert not (args.data_dir / 'lmdb').exists(), 'LMDB folder already exists' 14 | env = lmdb.open(str(args.data_dir / 'lmdb'), map_size=1024 * 1024 * 1024 * 2) 15 | 16 | # go over all png files 17 | fn_imgs = list((args.data_dir / 'img').walkfiles('*.png')) 18 | 19 | # and put the imgs into lmdb as pickled grayscale imgs 20 | with env.begin(write=True) as txn: 21 | for i, fn_img in enumerate(fn_imgs): 22 | print(i, len(fn_imgs)) 23 | img = cv2.imread(fn_img, cv2.IMREAD_GRAYSCALE) 24 | basename = fn_img.basename() 25 | txn.put(basename.encode("ascii"), pickle.dumps(img)) 26 | 27 | 
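# --- optional sanity check (illustrative addition, not part of the original script) ---
# read the first image back with the same lmdb/pickle calls that dataloader_iam.py uses
if fn_imgs:
    with env.begin() as txn:
        sample = pickle.loads(txn.get(fn_imgs[0].basename().encode("ascii")))
        print('read-back check:', None if sample is None else sample.shape)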
env.close() 28 | -------------------------------------------------------------------------------- /src/dataloader_iam.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import random 3 | from collections import namedtuple 4 | from typing import Tuple 5 | 6 | import cv2 7 | import lmdb 8 | import numpy as np 9 | from path import Path 10 | 11 | Sample = namedtuple('Sample', 'gt_text, file_path') 12 | Batch = namedtuple('Batch', 'imgs, gt_texts, batch_size') 13 | 14 | 15 | class DataLoaderIAM: 16 | """ 17 | Loads data which corresponds to IAM format, 18 | see: http://www.fki.inf.unibe.ch/databases/iam-handwriting-database 19 | """ 20 | 21 | def __init__(self, 22 | data_dir: Path, 23 | batch_size: int, 24 | data_split: float = 0.95, 25 | fast: bool = True) -> None: 26 | """Loader for dataset.""" 27 | 28 | assert data_dir.exists() 29 | 30 | self.fast = fast 31 | if fast: 32 | self.env = lmdb.open(str(data_dir / 'lmdb'), readonly=True) 33 | 34 | self.data_augmentation = False 35 | self.curr_idx = 0 36 | self.batch_size = batch_size 37 | self.samples = [] 38 | 39 | f = open(data_dir / 'gt/words.txt') 40 | chars = set() 41 | bad_samples_reference = ['a01-117-05-02', 'r06-022-03-05'] # known broken images in IAM dataset 42 | for line in f: 43 | # ignore empty and comment lines 44 | line = line.strip() 45 | if not line or line[0] == '#': 46 | continue 47 | 48 | line_split = line.split(' ') 49 | assert len(line_split) >= 9 50 | 51 | # filename: part1-part2-part3 --> part1/part1-part2/part1-part2-part3.png 52 | file_name_split = line_split[0].split('-') 53 | file_name_subdir1 = file_name_split[0] 54 | file_name_subdir2 = f'{file_name_split[0]}-{file_name_split[1]}' 55 | file_base_name = line_split[0] + '.png' 56 | file_name = data_dir / 'img' / file_name_subdir1 / file_name_subdir2 / file_base_name 57 | 58 | if line_split[0] in bad_samples_reference: 59 | print('Ignoring known broken image:', file_name) 60 | continue 61 | 62 | # GT text are columns starting at 9 63 | gt_text = ' '.join(line_split[8:]) 64 | chars = chars.union(set(list(gt_text))) 65 | 66 | # put sample into list 67 | self.samples.append(Sample(gt_text, file_name)) 68 | 69 | # split into training and validation set: 95% - 5% 70 | split_idx = int(data_split * len(self.samples)) 71 | self.train_samples = self.samples[:split_idx] 72 | self.validation_samples = self.samples[split_idx:] 73 | 74 | # put words into lists 75 | self.train_words = [x.gt_text for x in self.train_samples] 76 | self.validation_words = [x.gt_text for x in self.validation_samples] 77 | 78 | # start with train set 79 | self.train_set() 80 | 81 | # list of all chars in dataset 82 | self.char_list = sorted(list(chars)) 83 | 84 | def train_set(self) -> None: 85 | """Switch to randomly chosen subset of training set.""" 86 | self.data_augmentation = True 87 | self.curr_idx = 0 88 | random.shuffle(self.train_samples) 89 | self.samples = self.train_samples 90 | self.curr_set = 'train' 91 | 92 | def validation_set(self) -> None: 93 | """Switch to validation set.""" 94 | self.data_augmentation = False 95 | self.curr_idx = 0 96 | self.samples = self.validation_samples 97 | self.curr_set = 'val' 98 | 99 | def get_iterator_info(self) -> Tuple[int, int]: 100 | """Current batch index and overall number of batches.""" 101 | if self.curr_set == 'train': 102 | num_batches = int(np.floor(len(self.samples) / self.batch_size)) # train set: only full-sized batches 103 | else: 104 | num_batches = int(np.ceil(len(self.samples) / 
self.batch_size)) # val set: allow last batch to be smaller 105 | curr_batch = self.curr_idx // self.batch_size + 1 106 | return curr_batch, num_batches 107 | 108 | def has_next(self) -> bool: 109 | """Is there a next element?""" 110 | if self.curr_set == 'train': 111 | return self.curr_idx + self.batch_size <= len(self.samples) # train set: only full-sized batches 112 | else: 113 | return self.curr_idx < len(self.samples) # val set: allow last batch to be smaller 114 | 115 | def _get_img(self, i: int) -> np.ndarray: 116 | if self.fast: 117 | with self.env.begin() as txn: 118 | basename = Path(self.samples[i].file_path).basename() 119 | data = txn.get(basename.encode("ascii")) 120 | img = pickle.loads(data) 121 | else: 122 | img = cv2.imread(self.samples[i].file_path, cv2.IMREAD_GRAYSCALE) 123 | 124 | return img 125 | 126 | def get_next(self) -> Batch: 127 | """Get next element.""" 128 | batch_range = range(self.curr_idx, min(self.curr_idx + self.batch_size, len(self.samples))) 129 | 130 | imgs = [self._get_img(i) for i in batch_range] 131 | gt_texts = [self.samples[i].gt_text for i in batch_range] 132 | 133 | self.curr_idx += self.batch_size 134 | return Batch(imgs, gt_texts, len(imgs)) 135 | -------------------------------------------------------------------------------- /src/main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | from typing import Tuple, List 4 | 5 | import cv2 6 | import editdistance 7 | from path import Path 8 | 9 | from dataloader_iam import DataLoaderIAM, Batch 10 | from model import Model, DecoderType 11 | from preprocessor import Preprocessor 12 | 13 | 14 | class FilePaths: 15 | """Filenames and paths to data.""" 16 | fn_char_list = '../model/charList.txt' 17 | fn_summary = '../model/summary.json' 18 | fn_corpus = '../data/corpus.txt' 19 | 20 | 21 | def get_img_height() -> int: 22 | """Fixed height for NN.""" 23 | return 32 24 | 25 | 26 | def get_img_size(line_mode: bool = False) -> Tuple[int, int]: 27 | """Height is fixed for NN, width is set according to training mode (single words or text lines).""" 28 | if line_mode: 29 | return 256, get_img_height() 30 | return 128, get_img_height() 31 | 32 | 33 | def write_summary(average_train_loss: List[float], char_error_rates: List[float], word_accuracies: List[float]) -> None: 34 | """Writes training summary file for NN.""" 35 | with open(FilePaths.fn_summary, 'w') as f: 36 | json.dump({'averageTrainLoss': average_train_loss, 'charErrorRates': char_error_rates, 'wordAccuracies': word_accuracies}, f) 37 | 38 | 39 | def char_list_from_file() -> List[str]: 40 | with open(FilePaths.fn_char_list) as f: 41 | return list(f.read()) 42 | 43 | 44 | def train(model: Model, 45 | loader: DataLoaderIAM, 46 | line_mode: bool, 47 | early_stopping: int = 25) -> None: 48 | """Trains NN.""" 49 | epoch = 0 # number of training epochs since start 50 | summary_char_error_rates = [] 51 | summary_word_accuracies = [] 52 | 53 | train_loss_in_epoch = [] 54 | average_train_loss = [] 55 | 56 | preprocessor = Preprocessor(get_img_size(line_mode), data_augmentation=True, line_mode=line_mode) 57 | best_char_error_rate = float('inf') # best validation character error rate 58 | no_improvement_since = 0 # number of epochs no improvement of character error rate occurred 59 | # stop training after this number of epochs without improvement 60 | while True: 61 | epoch += 1 62 | print('Epoch:', epoch) 63 | 64 | # train 65 | print('Train NN') 66 | loader.train_set() 67 | while 
loader.has_next(): 68 | iter_info = loader.get_iterator_info() 69 | batch = loader.get_next() 70 | batch = preprocessor.process_batch(batch) 71 | loss = model.train_batch(batch) 72 | print(f'Epoch: {epoch} Batch: {iter_info[0]}/{iter_info[1]} Loss: {loss}') 73 | train_loss_in_epoch.append(loss) 74 | 75 | # validate 76 | char_error_rate, word_accuracy = validate(model, loader, line_mode) 77 | 78 | # write summary 79 | summary_char_error_rates.append(char_error_rate) 80 | summary_word_accuracies.append(word_accuracy) 81 | average_train_loss.append((sum(train_loss_in_epoch)) / len(train_loss_in_epoch)) 82 | write_summary(average_train_loss, summary_char_error_rates, summary_word_accuracies) 83 | 84 | # reset train loss list 85 | train_loss_in_epoch = [] 86 | 87 | # if best validation accuracy so far, save model parameters 88 | if char_error_rate < best_char_error_rate: 89 | print('Character error rate improved, save model') 90 | best_char_error_rate = char_error_rate 91 | no_improvement_since = 0 92 | model.save() 93 | else: 94 | print(f'Character error rate not improved, best so far: {best_char_error_rate * 100.0}%') 95 | no_improvement_since += 1 96 | 97 | # stop training if no more improvement in the last x epochs 98 | if no_improvement_since >= early_stopping: 99 | print(f'No more improvement for {early_stopping} epochs. Training stopped.') 100 | break 101 | 102 | 103 | def validate(model: Model, loader: DataLoaderIAM, line_mode: bool) -> Tuple[float, float]: 104 | """Validates NN.""" 105 | print('Validate NN') 106 | loader.validation_set() 107 | preprocessor = Preprocessor(get_img_size(line_mode), line_mode=line_mode) 108 | num_char_err = 0 109 | num_char_total = 0 110 | num_word_ok = 0 111 | num_word_total = 0 112 | while loader.has_next(): 113 | iter_info = loader.get_iterator_info() 114 | print(f'Batch: {iter_info[0]} / {iter_info[1]}') 115 | batch = loader.get_next() 116 | batch = preprocessor.process_batch(batch) 117 | recognized, _ = model.infer_batch(batch) 118 | 119 | print('Ground truth -> Recognized') 120 | for i in range(len(recognized)): 121 | num_word_ok += 1 if batch.gt_texts[i] == recognized[i] else 0 122 | num_word_total += 1 123 | dist = editdistance.eval(recognized[i], batch.gt_texts[i]) 124 | num_char_err += dist 125 | num_char_total += len(batch.gt_texts[i]) 126 | print('[OK]' if dist == 0 else '[ERR:%d]' % dist, '"' + batch.gt_texts[i] + '"', '->', 127 | '"' + recognized[i] + '"') 128 | 129 | # print validation result 130 | char_error_rate = num_char_err / num_char_total 131 | word_accuracy = num_word_ok / num_word_total 132 | print(f'Character error rate: {char_error_rate * 100.0}%. 
Word accuracy: {word_accuracy * 100.0}%.') 133 | return char_error_rate, word_accuracy 134 | 135 | 136 | def infer(model: Model, fn_img: Path) -> None: 137 | """Recognizes text in image provided by file path.""" 138 | img = cv2.imread(fn_img, cv2.IMREAD_GRAYSCALE) 139 | assert img is not None 140 | 141 | preprocessor = Preprocessor(get_img_size(), dynamic_width=True, padding=16) 142 | img = preprocessor.process_img(img) 143 | 144 | batch = Batch([img], None, 1) 145 | recognized, probability = model.infer_batch(batch, True) 146 | print(f'Recognized: "{recognized[0]}"') 147 | print(f'Probability: {probability[0]}') 148 | 149 | 150 | def parse_args() -> argparse.Namespace: 151 | """Parses arguments from the command line.""" 152 | parser = argparse.ArgumentParser() 153 | 154 | parser.add_argument('--mode', choices=['train', 'validate', 'infer'], default='infer') 155 | parser.add_argument('--decoder', choices=['bestpath', 'beamsearch', 'wordbeamsearch'], default='bestpath') 156 | parser.add_argument('--batch_size', help='Batch size.', type=int, default=100) 157 | parser.add_argument('--data_dir', help='Directory containing IAM dataset.', type=Path, required=False) 158 | parser.add_argument('--fast', help='Load samples from LMDB.', action='store_true') 159 | parser.add_argument('--line_mode', help='Train to read text lines instead of single words.', action='store_true') 160 | parser.add_argument('--img_file', help='Image used for inference.', type=Path, default='../data/word.png') 161 | parser.add_argument('--early_stopping', help='Early stopping epochs.', type=int, default=25) 162 | parser.add_argument('--dump', help='Dump output of NN to CSV file(s).', action='store_true') 163 | 164 | return parser.parse_args() 165 | 166 | 167 | def main(): 168 | """Main function.""" 169 | 170 | # parse arguments and set CTC decoder 171 | args = parse_args() 172 | decoder_mapping = {'bestpath': DecoderType.BestPath, 173 | 'beamsearch': DecoderType.BeamSearch, 174 | 'wordbeamsearch': DecoderType.WordBeamSearch} 175 | decoder_type = decoder_mapping[args.decoder] 176 | 177 | # train the model 178 | if args.mode == 'train': 179 | loader = DataLoaderIAM(args.data_dir, args.batch_size, fast=args.fast) 180 | 181 | # when in line mode, take care to have a whitespace in the char list 182 | char_list = loader.char_list 183 | if args.line_mode and ' ' not in char_list: 184 | char_list = [' '] + char_list 185 | 186 | # save characters and words 187 | with open(FilePaths.fn_char_list, 'w') as f: 188 | f.write(''.join(char_list)) 189 | 190 | with open(FilePaths.fn_corpus, 'w') as f: 191 | f.write(' '.join(loader.train_words + loader.validation_words)) 192 | 193 | model = Model(char_list, decoder_type) 194 | train(model, loader, line_mode=args.line_mode, early_stopping=args.early_stopping) 195 | 196 | # evaluate it on the validation set 197 | elif args.mode == 'validate': 198 | loader = DataLoaderIAM(args.data_dir, args.batch_size, fast=args.fast) 199 | model = Model(char_list_from_file(), decoder_type, must_restore=True) 200 | validate(model, loader, args.line_mode) 201 | 202 | # infer text on test image 203 | elif args.mode == 'infer': 204 | model = Model(char_list_from_file(), decoder_type, must_restore=True, dump=args.dump) 205 | infer(model, args.img_file) 206 | 207 | 208 | if __name__ == '__main__': 209 | main() 210 | -------------------------------------------------------------------------------- /src/model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import 
sys 3 | from typing import List, Tuple 4 | 5 | import numpy as np 6 | import tensorflow as tf 7 | 8 | from dataloader_iam import Batch 9 | 10 | # Disable eager mode 11 | tf.compat.v1.disable_eager_execution() 12 | 13 | 14 | class DecoderType: 15 | """CTC decoder types.""" 16 | BestPath = 0 17 | BeamSearch = 1 18 | WordBeamSearch = 2 19 | 20 | 21 | class Model: 22 | """Minimalistic TF model for HTR.""" 23 | 24 | def __init__(self, 25 | char_list: List[str], 26 | decoder_type: int = DecoderType.BestPath, 27 | must_restore: bool = False, 28 | dump: bool = False) -> None: 29 | """Init model: add CNN, RNN and CTC and initialize TF.""" 30 | self.dump = dump 31 | self.char_list = char_list 32 | self.decoder_type = decoder_type 33 | self.must_restore = must_restore 34 | self.snap_ID = 0 35 | 36 | # Whether to use normalization over a batch or a population 37 | self.is_train = tf.compat.v1.placeholder(tf.bool, name='is_train') 38 | 39 | # input image batch 40 | self.input_imgs = tf.compat.v1.placeholder(tf.float32, shape=(None, None, None)) 41 | 42 | # setup CNN, RNN and CTC 43 | self.setup_cnn() 44 | self.setup_rnn() 45 | self.setup_ctc() 46 | 47 | # setup optimizer to train NN 48 | self.batches_trained = 0 49 | self.update_ops = tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.UPDATE_OPS) 50 | with tf.control_dependencies(self.update_ops): 51 | self.optimizer = tf.compat.v1.train.AdamOptimizer().minimize(self.loss) 52 | 53 | # initialize TF 54 | self.sess, self.saver = self.setup_tf() 55 | 56 | def setup_cnn(self) -> None: 57 | """Create CNN layers.""" 58 | cnn_in4d = tf.expand_dims(input=self.input_imgs, axis=3) 59 | 60 | # list of parameters for the layers 61 | kernel_vals = [5, 5, 3, 3, 3] 62 | feature_vals = [1, 32, 64, 128, 128, 256] 63 | stride_vals = pool_vals = [(2, 2), (2, 2), (1, 2), (1, 2), (1, 2)] 64 | num_layers = len(stride_vals) 65 | 66 | # create layers 67 | pool = cnn_in4d # input to first CNN layer 68 | for i in range(num_layers): 69 | kernel = tf.Variable( 70 | tf.random.truncated_normal([kernel_vals[i], kernel_vals[i], feature_vals[i], feature_vals[i + 1]], 71 | stddev=0.1)) 72 | conv = tf.nn.conv2d(input=pool, filters=kernel, padding='SAME', strides=(1, 1, 1, 1)) 73 | conv_norm = tf.compat.v1.layers.batch_normalization(conv, training=self.is_train) 74 | relu = tf.nn.relu(conv_norm) 75 | pool = tf.nn.max_pool2d(input=relu, ksize=(1, pool_vals[i][0], pool_vals[i][1], 1), 76 | strides=(1, stride_vals[i][0], stride_vals[i][1], 1), padding='VALID') 77 | 78 | self.cnn_out_4d = pool 79 | 80 | def setup_rnn(self) -> None: 81 | """Create RNN layers.""" 82 | rnn_in3d = tf.squeeze(self.cnn_out_4d, axis=[2]) 83 | 84 | # basic cells which are used to build the RNN 85 | num_hidden = 256 86 | cells = [tf.compat.v1.nn.rnn_cell.LSTMCell(num_units=num_hidden, state_is_tuple=True) for _ in 87 | range(2)] # 2 layers 88 | 89 | # stack basic cells 90 | stacked = tf.compat.v1.nn.rnn_cell.MultiRNNCell(cells, state_is_tuple=True) 91 | 92 | # bidirectional RNN 93 | # BxTxF -> BxTx2H 94 | (fw, bw), _ = tf.compat.v1.nn.bidirectional_dynamic_rnn(cell_fw=stacked, cell_bw=stacked, inputs=rnn_in3d, 95 | dtype=rnn_in3d.dtype) 96 | 97 | # BxTxH + BxTxH -> BxTx2H -> BxTx1x2H 98 | concat = tf.expand_dims(tf.concat([fw, bw], 2), 2) 99 | 100 | # project output to chars (including blank): BxTx1x2H -> BxTx1xC -> BxTxC 101 | kernel = tf.Variable(tf.random.truncated_normal([1, 1, num_hidden * 2, len(self.char_list) + 1], stddev=0.1)) 102 | self.rnn_out_3d = tf.squeeze(tf.nn.atrous_conv2d(value=concat, filters=kernel, 
rate=1, padding='SAME'), 103 | axis=[2]) 104 | 105 | def setup_ctc(self) -> None: 106 | """Create CTC loss and decoder.""" 107 | # BxTxC -> TxBxC 108 | self.ctc_in_3d_tbc = tf.transpose(a=self.rnn_out_3d, perm=[1, 0, 2]) 109 | # ground truth text as sparse tensor 110 | self.gt_texts = tf.SparseTensor(tf.compat.v1.placeholder(tf.int64, shape=[None, 2]), 111 | tf.compat.v1.placeholder(tf.int32, [None]), 112 | tf.compat.v1.placeholder(tf.int64, [2])) 113 | 114 | # calc loss for batch 115 | self.seq_len = tf.compat.v1.placeholder(tf.int32, [None]) 116 | self.loss = tf.reduce_mean( 117 | input_tensor=tf.compat.v1.nn.ctc_loss(labels=self.gt_texts, inputs=self.ctc_in_3d_tbc, 118 | sequence_length=self.seq_len, 119 | ctc_merge_repeated=True)) 120 | 121 | # calc loss for each element to compute label probability 122 | self.saved_ctc_input = tf.compat.v1.placeholder(tf.float32, 123 | shape=[None, None, len(self.char_list) + 1]) 124 | self.loss_per_element = tf.compat.v1.nn.ctc_loss(labels=self.gt_texts, inputs=self.saved_ctc_input, 125 | sequence_length=self.seq_len, ctc_merge_repeated=True) 126 | 127 | # best path decoding or beam search decoding 128 | if self.decoder_type == DecoderType.BestPath: 129 | self.decoder = tf.nn.ctc_greedy_decoder(inputs=self.ctc_in_3d_tbc, sequence_length=self.seq_len) 130 | elif self.decoder_type == DecoderType.BeamSearch: 131 | self.decoder = tf.nn.ctc_beam_search_decoder(inputs=self.ctc_in_3d_tbc, sequence_length=self.seq_len, 132 | beam_width=50) 133 | # word beam search decoding (see https://github.com/githubharald/CTCWordBeamSearch) 134 | elif self.decoder_type == DecoderType.WordBeamSearch: 135 | # prepare information about language (dictionary, characters in dataset, characters forming words) 136 | chars = ''.join(self.char_list) 137 | word_chars = open('../model/wordCharList.txt').read().splitlines()[0] 138 | corpus = open('../data/corpus.txt').read() 139 | 140 | # decode using the "Words" mode of word beam search 141 | from word_beam_search import WordBeamSearch 142 | self.decoder = WordBeamSearch(50, 'Words', 0.0, corpus.encode('utf8'), chars.encode('utf8'), 143 | word_chars.encode('utf8')) 144 | 145 | # the input to the decoder must have softmax already applied 146 | self.wbs_input = tf.nn.softmax(self.ctc_in_3d_tbc, axis=2) 147 | 148 | def setup_tf(self) -> Tuple[tf.compat.v1.Session, tf.compat.v1.train.Saver]: 149 | """Initialize TF.""" 150 | print('Python: ' + sys.version) 151 | print('Tensorflow: ' + tf.__version__) 152 | 153 | sess = tf.compat.v1.Session() # TF session 154 | 155 | saver = tf.compat.v1.train.Saver(max_to_keep=1) # saver saves model to file 156 | model_dir = '../model/' 157 | latest_snapshot = tf.train.latest_checkpoint(model_dir) # is there a saved model? 
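# note: Model.save() writes checkpoints named 'snapshot-<N>' into ../model/, so this yields e.g. '../model/snapshot-13', or None if no checkpoint exists yet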
158 | 159 | # if model must be restored (for inference), there must be a snapshot 160 | if self.must_restore and not latest_snapshot: 161 | raise Exception('No saved model found in: ' + model_dir) 162 | 163 | # load saved model if available 164 | if latest_snapshot: 165 | print('Init with stored values from ' + latest_snapshot) 166 | saver.restore(sess, latest_snapshot) 167 | else: 168 | print('Init with new values') 169 | sess.run(tf.compat.v1.global_variables_initializer()) 170 | 171 | return sess, saver 172 | 173 | def to_sparse(self, texts: List[str]) -> Tuple[List[List[int]], List[int], List[int]]: 174 | """Put ground truth texts into sparse tensor for ctc_loss.""" 175 | indices = [] 176 | values = [] 177 | shape = [len(texts), 0] # last entry must be max(labelList[i]) 178 | 179 | # go over all texts 180 | for batchElement, text in enumerate(texts): 181 | # convert to string of label (i.e. class-ids) 182 | label_str = [self.char_list.index(c) for c in text] 183 | # sparse tensor must have size of max. label-string 184 | if len(label_str) > shape[1]: 185 | shape[1] = len(label_str) 186 | # put each label into sparse tensor 187 | for i, label in enumerate(label_str): 188 | indices.append([batchElement, i]) 189 | values.append(label) 190 | 191 | return indices, values, shape 192 | 193 | def decoder_output_to_text(self, ctc_output: tuple, batch_size: int) -> List[str]: 194 | """Extract texts from output of CTC decoder.""" 195 | 196 | # word beam search: already contains label strings 197 | if self.decoder_type == DecoderType.WordBeamSearch: 198 | label_strs = ctc_output 199 | 200 | # TF decoders: label strings are contained in sparse tensor 201 | else: 202 | # ctc returns tuple, first element is SparseTensor 203 | decoded = ctc_output[0][0] 204 | 205 | # contains string of labels for each batch element 206 | label_strs = [[] for _ in range(batch_size)] 207 | 208 | # go over all indices and save mapping: batch -> values 209 | for (idx, idx2d) in enumerate(decoded.indices): 210 | label = decoded.values[idx] 211 | batch_element = idx2d[0] # index according to [b,t] 212 | label_strs[batch_element].append(label) 213 | 214 | # map labels to chars for all batch elements 215 | return [''.join([self.char_list[c] for c in labelStr]) for labelStr in label_strs] 216 | 217 | def train_batch(self, batch: Batch) -> float: 218 | """Feed a batch into the NN to train it.""" 219 | num_batch_elements = len(batch.imgs) 220 | max_text_len = batch.imgs[0].shape[0] // 4 221 | sparse = self.to_sparse(batch.gt_texts) 222 | eval_list = [self.optimizer, self.loss] 223 | feed_dict = {self.input_imgs: batch.imgs, self.gt_texts: sparse, 224 | self.seq_len: [max_text_len] * num_batch_elements, self.is_train: True} 225 | _, loss_val = self.sess.run(eval_list, feed_dict) 226 | self.batches_trained += 1 227 | return loss_val 228 | 229 | @staticmethod 230 | def dump_nn_output(rnn_output: np.ndarray) -> None: 231 | """Dump the output of the NN to CSV file(s).""" 232 | dump_dir = '../dump/' 233 | if not os.path.isdir(dump_dir): 234 | os.mkdir(dump_dir) 235 | 236 | # iterate over all batch elements and create a CSV file for each one 237 | max_t, max_b, max_c = rnn_output.shape 238 | for b in range(max_b): 239 | csv = '' 240 | for t in range(max_t): 241 | csv += ';'.join([str(rnn_output[t, b, c]) for c in range(max_c)]) + ';\n' 242 | fn = dump_dir + 'rnnOutput_' + str(b) + '.csv' 243 | print('Write dump of NN to file: ' + fn) 244 | with open(fn, 'w') as f: 245 | f.write(csv) 246 | 247 | def infer_batch(self, batch: Batch, 
calc_probability: bool = False, probability_of_gt: bool = False): 248 | """Feed a batch into the NN to recognize the texts.""" 249 | 250 | # decode, optionally save RNN output 251 | num_batch_elements = len(batch.imgs) 252 | 253 | # put tensors to be evaluated into list 254 | eval_list = [] 255 | 256 | if self.decoder_type == DecoderType.WordBeamSearch: 257 | eval_list.append(self.wbs_input) 258 | else: 259 | eval_list.append(self.decoder) 260 | 261 | if self.dump or calc_probability: 262 | eval_list.append(self.ctc_in_3d_tbc) 263 | 264 | # sequence length depends on input image size (model downsizes width by 4) 265 | max_text_len = batch.imgs[0].shape[0] // 4 266 | 267 | # dict containing all tensor fed into the model 268 | feed_dict = {self.input_imgs: batch.imgs, self.seq_len: [max_text_len] * num_batch_elements, 269 | self.is_train: False} 270 | 271 | # evaluate model 272 | eval_res = self.sess.run(eval_list, feed_dict) 273 | 274 | # TF decoders: decoding already done in TF graph 275 | if self.decoder_type != DecoderType.WordBeamSearch: 276 | decoded = eval_res[0] 277 | # word beam search decoder: decoding is done in C++ function compute() 278 | else: 279 | decoded = self.decoder.compute(eval_res[0]) 280 | 281 | # map labels (numbers) to character string 282 | texts = self.decoder_output_to_text(decoded, num_batch_elements) 283 | 284 | # feed RNN output and recognized text into CTC loss to compute labeling probability 285 | probs = None 286 | if calc_probability: 287 | sparse = self.to_sparse(batch.gt_texts) if probability_of_gt else self.to_sparse(texts) 288 | ctc_input = eval_res[1] 289 | eval_list = self.loss_per_element 290 | feed_dict = {self.saved_ctc_input: ctc_input, self.gt_texts: sparse, 291 | self.seq_len: [max_text_len] * num_batch_elements, self.is_train: False} 292 | loss_vals = self.sess.run(eval_list, feed_dict) 293 | probs = np.exp(-loss_vals) 294 | 295 | # dump the output of the NN to CSV file(s) 296 | if self.dump: 297 | self.dump_nn_output(eval_res[1]) 298 | 299 | return texts, probs 300 | 301 | def save(self) -> None: 302 | """Save model to file.""" 303 | self.snap_ID += 1 304 | self.saver.save(self.sess, '../model/snapshot', global_step=self.snap_ID) 305 | -------------------------------------------------------------------------------- /src/preprocessor.py: -------------------------------------------------------------------------------- 1 | import random 2 | from typing import Tuple 3 | 4 | import cv2 5 | import numpy as np 6 | 7 | from dataloader_iam import Batch 8 | 9 | 10 | class Preprocessor: 11 | def __init__(self, 12 | img_size: Tuple[int, int], 13 | padding: int = 0, 14 | dynamic_width: bool = False, 15 | data_augmentation: bool = False, 16 | line_mode: bool = False) -> None: 17 | # dynamic width only supported when no data augmentation happens 18 | assert not (dynamic_width and data_augmentation) 19 | # when padding is on, we need dynamic width enabled 20 | assert not (padding > 0 and not dynamic_width) 21 | 22 | self.img_size = img_size 23 | self.padding = padding 24 | self.dynamic_width = dynamic_width 25 | self.data_augmentation = data_augmentation 26 | self.line_mode = line_mode 27 | 28 | @staticmethod 29 | def _truncate_label(text: str, max_text_len: int) -> str: 30 | """ 31 | Function ctc_loss can't compute loss if it cannot find a mapping between text label and input 32 | labels. Repeat letters cost double because of the blank symbol needing to be inserted. 33 | If a too-long label is provided, ctc_loss returns an infinite gradient. 
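Example: with max_text_len = 3, the text 'aab' costs 1 + 2 + 1 = 4 time-steps (the repeated 'a' needs a blank between its two instances), so it is truncated to 'aa'.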
34 | """ 35 | cost = 0 36 | for i in range(len(text)): 37 | if i != 0 and text[i] == text[i - 1]: 38 | cost += 2 39 | else: 40 | cost += 1 41 | if cost > max_text_len: 42 | return text[:i] 43 | return text 44 | 45 | def _simulate_text_line(self, batch: Batch) -> Batch: 46 | """Create image of a text line by pasting multiple word images into an image.""" 47 | 48 | default_word_sep = 30 49 | default_num_words = 5 50 | 51 | # go over all batch elements 52 | res_imgs = [] 53 | res_gt_texts = [] 54 | for i in range(batch.batch_size): 55 | # number of words to put into current line 56 | num_words = random.randint(1, 8) if self.data_augmentation else default_num_words 57 | 58 | # concat ground truth texts 59 | curr_gt = ' '.join([batch.gt_texts[(i + j) % batch.batch_size] for j in range(num_words)]) 60 | res_gt_texts.append(curr_gt) 61 | 62 | # put selected word images into list, compute target image size 63 | sel_imgs = [] 64 | word_seps = [0] 65 | h = 0 66 | w = 0 67 | for j in range(num_words): 68 | curr_sel_img = batch.imgs[(i + j) % batch.batch_size] 69 | curr_word_sep = random.randint(20, 50) if self.data_augmentation else default_word_sep 70 | h = max(h, curr_sel_img.shape[0]) 71 | w += curr_sel_img.shape[1] 72 | sel_imgs.append(curr_sel_img) 73 | if j + 1 < num_words: 74 | w += curr_word_sep 75 | word_seps.append(curr_word_sep) 76 | 77 | # put all selected word images into target image 78 | target = np.ones([h, w], np.uint8) * 255 79 | x = 0 80 | for curr_sel_img, curr_word_sep in zip(sel_imgs, word_seps): 81 | x += curr_word_sep 82 | y = (h - curr_sel_img.shape[0]) // 2 83 | target[y:y + curr_sel_img.shape[0]:, x:x + curr_sel_img.shape[1]] = curr_sel_img 84 | x += curr_sel_img.shape[1] 85 | 86 | # put image of line into result 87 | res_imgs.append(target) 88 | 89 | return Batch(res_imgs, res_gt_texts, batch.batch_size) 90 | 91 | def process_img(self, img: np.ndarray) -> np.ndarray: 92 | """Resize to target size, apply data augmentation.""" 93 | 94 | # there are damaged files in IAM dataset - just use black image instead 95 | if img is None: 96 | img = np.zeros(self.img_size[::-1]) 97 | 98 | # data augmentation 99 | img = img.astype(np.float) 100 | if self.data_augmentation: 101 | # photometric data augmentation 102 | if random.random() < 0.25: 103 | def rand_odd(): 104 | return random.randint(1, 3) * 2 + 1 105 | img = cv2.GaussianBlur(img, (rand_odd(), rand_odd()), 0) 106 | if random.random() < 0.25: 107 | img = cv2.dilate(img, np.ones((3, 3))) 108 | if random.random() < 0.25: 109 | img = cv2.erode(img, np.ones((3, 3))) 110 | 111 | # geometric data augmentation 112 | wt, ht = self.img_size 113 | h, w = img.shape 114 | f = min(wt / w, ht / h) 115 | fx = f * np.random.uniform(0.75, 1.05) 116 | fy = f * np.random.uniform(0.75, 1.05) 117 | 118 | # random position around center 119 | txc = (wt - w * fx) / 2 120 | tyc = (ht - h * fy) / 2 121 | freedom_x = max((wt - fx * w) / 2, 0) 122 | freedom_y = max((ht - fy * h) / 2, 0) 123 | tx = txc + np.random.uniform(-freedom_x, freedom_x) 124 | ty = tyc + np.random.uniform(-freedom_y, freedom_y) 125 | 126 | # map image into target image 127 | M = np.float32([[fx, 0, tx], [0, fy, ty]]) 128 | target = np.ones(self.img_size[::-1]) * 255 129 | img = cv2.warpAffine(img, M, dsize=self.img_size, dst=target, borderMode=cv2.BORDER_TRANSPARENT) 130 | 131 | # photometric data augmentation 132 | if random.random() < 0.5: 133 | img = img * (0.25 + random.random() * 0.75) 134 | if random.random() < 0.25: 135 | img = np.clip(img + (np.random.random(img.shape) - 0.5) 
* random.randint(1, 25), 0, 255) 136 | if random.random() < 0.1: 137 | img = 255 - img 138 | 139 | # no data augmentation 140 | else: 141 | if self.dynamic_width: 142 | ht = self.img_size[1] 143 | h, w = img.shape 144 | f = ht / h 145 | wt = int(f * w + self.padding) 146 | wt = wt + (4 - wt) % 4 147 | tx = (wt - w * f) / 2 148 | ty = 0 149 | else: 150 | wt, ht = self.img_size 151 | h, w = img.shape 152 | f = min(wt / w, ht / h) 153 | tx = (wt - w * f) / 2 154 | ty = (ht - h * f) / 2 155 | 156 | # map image into target image 157 | M = np.float32([[f, 0, tx], [0, f, ty]]) 158 | target = np.ones([ht, wt]) * 255 159 | img = cv2.warpAffine(img, M, dsize=(wt, ht), dst=target, borderMode=cv2.BORDER_TRANSPARENT) 160 | 161 | # transpose for TF 162 | img = cv2.transpose(img) 163 | 164 | # convert to range [-0.5, 0.5] 165 | img = img / 255 - 0.5 166 | return img 167 | 168 | def process_batch(self, batch: Batch) -> Batch: 169 | if self.line_mode: 170 | batch = self._simulate_text_line(batch) 171 | 172 | res_imgs = [self.process_img(img) for img in batch.imgs] 173 | max_text_len = res_imgs[0].shape[0] // 4 174 | res_gt_texts = [self._truncate_label(gt_text, max_text_len) for gt_text in batch.gt_texts] 175 | return Batch(res_imgs, res_gt_texts, batch.batch_size) 176 | 177 | 178 | def main(): 179 | import matplotlib.pyplot as plt 180 | 181 | img = cv2.imread('../data/word.png', cv2.IMREAD_GRAYSCALE) # word.png is shipped in the data directory 182 | img_aug = Preprocessor((256, 32), data_augmentation=True).process_img(img) 183 | plt.subplot(121) 184 | plt.imshow(img, cmap='gray') 185 | plt.subplot(122) 186 | plt.imshow(cv2.transpose(img_aug) + 0.5, cmap='gray', vmin=0, vmax=1) 187 | plt.show() 188 | 189 | 190 | if __name__ == '__main__': 191 | main() 192 | --------------------------------------------------------------------------------