├── .gitignore
├── LICENSE.md
├── README.md
├── data
│   └── test
│       ├── bentham.jpg
│       ├── cvl.jpg
│       └── random.jpg
├── doc
│   ├── aabbs.png
│   └── seg.png
├── model
│   └── .gitignore
└── src
    ├── aabb.py
    ├── aabb_clustering.py
    ├── coding.py
    ├── dataloader.py
    ├── dataset.py
    ├── eval.py
    ├── infer.py
    ├── iou.py
    ├── loss.py
    ├── net.py
    ├── resnet.py
    ├── train.py
    ├── utils.py
    └── visualization.py

/.gitignore:
--------------------------------------------------------------------------------
1 | log
2 | .idea
3 | __pycache__
--------------------------------------------------------------------------------
/LICENSE.md:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021 Harald Scheidl
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Handwritten Word Detector
2 |
3 | A neural-network-based detector for handwritten words.
4 |
5 | ## Run demo
6 | * Download the [trained model](https://www.dropbox.com/s/mqhco2q67ovpfjq/model.zip?dl=1) and place the unzipped files into the `model` directory
7 | * Go to the `src` directory and execute `python infer.py`
8 | * This opens a window showing the words detected in the test images (located in `data/test`)
9 | * Required libs: torch, numpy, sklearn, cv2, path, matplotlib
10 |
11 | ![aabbs](./doc/aabbs.png)
12 |
13 |
14 | ## Train model
15 | ### Data
16 | * The model is trained with the [IAM dataset](https://fki.tic.heia-fr.ch/databases/iam-handwriting-database)
17 | * Download the forms and the XML files
18 | * Create a dataset directory on your disk with two subdirectories: `gt` and `img`
19 | * Put all form images into the `img` directory
20 | * Put all XML files into the `gt` directory
21 |
22 | ### Start training
23 | * Go to `src` and execute `python train.py` with the following parameters specified (only the first one is required):
24 |   * `--data_dir`: dataset directory containing a `gt` and an `img` directory
25 |   * `--batch_size`: 27 images per batch are possible on an 8GB GPU
26 |   * `--caching`: cache the dataset to avoid loading and decoding the PNG images; the cache file is stored in the dataset directory
27 |   * `--pretrained`: initialize with saved model weights
28 |   * `--val_freq`: speed up training by only validating every n-th epoch
29 |   * `--early_stopping`: stop training after n validation steps without improvement
30 | * The model weights are saved every time the F1 score on the validation set increases
31 | * A log is written into the `log` directory, which can be opened with TensorBoard
32 | * Executing `python eval.py` evaluates the trained model
33 |
34 |
35 | ## Information about the model
36 | * The model classifies each pixel into one of three classes (see plot below):
37 |   * Inner part of a word (plot: red)
38 |   * Outer part of a word (plot: green)
39 |   * Background (plot: blue)
40 | * An axis-aligned bounding box is predicted for each inner-word pixel
41 | * DBSCAN clusters the predicted bounding boxes (a minimal sketch of this decoding and clustering step is shown below, after the test-image entries)
42 | * The backbone of the neural network is based on the ResNet18 model (taken from torchvision, with modifications)
43 | * The model is inspired by the ideas of [Zhou](https://openaccess.thecvf.com/content_cvpr_2017/papers/Zhou_EAST_An_Efficient_CVPR_2017_paper.pdf) and [Axler](http://www.cs.tau.ac.il/~wolf/papers/dataset-agnostic-word.pdf)
44 | * See [this article](https://githubharald.github.io/word_detector.html) for more details
45 |
46 | ![seg](./doc/seg.png)
47 |
48 |
49 |
50 |
--------------------------------------------------------------------------------
/data/test/bentham.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/githubharald/WordDetectorNN/a11f5bf4a69acc58dcf9c900d0caad17fb464d74/data/test/bentham.jpg
--------------------------------------------------------------------------------
/data/test/cvl.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/githubharald/WordDetectorNN/a11f5bf4a69acc58dcf9c900d0caad17fb464d74/data/test/cvl.jpg
--------------------------------------------------------------------------------
/data/test/random.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/githubharald/WordDetectorNN/a11f5bf4a69acc58dcf9c900d0caad17fb464d74/data/test/random.jpg
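The decoding and clustering step referenced in the README above can be sketched in a few lines. This is a minimal sketch following the post-processing in `src/eval.py`, to be run from the `src` directory; the random prediction map is only a stand-in to keep the snippet self-contained, and the threshold (0.5), box limit (1000) and upscale factor (2) mirror the values used in `eval.py` for the fixed 448/224 training sizes:

```python
import numpy as np

from aabb_clustering import cluster_aabbs
from coding import decode, fg_by_cc

# stand-in for real network output: 7 maps (3 segmentation, 4 geometry) at 224x224
pred_map = np.random.rand(7, 224, 224)

# one aabb per foreground pixel (word-interior probability > 0.5),
# scaled up by 2 to undo the network's 448 -> 224 downscaling
aabbs = decode(pred_map, comp_fg=fg_by_cc(0.5, 1000), f=2)

# DBSCAN merges the per-pixel boxes; each cluster yields one word box (median coordinates)
words = cluster_aabbs(aabbs)
print(f'{len(words)} words detected')
```

With real network output (see `src/infer.py`), `pred_map` would be one sample of the softmaxed prediction tensor.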
-------------------------------------------------------------------------------- /doc/aabbs.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/githubharald/WordDetectorNN/a11f5bf4a69acc58dcf9c900d0caad17fb464d74/doc/aabbs.png -------------------------------------------------------------------------------- /doc/seg.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/githubharald/WordDetectorNN/a11f5bf4a69acc58dcf9c900d0caad17fb464d74/doc/seg.png -------------------------------------------------------------------------------- /model/.gitignore: -------------------------------------------------------------------------------- 1 | # Ignore everything in this directory 2 | * 3 | # Except this file 4 | !.gitignore -------------------------------------------------------------------------------- /src/aabb.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class AABB: 5 | """axis aligned bounding box""" 6 | 7 | def __init__(self, xmin, xmax, ymin, ymax): 8 | self.xmin = xmin 9 | self.xmax = xmax 10 | self.ymin = ymin 11 | self.ymax = ymax 12 | 13 | def scale(self, fx, fy): 14 | new = AABB(self.xmin, self.xmax, self.ymin, self.ymax) 15 | new.xmin = fx * new.xmin 16 | new.xmax = fx * new.xmax 17 | new.ymin = fy * new.ymin 18 | new.ymax = fy * new.ymax 19 | return new 20 | 21 | def scale_around_center(self, fx, fy): 22 | cx = (self.xmin + self.xmax) / 2 23 | cy = (self.ymin + self.ymax) / 2 24 | 25 | new = AABB(self.xmin, self.xmax, self.ymin, self.ymax) 26 | new.xmin = cx - fx * (cx - self.xmin) 27 | new.xmax = cx + fx * (self.xmax - cx) 28 | new.ymin = cy - fy * (cy - self.ymin) 29 | new.ymax = cy + fy * (self.ymax - cy) 30 | return new 31 | 32 | def translate(self, tx, ty): 33 | new = AABB(self.xmin, self.xmax, self.ymin, self.ymax) 34 | new.xmin = new.xmin + tx 35 | new.xmax = new.xmax + tx 36 | new.ymin = new.ymin + ty 37 | new.ymax = new.ymax + ty 38 | return new 39 | 40 | def as_type(self, t): 41 | new = AABB(self.xmin, self.xmax, self.ymin, self.ymax) 42 | new.xmin = t(new.xmin) 43 | new.xmax = t(new.xmax) 44 | new.ymin = t(new.ymin) 45 | new.ymax = t(new.ymax) 46 | return new 47 | 48 | def enlarge_to_int_grid(self): 49 | new = AABB(self.xmin, self.xmax, self.ymin, self.ymax) 50 | new.xmin = np.floor(new.xmin) 51 | new.xmax = np.ceil(new.xmax) 52 | new.ymin = np.floor(new.ymin) 53 | new.ymax = np.ceil(new.ymax) 54 | return new 55 | 56 | def clip(self, clip_aabb): 57 | new = AABB(self.xmin, self.xmax, self.ymin, self.ymax) 58 | new.xmin = min(max(new.xmin, clip_aabb.xmin), clip_aabb.xmax) 59 | new.xmax = max(min(new.xmax, clip_aabb.xmax), clip_aabb.xmin) 60 | new.ymin = min(max(new.ymin, clip_aabb.ymin), clip_aabb.ymax) 61 | new.ymax = max(min(new.ymax, clip_aabb.ymax), clip_aabb.ymin) 62 | return new 63 | 64 | def area(self): 65 | return (self.xmax - self.xmin) * (self.ymax - self.ymin) 66 | 67 | def __str__(self): 68 | return f'AABB(xmin={self.xmin},xmax={self.xmax},ymin={self.ymin},ymax={self.ymax})' 69 | 70 | def __repr__(self): 71 | return str(self) 72 | -------------------------------------------------------------------------------- /src/aabb_clustering.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | 3 | import numpy as np 4 | from sklearn.cluster import DBSCAN 5 | 6 | from aabb import AABB 7 | from iou 
import compute_dist_mat 8 | 9 | 10 | def cluster_aabbs(aabbs): 11 | """cluster aabbs using DBSCAN and the Jaccard distance between bounding boxes""" 12 | if len(aabbs) < 2: 13 | return aabbs 14 | 15 | dists = compute_dist_mat(aabbs) 16 | clustering = DBSCAN(eps=0.7, min_samples=3, metric='precomputed').fit(dists) 17 | 18 | clusters = defaultdict(list) 19 | for i, c in enumerate(clustering.labels_): 20 | if c == -1: 21 | continue 22 | clusters[c].append(aabbs[i]) 23 | 24 | res_aabbs = [] 25 | for curr_cluster in clusters.values(): 26 | xmin = np.median([aabb.xmin for aabb in curr_cluster]) 27 | xmax = np.median([aabb.xmax for aabb in curr_cluster]) 28 | ymin = np.median([aabb.ymin for aabb in curr_cluster]) 29 | ymax = np.median([aabb.ymax for aabb in curr_cluster]) 30 | res_aabbs.append(AABB(xmin, xmax, ymin, ymax)) 31 | 32 | return res_aabbs 33 | -------------------------------------------------------------------------------- /src/coding.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | 4 | from aabb import AABB 5 | 6 | 7 | class MapOrdering: 8 | """order of the maps encoding the aabbs around the words""" 9 | SEG_WORD = 0 10 | SEG_SURROUNDING = 1 11 | SEG_BACKGROUND = 2 12 | GEO_TOP = 3 13 | GEO_BOTTOM = 4 14 | GEO_LEFT = 5 15 | GEO_RIGHT = 6 16 | NUM_MAPS = 7 17 | 18 | 19 | def encode(shape, gt, f=1.0): 20 | gt_map = np.zeros((MapOrdering.NUM_MAPS,) + shape) 21 | for aabb in gt: 22 | aabb = aabb.scale(f, f) 23 | 24 | # segmentation map 25 | aabb_clip = AABB(0, shape[0] - 1, 0, shape[1] - 1) 26 | 27 | aabb_word = aabb.scale_around_center(0.5, 0.5).as_type(int).clip(aabb_clip) 28 | aabb_sur = aabb.as_type(int).clip(aabb_clip) 29 | gt_map[MapOrdering.SEG_SURROUNDING, aabb_sur.ymin:aabb_sur.ymax + 1, aabb_sur.xmin:aabb_sur.xmax + 1] = 1 30 | gt_map[MapOrdering.SEG_SURROUNDING, aabb_word.ymin:aabb_word.ymax + 1, aabb_word.xmin:aabb_word.xmax + 1] = 0 31 | gt_map[MapOrdering.SEG_WORD, aabb_word.ymin:aabb_word.ymax + 1, aabb_word.xmin:aabb_word.xmax + 1] = 1 32 | 33 | # geometry map TODO vectorize 34 | for x in range(aabb_word.xmin, aabb_word.xmax + 1): 35 | for y in range(aabb_word.ymin, aabb_word.ymax + 1): 36 | gt_map[MapOrdering.GEO_TOP, y, x] = y - aabb.ymin 37 | gt_map[MapOrdering.GEO_BOTTOM, y, x] = aabb.ymax - y 38 | gt_map[MapOrdering.GEO_LEFT, y, x] = x - aabb.xmin 39 | gt_map[MapOrdering.GEO_RIGHT, y, x] = aabb.xmax - x 40 | 41 | gt_map[MapOrdering.SEG_BACKGROUND] = np.clip(1 - gt_map[MapOrdering.SEG_WORD] - gt_map[MapOrdering.SEG_SURROUNDING], 42 | 0, 1) 43 | 44 | return gt_map 45 | 46 | 47 | def subsample(idx, max_num): 48 | """restrict fg indices to a maximum number""" 49 | f = len(idx[0]) / max_num 50 | if f > 1: 51 | a = np.asarray([idx[0][int(j * f)] for j in range(max_num)], np.int64) 52 | b = np.asarray([idx[1][int(j * f)] for j in range(max_num)], np.int64) 53 | idx = (a, b) 54 | return idx 55 | 56 | 57 | def fg_by_threshold(thres, max_num=None): 58 | """all pixels above threshold are fg pixels, optionally limited to a maximum number""" 59 | 60 | def func(seg_map): 61 | idx = np.where(seg_map > thres) 62 | if max_num is not None: 63 | idx = subsample(idx, max_num) 64 | return idx 65 | 66 | return func 67 | 68 | 69 | def fg_by_cc(thres, max_num): 70 | """take a maximum number of pixels per connected component, but at least 3 (->DBSCAN minPts)""" 71 | 72 | def func(seg_map): 73 | seg_mask = (seg_map > thres).astype(np.uint8) 74 | num_labels, label_img = cv2.connectedComponents(seg_mask, 
connectivity=4) 75 | max_num_per_cc = max(max_num // (num_labels + 1), 3) # at least 3 because of DBSCAN clustering 76 | 77 | all_idx = [np.empty(0, np.int64), np.empty(0, np.int64)] 78 | for curr_label in range(1, num_labels): 79 | curr_idx = np.where(label_img == curr_label) 80 | curr_idx = subsample(curr_idx, max_num_per_cc) 81 | all_idx[0] = np.append(all_idx[0], curr_idx[0]) 82 | all_idx[1] = np.append(all_idx[1], curr_idx[1]) 83 | return tuple(all_idx) 84 | 85 | return func 86 | 87 | 88 | def decode(pred_map, comp_fg=fg_by_threshold(0.5), f=1): 89 | idx = comp_fg(pred_map[MapOrdering.SEG_WORD]) 90 | pred_map_masked = pred_map[..., idx[0], idx[1]] 91 | aabbs = [] 92 | for yc, xc, pred in zip(idx[0], idx[1], pred_map_masked.T): 93 | t = pred[MapOrdering.GEO_TOP] 94 | b = pred[MapOrdering.GEO_BOTTOM] 95 | l = pred[MapOrdering.GEO_LEFT] 96 | r = pred[MapOrdering.GEO_RIGHT] 97 | aabb = AABB(xc - l, xc + r, yc - t, yc + b) 98 | aabbs.append(aabb.scale(f, f)) 99 | return aabbs 100 | 101 | 102 | def main(): 103 | import matplotlib.pyplot as plt 104 | aabbs_in = [AABB(10, 30, 30, 60)] 105 | encoded = encode((50, 50), aabbs_in, f=0.5) 106 | aabbs_out = decode(encoded, f=2) 107 | print(aabbs_out[0]) 108 | plt.subplot(151) 109 | plt.imshow(encoded[MapOrdering.SEG_WORD:MapOrdering.SEG_BACKGROUND + 1].transpose(1, 2, 0)) 110 | 111 | plt.subplot(152) 112 | plt.imshow(encoded[MapOrdering.GEO_TOP]) 113 | plt.subplot(153) 114 | plt.imshow(encoded[MapOrdering.GEO_BOTTOM]) 115 | plt.subplot(154) 116 | plt.imshow(encoded[MapOrdering.GEO_LEFT]) 117 | plt.subplot(155) 118 | plt.imshow(encoded[MapOrdering.GEO_RIGHT]) 119 | 120 | plt.show() 121 | 122 | 123 | if __name__ == '__main__': 124 | main() 125 | -------------------------------------------------------------------------------- /src/dataloader.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | 3 | import cv2 4 | import numpy as np 5 | import torch 6 | 7 | from aabb import AABB 8 | from coding import encode 9 | from utils import compute_scale_down, prob_true 10 | 11 | DataLoaderItem = namedtuple('DataLoaderItem', 'batch_imgs,batch_gt_maps,batch_aabbs') 12 | 13 | 14 | class DataLoaderIAM: 15 | """loader for IAM dataset""" 16 | 17 | def __init__(self, dataset, batch_size, input_size, output_size): 18 | self.dataset = dataset 19 | self.batch_size = batch_size 20 | self.input_size = input_size 21 | self.output_size = output_size 22 | self.scale_down = compute_scale_down(input_size, output_size) 23 | self.shuffled_indices = np.arange(len(self.dataset)) 24 | self.curr_idx = 0 25 | self.is_random = False 26 | 27 | def __getitem__(self, item): 28 | batch_imgs = [] 29 | batch_gt_maps = [] 30 | batch_aabbs = [] 31 | for b in range(self.batch_size): 32 | if self.is_random: 33 | shuffled_idx = self.shuffled_indices[item * self.batch_size + b] 34 | else: 35 | shuffled_idx = item * self.batch_size + b 36 | 37 | img, aabbs = self.dataset[shuffled_idx] 38 | 39 | if self.is_random: 40 | # geometric data augmentation (image [0..255] and gt) 41 | if prob_true(0.75): 42 | # random scale 43 | fx = np.random.uniform(0.5, 1.5) 44 | fy = np.random.uniform(0.5, 1.5) 45 | 46 | # random position around center 47 | txc = self.input_size[1] * (1 - fx) / 2 48 | tyc = self.input_size[0] * (1 - fy) / 2 49 | freedom_x = self.input_size[1] // 10 50 | freedom_y = self.input_size[0] // 10 51 | tx = txc + np.random.randint(-freedom_x, freedom_x) 52 | ty = tyc + np.random.randint(-freedom_y, freedom_y) 53 | 54 | # 
map image into target image 55 | M = np.float32([[fx, 0, tx], [0, fy, ty]]) 56 | white_bg = np.ones(self.input_size, np.uint8) * 255 57 | img = cv2.warpAffine(img, M, dsize=self.input_size[::-1], dst=white_bg, 58 | borderMode=cv2.BORDER_TRANSPARENT) 59 | 60 | # apply the same transformations to gt, and clip/remove aabbs outside of target image 61 | aabb_clip = AABB(0, img.shape[1], 0, img.shape[0]) 62 | aabbs = [aabb.scale(fx, fy).translate(tx, ty).clip(aabb_clip) for aabb in aabbs] 63 | aabbs = [aabb for aabb in aabbs if aabb.area() > 0] 64 | 65 | # photometric data augmentation (image [-0.5..0.5] only) 66 | img = (img / 255 - 0.5) 67 | if prob_true(0.25): # random distractors (lines) 68 | num_lines = np.random.randint(1, 20) 69 | for _ in range(num_lines): 70 | rand_pt = lambda: (np.random.randint(0, img.shape[1]), np.random.randint(0, img.shape[0])) 71 | color = np.random.triangular(-0.5, 0, 0.5) 72 | thickness = np.random.randint(1, 3) 73 | cv2.line(img, rand_pt(), rand_pt(), color, thickness) 74 | if prob_true(0.75): # random contrast 75 | img = (img - img.min()) / (img.max() - img.min()) - 0.5 # stretch 76 | img = img * np.random.triangular(0.1, 0.9, 1) # reduce contrast 77 | if prob_true(0.25): # random noise 78 | img = img + np.random.uniform(-0.1, 0.1, size=img.shape) 79 | if prob_true(0.25): # change thickness of text 80 | img = cv2.erode(img, np.ones((3, 3))) 81 | if prob_true(0.25): # change thickness of text 82 | img = cv2.dilate(img, np.ones((3, 3))) 83 | if prob_true(0.25): # invert image 84 | img = 0.5 - img 85 | 86 | else: 87 | img = (img / 255 - 0.5) 88 | 89 | gt_map = encode(self.output_size, aabbs, self.scale_down) 90 | 91 | batch_imgs.append(img[None, ...].astype(np.float32)) 92 | batch_gt_maps.append(gt_map) 93 | batch_aabbs.append(aabbs) 94 | 95 | batch_imgs = np.stack(batch_imgs, axis=0) 96 | batch_gt_maps = np.stack(batch_gt_maps, axis=0) 97 | 98 | batch_imgs = torch.from_numpy(batch_imgs).to('cuda') 99 | batch_gt_maps = torch.from_numpy(batch_gt_maps.astype(np.float32)).to('cuda') 100 | 101 | return DataLoaderItem(batch_imgs, batch_gt_maps, batch_aabbs) 102 | 103 | def reset(self): 104 | self.curr_idx = 0 105 | 106 | def random(self, enable=True): 107 | np.random.shuffle(self.shuffled_indices) 108 | self.is_random = enable 109 | 110 | def __len__(self): 111 | return len(self.dataset) // self.batch_size 112 | 113 | 114 | class DataLoaderImgFile: 115 | """loader which simply goes through all jpg files of a directory""" 116 | 117 | def __init__(self, root_dir, input_size, device, max_side_len=1024): 118 | self.fn_imgs = root_dir.files('*.jpg') 119 | self.input_size = input_size 120 | self.device = device 121 | self.max_side_len = max_side_len 122 | 123 | def ceil32(self, val): 124 | if val % 32 == 0: 125 | return val 126 | val = (val // 32 + 1) * 32 127 | return val 128 | 129 | def __getitem__(self, item): 130 | orig = cv2.imread(self.fn_imgs[item], cv2.IMREAD_GRAYSCALE) 131 | 132 | f = min(self.max_side_len / orig.shape[0], self.max_side_len / orig.shape[1]) 133 | if f < 1: 134 | orig = cv2.resize(orig, dsize=None, fx=f, fy=f) 135 | img = np.ones((self.ceil32(orig.shape[0]), self.ceil32(orig.shape[1])), np.uint8) * 255 136 | img[:orig.shape[0], :orig.shape[1]] = orig 137 | 138 | img = (img / 255 - 0.5).astype(np.float32) 139 | imgs = img[None, None, ...] 
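# imgs now has shape (1, 1, H, W): batch and channel axes for the network; H and W were padded up to multiples of 32 (ceil32 above) so that all down-/upscaling stages of the net divide evenly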
140 | imgs = torch.from_numpy(imgs).to(self.device) 141 | return DataLoaderItem(imgs, None, None) 142 | 143 | def get_scale_factor(self, item): 144 | img = cv2.imread(self.fn_imgs[item], cv2.IMREAD_GRAYSCALE) 145 | f = min(self.max_side_len / img.shape[0], self.max_side_len / img.shape[1]) 146 | return f if f < 1 else 1 147 | 148 | def get_original_img(self, item): 149 | img = cv2.imread(self.fn_imgs[item], cv2.IMREAD_GRAYSCALE) 150 | img = (img / 255 - 0.5).astype(np.float32) 151 | return img 152 | 153 | def __len__(self): 154 | return len(self.fn_imgs) 155 | -------------------------------------------------------------------------------- /src/dataset.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import xml.etree.ElementTree as ET 3 | 4 | import cv2 5 | from path import Path 6 | 7 | from aabb import AABB 8 | 9 | 10 | class DatasetIAM: 11 | """loads the image and ground truth data of the IAM dataset""" 12 | 13 | def __init__(self, root_dir, input_size, output_size, caching=True): 14 | 15 | self.caching = caching 16 | self.input_size = input_size 17 | self.output_size = output_size 18 | self.loaded_img_scale = 0.25 19 | self.fn_gts = [] 20 | self.fn_imgs = [] 21 | self.img_cache = [] 22 | self.gt_cache = [] 23 | self.num_samples = 0 24 | 25 | fn_cache = root_dir / 'cache.pickle' 26 | if self.caching and fn_cache.exists(): 27 | self.img_cache, self.gt_cache = pickle.load(open(fn_cache, 'rb')) 28 | self.num_samples = len(self.img_cache) 29 | return 30 | 31 | gt_dir = root_dir / 'gt' 32 | img_dir = root_dir / 'img' 33 | for fn_gt in sorted(gt_dir.files('*.xml')): 34 | fn_img = img_dir / fn_gt.stem + '.png' 35 | if not fn_img.exists(): 36 | continue 37 | 38 | self.fn_imgs.append(fn_img.abspath()) 39 | self.fn_gts.append(fn_gt.abspath()) 40 | self.num_samples += 1 41 | 42 | if self.caching: 43 | img = cv2.imread(fn_img.abspath(), cv2.IMREAD_GRAYSCALE) 44 | img = cv2.resize(img, dsize=None, fx=self.loaded_img_scale, fy=self.loaded_img_scale) 45 | gt = self.parse_gt(fn_gt.abspath()) 46 | 47 | img, gt = self.crop(img, gt) 48 | img, gt = self.adjust_size(img, gt) 49 | 50 | self.img_cache.append(img) 51 | self.gt_cache.append(gt) 52 | 53 | if self.caching: 54 | pickle.dump([self.img_cache, self.gt_cache], open(fn_cache, 'wb')) 55 | 56 | def parse_gt(self, fn_gt): 57 | tree = ET.parse(fn_gt) 58 | root = tree.getroot() 59 | 60 | aabbs = [] # list of all axis aligned bounding boxes of current sample 61 | 62 | # go over all lines 63 | for line in root.findall("./handwritten-part/line"): 64 | 65 | # go over all words 66 | for word in line.findall('./word'): 67 | xmin, xmax, ymin, ymax = float('inf'), 0, float('inf'), 0 68 | success = False 69 | 70 | # go over all characters 71 | for cmp in word.findall('./cmp'): 72 | success = True 73 | x = float(cmp.attrib['x']) 74 | y = float(cmp.attrib['y']) 75 | w = float(cmp.attrib['width']) 76 | h = float(cmp.attrib['height']) 77 | 78 | # aabb around all characters is aabb around word 79 | xmin = min(xmin, x) 80 | xmax = max(xmax, x + w) 81 | ymin = min(ymin, y) 82 | ymax = max(ymax, y + h) 83 | 84 | if success: 85 | aabbs.append(AABB(xmin, xmax, ymin, ymax).scale(self.loaded_img_scale, self.loaded_img_scale)) 86 | 87 | return aabbs 88 | 89 | def crop(self, img, gt): 90 | xmin = min([aabb.xmin for aabb in gt]) 91 | xmax = max([aabb.xmax for aabb in gt]) 92 | ymin = min([aabb.ymin for aabb in gt]) 93 | ymax = max([aabb.ymax for aabb in gt]) 94 | 95 | gt_crop = [aabb.translate(-xmin, -ymin) for aabb in gt] 
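# next, cut the image down to the tight bounding box around all words; the gt boxes above were already shifted into this crop's coordinate frame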
96 |         img_crop = img[int(ymin):int(ymax), int(xmin):int(xmax)]
97 |         return img_crop, gt_crop
98 |
99 |     def adjust_size(self, img, gt):
100 |         h, w = img.shape
101 |         fx = self.input_size[1] / w
102 |         fy = self.input_size[0] / h
103 |         gt = [aabb.scale(fx, fy) for aabb in gt]
104 |         img = cv2.resize(img, dsize=self.input_size)
105 |         return img, gt
106 |
107 |     def __getitem__(self, idx):
108 |
109 |         if self.caching:
110 |             img = self.img_cache[idx]
111 |             gt = self.gt_cache[idx]
112 |         else:
113 |             img = cv2.imread(self.fn_imgs[idx], cv2.IMREAD_GRAYSCALE)
114 |             img = cv2.resize(img, dsize=None, fx=self.loaded_img_scale, fy=self.loaded_img_scale)
115 |             gt = self.parse_gt(self.fn_gts[idx])
116 |             img, gt = self.crop(img, gt)
117 |             img, gt = self.adjust_size(img, gt)
118 |
119 |         return img, gt
120 |
121 |     def __len__(self):
122 |         return self.num_samples
123 |
124 |
125 | class DatasetIAMSplit:
126 |     """wrapper which provides a dataset interface for a split of the original dataset"""
127 |     def __init__(self, dataset, start_idx, end_idx):
128 |         assert start_idx >= 0 and end_idx <= len(dataset)
129 |
130 |         self.dataset = dataset
131 |         self.start_idx = start_idx
132 |         self.end_idx = end_idx
133 |
134 |     def __getitem__(self, idx):
135 |         return self.dataset[self.start_idx + idx]
136 |
137 |     def __len__(self):
138 |         return self.end_idx - self.start_idx
139 |
140 |
141 | if __name__ == '__main__':
142 |     from visualization import visualize
143 |     from coding import encode, decode
144 |     import matplotlib.pyplot as plt
145 |
146 |     dataset = DatasetIAM(Path('../data'), (350, 350), (350, 350), caching=False)
147 |     img, gt = dataset[0]
148 |     gt_map = encode(img.shape, gt)
149 |     gt = decode(gt_map)
150 |
151 |     plt.imshow(visualize(img / 255 - 0.5, gt))
152 |     plt.show()
153 |
--------------------------------------------------------------------------------
/src/eval.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from collections import namedtuple
3 |
4 | import numpy as np
5 | import torch
6 | from path import Path
7 |
8 | from aabb import AABB
9 | from aabb_clustering import cluster_aabbs
10 | from coding import decode, fg_by_cc
11 | from dataloader import DataLoaderIAM
12 | from dataset import DatasetIAM, DatasetIAMSplit
13 | from iou import compute_dist_mat_2
14 | from loss import compute_loss
15 | from net import WordDetectorNet
16 | from utils import compute_scale_down
17 | from visualization import visualize_and_plot
18 |
19 | EvaluateRes = namedtuple('EvaluateRes', 'batch_imgs,batch_aabbs,loss,metrics')
20 |
21 |
22 | class BinaryClassificationMetrics:
23 |     def __init__(self, tp, fp, fn):
24 |         self.tp = tp
25 |         self.fp = fp
26 |         self.fn = fn
27 |
28 |     def accumulate(self, other):
29 |         tp = self.tp + other.tp
30 |         fp = self.fp + other.fp
31 |         fn = self.fn + other.fn
32 |         return BinaryClassificationMetrics(tp, fp, fn)
33 |
34 |     def recall(self):
35 |         return self.tp / (self.tp + self.fn) if self.tp + self.fn > 0 else 0
36 |
37 |     def precision(self):
38 |         return self.tp / (self.tp + self.fp) if self.tp + self.fp > 0 else 0
39 |
40 |     def f1(self):
41 |         re = self.recall()
42 |         pr = self.precision()
43 |         return 2 * pr * re / (pr + re) if pr + re > 0 else 0
44 |
45 |
46 | def binary_classification_metrics(gt_aabbs, pred_aabbs):
47 |     iou_thres = 0.7
48 |
49 |     ious = 1 - compute_dist_mat_2(gt_aabbs, pred_aabbs)
50 |     match_counter = (ious > iou_thres).astype(int)
51 |     gt_counter = np.sum(match_counter, axis=1)
52 |     pred_counter = np.sum(match_counter, axis=0)
53
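# counting scheme used below: a predicted box matching exactly one gt box (IoU > 0.7) is a true positive, a prediction matching no gt box is a false positive, and a gt box matched by no prediction is a false negative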
| 54 | tp = np.count_nonzero(pred_counter == 1) 55 | fp = np.count_nonzero(pred_counter == 0) 56 | fn = np.count_nonzero(gt_counter == 0) 57 | 58 | return BinaryClassificationMetrics(tp, fp, fn) 59 | 60 | 61 | def evaluate(net, loader, thres=0.5, max_aabbs=None): 62 | batch_imgs = [] 63 | batch_aabbs = [] 64 | loss = 0 65 | 66 | for i in range(len(loader)): 67 | # get batch 68 | loader_item = loader[i] 69 | with torch.no_grad(): 70 | y = net(loader_item.batch_imgs, apply_softmax=True) 71 | y_np = y.to('cpu').numpy() 72 | if loader_item.batch_gt_maps is not None: 73 | loss += compute_loss(y, loader_item.batch_gt_maps).to('cpu').numpy() 74 | 75 | scale_up = 1 / compute_scale_down(WordDetectorNet.input_size, WordDetectorNet.output_size) 76 | metrics = BinaryClassificationMetrics(0, 0, 0) 77 | for i in range(len(y_np)): 78 | img_np = loader_item.batch_imgs[i, 0].to('cpu').numpy() 79 | pred_map = y_np[i] 80 | 81 | aabbs = decode(pred_map, comp_fg=fg_by_cc(thres, max_aabbs), f=scale_up) 82 | h, w = img_np.shape 83 | aabbs = [aabb.clip(AABB(0, w - 1, 0, h - 1)) for aabb in aabbs] # bounding box must be inside img 84 | clustered_aabbs = cluster_aabbs(aabbs) 85 | 86 | if loader_item.batch_aabbs is not None: 87 | curr_metrics = binary_classification_metrics(loader_item.batch_aabbs[i], clustered_aabbs) 88 | metrics = metrics.accumulate(curr_metrics) 89 | 90 | batch_imgs.append(img_np) 91 | batch_aabbs.append(clustered_aabbs) 92 | 93 | return EvaluateRes(batch_imgs, batch_aabbs, loss / len(loader), metrics) 94 | 95 | 96 | def main(): 97 | parser = argparse.ArgumentParser() 98 | parser.add_argument('--batch_size', type=int, default=10) 99 | parser.add_argument('--data_dir', type=Path, required=True) 100 | args = parser.parse_args() 101 | 102 | net = WordDetectorNet() 103 | net.load_state_dict(torch.load('../model/weights')) 104 | net.eval() 105 | net.to('cuda') 106 | 107 | dataset = DatasetIAM(args.data_dir, net.input_size, net.output_size, caching=False) 108 | dataset_eval = DatasetIAMSplit(dataset, 0, 10) 109 | loader = DataLoaderIAM(dataset_eval, args.batch_size, net.input_size, net.output_size) 110 | 111 | res = evaluate(net, loader, max_aabbs=1000) 112 | print(f'Loss: {res.loss}') 113 | print(f'Recall: {res.metrics.recall()}') 114 | print(f'Precision: {res.metrics.precision()}') 115 | print(f'F1 score: {res.metrics.f1()}') 116 | 117 | for img, aabbs in zip(res.batch_imgs, res.batch_aabbs): 118 | visualize_and_plot(img, aabbs) 119 | 120 | 121 | if __name__ == '__main__': 122 | main() 123 | -------------------------------------------------------------------------------- /src/infer.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import torch 4 | from path import Path 5 | 6 | from dataloader import DataLoaderImgFile 7 | from eval import evaluate 8 | from net import WordDetectorNet 9 | from visualization import visualize_and_plot 10 | 11 | 12 | def main(): 13 | parser = argparse.ArgumentParser() 14 | parser.add_argument('--device', choices=['cpu', 'cuda'], default='cuda') 15 | args = parser.parse_args() 16 | 17 | net = WordDetectorNet() 18 | net.load_state_dict(torch.load('../model/weights', map_location=args.device)) 19 | net.eval() 20 | net.to(args.device) 21 | 22 | loader = DataLoaderImgFile(Path('../data/test'), net.input_size, args.device) 23 | res = evaluate(net, loader, max_aabbs=1000) 24 | 25 | for i, (img, aabbs) in enumerate(zip(res.batch_imgs, res.batch_aabbs)): 26 | f = loader.get_scale_factor(i) 27 | aabbs = [aabb.scale(1 / 
f, 1 / f) for aabb in aabbs] 28 | img = loader.get_original_img(i) 29 | visualize_and_plot(img, aabbs) 30 | 31 | 32 | if __name__ == '__main__': 33 | main() 34 | -------------------------------------------------------------------------------- /src/iou.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def compute_iou(ra, rb): 5 | """intersection over union of two axis aligned rectangles ra and rb""" 6 | if ra.xmax < rb.xmin or rb.xmax < ra.xmin or ra.ymax < rb.ymin or rb.ymax < ra.ymin: 7 | return 0 8 | 9 | l = max(ra.xmin, rb.xmin) 10 | r = min(ra.xmax, rb.xmax) 11 | t = max(ra.ymin, rb.ymin) 12 | b = min(ra.ymax, rb.ymax) 13 | 14 | intersection = (r - l) * (b - t) 15 | union = ra.area() + rb.area() - intersection 16 | 17 | iou = intersection / union 18 | return iou 19 | 20 | 21 | def compute_dist_mat(aabbs): 22 | """Jaccard distance matrix of all pairs of aabbs""" 23 | num_aabbs = len(aabbs) 24 | 25 | dists = np.zeros((num_aabbs, num_aabbs)) 26 | for i in range(num_aabbs): 27 | for j in range(num_aabbs): 28 | if j > i: 29 | break 30 | 31 | dists[i, j] = dists[j, i] = 1 - compute_iou(aabbs[i], aabbs[j]) 32 | 33 | return dists 34 | 35 | 36 | def compute_dist_mat_2(aabbs1, aabbs2): 37 | """Jaccard distance matrix of all pairs of aabbs from lists aabbs1 and aabbs2""" 38 | num_aabbs1 = len(aabbs1) 39 | num_aabbs2 = len(aabbs2) 40 | 41 | dists = np.zeros((num_aabbs1, num_aabbs2)) 42 | for i in range(num_aabbs1): 43 | for j in range(num_aabbs2): 44 | dists[i, j] = 1 - compute_iou(aabbs1[i], aabbs2[j]) 45 | 46 | return dists 47 | -------------------------------------------------------------------------------- /src/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | from coding import MapOrdering 5 | 6 | 7 | def compute_loss(y, gt_map): 8 | # 1. segmentation loss 9 | target_labels = torch.argmax(gt_map[:, MapOrdering.SEG_WORD:MapOrdering.SEG_BACKGROUND + 1], dim=1) 10 | loss_seg = F.cross_entropy(y[:, MapOrdering.SEG_WORD:MapOrdering.SEG_BACKGROUND + 1], target_labels) 11 | 12 | # 2. geometry loss 13 | # distances to all sides of aabb 14 | t = torch.minimum(y[:, MapOrdering.GEO_TOP], gt_map[:, MapOrdering.GEO_TOP]) 15 | b = torch.minimum(y[:, MapOrdering.GEO_BOTTOM], gt_map[:, MapOrdering.GEO_BOTTOM]) 16 | l = torch.minimum(y[:, MapOrdering.GEO_LEFT], gt_map[:, MapOrdering.GEO_LEFT]) 17 | r = torch.minimum(y[:, MapOrdering.GEO_RIGHT], gt_map[:, MapOrdering.GEO_RIGHT]) 18 | 19 | # area of predicted aabb 20 | y_width = y[:, MapOrdering.GEO_LEFT, ...] + y[:, MapOrdering.GEO_RIGHT, ...] 21 | y_height = y[:, MapOrdering.GEO_TOP, ...] + y[:, MapOrdering.GEO_BOTTOM, ...] 22 | area1 = y_width * y_height 23 | 24 | # area of gt aabb 25 | gt_width = gt_map[:, MapOrdering.GEO_LEFT, ...] + gt_map[:, MapOrdering.GEO_RIGHT, ...] 26 | gt_height = gt_map[:, MapOrdering.GEO_TOP, ...] + gt_map[:, MapOrdering.GEO_BOTTOM, ...] 
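# the widths/heights above turn the per-pixel side distances into box areas; intersection and union for the IoU loss follow below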
27 | area2 = gt_width * gt_height 28 | 29 | # compute intersection over union 30 | intersection = (r + l) * (b + t) 31 | union = area1 + area2 - intersection 32 | eps = 0.01 # avoid division by 0 33 | iou = intersection / (union + eps) 34 | iou = iou[gt_map[:, MapOrdering.SEG_WORD] > 0] 35 | loss_aabb = -torch.log(torch.mean(iou)) 36 | 37 | # total loss is simply the sum of both losses 38 | loss = loss_seg + loss_aabb 39 | return loss 40 | -------------------------------------------------------------------------------- /src/net.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | from resnet import resnet18 5 | from coding import MapOrdering 6 | from utils import compute_scale_down 7 | 8 | 9 | class UpscaleAndConcatLayer(torch.nn.Module): 10 | """ 11 | take small map with cx channels 12 | upscale to size of large map (s*s) 13 | concat large map with cy channels and upscaled small map 14 | apply conv and output map with cz channels 15 | """ 16 | 17 | def __init__(self, cx, cy, cz): 18 | super(UpscaleAndConcatLayer, self).__init__() 19 | self.conv = torch.nn.Conv2d(cx + cy, cz, 3, padding=1) 20 | 21 | def forward(self, x, y, s): 22 | x = F.interpolate(x, s) 23 | z = torch.cat((x, y), 1) 24 | z = F.relu(self.conv(z)) 25 | return z 26 | 27 | 28 | class WordDetectorNet(torch.nn.Module): 29 | # fixed sizes for training 30 | input_size = (448, 448) 31 | output_size = (224, 224) 32 | scale_down = compute_scale_down(input_size, output_size) 33 | 34 | def __init__(self): 35 | super(WordDetectorNet, self).__init__() 36 | 37 | self.backbone = resnet18() 38 | 39 | self.up1 = UpscaleAndConcatLayer(512, 256, 256) # input//16 40 | self.up2 = UpscaleAndConcatLayer(256, 128, 128) # input//8 41 | self.up3 = UpscaleAndConcatLayer(128, 64, 64) # input//4 42 | self.up4 = UpscaleAndConcatLayer(64, 64, 32) # input//2 43 | 44 | self.conv1 = torch.nn.Conv2d(32, MapOrdering.NUM_MAPS, 3, 1, padding=1) 45 | 46 | @staticmethod 47 | def scale_shape(s, f): 48 | assert s[0] % f == 0 and s[1] % f == 0 49 | return s[0] // f, s[1] // f 50 | 51 | def output_activation(self, x, apply_softmax): 52 | if apply_softmax: 53 | seg = torch.softmax(x[:, MapOrdering.SEG_WORD:MapOrdering.SEG_BACKGROUND + 1], dim=1) 54 | else: 55 | seg = x[:, MapOrdering.SEG_WORD:MapOrdering.SEG_BACKGROUND + 1] 56 | geo = torch.sigmoid(x[:, MapOrdering.GEO_TOP:]) * self.input_size[0] 57 | y = torch.cat([seg, geo], dim=1) 58 | return y 59 | 60 | def forward(self, x, apply_softmax=False): 61 | # x: BxCxHxW 62 | # eval backbone with 448px: bb1: 224px, bb2: 112px, bb3: 56px, bb4: 28px, bb5: 14px 63 | s = x.shape[2:] 64 | bb5, bb4, bb3, bb2, bb1 = self.backbone(x) 65 | 66 | x = self.up1(bb5, bb4, self.scale_shape(s, 16)) 67 | x = self.up2(x, bb3, self.scale_shape(s, 8)) 68 | x = self.up3(x, bb2, self.scale_shape(s, 4)) 69 | x = self.up4(x, bb1, self.scale_shape(s, 2)) 70 | x = self.conv1(x) 71 | 72 | return self.output_activation(x, apply_softmax) 73 | -------------------------------------------------------------------------------- /src/resnet.py: -------------------------------------------------------------------------------- 1 | """ 2 | ResNet 3 | taken from https://raw.githubusercontent.com/pytorch/vision/master/torchvision/models/resnet.py 4 | with modifications 5 | """ 6 | 7 | from typing import Type, Any, Callable, Union, List, Optional 8 | 9 | import torch.nn as nn 10 | from torch import Tensor 11 | 12 | 13 | def conv3x3(in_planes: int, out_planes: int, stride: 
int = 1, groups: int = 1, dilation: int = 1) -> nn.Conv2d: 14 | """3x3 convolution with padding""" 15 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 16 | padding=dilation, groups=groups, bias=False, dilation=dilation) 17 | 18 | 19 | def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d: 20 | """1x1 convolution""" 21 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 22 | 23 | 24 | class BasicBlock(nn.Module): 25 | expansion: int = 1 26 | 27 | def __init__( 28 | self, 29 | inplanes: int, 30 | planes: int, 31 | stride: int = 1, 32 | downsample: Optional[nn.Module] = None, 33 | groups: int = 1, 34 | base_width: int = 64, 35 | dilation: int = 1, 36 | norm_layer: Optional[Callable[..., nn.Module]] = None 37 | ) -> None: 38 | super(BasicBlock, self).__init__() 39 | if norm_layer is None: 40 | norm_layer = nn.BatchNorm2d 41 | if groups != 1 or base_width != 64: 42 | raise ValueError('BasicBlock only supports groups=1 and base_width=64') 43 | if dilation > 1: 44 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock") 45 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1 46 | self.conv1 = conv3x3(inplanes, planes, stride) 47 | self.bn1 = norm_layer(planes) 48 | self.relu = nn.ReLU(inplace=True) 49 | self.conv2 = conv3x3(planes, planes) 50 | self.bn2 = norm_layer(planes) 51 | self.downsample = downsample 52 | self.stride = stride 53 | 54 | def forward(self, x: Tensor) -> Tensor: 55 | identity = x 56 | 57 | out = self.conv1(x) 58 | out = self.bn1(out) 59 | out = self.relu(out) 60 | 61 | out = self.conv2(out) 62 | out = self.bn2(out) 63 | 64 | if self.downsample is not None: 65 | identity = self.downsample(x) 66 | 67 | out += identity 68 | out = self.relu(out) 69 | 70 | return out 71 | 72 | 73 | class Bottleneck(nn.Module): 74 | # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2) 75 | # while original implementation places the stride at the first 1x1 convolution(self.conv1) 76 | # according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385. 77 | # This variant is also known as ResNet V1.5 and improves accuracy according to 78 | # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch. 
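# note: the resnet18 backbone instantiated in net.py is built from BasicBlock only; Bottleneck is kept unchanged from torchvision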
79 | 80 | expansion: int = 4 81 | 82 | def __init__( 83 | self, 84 | inplanes: int, 85 | planes: int, 86 | stride: int = 1, 87 | downsample: Optional[nn.Module] = None, 88 | groups: int = 1, 89 | base_width: int = 64, 90 | dilation: int = 1, 91 | norm_layer: Optional[Callable[..., nn.Module]] = None 92 | ) -> None: 93 | super(Bottleneck, self).__init__() 94 | if norm_layer is None: 95 | norm_layer = nn.BatchNorm2d 96 | width = int(planes * (base_width / 64.)) * groups 97 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1 98 | self.conv1 = conv1x1(inplanes, width) 99 | self.bn1 = norm_layer(width) 100 | self.conv2 = conv3x3(width, width, stride, groups, dilation) 101 | self.bn2 = norm_layer(width) 102 | self.conv3 = conv1x1(width, planes * self.expansion) 103 | self.bn3 = norm_layer(planes * self.expansion) 104 | self.relu = nn.ReLU(inplace=True) 105 | self.downsample = downsample 106 | self.stride = stride 107 | 108 | def forward(self, x: Tensor) -> Tensor: 109 | identity = x 110 | 111 | out = self.conv1(x) 112 | out = self.bn1(out) 113 | out = self.relu(out) 114 | 115 | out = self.conv2(out) 116 | out = self.bn2(out) 117 | out = self.relu(out) 118 | 119 | out = self.conv3(out) 120 | out = self.bn3(out) 121 | 122 | if self.downsample is not None: 123 | identity = self.downsample(x) 124 | 125 | out += identity 126 | out = self.relu(out) 127 | 128 | return out 129 | 130 | 131 | class ResNet(nn.Module): 132 | 133 | def __init__( 134 | self, 135 | block: Type[Union[BasicBlock, Bottleneck]], 136 | layers: List[int], 137 | num_classes: int = 1000, 138 | zero_init_residual: bool = False, 139 | groups: int = 1, 140 | width_per_group: int = 64, 141 | replace_stride_with_dilation: Optional[List[bool]] = None, 142 | norm_layer: Optional[Callable[..., nn.Module]] = None 143 | ) -> None: 144 | super(ResNet, self).__init__() 145 | if norm_layer is None: 146 | norm_layer = nn.BatchNorm2d 147 | self._norm_layer = norm_layer 148 | 149 | self.inplanes = 64 150 | self.dilation = 1 151 | if replace_stride_with_dilation is None: 152 | # each element in the tuple indicates if we should replace 153 | # the 2x2 stride with a dilated convolution instead 154 | replace_stride_with_dilation = [False, False, False] 155 | if len(replace_stride_with_dilation) != 3: 156 | raise ValueError("replace_stride_with_dilation should be None " 157 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) 158 | self.groups = groups 159 | self.base_width = width_per_group 160 | self.conv1 = nn.Conv2d(1, self.inplanes, kernel_size=7, stride=2, padding=3, 161 | bias=False) 162 | self.bn1 = norm_layer(self.inplanes) 163 | self.relu = nn.ReLU(inplace=True) 164 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 165 | self.layer1 = self._make_layer(block, 64, layers[0]) 166 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, 167 | dilate=replace_stride_with_dilation[0]) 168 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2, 169 | dilate=replace_stride_with_dilation[1]) 170 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2, 171 | dilate=replace_stride_with_dilation[2]) 172 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 173 | self.fc = nn.Linear(512 * block.expansion, num_classes) 174 | 175 | for m in self.modules(): 176 | if isinstance(m, nn.Conv2d): 177 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 178 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 179 | nn.init.constant_(m.weight, 1) 180 | 
nn.init.constant_(m.bias, 0) 181 | 182 | # Zero-initialize the last BN in each residual branch, 183 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 184 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 185 | if zero_init_residual: 186 | for m in self.modules(): 187 | if isinstance(m, Bottleneck): 188 | nn.init.constant_(m.bn3.weight, 0) # type: ignore[arg-type] 189 | elif isinstance(m, BasicBlock): 190 | nn.init.constant_(m.bn2.weight, 0) # type: ignore[arg-type] 191 | 192 | def _make_layer(self, block: Type[Union[BasicBlock, Bottleneck]], planes: int, blocks: int, 193 | stride: int = 1, dilate: bool = False) -> nn.Sequential: 194 | norm_layer = self._norm_layer 195 | downsample = None 196 | previous_dilation = self.dilation 197 | if dilate: 198 | self.dilation *= stride 199 | stride = 1 200 | if stride != 1 or self.inplanes != planes * block.expansion: 201 | downsample = nn.Sequential( 202 | conv1x1(self.inplanes, planes * block.expansion, stride), 203 | norm_layer(planes * block.expansion), 204 | ) 205 | 206 | layers = [] 207 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups, 208 | self.base_width, previous_dilation, norm_layer)) 209 | self.inplanes = planes * block.expansion 210 | for _ in range(1, blocks): 211 | layers.append(block(self.inplanes, planes, groups=self.groups, 212 | base_width=self.base_width, dilation=self.dilation, 213 | norm_layer=norm_layer)) 214 | 215 | return nn.Sequential(*layers) 216 | 217 | def _forward_impl(self, x: Tensor) -> Tensor: 218 | # See note [TorchScript super()] 219 | x = self.conv1(x) 220 | x = self.bn1(x) 221 | out1 = self.relu(x) 222 | x = self.maxpool(out1) 223 | 224 | out2 = self.layer1(x) 225 | out3 = self.layer2(out2) 226 | out4 = self.layer3(out3) 227 | out5 = self.layer4(out4) 228 | 229 | return out5, out4, out3, out2, out1 230 | 231 | def forward(self, x: Tensor) -> Tensor: 232 | return self._forward_impl(x) 233 | 234 | 235 | def _resnet( 236 | arch: str, 237 | block: Type[Union[BasicBlock, Bottleneck]], 238 | layers: List[int], 239 | pretrained: bool, 240 | progress: bool, 241 | **kwargs: Any 242 | ) -> ResNet: 243 | model = ResNet(block, layers, **kwargs) 244 | return model 245 | 246 | 247 | def resnet18(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: 248 | r"""ResNet-18 model from 249 | `"Deep Residual Learning for Image Recognition" `_. 250 | 251 | Args: 252 | pretrained (bool): If True, returns a model pre-trained on ImageNet 253 | progress (bool): If True, displays a progress bar of the download to stderr 254 | """ 255 | return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress, 256 | **kwargs) 257 | 258 | 259 | def resnet34(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: 260 | r"""ResNet-34 model from 261 | `"Deep Residual Learning for Image Recognition" `_. 262 | 263 | Args: 264 | pretrained (bool): If True, returns a model pre-trained on ImageNet 265 | progress (bool): If True, displays a progress bar of the download to stderr 266 | """ 267 | return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress, 268 | **kwargs) 269 | 270 | 271 | def resnet50(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: 272 | r"""ResNet-50 model from 273 | `"Deep Residual Learning for Image Recognition" `_. 
274 | 275 | Args: 276 | pretrained (bool): If True, returns a model pre-trained on ImageNet 277 | progress (bool): If True, displays a progress bar of the download to stderr 278 | """ 279 | return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress, 280 | **kwargs) 281 | 282 | 283 | def resnet101(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: 284 | r"""ResNet-101 model from 285 | `"Deep Residual Learning for Image Recognition" `_. 286 | 287 | Args: 288 | pretrained (bool): If True, returns a model pre-trained on ImageNet 289 | progress (bool): If True, displays a progress bar of the download to stderr 290 | """ 291 | return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress, 292 | **kwargs) 293 | 294 | 295 | def resnet152(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: 296 | r"""ResNet-152 model from 297 | `"Deep Residual Learning for Image Recognition" `_. 298 | 299 | Args: 300 | pretrained (bool): If True, returns a model pre-trained on ImageNet 301 | progress (bool): If True, displays a progress bar of the download to stderr 302 | """ 303 | return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress, 304 | **kwargs) 305 | 306 | 307 | def resnext50_32x4d(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: 308 | r"""ResNeXt-50 32x4d model from 309 | `"Aggregated Residual Transformation for Deep Neural Networks" `_. 310 | 311 | Args: 312 | pretrained (bool): If True, returns a model pre-trained on ImageNet 313 | progress (bool): If True, displays a progress bar of the download to stderr 314 | """ 315 | kwargs['groups'] = 32 316 | kwargs['width_per_group'] = 4 317 | return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3], 318 | pretrained, progress, **kwargs) 319 | 320 | 321 | def resnext101_32x8d(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: 322 | r"""ResNeXt-101 32x8d model from 323 | `"Aggregated Residual Transformation for Deep Neural Networks" `_. 324 | 325 | Args: 326 | pretrained (bool): If True, returns a model pre-trained on ImageNet 327 | progress (bool): If True, displays a progress bar of the download to stderr 328 | """ 329 | kwargs['groups'] = 32 330 | kwargs['width_per_group'] = 8 331 | return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3], 332 | pretrained, progress, **kwargs) 333 | 334 | 335 | def wide_resnet50_2(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: 336 | r"""Wide ResNet-50-2 model from 337 | `"Wide Residual Networks" `_. 338 | 339 | The model is the same as ResNet except for the bottleneck number of channels 340 | which is twice larger in every block. The number of channels in outer 1x1 341 | convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048 342 | channels, and in Wide ResNet-50-2 has 2048-1024-2048. 343 | 344 | Args: 345 | pretrained (bool): If True, returns a model pre-trained on ImageNet 346 | progress (bool): If True, displays a progress bar of the download to stderr 347 | """ 348 | kwargs['width_per_group'] = 64 * 2 349 | return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3], 350 | pretrained, progress, **kwargs) 351 | 352 | 353 | def wide_resnet101_2(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: 354 | r"""Wide ResNet-101-2 model from 355 | `"Wide Residual Networks" `_. 356 | 357 | The model is the same as ResNet except for the bottleneck number of channels 358 | which is twice larger in every block. 
The number of channels in outer 1x1 359 | convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048 360 | channels, and in Wide ResNet-50-2 has 2048-1024-2048. 361 | 362 | Args: 363 | pretrained (bool): If True, returns a model pre-trained on ImageNet 364 | progress (bool): If True, displays a progress bar of the download to stderr 365 | """ 366 | kwargs['width_per_group'] = 64 * 2 367 | return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3], 368 | pretrained, progress, **kwargs) 369 | -------------------------------------------------------------------------------- /src/train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | 4 | import torch 5 | from path import Path 6 | from torch.utils.tensorboard import SummaryWriter 7 | 8 | from dataloader import DataLoaderIAM 9 | from dataset import DatasetIAM, DatasetIAMSplit 10 | from eval import evaluate 11 | from loss import compute_loss 12 | from net import WordDetectorNet 13 | from visualization import visualize 14 | 15 | global_step = 0 16 | 17 | 18 | def validate(net, loader, writer): 19 | global global_step 20 | 21 | net.eval() 22 | loader.reset() 23 | res = evaluate(net, loader, max_aabbs=1000) 24 | 25 | for i, (img, aabbs) in enumerate(zip(res.batch_imgs, res.batch_aabbs)): 26 | vis = visualize(img, aabbs) 27 | writer.add_image(f'img{i}', vis.transpose((2, 0, 1)), global_step) 28 | writer.add_scalar('val_loss', res.loss, global_step) 29 | writer.add_scalar('val_recall', res.metrics.recall(), global_step) 30 | writer.add_scalar('val_precision', res.metrics.precision(), global_step) 31 | writer.add_scalar('val_f1', res.metrics.f1(), global_step) 32 | 33 | return res.metrics.f1() 34 | 35 | 36 | def train(net, optimizer, loader, writer): 37 | global global_step 38 | 39 | net.train() 40 | loader.reset() 41 | loader.random() 42 | for i in range(len(loader)): 43 | # get batch 44 | loader_item = loader[i] 45 | 46 | # forward pass 47 | optimizer.zero_grad() 48 | y = net(loader_item.batch_imgs) 49 | loss = compute_loss(y, loader_item.batch_gt_maps) 50 | 51 | # backward pass, optimize loss 52 | loss.backward() 53 | optimizer.step() 54 | 55 | # output 56 | print(f'{i + 1}/{len(loader)}: {loss}') 57 | writer.add_scalar('loss', loss, global_step) 58 | global_step += 1 59 | 60 | 61 | def main(): 62 | parser = argparse.ArgumentParser() 63 | parser.add_argument('--batch_size', type=int, default=10) 64 | parser.add_argument('--caching', action='store_true') 65 | parser.add_argument('--data_dir', type=Path, required=True) 66 | parser.add_argument('--pretrained', action='store_true') 67 | parser.add_argument('--val_freq', type=int, default=1) 68 | parser.add_argument('--early_stopping', type=int, default=50) 69 | args = parser.parse_args() 70 | 71 | writer = SummaryWriter('../log') 72 | 73 | net = WordDetectorNet() 74 | if args.pretrained: 75 | net.load_state_dict(torch.load('../model/weights')) 76 | net.to('cuda') 77 | 78 | # dataset that actually holds the data and 2 views for training and validation set 79 | dataset = DatasetIAM(args.data_dir, net.input_size, net.output_size, caching=args.caching) 80 | dataset_train = DatasetIAMSplit(dataset, 2 * args.batch_size, len(dataset)) 81 | dataset_val = DatasetIAMSplit(dataset, 0, 2 * args.batch_size) 82 | 83 | # loaders 84 | loader_train = DataLoaderIAM(dataset_train, args.batch_size, net.input_size, net.output_size) 85 | loader_val = DataLoaderIAM(dataset_val, args.batch_size, net.input_size, net.output_size) 86 | 87 | 
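# note: loader_val serves the first 2 * batch_size samples, loader_train the remaining ones (see the dataset splits above)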
# optimizer 88 | optimizer = torch.optim.Adam(net.parameters()) 89 | 90 | # main training loop 91 | epoch = 0 92 | best_val_f1 = 0 93 | no_improvement_since = 0 94 | while True: 95 | epoch += 1 96 | print(f'Epoch: {epoch}') 97 | train(net, optimizer, loader_train, writer) 98 | 99 | if epoch % args.val_freq == 0: 100 | val_f1 = validate(net, loader_val, writer) 101 | if val_f1 > best_val_f1: 102 | print(f'Improved on validation set (f1: {best_val_f1}->{val_f1}), save model') 103 | no_improvement_since = 0 104 | best_val_f1 = val_f1 105 | torch.save(net.state_dict(), '../model/weights') 106 | with open('../model/metadata.json', 'w') as f: 107 | json.dump({'epoch': epoch, 'val_f1': val_f1}, f) 108 | else: 109 | no_improvement_since += 1 110 | 111 | # stop training if there were too many validation steps without improvement 112 | if no_improvement_since >= args.early_stopping: 113 | print(f'No improvement for {no_improvement_since} validation steps, stop training') 114 | break 115 | 116 | 117 | if __name__ == '__main__': 118 | main() 119 | -------------------------------------------------------------------------------- /src/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def compute_scale_down(input_size, output_size): 5 | """compute scale down factor of neural network, given input and output size""" 6 | return output_size[0] / input_size[0] 7 | 8 | 9 | def prob_true(p): 10 | """return True with probability p""" 11 | return np.random.random() < p 12 | -------------------------------------------------------------------------------- /src/visualization.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import matplotlib.pyplot as plt 3 | import numpy as np 4 | 5 | 6 | def visualize(img, aabbs): 7 | img = ((img + 0.5) * 255).astype(np.uint8) 8 | img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) 9 | 10 | for aabb in aabbs: 11 | aabb = aabb.enlarge_to_int_grid().as_type(int) 12 | cv2.rectangle(img, (aabb.xmin, aabb.ymin), (aabb.xmax, aabb.ymax), (255, 0, 255), 2) 13 | 14 | return img 15 | 16 | 17 | def visualize_and_plot(img, aabbs): 18 | plt.imshow(img, cmap='gray') 19 | for aabb in aabbs: 20 | plt.plot([aabb.xmin, aabb.xmin, aabb.xmax, aabb.xmax, aabb.xmin], 21 | [aabb.ymin, aabb.ymax, aabb.ymax, aabb.ymin, aabb.ymin]) 22 | 23 | plt.show() 24 | --------------------------------------------------------------------------------
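As a quick sanity check of the shapes flowing through the detector, the following sketch (not part of the repository: it assumes an untrained `WordDetectorNet` on the CPU and a zero-filled dummy batch, run from the `src` directory) performs one forward pass and decodes the output; it is meant for shape debugging only, not for producing meaningful detections:

```python
import torch

from coding import decode
from net import WordDetectorNet

net = WordDetectorNet()  # untrained weights are fine for a pure shape check
net.eval()

x = torch.zeros((1, 1) + WordDetectorNet.input_size)  # dummy batch: 1 x 1 x 448 x 448
with torch.no_grad():
    y = net(x, apply_softmax=True)

print(y.shape)  # expected: torch.Size([1, 7, 224, 224]) -> 3 segmentation + 4 geometry maps
aabbs = decode(y[0].numpy(), f=1 / WordDetectorNet.scale_down)  # scale boxes back to 448px input space
print(f'{len(aabbs)} raw per-pixel boxes (usually 0 for an untrained net)')
```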