├── .readthedocs.yml ├── LICENSE ├── README.md ├── demo.py ├── demo ├── demo_cbad_config.json ├── demo_config.json ├── demo_page_config.json └── interactive_demo.ipynb ├── dh_segment ├── __init__.py ├── estimator_fn.py ├── inference │ ├── __init__.py │ └── loader.py ├── io │ ├── PAGE.py │ ├── __init__.py │ ├── input.py │ ├── input_utils.py │ └── via.py ├── network │ ├── __init__.py │ ├── model.py │ └── pretrained_models.py ├── post_processing │ ├── __init__.py │ ├── binarization.py │ ├── boxes_detection.py │ ├── line_vectorization.py │ └── polygon_detection.py └── utils │ ├── __init__.py │ ├── evaluation.py │ ├── labels.py │ ├── misc.py │ └── params_config.py ├── doc ├── Makefile ├── _static │ ├── cbad.jpg │ ├── cini.jpg │ ├── cini_input.jpg │ ├── cini_labels.jpg │ ├── diva.jpg │ ├── diva_preds.png │ ├── ornaments.jpg │ ├── page.jpg │ ├── system.png │ ├── tensorboard_1.png │ ├── tensorboard_2.png │ └── tensorboard_3.png ├── changelog.rst ├── conf.py ├── index.rst ├── intro │ └── intro.rst ├── reference │ ├── index.rst │ ├── inference.rst │ ├── io.rst │ ├── network.rst │ ├── post_processing.rst │ └── utils.rst ├── references.bib ├── references.rst ├── start │ ├── annotating.rst │ ├── demo.rst │ ├── index.rst │ ├── install.rst │ └── training.rst └── tutorials │ └── index.rst ├── exps ├── README.md ├── __init__.py ├── cbad │ ├── README.md │ ├── __init__.py │ ├── demo_processing.py │ ├── evaluation.py │ ├── example_evaluation.ipynb │ ├── make_cbad.py │ ├── process.py │ └── utils.py ├── commonutils.py └── page │ ├── README.md │ ├── __init__.py │ ├── demo_processing.py │ ├── evaluation.py │ ├── example_evaluation.ipynb │ ├── example_processing.ipynb │ ├── make_page.py │ ├── process.py │ └── utils.py ├── general_config.json ├── pretrained_models ├── download_resnet_pretrained_model.py └── download_vgg_pretrained_model.py ├── setup.py └── train.py /.readthedocs.yml: -------------------------------------------------------------------------------- 1 | python: 2 | version: 3.5 3 | pip_install: true 4 | extra_requirements: 5 | - doc -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # dhSegment 2 | 3 | [![Documentation Status](https://readthedocs.org/projects/dhsegment/badge/?version=latest)](https://dhsegment.readthedocs.io/en/latest/?badge=latest) 4 | 5 | **dhSegment** is a tool for Historical Document Processing. Its generic approach allows to segment regions and 6 | extract content from different type of documents. See 7 | [some examples here](https://dhsegment.readthedocs.io/en/latest/intro/intro.html#use-cases). 8 | 9 | The complete description of the system can be found in the corresponding [paper](https://arxiv.org/abs/1804.10371). 10 | 11 | It was created by [Benoit Seguin](https://twitter.com/Seguin_Be) and Sofia Ares Oliveira at DHLAB, EPFL. 12 | 13 | ## Installation and usage 14 | The [installation procedure](https://dhsegment.readthedocs.io/en/latest/start/install.html) 15 | and examples of usage can be found in the documentation (see section below). 16 | 17 | ## Demo 18 | Have a try at the [demo](https://dhsegment.readthedocs.io/en/latest/start/demo.html) to train (optional) and apply dhSegment in page extraction using the `demo.py` script. 19 | 20 | ## Documentation 21 | The documentation is available on [readthedocs](https://dhsegment.readthedocs.io/). 
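A minimal inference sketch, adapted from `demo.py` (the model is assumed to be unzipped in `demo/model/`, and the image path below is only a placeholder):

```python
import tensorflow as tf
from dh_segment.inference import LoadedModel

with tf.Session():  # the exported model is run inside a TF session
    model = LoadedModel('demo/model/', predict_mode='filename')
    outputs = model.predict('demo/pages/test_a1/images/some_page.jpg')  # placeholder path
    probs = outputs['probs'][0]                 # per-pixel class probabilities
    original_shape = outputs['original_shape']  # (height, width) of the input image
```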
22 | 23 | ## 24 | If you are using this code for your research, you can cite the corresponding paper as : 25 | ``` 26 | @inproceedings{oliveiraseguinkaplan2018dhsegment, 27 | title={dhSegment: A generic deep-learning approach for document segmentation}, 28 | author={Ares Oliveira, Sofia and Seguin, Benoit and Kaplan, Frederic}, 29 | booktitle={Frontiers in Handwriting Recognition (ICFHR), 2018 16th International Conference on}, 30 | pages={7--12}, 31 | year={2018}, 32 | organization={IEEE} 33 | } 34 | ``` 35 | -------------------------------------------------------------------------------- /demo.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import os 4 | from glob import glob 5 | 6 | import cv2 7 | import numpy as np 8 | import tensorflow as tf 9 | from imageio import imread, imsave 10 | from tqdm import tqdm 11 | 12 | from dh_segment.io import PAGE 13 | from dh_segment.inference import LoadedModel 14 | from dh_segment.post_processing import boxes_detection, binarization 15 | 16 | # To output results in PAGE XML format (http://www.primaresearch.org/schema/PAGE/gts/pagecontent/2013-07-15/) 17 | PAGE_XML_DIR = './page_xml' 18 | 19 | 20 | def page_make_binary_mask(probs: np.ndarray, threshold: float=-1) -> np.ndarray: 21 | """ 22 | Computes the binary mask of the detected Page from the probabilities outputed by network 23 | :param probs: array with values in range [0, 1] 24 | :param threshold: threshold between [0 and 1], if negative Otsu's adaptive threshold will be used 25 | :return: binary mask 26 | """ 27 | 28 | mask = binarization.thresholding(probs, threshold) 29 | mask = binarization.cleaning_binary(mask, kernel_size=5) 30 | return mask 31 | 32 | 33 | def format_quad_to_string(quad): 34 | """ 35 | Formats the corner points into a string. 
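For example, a quadrilateral [[x1, y1], [x2, y2], [x3, y3], [x4, y4]] is formatted as 'x1,y1,x2,y2,x3,y3,x4,y4'.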
36 | :param quad: coordinates of the quadrilateral 37 | :return: 38 | """ 39 | s = '' 40 | for corner in quad: 41 | s += '{},{},'.format(corner[0], corner[1]) 42 | return s[:-1] 43 | 44 | 45 | if __name__ == '__main__': 46 | 47 | # If the model has been trained load the model, otherwise use the given model 48 | model_dir = 'demo/page_model/export' 49 | if not os.path.exists(model_dir): 50 | model_dir = 'demo/model/' 51 | 52 | input_files = glob('demo/pages/test_a1/images/*') 53 | 54 | output_dir = 'demo/processed_images' 55 | os.makedirs(output_dir, exist_ok=True) 56 | # PAGE XML format output 57 | output_pagexml_dir = os.path.join(output_dir, PAGE_XML_DIR) 58 | os.makedirs(output_pagexml_dir, exist_ok=True) 59 | 60 | # Store coordinates of page in a .txt file 61 | txt_coordinates = '' 62 | 63 | with tf.Session(): # Start a tensorflow session 64 | # Load the model 65 | m = LoadedModel(model_dir, predict_mode='filename') 66 | 67 | for filename in tqdm(input_files, desc='Processed files'): 68 | # For each image, predict each pixel's label 69 | prediction_outputs = m.predict(filename) 70 | probs = prediction_outputs['probs'][0] 71 | original_shape = prediction_outputs['original_shape'] 72 | probs = probs[:, :, 1] # Take only class '1' (class 0 is the background, class 1 is the page) 73 | probs = probs / np.max(probs) # Normalize to be in [0, 1] 74 | 75 | # Binarize the predictions 76 | page_bin = page_make_binary_mask(probs) 77 | 78 | # Upscale to have full resolution image (cv2 uses (w,h) and not (h,w) for giving shapes) 79 | bin_upscaled = cv2.resize(page_bin.astype(np.uint8, copy=False), 80 | tuple(original_shape[::-1]), interpolation=cv2.INTER_NEAREST) 81 | 82 | # Find quadrilateral enclosing the page 83 | pred_page_coords = boxes_detection.find_boxes(bin_upscaled.astype(np.uint8, copy=False), 84 | mode='min_rectangle', min_area=0.2, n_max_boxes=1) 85 | 86 | # Draw page box on original image and export it. 
Add also box coordinates to the txt file 87 | original_img = imread(filename, pilmode='RGB') 88 | if pred_page_coords is not None: 89 | cv2.polylines(original_img, [pred_page_coords[:, None, :]], True, (0, 0, 255), thickness=5) 90 | # Write corners points into a .txt file 91 | txt_coordinates += '{},{}\n'.format(filename, format_quad_to_string(pred_page_coords)) 92 | 93 | # Create page region and XML file 94 | page_border = PAGE.Border(coords=PAGE.Point.cv2_to_point_list(pred_page_coords[:, None, :])) 95 | else: 96 | print('No box found in {}'.format(filename)) 97 | page_border = PAGE.Border() 98 | 99 | basename = os.path.basename(filename).split('.')[0] 100 | imsave(os.path.join(output_dir, '{}_boxes.jpg'.format(basename)), original_img) 101 | 102 | page_xml = PAGE.Page(image_filename=filename, image_width=original_shape[1], image_height=original_shape[0], 103 | page_border=page_border) 104 | xml_filename = os.path.join(output_pagexml_dir, '{}.xml'.format(basename)) 105 | page_xml.write_to_file(xml_filename, creator_name='PageExtractor') 106 | 107 | # Save txt file 108 | with open(os.path.join(output_dir, 'pages.txt'), 'w') as f: 109 | f.write(txt_coordinates) 110 | -------------------------------------------------------------------------------- /demo/demo_cbad_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "training_params" : { 3 | "learning_rate": 5e-5, 4 | "batch_size": 1, 5 | "make_patches": false, 6 | "training_margin" : 0, 7 | "n_epochs": 30, 8 | "data_augmentation" : true, 9 | "data_augmentation_max_rotation" : 0.2, 10 | "data_augmentation_max_scaling" : 0.2, 11 | "data_augmentation_flip_lr": true, 12 | "data_augmentation_flip_ud": false, 13 | "data_augmentation_color": false, 14 | "evaluate_every_epoch" : 10, 15 | "input_resized_size": 1e6 16 | }, 17 | "pretrained_model_name" : "resnet50", 18 | "prediction_type": "CLASSIFICATION", 19 | "train_data" : "data/cbad-masks/simple/train/train_data.csv", 20 | "eval_data" : "data/cbad-masks/simple/train/eval_data.csv", 21 | "classes_file" : "data/cbad-masks/simple/train/classes.txt", 22 | "model_output_dir" : "demo/cbad_simple_model", 23 | "gpu" : "0" 24 | } -------------------------------------------------------------------------------- /demo/demo_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "training_params" : { 3 | "learning_rate": 5e-5, 4 | "batch_size": 1, 5 | "make_patches": false, 6 | "training_margin" : 0, 7 | "n_epochs": 30, 8 | "data_augmentation" : true, 9 | "data_augmentation_max_rotation" : 0.2, 10 | "data_augmentation_max_scaling" : 0.2, 11 | "data_augmentation_flip_lr": true, 12 | "data_augmentation_flip_ud": true, 13 | "data_augmentation_color": false, 14 | "evaluate_every_epoch" : 10 15 | }, 16 | "pretrained_model_name" : "resnet50", 17 | "prediction_type": "CLASSIFICATION", 18 | "train_data" : "demo/pages/train/", 19 | "eval_data" : "demo/pages/val_a1", 20 | "classes_file" : "demo/pages/classes.txt", 21 | "model_output_dir" : "demo/page_model", 22 | "gpu" : "0" 23 | } -------------------------------------------------------------------------------- /demo/demo_page_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "training_params" : { 3 | "learning_rate": 5e-5, 4 | "batch_size": 1, 5 | "make_patches": false, 6 | "training_margin" : 0, 7 | "n_epochs": 30, 8 | "data_augmentation" : true, 9 | "data_augmentation_max_rotation" : 0.2, 10 | 
"data_augmentation_max_scaling" : 0.2, 11 | "data_augmentation_flip_lr": true, 12 | "data_augmentation_flip_ud": true, 13 | "data_augmentation_color": false, 14 | "evaluate_every_epoch" : 10 15 | }, 16 | "pretrained_model_name" : "resnet50", 17 | "prediction_type": "CLASSIFICATION", 18 | "train_data" : "data/page-masks/train/", 19 | "eval_data" : "data/page-masks/eval", 20 | "classes_file" : "data/page-masks/train/classes.txt", 21 | "model_output_dir" : "demo/page_model", 22 | "gpu" : "0" 23 | } -------------------------------------------------------------------------------- /demo/interactive_demo.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Interactive demo to load a trained model for page extraction and apply it to a randomly selected file" 8 | ] 9 | }, 10 | { 11 | "cell_type": "markdown", 12 | "metadata": {}, 13 | "source": [ 14 | "#### 1. Get the annotated sample dataset, which already contains the folders images and labels. Unzip it into `demo/pages_sample`." 15 | ] 16 | }, 17 | { 18 | "cell_type": "code", 19 | "execution_count": null, 20 | "metadata": {}, 21 | "outputs": [], 22 | "source": [ 23 | "! wget https://github.com/dhlab-epfl/dhSegment/releases/download/v0.2/pages.zip\n", 24 | "! unzip pages.zip" 25 | ] 26 | }, 27 | { 28 | "cell_type": "markdown", 29 | "metadata": {}, 30 | "source": [ 31 | "#### 2. Download the provided model (download and unzip it in `demo/model`)" 32 | ] 33 | }, 34 | { 35 | "cell_type": "code", 36 | "execution_count": null, 37 | "metadata": {}, 38 | "outputs": [], 39 | "source": [ 40 | "! wget https://github.com/dhlab-epfl/dhSegment/releases/download/v0.2/model.zip\n", 41 | "! unzip model.zip" 42 | ] 43 | }, 44 | { 45 | "cell_type": "markdown", 46 | "metadata": {}, 47 | "source": [ 48 | "#### 3. 
Run the code step by step" 49 | ] 50 | }, 51 | { 52 | "cell_type": "code", 53 | "execution_count": null, 54 | "metadata": {}, 55 | "outputs": [], 56 | "source": [ 57 | "import os\n", 58 | "import cv2\n", 59 | "from glob import glob\n", 60 | "import numpy as np\n", 61 | "import random\n", 62 | "import tensorflow as tf\n", 63 | "from imageio import imread, imsave" 64 | ] 65 | }, 66 | { 67 | "cell_type": "code", 68 | "execution_count": null, 69 | "metadata": {}, 70 | "outputs": [], 71 | "source": [ 72 | "import matplotlib.pyplot as plt\n", 73 | "%matplotlib inline" 74 | ] 75 | }, 76 | { 77 | "cell_type": "code", 78 | "execution_count": null, 79 | "metadata": {}, 80 | "outputs": [], 81 | "source": [ 82 | "from dh_segment.io import PAGE\n", 83 | "from dh_segment.inference import LoadedModel\n", 84 | "from dh_segment.post_processing import boxes_detection, binarization" 85 | ] 86 | }, 87 | { 88 | "cell_type": "code", 89 | "execution_count": null, 90 | "metadata": {}, 91 | "outputs": [], 92 | "source": [ 93 | "def page_make_binary_mask(probs: np.ndarray, threshold: float=-1) -> np.ndarray:\n", 94 | " \"\"\"\n", 95 | " Computes the binary mask of the detected Page from the probabilities outputed by network\n", 96 | " :param probs: array with values in range [0, 1]\n", 97 | " :param threshold: threshold between [0 and 1], if negative Otsu's adaptive threshold will be used\n", 98 | " :return: binary mask\n", 99 | " \"\"\"\n", 100 | "\n", 101 | " mask = binarization.thresholding(probs, threshold)\n", 102 | " mask = binarization.cleaning_binary(mask, kernel_size=5)\n", 103 | " return mask" 104 | ] 105 | }, 106 | { 107 | "cell_type": "code", 108 | "execution_count": null, 109 | "metadata": {}, 110 | "outputs": [], 111 | "source": [ 112 | "" 113 | ] 114 | }, 115 | { 116 | "cell_type": "markdown", 117 | "metadata": {}, 118 | "source": [ 119 | "Define input and output directories / files" 120 | ] 121 | }, 122 | { 123 | "cell_type": "code", 124 | "execution_count": null, 125 | "metadata": {}, 126 | "outputs": [], 127 | "source": [ 128 | "model_dir = 'page_model/export'\n", 129 | "if not os.path.exists(model_dir):\n", 130 | " model_dir = 'model/'\n", 131 | "assert(os.path.exists(model_dir))\n", 132 | "\n", 133 | "input_files = glob(os.path.join('pages', 'test_a1', 'images/*'))" 134 | ] 135 | }, 136 | { 137 | "cell_type": "code", 138 | "execution_count": null, 139 | "metadata": {}, 140 | "outputs": [], 141 | "source": [ 142 | "output_dir = './processed_images'\n", 143 | "os.makedirs(output_dir, exist_ok=True)\n", 144 | "# PAGE XML format output\n", 145 | "output_pagexml_dir = os.path.join(output_dir, 'page_xml')\n", 146 | "os.makedirs(output_pagexml_dir, exist_ok=True)" 147 | ] 148 | }, 149 | { 150 | "cell_type": "markdown", 151 | "metadata": {}, 152 | "source": [ 153 | "Start a tensorflow session" 154 | ] 155 | }, 156 | { 157 | "cell_type": "code", 158 | "execution_count": null, 159 | "metadata": {}, 160 | "outputs": [], 161 | "source": [ 162 | "session = tf.InteractiveSession()" 163 | ] 164 | }, 165 | { 166 | "cell_type": "markdown", 167 | "metadata": {}, 168 | "source": [ 169 | "Select a random image" 170 | ] 171 | }, 172 | { 173 | "cell_type": "code", 174 | "execution_count": null, 175 | "metadata": {}, 176 | "outputs": [], 177 | "source": [ 178 | "file_to_process = random.sample(input_files, 1)[0]" 179 | ] 180 | }, 181 | { 182 | "cell_type": "markdown", 183 | "metadata": {}, 184 | "source": [ 185 | "Load the model" 186 | ] 187 | }, 188 | { 189 | "cell_type": "code", 190 | "execution_count": null, 191 | 
"metadata": {}, 192 | "outputs": [], 193 | "source": [ 194 | "m = LoadedModel(model_dir, predict_mode='filename')" 195 | ] 196 | }, 197 | { 198 | "cell_type": "markdown", 199 | "metadata": {}, 200 | "source": [ 201 | "Predict each pixel's label" 202 | ] 203 | }, 204 | { 205 | "cell_type": "code", 206 | "execution_count": null, 207 | "metadata": {}, 208 | "outputs": [], 209 | "source": [ 210 | "# For each image, predict each pixel's label\n", 211 | "prediction_outputs = m.predict(file_to_process)\n", 212 | "probs = prediction_outputs['probs'][0]\n", 213 | "original_shape = prediction_outputs['original_shape']\n", 214 | "\n", 215 | "probs = probs[:, :, 1] # Take only class '1' (class 0 is the background, class 1 is the page)\n", 216 | "probs = probs / np.max(probs) # Normalize to be in [0, 1]\n", 217 | "\n", 218 | "# Binarize the predictions\n", 219 | "page_bin = page_make_binary_mask(probs)\n", 220 | "\n", 221 | "# Upscale to have full resolution image (cv2 uses (w,h) and not (h,w) for giving shapes)\n", 222 | "bin_upscaled = cv2.resize(page_bin.astype(np.uint8, copy=False),\n", 223 | " tuple(original_shape[::-1]), interpolation=cv2.INTER_NEAREST)" 224 | ] 225 | }, 226 | { 227 | "cell_type": "markdown", 228 | "metadata": {}, 229 | "source": [ 230 | "Show the probability map and binarized mask" 231 | ] 232 | }, 233 | { 234 | "cell_type": "code", 235 | "execution_count": null, 236 | "metadata": {}, 237 | "outputs": [], 238 | "source": [ 239 | "plt.figure(figsize=(10,10))\n", 240 | "plt.subplot(1,2,1)\n", 241 | "plt.imshow(probs, cmap='gray')\n", 242 | "plt.axis('off')\n", 243 | "plt.title('Probability map')\n", 244 | "plt.subplot(1,2,2)\n", 245 | "plt.imshow(page_bin, cmap='gray')\n", 246 | "plt.axis('off')\n", 247 | "plt.title('Binary mask')" 248 | ] 249 | }, 250 | { 251 | "cell_type": "markdown", 252 | "metadata": {}, 253 | "source": [ 254 | "Find quadrilateral enclosing the page" 255 | ] 256 | }, 257 | { 258 | "cell_type": "code", 259 | "execution_count": null, 260 | "metadata": {}, 261 | "outputs": [], 262 | "source": [ 263 | "pred_page_coords = boxes_detection.find_boxes(bin_upscaled.astype(np.uint8, copy=False),\n", 264 | " mode='min_rectangle', n_max_boxes=1)" 265 | ] 266 | }, 267 | { 268 | "cell_type": "code", 269 | "execution_count": null, 270 | "metadata": {}, 271 | "outputs": [], 272 | "source": [ 273 | "# Draw page box on original image and export it. 
Add also box coordinates to the txt file\n", 274 | "original_img = imread(file_to_process, pilmode='RGB')\n", 275 | "if pred_page_coords is not None:\n", 276 | " cv2.polylines(original_img, [pred_page_coords[:, None, :]], True, (0, 0, 255), thickness=5)\n", 277 | "else:\n", 278 | " print('No box found in {}'.format(filename))" 279 | ] 280 | }, 281 | { 282 | "cell_type": "code", 283 | "execution_count": null, 284 | "metadata": {}, 285 | "outputs": [], 286 | "source": [ 287 | "plt.figure(figsize=(10,10))\n", 288 | "plt.imshow(original_img)" 289 | ] 290 | }, 291 | { 292 | "cell_type": "markdown", 293 | "metadata": {}, 294 | "source": [ 295 | "Export image and create page region and XML file" 296 | ] 297 | }, 298 | { 299 | "cell_type": "code", 300 | "execution_count": null, 301 | "metadata": {}, 302 | "outputs": [], 303 | "source": [ 304 | "basename = os.path.basename(file_to_process).split('.')[0]\n", 305 | "imsave(os.path.join(output_dir, '{}_boxes.jpg'.format(basename)), original_img)\n", 306 | "\n", 307 | "page_border = PAGE.Border(coords=PAGE.Point.cv2_to_point_list(pred_page_coords[:, None, :]))\n", 308 | "page_xml = PAGE.Page(image_filename=file_to_process, image_width=original_shape[1], image_height=original_shape[0], page_border=page_border)\n", 309 | "xml_filename = os.path.join(output_pagexml_dir, '{}.xml'.format(basename))\n", 310 | "page_xml.write_to_file(xml_filename, creator_name='PageExtractor')" 311 | ] 312 | }, 313 | { 314 | "cell_type": "code", 315 | "execution_count": null, 316 | "metadata": {}, 317 | "outputs": [], 318 | "source": [ 319 | "" 320 | ] 321 | }, 322 | { 323 | "cell_type": "markdown", 324 | "metadata": {}, 325 | "source": [ 326 | "#### 4. Have a look at the results in ``demo/processed_images``" 327 | ] 328 | } 329 | ], 330 | "metadata": { 331 | "kernelspec": { 332 | "display_name": "Python [conda env:dhsegment]", 333 | "language": "python", 334 | "name": "conda-env-dhsegment-py" 335 | }, 336 | "language_info": { 337 | "codemirror_mode": { 338 | "name": "ipython", 339 | "version": 3.0 340 | }, 341 | "file_extension": ".py", 342 | "mimetype": "text/x-python", 343 | "name": "python", 344 | "nbconvert_exporter": "python", 345 | "pygments_lexer": "ipython3", 346 | "version": "3.5.6" 347 | } 348 | }, 349 | "nbformat": 4, 350 | "nbformat_minor": 0 351 | } -------------------------------------------------------------------------------- /dh_segment/__init__.py: -------------------------------------------------------------------------------- 1 | # _MODEL = [ 2 | # 'inference_vgg16', 3 | # 'inference_resnet_v1_50', 4 | # 'inference_u_net', 5 | # 'vgg_16_fn', 6 | # 'resnet_v1_50_fn' 7 | # ] 8 | # 9 | # _INPUT = [ 10 | # 'input_fn', 11 | # 'serving_input_filename', 12 | # 'serving_input_image', 13 | # 'data_augmentation_fn', 14 | # 'rotate_crop', 15 | # 'resize_image', 16 | # 'load_and_resize_image', 17 | # 'extract_patches_fn', 18 | # 'local_entropy' 19 | # ] 20 | # 21 | # _ESTIMATOR = [ 22 | # 'model_fn' 23 | # ] 24 | # 25 | # _LOADER = [ 26 | # 'LoadedModel' 27 | # ] 28 | # 29 | # _UTILS = [ 30 | # 'PredictionType', 31 | # 'VGG16ModelParams', 32 | # 'ResNetModelParams', 33 | # 'UNetModelParams', 34 | # 'ModelParams', 35 | # 'TrainingParams', 36 | # 'label_image_to_class', 37 | # 'class_to_label_image', 38 | # 'multilabel_image_to_class', 39 | # 'multiclass_to_label_image', 40 | # 'get_classes_color_from_file', 41 | # 'get_n_classes_from_file', 42 | # 'get_classes_color_from_file_multilabel', 43 | # 'get_n_classes_from_file_multilabel', 44 | # '_get_image_shape_tensor', 
45 | # ] 46 | # 47 | # __all__ = _MODEL + _INPUT + _ESTIMATOR + _LOADER + _UTILS 48 | # 49 | # from dh_segment.model.pretrained_models import * 50 | # 51 | # from dh_segment.network import * 52 | # from .estimator_fn import * 53 | # from .io import * 54 | # from .network import * 55 | # from .inference import * 56 | # from .utils import * -------------------------------------------------------------------------------- /dh_segment/estimator_fn.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from .utils import PredictionType, ModelParams, TrainingParams, \ 3 | class_to_label_image, multiclass_to_label_image 4 | import numpy as np 5 | from .network.model import inference_resnet_v1_50, inference_vgg16, inference_u_net 6 | 7 | 8 | def model_fn(mode, features, labels, params): 9 | model_params = ModelParams(**params['model_params']) 10 | training_params = TrainingParams.from_dict(params['training_params']) 11 | prediction_type = params['prediction_type'] 12 | classes_file = params['classes_file'] 13 | 14 | input_images = features['images'] 15 | 16 | if mode == tf.estimator.ModeKeys.PREDICT: 17 | margin = training_params.training_margin 18 | input_images = tf.pad(input_images, [[0, 0], [margin, margin], [margin, margin], [0, 0]], 19 | mode='SYMMETRIC', name='mirror_padding') 20 | 21 | if model_params.pretrained_model_name == 'vgg16': 22 | network_output = inference_vgg16(input_images, 23 | model_params, 24 | model_params.n_classes, 25 | use_batch_norm=model_params.batch_norm, 26 | weight_decay=model_params.weight_decay, 27 | is_training=(mode == tf.estimator.ModeKeys.TRAIN) 28 | ) 29 | key_restore_model = 'vgg_16' 30 | 31 | elif model_params.pretrained_model_name == 'resnet50': 32 | network_output = inference_resnet_v1_50(input_images, 33 | model_params, 34 | model_params.n_classes, 35 | use_batch_norm=model_params.batch_norm, 36 | weight_decay=model_params.weight_decay, 37 | is_training=(mode == tf.estimator.ModeKeys.TRAIN) 38 | ) 39 | key_restore_model = 'resnet_v1_50' 40 | elif model_params.pretrained_model_name == 'unet': 41 | network_output = inference_u_net(input_images, 42 | model_params, 43 | model_params.n_classes, 44 | use_batch_norm=model_params.batch_norm, 45 | weight_decay=model_params.weight_decay, 46 | is_training=(mode == tf.estimator.ModeKeys.TRAIN) 47 | ) 48 | key_restore_model = None 49 | else: 50 | raise NotImplementedError 51 | 52 | if mode == tf.estimator.ModeKeys.TRAIN: 53 | if key_restore_model is not None: 54 | # Pretrained weights as initialization 55 | pretrained_restorer = tf.train.Saver(var_list=[v for v in tf.global_variables() 56 | if key_restore_model in v.name]) 57 | 58 | def init_fn(scaffold, session): 59 | pretrained_restorer.restore(session, model_params.pretrained_model_file) 60 | else: 61 | init_fn = None 62 | else: 63 | init_fn = None 64 | 65 | if mode == tf.estimator.ModeKeys.PREDICT: 66 | margin = training_params.training_margin 67 | # Crop padding 68 | if margin > 0: 69 | network_output = network_output[:, margin:-margin, margin:-margin, :] 70 | 71 | # Prediction 72 | # ---------- 73 | if prediction_type == PredictionType.CLASSIFICATION: 74 | prediction_probs = tf.nn.softmax(network_output, name='softmax') 75 | prediction_labels = tf.argmax(network_output, axis=-1, name='label_preds') 76 | predictions = {'probs': prediction_probs, 'labels': prediction_labels} 77 | elif prediction_type == PredictionType.REGRESSION: 78 | predictions = {'output_values': network_output} 79 | 
prediction_labels = network_output 80 | elif prediction_type == PredictionType.MULTILABEL: 81 | with tf.name_scope('prediction_ops'): 82 | prediction_probs = tf.nn.sigmoid(network_output, name='sigmoid') # [B,H,W,C] 83 | prediction_labels = tf.cast(tf.greater_equal(prediction_probs, 0.5, name='labels'), tf.int32) # [B,H,W,C] 84 | predictions = {'probs': prediction_probs, 'labels': prediction_labels} 85 | else: 86 | raise NotImplementedError 87 | 88 | # Loss 89 | # ---- 90 | if mode in [tf.estimator.ModeKeys.TRAIN, tf.estimator.ModeKeys.EVAL]: 91 | regularized_loss = tf.losses.get_regularization_loss() 92 | if prediction_type == PredictionType.CLASSIFICATION: 93 | onehot_labels = tf.one_hot(indices=labels, depth=model_params.n_classes) 94 | with tf.name_scope("loss"): 95 | per_pixel_loss = tf.nn.softmax_cross_entropy_with_logits(logits=network_output, 96 | labels=onehot_labels, name='per_pixel_loss') 97 | if training_params.focal_loss_gamma > 0.0: 98 | # Probability per pixel of getting the correct label 99 | probs_correct_label = tf.reduce_max(tf.multiply(prediction_probs, onehot_labels)) 100 | modulation = tf.pow((1. - probs_correct_label), training_params.focal_loss_gamma) 101 | per_pixel_loss = tf.multiply(per_pixel_loss, modulation) 102 | 103 | if training_params.weights_labels is not None: 104 | weight_mask = tf.reduce_sum( 105 | tf.constant(np.array(training_params.weights_labels, dtype=np.float32)[None, None, None]) * 106 | onehot_labels, axis=-1) 107 | per_pixel_loss = per_pixel_loss * weight_mask 108 | if training_params.local_entropy_ratio > 0: 109 | assert 'weight_maps' in features 110 | r = training_params.local_entropy_ratio 111 | per_pixel_loss = per_pixel_loss * ((1 - r) + r * features['weight_maps']) 112 | 113 | elif prediction_type == PredictionType.REGRESSION: 114 | per_pixel_loss = tf.squared_difference(labels, network_output, name='per_pixel_loss') 115 | elif prediction_type == PredictionType.MULTILABEL: 116 | with tf.name_scope('sigmoid_xentropy_loss'): 117 | labels_floats = tf.cast(labels, tf.float32) 118 | per_pixel_loss = tf.nn.sigmoid_cross_entropy_with_logits(labels=labels_floats, 119 | logits=network_output, name='per_pixel_loss') 120 | if training_params.weights_labels is not None: 121 | weight_mask = tf.maximum( 122 | tf.reduce_max(tf.constant( 123 | np.array(training_params.weights_labels, dtype=np.float32)[None, None, None]) 124 | * labels_floats, axis=-1), 1.0) 125 | per_pixel_loss = per_pixel_loss * weight_mask[:, :, :, None] 126 | else: 127 | raise NotImplementedError 128 | 129 | margin = training_params.training_margin 130 | input_shapes = features['shapes'] 131 | with tf.name_scope('Loss'): 132 | def _fn(_in): 133 | output, shape = _in 134 | return tf.reduce_mean(output[margin:shape[0] - margin, margin:shape[1] - margin]) 135 | 136 | per_img_loss = tf.map_fn(_fn, (per_pixel_loss, input_shapes), dtype=tf.float32) 137 | loss = tf.reduce_mean(per_img_loss, name='loss') 138 | 139 | loss += regularized_loss 140 | else: 141 | loss, regularized_loss = None, None 142 | 143 | # Train 144 | # ----- 145 | if mode == tf.estimator.ModeKeys.TRAIN: 146 | # >> Stucks the training... Why ? 
147 | # ema = tf.train.ExponentialMovingAverage(0.9) 148 | # tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, ema.apply([loss])) 149 | # ema_loss = ema.average(loss) 150 | 151 | if training_params.exponential_learning: 152 | global_step = tf.train.get_or_create_global_step() 153 | learning_rate = tf.train.exponential_decay(training_params.learning_rate, global_step, decay_steps=200, 154 | decay_rate=0.95, staircase=False) 155 | else: 156 | learning_rate = training_params.learning_rate 157 | tf.summary.scalar('learning_rate', learning_rate) 158 | optimizer = tf.train.AdamOptimizer(learning_rate) 159 | with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)): 160 | train_op = optimizer.minimize(loss, global_step=tf.train.get_or_create_global_step()) 161 | else: 162 | ema_loss, train_op = None, None 163 | 164 | # Summaries 165 | # --------- 166 | if mode == tf.estimator.ModeKeys.TRAIN: 167 | with tf.name_scope('summaries'): 168 | tf.summary.scalar('losses/loss', loss) 169 | tf.summary.scalar('losses/loss_per_batch', loss) 170 | tf.summary.scalar('losses/regularized_loss', regularized_loss) 171 | if prediction_type == PredictionType.CLASSIFICATION: 172 | tf.summary.image('output/prediction', 173 | tf.image.resize_images(class_to_label_image(prediction_labels, classes_file), 174 | tf.cast(tf.shape(network_output)[1:3] / 3, tf.int32)), 175 | max_outputs=1) 176 | if model_params.n_classes == 3: 177 | tf.summary.image('output/probs', 178 | tf.image.resize_images(prediction_probs[:, :, :, :], 179 | tf.cast(tf.shape(network_output)[1:3] / 3, tf.int32)), 180 | max_outputs=1) 181 | if model_params.n_classes == 2: 182 | tf.summary.image('output/probs', 183 | tf.image.resize_images(prediction_probs[:, :, :, 1:2], 184 | tf.cast(tf.shape(network_output)[1:3] / 3, tf.int32)), 185 | max_outputs=1) 186 | elif prediction_type == PredictionType.REGRESSION: 187 | summary_img = tf.nn.relu(network_output)[:, :, :, 0:1] # Put negative values to zero 188 | tf.summary.image('output/prediction', summary_img, max_outputs=1) 189 | elif prediction_type == PredictionType.MULTILABEL: 190 | labels_visualization = tf.cast(prediction_labels, tf.int32) 191 | labels_visualization = multiclass_to_label_image(labels_visualization, classes_file) 192 | tf.summary.image('output/prediction_image', 193 | tf.image.resize_images(labels_visualization, 194 | tf.cast(tf.shape(labels_visualization)[1:3] / 3, tf.int32)), 195 | max_outputs=1) 196 | class_dim = prediction_probs.get_shape().as_list()[-1] 197 | for c in range(0, class_dim): 198 | tf.summary.image('output/prediction_probs_{}'.format(c), 199 | tf.image.resize_images(prediction_probs[:, :, :, c:c + 1], 200 | tf.cast(tf.shape(network_output)[1:3] / 3, tf.int32)), 201 | max_outputs=1) 202 | 203 | # beta = tf.get_default_graph().get_tensor_by_name('upsampling/deconv_5/conv5/batch_norm/beta/read:0') 204 | # tf.summary.histogram('Beta', beta) 205 | 206 | # Evaluation 207 | # ---------- 208 | if mode == tf.estimator.ModeKeys.EVAL: 209 | if prediction_type == PredictionType.CLASSIFICATION: 210 | metrics = { 211 | 'eval/accuracy': tf.metrics.accuracy(labels, predictions=prediction_labels), 212 | 'eval/mIOU': tf.metrics.mean_iou(labels, prediction_labels, num_classes=model_params.n_classes,) 213 | # weights=tf.cast(training_params.weights_evaluation_miou, tf.float32)) 214 | } 215 | elif prediction_type == PredictionType.REGRESSION: 216 | metrics = {'eval/accuracy': tf.metrics.mean_squared_error(labels, predictions=prediction_labels)} 217 | elif prediction_type == 
PredictionType.MULTILABEL: 218 | metrics = {'eval/MSE': tf.metrics.mean_squared_error(tf.cast(labels, tf.float32), 219 | predictions=prediction_probs), 220 | 'eval/accuracy': tf.metrics.accuracy(tf.cast(labels, tf.bool), 221 | predictions=tf.cast(prediction_labels, tf.bool)), 222 | 'eval/mIOU': tf.metrics.mean_iou(labels, prediction_labels, num_classes=model_params.n_classes) 223 | # weights=training_params.weights_evaluation_miou) 224 | } 225 | else: 226 | metrics = None 227 | 228 | # Export 229 | # ------ 230 | if mode == tf.estimator.ModeKeys.PREDICT: 231 | 232 | export_outputs = dict() 233 | 234 | if 'original_shape' in features.keys(): 235 | with tf.name_scope('ResizeOutput'): 236 | resized_predictions = dict() 237 | # Resize all the elements in predictions 238 | for k, v in predictions.items(): 239 | # Labels is rank-3 so we need to be careful in using tf.image.resize_images 240 | assert isinstance(v, tf.Tensor) 241 | v2 = v if len(v.get_shape()) == 4 else v[:, :, :, None] 242 | v2 = tf.image.resize_images(v2, features['original_shape'], 243 | method=tf.image.ResizeMethod.BILINEAR if v.dtype == tf.float32 244 | else tf.image.ResizeMethod.NEAREST_NEIGHBOR) 245 | v2 = v2 if len(v.get_shape()) == 4 else v2[:, :, :, 0] 246 | resized_predictions[k] = v2 247 | export_outputs['resized_output'] = tf.estimator.export.PredictOutput(resized_predictions) 248 | 249 | predictions['original_shape'] = features['original_shape'] 250 | 251 | export_outputs['output'] = tf.estimator.export.PredictOutput(predictions) 252 | 253 | export_outputs[tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY] = export_outputs['output'] 254 | else: 255 | export_outputs = None 256 | 257 | return tf.estimator.EstimatorSpec(mode, 258 | predictions=predictions, 259 | loss=loss, 260 | train_op=train_op, 261 | eval_metric_ops=metrics, 262 | export_outputs=export_outputs, 263 | scaffold=tf.train.Scaffold(init_fn=init_fn) 264 | ) 265 | -------------------------------------------------------------------------------- /dh_segment/inference/__init__.py: -------------------------------------------------------------------------------- 1 | r""" 2 | The :mod:`dh_segment.inference` module implements the function related to the usage of a dhSegment model, 3 | for instance to use a trained model to inference on new data. 4 | 5 | Loading a model 6 | --------------- 7 | 8 | .. autosummary:: 9 | LoadedModel 10 | 11 | 12 | ----- 13 | """ 14 | 15 | __all__ = ['LoadedModel'] 16 | 17 | from .loader import * -------------------------------------------------------------------------------- /dh_segment/inference/loader.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import os 3 | from threading import Semaphore 4 | import numpy as np 5 | import tempfile 6 | from imageio import imsave, imread 7 | 8 | _original_shape_key = 'original_shape' 9 | 10 | 11 | class LoadedModel: 12 | """ 13 | Loads an exported dhSegment model 14 | 15 | :param model_base_dir: the model directory i.e. containing `saved_model.{pb|pbtxt}`. If not, it is assumed to \ 16 | be a TF exporter directory, and the latest export directory will be automatically selected. 
17 | :param predict_mode: defines the input/output format of the prediction output (see `.predict()`) 18 | :param num_parallel_predictions: limits the number of conccurent calls of `predict` to avoid Out-Of-Memory \ 19 | issues if predicting on GPU 20 | """ 21 | 22 | def __init__(self, model_base_dir, predict_mode='filename', num_parallel_predictions=2): 23 | if os.path.exists(os.path.join(model_base_dir, 'saved_model.pbtxt')) or \ 24 | os.path.exists(os.path.join(model_base_dir, 'saved_model.pb')): 25 | model_dir = model_base_dir 26 | else: 27 | possible_dirs = os.listdir(model_base_dir) 28 | model_dir = os.path.join(model_base_dir, max(possible_dirs)) # Take latest export 29 | print("Loading {}".format(model_dir)) 30 | 31 | if predict_mode == 'filename': 32 | input_dict_key = 'filename' 33 | signature_def_key = 'serving_default' 34 | elif predict_mode == 'filename_original_shape': 35 | input_dict_key = 'filename' 36 | signature_def_key = 'resized_output' 37 | elif predict_mode == 'image': 38 | input_dict_key = 'image' 39 | signature_def_key = 'from_image:serving_default' 40 | elif predict_mode == 'image_original_shape': 41 | input_dict_key = 'image' 42 | signature_def_key = 'from_image:resized_output' 43 | elif predict_mode == 'resized_images': 44 | input_dict_key = 'resized_images' 45 | signature_def_key = 'from_resized_images:serving_default' 46 | else: 47 | raise NotImplementedError 48 | self.predict_mode = predict_mode 49 | 50 | self.sess = tf.get_default_session() 51 | loaded_model = tf.saved_model.loader.load(self.sess, ['serve'], model_dir) 52 | assert 'serving_default' in list(loaded_model.signature_def) 53 | 54 | input_dict, output_dict = _signature_def_to_tensors(loaded_model.signature_def[signature_def_key]) 55 | assert input_dict_key in input_dict.keys(), "{} not present in input_keys, " \ 56 | "possible values: {}".format(input_dict_key, input_dict.keys()) 57 | self._input_tensor = input_dict[input_dict_key] 58 | self._output_dict = output_dict 59 | if predict_mode == 'resized_images': 60 | # This node is not defined in this specific run-mode as there is no original image 61 | del self._output_dict['original_shape'] 62 | self.sema = Semaphore(num_parallel_predictions) 63 | 64 | def predict(self, input_tensor, prediction_key=None): 65 | """ 66 | Performs the prediction from the loaded model according to the prediction mode. 
\n 67 | Prediction modes: 68 | 69 | +-----------------------------+-----------------------------------------------+--------------------------------------+---------------------------------------------------------------------------------------------------+ 70 | | `prediction_mode` | `input_tensor` | Output prediction dictionnary | Comment | 71 | +=============================+===============================================+======================================+===================================================================================================+ 72 | | `filename` | Single filename string | `labels`, `probs`, `original_shape` | Loads the image, resizes it, and predicts | 73 | +-----------------------------+-----------------------------------------------+--------------------------------------+---------------------------------------------------------------------------------------------------+ 74 | | `filename_original_shape` | Single filename string | `labels`, `probs` | Loads the image, resizes it, predicts and scale the output to the original resolution of the file | 75 | +-----------------------------+-----------------------------------------------+--------------------------------------+---------------------------------------------------------------------------------------------------+ 76 | | `image` | Single input image [1,H,W,3] float32 (0..255) | `labels`, `probs`, `original_shape` | Resizes the image, and predicts | 77 | +-----------------------------+-----------------------------------------------+--------------------------------------+---------------------------------------------------------------------------------------------------+ 78 | | `image_original_shape` | Single input image [1,H,W,3] float32 (0..255) | `labels`, `probs` | Resizes the image, predicts, and scale the output to the original resolution of the input | 79 | +-----------------------------+-----------------------------------------------+--------------------------------------+---------------------------------------------------------------------------------------------------+ 80 | | `image_resized` | Single input image [1,H,W,3] float32 (0..255) | `labels`, `probs` | Predicts from the image input directly | 81 | +-----------------------------+-----------------------------------------------+--------------------------------------+---------------------------------------------------------------------------------------------------+ 82 | 83 | :param input_tensor: a single input whose format should match the prediction mode 84 | :param prediction_key: if not `None`, will returns the value of the corresponding key of the output dictionnary \ 85 | instead of the full dictionnary 86 | :return: the prediction output 87 | """ 88 | with self.sema: 89 | if prediction_key: 90 | desired_output = self._output_dict[prediction_key] 91 | else: 92 | desired_output = self._output_dict 93 | return self.sess.run(desired_output, feed_dict={self._input_tensor: input_tensor}) 94 | 95 | def predict_with_tiles(self, filename: str, resized_size: int=None, tile_size: int=500, 96 | min_overlap: float=0.2, linear_interpolation: bool=True): 97 | 98 | # TODO this part should only happen if self.predict_mode == 'resized_images' 99 | 100 | if resized_size is None or resized_size < 0: 101 | image_np = imread(filename) 102 | h, w = image_np.shape[:2] 103 | batch_size = 1 104 | else: 105 | raise NotImplementedError 106 | assert h > tile_size, w > tile_size 107 | # Get x and y coordinates of beginning of tiles and compute prediction for each 
tile 108 | y_step = np.ceil((h - tile_size) / (tile_size * (1 - min_overlap))) 109 | x_step = np.ceil((w - tile_size) / (tile_size * (1 - min_overlap))) 110 | y_pos = np.round(np.arange(y_step + 1) / y_step * (h - tile_size)).astype(np.int32) 111 | x_pos = np.round(np.arange(x_step + 1) / x_step * (w - tile_size)).astype(np.int32) 112 | 113 | all_outputs = list() 114 | with tempfile.TemporaryDirectory() as tmpdirname: 115 | for i, y in enumerate(y_pos): 116 | inside_list = list() 117 | for j, x in enumerate(x_pos): 118 | filename_tile = os.path.join(tmpdirname, 'tile{}{}.png'.format(i, j)) 119 | imsave(filename_tile, image_np[y:y + tile_size, x:x + tile_size]) 120 | inside_list.append(self.predict(filename_tile))#, prediction_key='probs')) 121 | all_outputs.append(inside_list) 122 | 123 | def _merge_x(full_output, assigned_up_to, new_input, begin_position): 124 | assert full_output.shape[1] == new_input.shape[1], \ 125 | "Shape full output is {}, but shape new_input is {}".format(full_output.shape[1], new_input.shape[1]) 126 | overlap_size = assigned_up_to - begin_position 127 | normal_part_size = new_input.shape[2] - overlap_size 128 | assert normal_part_size > 0 129 | full_output[:, :, assigned_up_to:assigned_up_to + normal_part_size] = new_input[:, :, overlap_size:] 130 | if overlap_size > 0: 131 | weights = np.arange(0, overlap_size) / overlap_size 132 | full_output[:, :, begin_position:assigned_up_to] = (1 - weights)[:, None] * full_output[:, :, 133 | begin_position:assigned_up_to] + \ 134 | weights[:, None] * new_input[:, :, :overlap_size] 135 | 136 | def _merge_y(full_output, assigned_up_to, new_input, begin_position): 137 | assert full_output.shape[2] == new_input.shape[2] 138 | overlap_size = assigned_up_to - begin_position 139 | normal_part_size = new_input.shape[1] - overlap_size 140 | assert normal_part_size > 0 141 | full_output[:, assigned_up_to:assigned_up_to + normal_part_size] = new_input[:, overlap_size:] 142 | if overlap_size > 0: 143 | weights = np.arange(0, overlap_size) / overlap_size 144 | full_output[:, begin_position:assigned_up_to] = (1 - weights)[:, None, None] * full_output[:, 145 | begin_position:assigned_up_to] + \ 146 | weights[:, None, None] * new_input[:, :overlap_size] 147 | 148 | result = {k: np.empty([batch_size, h, w] + list(v.shape[3:]), v.dtype) for k, v in all_outputs[0][0].items() 149 | if k != _original_shape_key} # do not try to merge 'original_shape' content... 
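# Merge the per-tile predictions back into a single full-resolution output.
# With linear interpolation, overlapping regions are cross-faded: the blending weight of the
# newly merged tile ramps linearly from 0 to 1 across the overlap, so each pixel is dominated
# by the tile whose seam is furthest away. Without interpolation, tiles processed later
# simply overwrite earlier ones in the overlapping areas.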
150 | if linear_interpolation: 151 | for k in result.keys(): 152 | assigned_up_to_y = 0 153 | for y, y_outputs in zip(y_pos, all_outputs): 154 | s = list(result[k].shape) 155 | tmp = np.zeros([batch_size, tile_size] + s[2:], result[k].dtype) 156 | assigned_up_to_x = 0 157 | for x, output in zip(x_pos, y_outputs): 158 | _merge_x(tmp, assigned_up_to_x, output[k], x) 159 | assigned_up_to_x = x + tile_size 160 | _merge_y(result[k], assigned_up_to_y, tmp, y) 161 | assigned_up_to_y = y + tile_size 162 | else: 163 | for k in result.keys(): 164 | for y, y_outputs in zip(y_pos, all_outputs): 165 | for x, output in zip(x_pos, y_outputs): 166 | result[k][:, y:y + tile_size, x:x + tile_size] = output[k] 167 | 168 | result[_original_shape_key] = np.array([h, w], np.uint) 169 | return result 170 | 171 | 172 | def _signature_def_to_tensors(signature_def): 173 | g = tf.get_default_graph() 174 | return {k: g.get_tensor_by_name(v.name) for k, v in signature_def.inputs.items()}, \ 175 | {k: g.get_tensor_by_name(v.name) for k, v in signature_def.outputs.items()} 176 | -------------------------------------------------------------------------------- /dh_segment/io/__init__.py: -------------------------------------------------------------------------------- 1 | r""" 2 | The :mod:`dh_segment.io` module implements input / output functions and classes. 3 | 4 | Input functions for ``tf.Estimator`` 5 | ------------------------------------ 6 | 7 | **Input function** 8 | 9 | .. autosummary:: 10 | input_fn 11 | 12 | **Data augmentation** 13 | 14 | .. autosummary:: 15 | data_augmentation_fn 16 | extract_patches_fn 17 | rotate_crop 18 | 19 | **Resizing function** 20 | 21 | .. autosummary:: 22 | resize_image 23 | load_and_resize_image 24 | 25 | 26 | Tensorflow serving functions 27 | ---------------------------- 28 | 29 | .. autosummary:: 30 | serving_input_filename 31 | serving_input_image 32 | 33 | ---- 34 | 35 | PAGE XML and JSON import / export 36 | --------------------------------- 37 | 38 | **PAGE classes** 39 | 40 | .. autosummary:: 41 | PAGE.Point 42 | PAGE.Text 43 | PAGE.Border 44 | PAGE.TextRegion 45 | PAGE.TextLine 46 | PAGE.GraphicRegion 47 | PAGE.TableRegion 48 | PAGE.SeparatorRegion 49 | PAGE.GroupSegment 50 | PAGE.Metadata 51 | PAGE.Page 52 | 53 | **Abstract classes** 54 | 55 | .. autosummary:: 56 | PAGE.BaseElement 57 | PAGE.Region 58 | 59 | **Parsing and helpers** 60 | 61 | .. autosummary:: 62 | PAGE.parse_file 63 | PAGE.json_serialize 64 | 65 | ---- 66 | 67 | .. _ref_via: 68 | 69 | VGG Image Annotator helpers 70 | --------------------------- 71 | 72 | 73 | **VIA objects** 74 | 75 | .. autosummary:: 76 | via.WorkingItem 77 | via.VIAttribute 78 | 79 | 80 | **Creating masks with VIA annotations** 81 | 82 | .. autosummary:: 83 | via.load_annotation_data 84 | via.export_annotation_dict 85 | via.get_annotations_per_file 86 | via.parse_via_attributes 87 | via.get_via_attributes 88 | via.collect_working_items 89 | via.create_masks 90 | 91 | 92 | **Formatting in VIA JSON format** 93 | 94 | .. 
autosummary:: 95 | via.create_via_region_from_coordinates 96 | via.create_via_annotation_single_image 97 | 98 | ---- 99 | 100 | """ 101 | 102 | 103 | _INPUT = [ 104 | 'input_fn', 105 | 'serving_input_filename', 106 | 'serving_input_image', 107 | 'data_augmentation_fn', 108 | 'rotate_crop', 109 | 'resize_image', 110 | 'load_and_resize_image', 111 | 'extract_patches_fn', 112 | 'local_entropy' 113 | ] 114 | 115 | # _PAGE_OBJECTS = [ 116 | # 'Point', 117 | # 'Text', 118 | # 'Region', 119 | # 'TextLine', 120 | # 'GraphicRegion', 121 | # 'TextRegion', 122 | # 'TableRegion', 123 | # 'SeparatorRegion', 124 | # 'Border', 125 | # 'Metadata', 126 | # 'GroupSegment', 127 | # 'Page' 128 | # ] 129 | # 130 | # _PAGE_FN = [ 131 | # 'parse_file', 132 | # 'json_serialize' 133 | # ] 134 | 135 | __all__ = _INPUT # + _PAGE_OBJECTS + _PAGE_FN 136 | 137 | from .input import * 138 | from .input_utils import * 139 | from . import PAGE 140 | from . import via 141 | 142 | -------------------------------------------------------------------------------- /dh_segment/io/input.py: -------------------------------------------------------------------------------- 1 | from glob import glob 2 | import os 3 | import tensorflow as tf 4 | import numpy as np 5 | from .. import utils 6 | from tqdm import tqdm 7 | from typing import Union, List 8 | from enum import Enum 9 | import pandas as pd 10 | from .input_utils import data_augmentation_fn, extract_patches_fn, load_and_resize_image, \ 11 | rotate_crop, resize_image, local_entropy 12 | 13 | 14 | class InputCase(Enum): 15 | INPUT_LIST = 'INPUT_LIST' 16 | INPUT_DIR = 'INPUT_DIR' 17 | INPUT_CSV = 'INPUT_CSV' 18 | 19 | 20 | def input_fn(input_data: Union[str, List[str]], params: dict, input_label_dir: str=None, 21 | data_augmentation: bool=False, batch_size: int=5, make_patches: bool=False, num_epochs: int=1, 22 | num_threads: int=4, image_summaries: bool=False): 23 | """ 24 | Input_fn for estimator 25 | 26 | :param input_data: input data. It can be a directory containing the images, it can be 27 | a list of image filenames, or it can be a path to a csv file. 28 | :param params: params from utils.Params object 29 | :param input_label_dir: directory containing the label images 30 | :param data_augmentation: boolean, if True will scale, roatate, ... 
the images 31 | :param batch_size: size of the bach 32 | :param make_patches: bool, whether to make patches (crop image in smaller pieces) or not 33 | :param num_epochs: number of epochs to cycle trough data (set it to None for infinite repeat) 34 | :param num_threads: number of thread to use in parallele when usin tf.data.Dataset.map 35 | :param image_summaries: boolean, whether to make tf.Summary to watch on tensorboard 36 | :return: fn 37 | """ 38 | training_params = utils.TrainingParams.from_dict(params['training_params']) 39 | prediction_type = params['prediction_type'] 40 | classes_file = params['classes_file'] 41 | 42 | # --- Map functions 43 | def _make_patches_fn(input_image: tf.Tensor, label_image: tf.Tensor, offsets: tuple) -> (tf.Tensor, tf.Tensor): 44 | with tf.name_scope('patching'): 45 | patches_image = extract_patches_fn(input_image, training_params.patch_shape, offsets) 46 | patches_label = extract_patches_fn(label_image, training_params.patch_shape, offsets) 47 | 48 | return patches_image, patches_label 49 | 50 | # Load and resize images 51 | def _load_image_fn(image_filename, label_filename): 52 | if training_params.data_augmentation and training_params.input_resized_size > 0: 53 | random_scaling = tf.random_uniform([], 54 | np.maximum(1 - training_params.data_augmentation_max_scaling, 0), 55 | 1 + training_params.data_augmentation_max_scaling) 56 | new_size = training_params.input_resized_size * random_scaling 57 | else: 58 | new_size = training_params.input_resized_size 59 | 60 | if prediction_type in [utils.PredictionType.CLASSIFICATION, utils.PredictionType.MULTILABEL]: 61 | label_image = load_and_resize_image(label_filename, 3, new_size, interpolation='NEAREST') 62 | elif prediction_type == utils.PredictionType.REGRESSION: 63 | label_image = load_and_resize_image(label_filename, 1, new_size, interpolation='NEAREST') 64 | else: 65 | raise NotImplementedError 66 | input_image = load_and_resize_image(image_filename, 3, new_size) 67 | return input_image, label_image 68 | 69 | # Data augmentation, patching 70 | def _scaling_and_patch_fn(input_image, label_image): 71 | if data_augmentation: 72 | # Rotation of the original image 73 | if training_params.data_augmentation_max_rotation > 0: 74 | with tf.name_scope('random_rotation'): 75 | rotation_angle = tf.random_uniform([], 76 | -training_params.data_augmentation_max_rotation, 77 | training_params.data_augmentation_max_rotation) 78 | label_image = rotate_crop(label_image, rotation_angle, 79 | minimum_shape=[(i * 3) // 2 for i in training_params.patch_shape], 80 | interpolation='NEAREST') 81 | input_image = rotate_crop(input_image, rotation_angle, 82 | minimum_shape=[(i * 3) // 2 for i in training_params.patch_shape], 83 | interpolation='BILINEAR') 84 | 85 | if make_patches: 86 | # Offsets for patch extraction 87 | offsets = (tf.random_uniform(shape=[], minval=0, maxval=1, dtype=tf.float32), 88 | tf.random_uniform(shape=[], minval=0, maxval=1, dtype=tf.float32)) 89 | # offsets = (0, 0) 90 | batch_image, batch_label = _make_patches_fn(input_image, label_image, offsets) 91 | else: 92 | with tf.name_scope('formatting'): 93 | batch_image = tf.expand_dims(input_image, axis=0) 94 | batch_label = tf.expand_dims(label_image, axis=0) 95 | return tf.data.Dataset.from_tensor_slices((batch_image, batch_label)) 96 | 97 | # Data augmentation 98 | def _augment_data_fn(input_image, label_image): \ 99 | return data_augmentation_fn(input_image, label_image, training_params.data_augmentation_flip_lr, 100 | 
training_params.data_augmentation_flip_ud, training_params.data_augmentation_color) 101 | 102 | # Assign color to class id 103 | def _assign_color_to_class_id(input_image, label_image): 104 | # Convert RGB to class id 105 | if prediction_type == utils.PredictionType.CLASSIFICATION: 106 | label_image = utils.label_image_to_class(label_image, classes_file) 107 | elif prediction_type == utils.PredictionType.MULTILABEL: 108 | label_image = utils.multilabel_image_to_class(label_image, classes_file) 109 | output = {'images': input_image, 'labels': label_image} 110 | 111 | if training_params.local_entropy_ratio > 0 and prediction_type == utils.PredictionType.CLASSIFICATION: 112 | output['weight_maps'] = local_entropy(tf.equal(label_image, 1), 113 | sigma=training_params.local_entropy_sigma) 114 | return output 115 | # --- 116 | 117 | # Finding the list of images to be used 118 | if isinstance(input_data, list): 119 | input_case = InputCase.INPUT_LIST 120 | input_image_filenames = input_data 121 | print('Found {} images'.format(len(input_image_filenames))) 122 | 123 | elif os.path.isdir(input_data): 124 | input_case = InputCase.INPUT_DIR 125 | input_image_filenames = glob(os.path.join(input_data, '**', '*.jpg'), 126 | recursive=True) + \ 127 | glob(os.path.join(input_data, '**', '*.png'), 128 | recursive=True) 129 | print('Found {} images'.format(len(input_image_filenames))) 130 | 131 | elif os.path.isfile(input_data) and \ 132 | input_data.endswith('.csv'): 133 | input_case = InputCase.INPUT_CSV 134 | else: 135 | raise NotImplementedError('Input data should be a directory, a csv file or a list of filenames but got {}'.format(input_data)) 136 | 137 | # Finding the list of labelled images if available 138 | has_labelled_data = False 139 | if input_label_dir and input_case in [InputCase.INPUT_LIST, InputCase.INPUT_DIR]: 140 | label_image_filenames = [] 141 | for input_image_filename in input_image_filenames: 142 | label_image_filename = os.path.join(input_label_dir, os.path.basename(input_image_filename)) 143 | if not os.path.exists(label_image_filename): 144 | filename, extension = os.path.splitext(os.path.basename(input_image_filename)) 145 | new_extension = '.png' if extension == '.jpg' else '.jpg' 146 | label_image_filename = os.path.join(input_label_dir, filename + new_extension) 147 | label_image_filenames.append(label_image_filename) 148 | has_labelled_data = True 149 | 150 | # Read image filenames and labels in case of csv file 151 | if input_case == InputCase.INPUT_CSV: 152 | df = pd.read_csv(input_data, header=None, names=['images', 'labels']) 153 | input_image_filenames = list(df.images.values) 154 | # If the label column exists 155 | if not np.alltrue(pd.isnull(df.labels.values)): 156 | label_image_filenames = list(df.labels.values) 157 | has_labelled_data = True 158 | 159 | # Checks that all image files can be found 160 | for img_filename in input_image_filenames: 161 | if not os.path.exists(img_filename): 162 | raise FileNotFoundError(img_filename) 163 | if has_labelled_data: 164 | for label_filename in label_image_filenames: 165 | if not os.path.exists(label_filename): 166 | raise FileNotFoundError(label_filename) 167 | 168 | # Tensorflow input_fn 169 | def fn(): 170 | if not has_labelled_data: 171 | encoded_filenames = [f.encode() for f in input_image_filenames] 172 | dataset = tf.data.Dataset.from_generator(lambda: tqdm(encoded_filenames, desc='Dataset'), 173 | tf.string, tf.TensorShape([])) 174 | dataset = dataset.repeat(count=num_epochs) 175 | dataset = dataset.map(lambda 
filename: {'images': load_and_resize_image(filename, 3, 176 | training_params.input_resized_size)}) 177 | else: 178 | encoded_filenames = [(i.encode(), l.encode()) for i, l in zip(input_image_filenames, label_image_filenames)] 179 | dataset = tf.data.Dataset.from_generator(lambda: tqdm(utils.shuffled(encoded_filenames), desc='Dataset'), 180 | (tf.string, tf.string), (tf.TensorShape([]), tf.TensorShape([]))) 181 | 182 | dataset = dataset.repeat(count=num_epochs) 183 | dataset = dataset.map(_load_image_fn, num_threads).flat_map(_scaling_and_patch_fn) 184 | 185 | if data_augmentation: 186 | dataset = dataset.map(_augment_data_fn, num_threads) 187 | dataset = dataset.map(_assign_color_to_class_id, num_threads) 188 | 189 | # Save original size of images 190 | dataset = dataset.map(lambda d: {'shapes': tf.shape(d['images'])[:2], **d}) 191 | if make_patches: 192 | dataset = dataset.shuffle(128) 193 | 194 | if make_patches and input_label_dir: 195 | base_shape_images = list(training_params.patch_shape) 196 | elif make_patches and input_case == InputCase.INPUT_CSV: 197 | base_shape_images = list(training_params.patch_shape) 198 | else: 199 | base_shape_images = [-1, -1] 200 | # Pad things 201 | padded_shapes = { 202 | 'images': base_shape_images + [3], 203 | 'shapes': [2] 204 | } 205 | if 'labels' in dataset.output_shapes.keys(): 206 | output_shapes_label = dataset.output_shapes['labels'] 207 | padded_shapes['labels'] = base_shape_images + list(output_shapes_label[2:]) 208 | if 'weight_maps' in dataset.output_shapes.keys(): 209 | padded_shapes['weight_maps'] = base_shape_images 210 | 211 | dataset = dataset.padded_batch(batch_size=batch_size, padded_shapes=padded_shapes).prefetch(8) 212 | prepared_batch = dataset.make_one_shot_iterator().get_next() 213 | 214 | # Summaries for checking that the loading and data augmentation goes fine 215 | if image_summaries: 216 | shape_summary_img = tf.cast(tf.shape(prepared_batch['images'])[1:3] / 3, tf.int32) 217 | tf.summary.image('input/image', 218 | tf.image.resize_images(prepared_batch['images'], shape_summary_img), 219 | max_outputs=1) 220 | if 'labels' in prepared_batch: 221 | label_export = prepared_batch['labels'] 222 | if prediction_type == utils.PredictionType.CLASSIFICATION: 223 | label_export = utils.class_to_label_image(label_export, classes_file) 224 | if prediction_type == utils.PredictionType.MULTILABEL: 225 | label_export = tf.cast(label_export, tf.int32) 226 | label_export = utils.multiclass_to_label_image(label_export, classes_file) 227 | tf.summary.image('input/label', 228 | tf.image.resize_images(label_export, shape_summary_img), max_outputs=1) 229 | if 'weight_maps' in prepared_batch: 230 | tf.summary.image('input/weight_map', 231 | tf.image.resize_images(prepared_batch['weight_maps'][:, :, :, None], 232 | shape_summary_img), 233 | max_outputs=1) 234 | 235 | return prepared_batch, prepared_batch.get('labels') 236 | 237 | return fn 238 | 239 | 240 | def serving_input_filename(resized_size): 241 | def serving_input_fn(): 242 | # define placeholder for filename 243 | filename = tf.placeholder(dtype=tf.string) 244 | 245 | # TODO : make it batch-compatible (with Dataset or string input producer) 246 | decoded_image = tf.to_float(tf.image.decode_jpeg(tf.read_file(filename), channels=3, 247 | try_recover_truncated=True)) 248 | original_shape = tf.shape(decoded_image)[:2] 249 | 250 | if resized_size is not None and resized_size > 0: 251 | image = resize_image(decoded_image, resized_size) 252 | else: 253 | image = decoded_image 254 | 255 | 
image_batch = image[None] 256 | features = {'images': image_batch, 'original_shape': original_shape} 257 | 258 | receiver_inputs = {'filename': filename} 259 | 260 | input_from_resized_images = {'resized_images': image_batch} 261 | input_from_original_image = {'image': decoded_image} 262 | 263 | return tf.estimator.export.ServingInputReceiver(features, receiver_inputs, 264 | receiver_tensors_alternatives={'from_image': 265 | input_from_original_image, 266 | 'from_resized_images': 267 | input_from_resized_images}) 268 | 269 | return serving_input_fn 270 | 271 | 272 | def serving_input_image(): 273 | dic_input_serving = {'images': tf.placeholder(tf.float32, [None, None, None, 3])} 274 | return tf.estimator.export.build_raw_serving_input_receiver_fn(dic_input_serving) 275 | -------------------------------------------------------------------------------- /dh_segment/io/input_utils.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from tensorflow.contrib.image import rotate as tf_rotate 3 | from scipy import ndimage 4 | import numpy as np 5 | from typing import Tuple 6 | 7 | 8 | def data_augmentation_fn(input_image: tf.Tensor, label_image: tf.Tensor, flip_lr: bool=True, 9 | flip_ud: bool=True, color: bool=True) -> (tf.Tensor, tf.Tensor): 10 | """Applies data augmentation to both images and label images. 11 | Includes left-right flip, up-down flip and color change. 12 | 13 | :param input_image: images to be augmented [B, H, W, C] 14 | :param label_image: corresponding label images [B, H, W, C] 15 | :param flip_lr: option to flip image in left-right direction 16 | :param flip_ud: option to flip image in up-down direction 17 | :param color: option to change color of images 18 | :return: the tuple (augmented images, augmented label images) [B, H, W, C] 19 | """ 20 | with tf.name_scope('DataAugmentation'): 21 | if flip_lr: 22 | with tf.name_scope('random_flip_lr'): 23 | sample = tf.random_uniform([], 0, 1) 24 | label_image = tf.cond(sample > 0.5, lambda: tf.image.flip_left_right(label_image), lambda: label_image) 25 | input_image = tf.cond(sample > 0.5, lambda: tf.image.flip_left_right(input_image), lambda: input_image) 26 | if flip_ud: 27 | with tf.name_scope('random_flip_ud'): 28 | sample = tf.random_uniform([], 0, 1) 29 | label_image = tf.cond(sample > 0.5, lambda: tf.image.flip_up_down(label_image), lambda: label_image) 30 | input_image = tf.cond(sample > 0.5, lambda: tf.image.flip_up_down(input_image), lambda: input_image) 31 | 32 | chanels = input_image.get_shape()[-1] 33 | if color: 34 | input_image = tf.image.random_contrast(input_image, lower=0.8, upper=1.0) 35 | if chanels == 3: 36 | input_image = tf.image.random_hue(input_image, max_delta=0.1) 37 | input_image = tf.image.random_saturation(input_image, lower=0.8, upper=1.2) 38 | return input_image, label_image 39 | 40 | 41 | def rotate_crop(image: tf.Tensor, rotation: float, crop: bool=True, minimum_shape: Tuple[int, int]=[0, 0], 42 | interpolation: str='NEAREST') -> tf.Tensor: 43 | """Rotates and crops the images. 
44 | 45 | :param image: image to be rotated and cropped [H, W, C] 46 | :param rotation: angle of rotation (in radians) 47 | :param crop: option to crop rotated image to avoid black borders due to rotation 48 | :param minimum_shape: minimum shape of the rotated image / cropped image 49 | :param interpolation: which interpolation to use ``NEAREST`` or ``BILINEAR`` 50 | :return: 51 | """ 52 | with tf.name_scope('RotateCrop'): 53 | rotated_image = tf_rotate(image, rotation, interpolation) 54 | if crop: 55 | rotation = tf.abs(rotation) 56 | original_shape = tf.shape(rotated_image)[:2] 57 | h, w = original_shape[0], original_shape[1] 58 | # see https://stackoverflow.com/questions/16702966/rotate-image-and-crop-out-black-borders for formulae 59 | old_l, old_s = tf.cond(h > w, lambda: [h, w], lambda: [w, h]) 60 | old_l, old_s = tf.cast(old_l, tf.float32), tf.cast(old_s, tf.float32) 61 | new_l = (old_l * tf.cos(rotation) - old_s * tf.sin(rotation)) / tf.cos(2 * rotation) 62 | new_s = (old_s - tf.sin(rotation) * new_l) / tf.cos(rotation) 63 | new_h, new_w = tf.cond(h > w, lambda: [new_l, new_s], lambda: [new_s, new_l]) 64 | new_h, new_w = tf.cast(new_h, tf.int32), tf.cast(new_w, tf.int32) 65 | bb_begin = tf.cast(tf.ceil((h - new_h) / 2), tf.int32), tf.cast(tf.ceil((w - new_w) / 2), tf.int32) 66 | rotated_image_crop = rotated_image[bb_begin[0]:h - bb_begin[0], bb_begin[1]:w - bb_begin[1], :] 67 | 68 | # If crop removes the entire image, keep the original image 69 | rotated_image = tf.cond(tf.less_equal(tf.reduce_min(tf.shape(rotated_image_crop)[:2]), 70 | tf.reduce_max(minimum_shape)), 71 | true_fn=lambda: image, 72 | false_fn=lambda: rotated_image_crop) 73 | return rotated_image 74 | 75 | 76 | def resize_image(image: tf.Tensor, size: int, interpolation: str='BILINEAR') -> tf.Tensor: 77 | """Resizes the image 78 | 79 | :param image: image to be resized [H, W, C] 80 | :param size: size of the resized image (in pixels) 81 | :param interpolation: which interpolation to use, ``NEAREST`` or ``BILINEAR`` 82 | :return: resized image 83 | """ 84 | assert interpolation in ['BILINEAR', 'NEAREST'] 85 | 86 | with tf.name_scope('ImageRescaling'): 87 | input_shape = tf.cast(tf.shape(image)[:2], tf.float32) 88 | size = tf.cast(size, tf.float32) 89 | # Compute new shape 90 | # We want X/Y = x/y and we have size = x*y so : 91 | ratio = tf.div(input_shape[1], input_shape[0]) 92 | new_height = tf.sqrt(tf.div(size, ratio)) 93 | new_width = tf.div(size, new_height) 94 | new_shape = tf.cast([new_height, new_width], tf.int32) 95 | resize_method = { 96 | 'NEAREST': tf.image.ResizeMethod.NEAREST_NEIGHBOR, 97 | 'BILINEAR': tf.image.ResizeMethod.BILINEAR 98 | } 99 | return tf.image.resize_images(image, new_shape, method=resize_method[interpolation]) 100 | 101 | 102 | def load_and_resize_image(filename: str, channels: int, size: int=None, interpolation: str='BILINEAR') -> tf.Tensor: 103 | """Loads an image from its filename and resizes it to the desired output size. 
104 | 105 | :param filename: string tensor 106 | :param channels: number of channels for the decoded image 107 | :param size: number of desired pixels in the resized image, tf.Tensor or int (None for no resizing) 108 | :param interpolation: 109 | :param return_original_shape: returns the original shape of the image before resizing if this flag is True 110 | :return: decoded and resized float32 tensor [h, w, channels], 111 | """ 112 | with tf.name_scope('load_img'): 113 | decoded_image = tf.to_float(tf.image.decode_jpeg(tf.read_file(filename), channels=channels, 114 | try_recover_truncated=True)) 115 | # TODO : if one side is smaller than size of patches (and make patches == true), 116 | # TODO : force the image to have at least patch size 117 | if size is not None and not(isinstance(size, int) and size <= 0): 118 | result_image = resize_image(decoded_image, size, interpolation) 119 | else: 120 | result_image = decoded_image 121 | 122 | return result_image 123 | 124 | 125 | def extract_patches_fn(image: tf.Tensor, patch_shape: Tuple[int, int], offsets: Tuple[int, int]) -> tf.Tensor: 126 | """Will cut a given image into patches. 127 | 128 | :param image: tf.Tensor 129 | :param patch_shape: shape of the extracted patches [h, w] 130 | :param offsets: offset to add to the origin of first patch top-right coordinate, useful during data augmentation \ 131 | to have slighlty different patches each time. This value will be multiplied by [h/2, w/2] (range values [0,1]) 132 | :return: patches [batch_patches, h, w, c] 133 | """ 134 | with tf.name_scope('patch_extraction'): 135 | h, w = patch_shape 136 | c = image.get_shape()[-1] 137 | 138 | offset_h = tf.cast(tf.round(offsets[0] * h // 2), dtype=tf.int32) 139 | offset_w = tf.cast(tf.round(offsets[1] * w // 2), dtype=tf.int32) 140 | offset_img = image[offset_h:, offset_w:, :] 141 | offset_img = offset_img[None, :, :, :] 142 | 143 | patches = tf.extract_image_patches(offset_img, ksizes=[1, h, w, 1], strides=[1, h // 2, w // 2, 1], 144 | rates=[1, 1, 1, 1], padding='VALID') 145 | patches_shape = tf.shape(patches) 146 | return tf.reshape(patches, [tf.reduce_prod(patches_shape[:3]), h, w, int(c)]) 147 | 148 | 149 | def local_entropy(tf_binary_img: tf.Tensor, sigma: float=3) -> tf.Tensor: 150 | """ 151 | 152 | :param tf_binary_img: 153 | :param sigma: 154 | :return: 155 | """ 156 | tf_binary_img.get_shape().assert_has_rank(2) 157 | 158 | def get_gaussian_filter_1d(sigma): 159 | sigma_r = int(np.round(sigma)) 160 | x = np.zeros(6 * sigma_r + 1, dtype=np.float32) 161 | x[3 * sigma_r] = 1 162 | return ndimage.filters.gaussian_filter(x, sigma=sigma) 163 | 164 | def _fn(img): 165 | labelled, nb_components = ndimage.measurements.label(img) 166 | lut = np.concatenate( 167 | [np.array([0], np.int32), np.random.randint(20, size=nb_components + 1, dtype=np.int32) + 1]) 168 | output = lut[labelled] 169 | return output 170 | 171 | label_components = tf.py_func(_fn, [tf_binary_img], tf.int32) 172 | label_components.set_shape([None, None]) 173 | one_hot_components = tf.one_hot(label_components, tf.reduce_max(label_components)) 174 | one_hot_components = tf.transpose(one_hot_components, [2, 0, 1]) 175 | 176 | local_components_avg = tf.nn.conv2d(one_hot_components[:, :, :, None], 177 | get_gaussian_filter_1d(sigma)[None, :, None, None], (1, 1, 1, 1), 178 | padding='SAME') 179 | local_components_avg = tf.nn.conv2d(local_components_avg, get_gaussian_filter_1d(sigma)[:, None, None, None], 180 | (1, 1, 1, 1), padding='SAME') 181 | local_components_avg = 
tf.transpose(local_components_avg[:, :, :, 0], [1, 2, 0]) 182 | local_components_avg = tf.pow(local_components_avg, 1 / 1.4) 183 | local_components_avg = local_components_avg / (tf.reduce_sum(local_components_avg, axis=2, keep_dims=True) + 1e-6) 184 | return -tf.reduce_sum(local_components_avg * tf.log(local_components_avg + 1e-6), axis=2) 185 | -------------------------------------------------------------------------------- /dh_segment/network/__init__.py: -------------------------------------------------------------------------------- 1 | _MODEL = [ 2 | 'inference_vgg16', 3 | 'inference_resnet_v1_50', 4 | 'inference_u_net', 5 | 'vgg_16_fn', 6 | 'resnet_v1_50_fn' 7 | ] 8 | 9 | __all__ = _MODEL 10 | 11 | from .model import * 12 | from .pretrained_models import * 13 | -------------------------------------------------------------------------------- /dh_segment/network/pretrained_models.py: -------------------------------------------------------------------------------- 1 | from tensorflow.contrib import slim, layers 2 | import tensorflow as tf 3 | from tensorflow.contrib.slim import nets 4 | import numpy as np 5 | 6 | _VGG_MEANS = [123.68, 116.78, 103.94] 7 | 8 | 9 | def mean_substraction(input_tensor, means=_VGG_MEANS): 10 | return tf.subtract(input_tensor, np.array(means)[None, None, None, :], name='MeanSubstraction') 11 | 12 | 13 | def vgg_16_fn(input_tensor: tf.Tensor, scope='vgg_16', blocks=5, weight_decay=0.0005) \ 14 | -> (tf.Tensor, list): # list of tf.Tensors (layers) 15 | intermediate_levels = [] 16 | # intermediate_levels.append(input_tensor) 17 | with slim.arg_scope(nets.vgg.vgg_arg_scope(weight_decay=weight_decay)): 18 | with tf.variable_scope(scope, 'vgg_16', [input_tensor]) as sc: 19 | input_tensor = mean_substraction(input_tensor) 20 | intermediate_levels.append(input_tensor) 21 | end_points_collection = sc.original_name_scope + '_end_points' 22 | # Collect outputs for conv2d, fully_connected and max_pool2d. 
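Note that the `size` argument of `resize_image` / `load_and_resize_image` in `dh_segment/io/input_utils.py` above is a total pixel budget rather than a side length: the aspect ratio is preserved and the new height/width are chosen so that their product is approximately `size`. A NumPy-only sketch of the same arithmetic, with hypothetical image dimensions for illustration:

```python
import numpy as np

def target_shape(height: int, width: int, size: float = 72e4) -> tuple:
    # Same arithmetic as resize_image(): keep the aspect ratio W/H and
    # pick new dimensions whose product is approximately `size` pixels.
    ratio = width / height
    new_height = np.sqrt(size / ratio)
    new_width = size / new_height
    return int(new_height), int(new_width)

# A hypothetical 2000 x 1000 (H x W) scan with the default budget of
# 72e4 pixels (600 * 1200) keeps its 2:1 aspect ratio:
print(target_shape(2000, 1000))  # -> (1200, 600)
```

This is also why `input_resized_size` in `TrainingParams` (see `params_config.py` further down) is given as a pixel count (`72e4` by default) rather than as a width or height.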
23 | with slim.arg_scope( 24 | [layers.conv2d, layers.fully_connected, layers.max_pool2d], 25 | outputs_collections=end_points_collection): 26 | net = layers.repeat( 27 | input_tensor, 2, layers.conv2d, 64, [3, 3], scope='conv1') 28 | intermediate_levels.append(net) 29 | net = layers.max_pool2d(net, [2, 2], scope='pool1') 30 | if blocks >= 2: 31 | net = layers.repeat(net, 2, layers.conv2d, 128, [3, 3], scope='conv2') 32 | intermediate_levels.append(net) 33 | net = layers.max_pool2d(net, [2, 2], scope='pool2') 34 | if blocks >= 3: 35 | net = layers.repeat(net, 3, layers.conv2d, 256, [3, 3], scope='conv3') 36 | intermediate_levels.append(net) 37 | net = layers.max_pool2d(net, [2, 2], scope='pool3') 38 | if blocks >= 4: 39 | net = layers.repeat(net, 3, layers.conv2d, 512, [3, 3], scope='conv4') 40 | intermediate_levels.append(net) 41 | net = layers.max_pool2d(net, [2, 2], scope='pool4') 42 | if blocks >= 5: 43 | net = layers.repeat(net, 3, layers.conv2d, 512, [3, 3], scope='conv5') 44 | intermediate_levels.append(net) 45 | net = layers.max_pool2d(net, [2, 2], scope='pool5') 46 | 47 | return net, intermediate_levels 48 | 49 | 50 | def resnet_v1_50_fn(input_tensor: tf.Tensor, is_training=False, blocks=4, weight_decay=0.0001, 51 | renorm=True, corrected_version=False) -> tf.Tensor: 52 | with slim.arg_scope(nets.resnet_v1.resnet_arg_scope(weight_decay=weight_decay, batch_norm_decay=0.999)), \ 53 | slim.arg_scope([layers.batch_norm], renorm_decay=0.95, renorm=renorm): 54 | input_tensor = mean_substraction(input_tensor) 55 | assert 0 < blocks <= 4 56 | 57 | if corrected_version: 58 | def corrected_resnet_v1_block(scope, base_depth, num_units, stride): 59 | """Helper function for creating a resnet_v1 bottleneck block. 60 | 61 | Args: 62 | scope: The scope of the block. 63 | base_depth: The depth of the bottleneck layer for each unit. 64 | num_units: The number of units in the block. 65 | stride: The stride of the block, implemented as a stride in the last unit. 66 | All other units have stride=1. 67 | 68 | Returns: 69 | A resnet_v1 bottleneck block. 
70 | """ 71 | return nets.resnet_utils.Block(scope, nets.resnet_v1.bottleneck,[{ 72 | 'depth': base_depth * 4, 73 | 'depth_bottleneck': base_depth, 74 | 'stride': stride 75 | }] + [{ 76 | 'depth': base_depth * 4, 77 | 'depth_bottleneck': base_depth, 78 | 'stride': 1 79 | }] * (num_units - 1)) 80 | 81 | blocks_list = [ 82 | corrected_resnet_v1_block('block1', base_depth=64, num_units=3, stride=1), 83 | corrected_resnet_v1_block('block2', base_depth=128, num_units=4, stride=2), 84 | corrected_resnet_v1_block('block3', base_depth=256, num_units=6, stride=2), 85 | corrected_resnet_v1_block('block4', base_depth=512, num_units=3, stride=2), 86 | ] 87 | desired_endpoints = [ 88 | 'resnet_v1_50/conv1', 89 | 'resnet_v1_50/block1/unit_3/bottleneck_v1', 90 | 'resnet_v1_50/block2/unit_4/bottleneck_v1', 91 | 'resnet_v1_50/block3/unit_6/bottleneck_v1', 92 | 'resnet_v1_50/block4/unit_3/bottleneck_v1' 93 | ] 94 | else: 95 | blocks_list = [ 96 | nets.resnet_v1.resnet_v1_block('block1', base_depth=64, num_units=3, stride=2), 97 | nets.resnet_v1.resnet_v1_block('block2', base_depth=128, num_units=4, stride=2), 98 | nets.resnet_v1.resnet_v1_block('block3', base_depth=256, num_units=6, stride=2), 99 | nets.resnet_v1.resnet_v1_block('block4', base_depth=512, num_units=3, stride=1), 100 | ] 101 | desired_endpoints = [ 102 | 'resnet_v1_50/conv1', 103 | 'resnet_v1_50/block1/unit_2/bottleneck_v1', 104 | 'resnet_v1_50/block2/unit_3/bottleneck_v1', 105 | 'resnet_v1_50/block3/unit_5/bottleneck_v1', 106 | 'resnet_v1_50/block4/unit_3/bottleneck_v1' 107 | ] 108 | 109 | net, endpoints = nets.resnet_v1.resnet_v1(input_tensor, 110 | blocks=blocks_list[:blocks], 111 | num_classes=None, 112 | is_training=is_training, 113 | global_pool=False, 114 | output_stride=None, 115 | include_root_block=True, 116 | reuse=None, 117 | scope='resnet_v1_50') 118 | 119 | intermediate_layers = list() 120 | for d in desired_endpoints[:blocks + 1]: 121 | intermediate_layers.append(endpoints[d]) 122 | 123 | return net, intermediate_layers 124 | -------------------------------------------------------------------------------- /dh_segment/post_processing/__init__.py: -------------------------------------------------------------------------------- 1 | r""" 2 | The :mod:`dh_segment.post_processing` module contains functions to post-process probability maps. 3 | 4 | **Binarization** 5 | 6 | .. autosummary:: 7 | thresholding 8 | cleaning_binary 9 | 10 | **Detection** 11 | 12 | .. autosummary:: 13 | find_boxes 14 | find_polygonal_regions 15 | 16 | **Vectorization** 17 | 18 | .. autosummary:: 19 | find_lines 20 | 21 | ------ 22 | 23 | """ 24 | 25 | _BINARIZATION = [ 26 | 'thresholding', 27 | 'cleaning_binary', 28 | 29 | ] 30 | 31 | _DETECTION = [ 32 | 'find_boxes', 33 | 'find_polygonal_regions' 34 | ] 35 | 36 | _VECTORIZATION = [ 37 | 'find_lines' 38 | ] 39 | 40 | __all__ = _BINARIZATION + _DETECTION + _VECTORIZATION 41 | 42 | from .binarization import * 43 | from .boxes_detection import * 44 | from .line_vectorization import * 45 | from .polygon_detection import * 46 | 47 | -------------------------------------------------------------------------------- /dh_segment/post_processing/binarization.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | from scipy.ndimage import label 4 | 5 | 6 | def thresholding(probs: np.ndarray, threshold: float=-1) -> np.ndarray: 7 | """ 8 | Computes the binary mask of the detected Page from the probabilities output by network. 
9 | 10 | :param probs: array in range [0, 1] of shape HxWx2 11 | :param threshold: threshold between [0 and 1], if negative Otsu's adaptive threshold will be used 12 | :return: binary mask 13 | """ 14 | 15 | if threshold < 0: # Otsu's thresholding 16 | probs = np.uint8(probs * 255) 17 | #TODO Correct that weird gaussianBlur 18 | probs = cv2.GaussianBlur(probs, (5, 5), 0) 19 | 20 | thresh_val, bin_img = cv2.threshold(probs, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU) 21 | mask = np.uint8(bin_img / 255) 22 | else: 23 | mask = np.uint8(probs > threshold) 24 | 25 | return mask 26 | 27 | 28 | def cleaning_binary(mask: np.ndarray, kernel_size: int=5) -> np.ndarray: 29 | """ 30 | Uses mathematical morphology to clean and remove small elements from binary images. 31 | 32 | :param mask: the binary image to clean 33 | :param kernel_size: size of the kernel 34 | :return: the cleaned mask 35 | """ 36 | 37 | ksize_open = (kernel_size, kernel_size) 38 | ksize_close = (kernel_size, kernel_size) 39 | mask = cv2.morphologyEx((mask.astype(np.uint8, copy=False) * 255), cv2.MORPH_OPEN, kernel=np.ones(ksize_open)) 40 | mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel=np.ones(ksize_close)) 41 | return np.uint8(mask / 255) 42 | 43 | 44 | def hysteresis_thresholding(probs: np.array, low_threshold: float, high_threshold: float, 45 | candidates_mask: np.ndarray=None) -> np.ndarray: 46 | low_mask = probs > low_threshold 47 | if candidates_mask is not None: 48 | low_mask = candidates_mask & low_mask 49 | # Connected components extraction 50 | label_components, count = label(low_mask, np.ones((3, 3))) 51 | # Keep components with high threshold elements 52 | good_labels = np.unique(label_components[low_mask & (probs > high_threshold)]) 53 | label_masks = np.zeros((count + 1,), bool) 54 | label_masks[good_labels] = 1 55 | return label_masks[label_components] 56 | 57 | 58 | def cleaning_probs(probs: np.ndarray, sigma: float) -> np.ndarray: 59 | # Smooth 60 | if sigma > 0.: 61 | return cv2.GaussianBlur(probs, (int(3*sigma)*2+1, int(3*sigma)*2+1), sigma) 62 | elif sigma == 0.: 63 | return cv2.fastNlMeansDenoising((probs*255).astype(np.uint8), h=20)/255 64 | else: # Negative sigma, do not do anything 65 | return probs 66 | -------------------------------------------------------------------------------- /dh_segment/post_processing/boxes_detection.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | import math 4 | from shapely import geometry 5 | from scipy.spatial import KDTree 6 | 7 | 8 | def find_boxes(boxes_mask: np.ndarray, 9 | mode: str= 'min_rectangle', 10 | min_area: float=0., 11 | p_arc_length: float=0.01, 12 | n_max_boxes=math.inf) -> list: 13 | """ 14 | Finds the coordinates of the box in the binary image `boxes_mask`. 15 | 16 | :param boxes_mask: Binary image: the mask of the box to find. uint8, 2D array 17 | :param mode: 'min_rectangle' : minimum enclosing rectangle, can be rotated 18 | 'rectangle' : minimum enclosing rectangle, not rotated 19 | 'quadrilateral' : minimum polygon approximated by a quadrilateral 20 | :param min_area: minimum area of the box to be found. A value in percentage of the total area of the image. 21 | :param p_arc_length: used to compute the epsilon value to approximate the polygon with a quadrilateral. 22 | Only used when 'quadrilateral' mode is chosen. 23 | :param n_max_boxes: maximum number of boxes that can be found (default inf). 24 | This will select n_max_boxes with largest area. 
25 | :return: list of length n_max_boxes containing boxes with 4 corners [[x1,y1], ..., [x4,y4]] 26 | """ 27 | 28 | assert len(boxes_mask.shape) == 2, \ 29 | 'Input mask must be a 2D array ! Mask is now of shape {}'.format(boxes_mask.shape) 30 | 31 | contours, _ = cv2.findContours(boxes_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) 32 | if contours is None: 33 | print('No contour found') 34 | return None 35 | found_boxes = list() 36 | 37 | h_img, w_img = boxes_mask.shape[:2] 38 | 39 | def validate_box(box: np.array) -> (np.array, float): 40 | """ 41 | 42 | :param box: array of 4 coordinates with format [[x1,y1], ..., [x4,y4]] 43 | :return: (box, area) 44 | """ 45 | polygon = geometry.Polygon([point for point in box]) 46 | if polygon.area > min_area * boxes_mask.size: 47 | 48 | # Correct out of range corners 49 | box = np.maximum(box, 0) 50 | box = np.stack((np.minimum(box[:, 0], boxes_mask.shape[1]), 51 | np.minimum(box[:, 1], boxes_mask.shape[0])), axis=1) 52 | 53 | # return box 54 | return box, polygon.area 55 | 56 | if mode not in ['quadrilateral', 'min_rectangle', 'rectangle']: 57 | raise NotImplementedError 58 | if mode == 'quadrilateral': 59 | for c in contours: 60 | epsilon = p_arc_length * cv2.arcLength(c, True) 61 | cnt = cv2.approxPolyDP(c, epsilon, True) 62 | # box = np.vstack(simplify_douglas_peucker(cnt[:, 0, :], 4)) 63 | 64 | # Find extreme points in Convex Hull 65 | hull_points = cv2.convexHull(cnt, returnPoints=True) 66 | # points = cnt 67 | points = hull_points 68 | if len(points) > 4: 69 | # Find closes points to corner using nearest neighbors 70 | tree = KDTree(points[:, 0, :]) 71 | _, ul = tree.query((0, 0)) 72 | _, ur = tree.query((w_img, 0)) 73 | _, dl = tree.query((0, h_img)) 74 | _, dr = tree.query((w_img, h_img)) 75 | box = np.vstack([points[ul, 0, :], points[ur, 0, :], 76 | points[dr, 0, :], points[dl, 0, :]]) 77 | elif len(hull_points) == 4: 78 | box = hull_points[:, 0, :] 79 | else: 80 | continue 81 | # Todo : test if it looks like a rectangle (2 sides must be more or less parallel) 82 | # todo : (otherwise we may end with strange quadrilaterals) 83 | if len(box) != 4: 84 | mode = 'min_rectangle' 85 | print('Quadrilateral has {} points. 
Switching to minimal rectangle mode'.format(len(box))) 86 | else: 87 | # found_box = validate_box(box) 88 | found_boxes.append(validate_box(box)) 89 | if mode == 'min_rectangle': 90 | for c in contours: 91 | rect = cv2.minAreaRect(c) 92 | box = np.int0(cv2.boxPoints(rect)) 93 | found_boxes.append(validate_box(box)) 94 | elif mode == 'rectangle': 95 | for c in contours: 96 | x, y, w, h = cv2.boundingRect(c) 97 | box = np.array([[x, y], [x + w, y], [x + w, y + h], [x, y + h]], dtype=int) 98 | found_boxes.append(validate_box(box)) 99 | # sort by area 100 | found_boxes = [fb for fb in found_boxes if fb is not None] 101 | found_boxes = sorted(found_boxes, key=lambda x: x[1], reverse=True) 102 | if n_max_boxes == 1: 103 | if found_boxes: 104 | return found_boxes[0][0] 105 | else: 106 | return None 107 | else: 108 | return [fb[0] for i, fb in enumerate(found_boxes) if i < n_max_boxes] 109 | -------------------------------------------------------------------------------- /dh_segment/post_processing/line_vectorization.py: -------------------------------------------------------------------------------- 1 | from skimage.graph import MCP_Connect 2 | from skimage.morphology import skeletonize 3 | from skimage.measure import label as skimage_label 4 | from sklearn.metrics.pairwise import euclidean_distances 5 | from scipy.signal import convolve2d 6 | from collections import defaultdict 7 | import numpy as np 8 | 9 | 10 | def find_lines(lines_mask: np.ndarray) -> list: 11 | """ 12 | Finds the longest central line for each connected component in the given binary mask. 13 | 14 | :param lines_mask: Binary mask of the detected line-areas 15 | :return: a list of Opencv-style polygonal lines (each contour encoded as [N,1,2] elements where each tuple is (x,y) ) 16 | """ 17 | # Make sure one-pixel wide 8-connected mask 18 | lines_mask = skeletonize(lines_mask) 19 | 20 | class MakeLineMCP(MCP_Connect): 21 | def __init__(self, *args, **kwargs): 22 | super().__init__(*args, **kwargs) 23 | self.connections = dict() 24 | self.scores = defaultdict(lambda: np.inf) 25 | 26 | def create_connection(self, id1, id2, pos1, pos2, cost1, cost2): 27 | k = (min(id1, id2), max(id1, id2)) 28 | s = cost1 + cost2 29 | if self.scores[k] > s: 30 | self.connections[k] = (pos1, pos2, s) 31 | self.scores[k] = s 32 | 33 | def get_connections(self, subsample=5): 34 | results = dict() 35 | for k, (pos1, pos2, s) in self.connections.items(): 36 | path = np.concatenate([self.traceback(pos1), self.traceback(pos2)[::-1]]) 37 | results[k] = path[::subsample] 38 | return results 39 | 40 | def goal_reached(self, int_index, float_cumcost): 41 | if float_cumcost > 0: 42 | return 2 43 | else: 44 | return 0 45 | 46 | if np.sum(lines_mask) == 0: 47 | return [] 48 | # Find extremities points 49 | end_points_candidates = np.stack(np.where((convolve2d(lines_mask, np.ones((3, 3)), mode='same') == 2) & lines_mask)).T 50 | connected_components = skimage_label(lines_mask, connectivity=2) 51 | # Group endpoint by connected components and keep only the two points furthest away 52 | d = defaultdict(list) 53 | for pt in end_points_candidates: 54 | d[connected_components[pt[0], pt[1]]].append(pt) 55 | end_points = [] 56 | for pts in d.values(): 57 | d = euclidean_distances(np.stack(pts), np.stack(pts)) 58 | i, j = np.unravel_index(d.argmax(), d.shape) 59 | end_points.append(pts[i]) 60 | end_points.append(pts[j]) 61 | end_points = np.stack(end_points) 62 | 63 | mcp = MakeLineMCP(~lines_mask) 64 | mcp.find_costs(end_points) 65 | connections = mcp.get_connections() 
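The binarization and box-detection helpers above are typically chained together when extracting a page region. A minimal sketch using a synthetic probability map (in practice the input would be one channel of the network's output, e.g. `probs[:, :, 1]`):

```python
import numpy as np
from dh_segment.post_processing import thresholding, cleaning_binary, find_boxes

# Synthetic stand-in for a page-probability channel with values in [0, 1].
probs = np.zeros((600, 400), dtype=np.float32)
probs[100:500, 50:350] = 0.9

binary = thresholding(probs, threshold=-1)       # negative value -> Otsu
binary = cleaning_binary(binary, kernel_size=5)  # morphological opening/closing

# Best enclosing rectangle as 4 corner points [[x1, y1], ..., [x4, y4]]
box = find_boxes(binary, mode='min_rectangle', min_area=0.2, n_max_boxes=1)
if box is not None:
    print(box.shape)  # (4, 2)
```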
66 | if not np.all(np.array(sorted([i for k in connections.keys() for i in k])) == np.arange(len(end_points))): 67 | print('Warning : find_lines seems weird') 68 | return [c[:, None, ::-1] for c in connections.values()] 69 | -------------------------------------------------------------------------------- /dh_segment/post_processing/polygon_detection.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import cv2 4 | import numpy as np 5 | import math 6 | from shapely import geometry 7 | 8 | 9 | def find_polygonal_regions(image_mask: np.ndarray, 10 | min_area: float=0., 11 | n_max_polygons: int=math.inf) -> list: 12 | """ 13 | Finds the shapes in a binary mask and returns their coordinates as polygons. 14 | 15 | :param image_mask: Uint8 binary 2D array 16 | :param min_area: minimum area the polygon should have in order to be considered as valid 17 | (value within [0,1] representing a percent of the total size of the image) 18 | :param n_max_polygons: maximum number of boxes that can be found (default inf). 19 | This will select n_max_boxes with largest area. 20 | :return: list of length n_max_polygons containing polygon's n coordinates [[x1, y1], ... [xn, yn]] 21 | """ 22 | 23 | contours, _ = cv2.findContours(image_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) 24 | if contours is None: 25 | print('No contour found') 26 | return None 27 | found_polygons = list() 28 | 29 | for c in contours: 30 | if len(c) < 3: # A polygon cannot have less than 3 points 31 | continue 32 | polygon = geometry.Polygon([point[0] for point in c]) 33 | # Check that polygon has area greater than minimal area 34 | if polygon.area >= min_area*np.prod(image_mask.shape[:2]): 35 | found_polygons.append( 36 | (np.array([point for point in polygon.exterior.coords], dtype=np.uint), polygon.area) 37 | ) 38 | 39 | # sort by area 40 | found_polygons = [fp for fp in found_polygons if fp is not None] 41 | found_polygons = sorted(found_polygons, key=lambda x: x[1], reverse=True) 42 | 43 | if found_polygons: 44 | return [fp[0] for i, fp in enumerate(found_polygons) if i <= n_max_polygons] 45 | else: 46 | return None 47 | -------------------------------------------------------------------------------- /dh_segment/utils/__init__.py: -------------------------------------------------------------------------------- 1 | r""" 2 | The :mod:`dh_segment.utils` module contains the parameters for config with `sacred`_ package, 3 | image label vizualization functions and miscelleanous helpers. 4 | 5 | Parameters 6 | ---------- 7 | 8 | .. autosummary:: 9 | ModelParams 10 | TrainingParams 11 | 12 | Label image helpers 13 | ------------------- 14 | 15 | .. autosummary:: 16 | label_image_to_class 17 | class_to_label_image 18 | multilabel_image_to_class 19 | multiclass_to_label_image 20 | get_classes_color_from_file 21 | get_n_classes_from_file 22 | get_classes_color_from_file_multilabel 23 | get_n_classes_from_file_multilabel 24 | 25 | Evaluation utils 26 | ---------------- 27 | 28 | .. autosummary:: 29 | Metrics 30 | intersection_over_union 31 | 32 | Miscellaneous helpers 33 | --------------------- 34 | 35 | .. autosummary:: 36 | parse_json 37 | dump_json 38 | load_pickle 39 | dump_pickle 40 | hash_dict 41 | 42 | .. 
_sacred : https://sacred.readthedocs.io/en/latest/index.html 43 | 44 | ------ 45 | """ 46 | 47 | _PARAMSCONFIG = [ 48 | 'PredictionType', 49 | 'VGG16ModelParams', 50 | 'ResNetModelParams', 51 | 'UNetModelParams', 52 | 'ModelParams', 53 | 'TrainingParams' 54 | ] 55 | 56 | 57 | _LABELS = [ 58 | 'label_image_to_class', 59 | 'class_to_label_image', 60 | 'multilabel_image_to_class', 61 | 'multiclass_to_label_image', 62 | 'get_classes_color_from_file', 63 | 'get_n_classes_from_file', 64 | 'get_classes_color_from_file_multilabel', 65 | 'get_n_classes_from_file_multilabel' 66 | ] 67 | 68 | _MISC = [ 69 | 'parse_json', 70 | 'dump_json', 71 | 'load_pickle', 72 | 'dump_pickle', 73 | 'hash_dict' 74 | ] 75 | 76 | _EVALUATION = [ 77 | 'Metrics', 78 | 'intersection_over_union' 79 | ] 80 | 81 | __all__ = _PARAMSCONFIG + _LABELS + _MISC + _EVALUATION 82 | 83 | from .params_config import * 84 | from .labels import * 85 | from .misc import * 86 | from .evaluation import * -------------------------------------------------------------------------------- /dh_segment/utils/evaluation.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | __author__ = "solivr" 3 | __license__ = "GPL" 4 | 5 | import numpy as np 6 | import json 7 | import cv2 8 | 9 | 10 | class Metrics: 11 | def __init__(self): 12 | self.total_elements = 0 13 | self.true_positives = 0 14 | self.true_negatives = 0 15 | self.false_positives = 0 16 | self.false_negatives = 0 17 | self.SE_list = list() 18 | self.IOU_list = list() 19 | 20 | self.MSE = 0 21 | self.psnr = 0 22 | self.mIOU = 0 23 | self.IU = 0 24 | self.accuracy = 0 25 | self.recall = 0 26 | self.precision = 0 27 | self.f_measure = 0 28 | 29 | def __add__(self, other): 30 | if isinstance(other, self.__class__): 31 | summable_attr = ['total_elements', 'false_negatives', 'false_positives', 'true_positives', 'true_negatives'] 32 | addlist_attr = ['SE_list', 'IOU_list'] 33 | m = Metrics() 34 | for k, v in self.__dict__.items(): 35 | if k in summable_attr: 36 | setattr(m, k, self.__dict__[k] + other.__dict__[k]) 37 | elif k in addlist_attr: 38 | mse1 = [self.__dict__[k]] if not isinstance(self.__dict__[k], list) else self.__dict__[k] 39 | mse2 = [other.__dict__[k]] if not isinstance(other.__dict__[k], list) else other.__dict__[k] 40 | 41 | setattr(m, k, mse1 + mse2) 42 | return m 43 | else: 44 | raise NotImplementedError 45 | 46 | def __radd__(self, other): 47 | return self.__add__(other) 48 | 49 | def compute_mse(self): 50 | self.MSE = np.sum(self.SE_list) / self.total_elements if self.total_elements > 0 else np.inf 51 | return self.MSE 52 | 53 | def compute_psnr(self): 54 | if self.MSE != 0: 55 | self.psnr = 10 * np.log10((1 ** 2) / self.MSE) 56 | return self.psnr 57 | else: 58 | print('Cannot compute PSNR, MSE is 0.') 59 | 60 | def compute_prf(self, beta=1): 61 | self.recall = self.true_positives / (self.true_positives + self.false_negatives) \ 62 | if (self.true_positives + self.false_negatives) > 0 else 0 63 | self.precision = self.true_positives / (self.true_positives + self.false_positives) \ 64 | if (self.true_positives + self.false_negatives) > 0 else 0 65 | self.f_measure = ((1 + beta ** 2) * self.recall * self.precision) / (self.recall + (beta ** 2) * self.precision) \ 66 | if (self.recall + self.precision) > 0 else 0 67 | 68 | return self.recall, self.precision, self.f_measure 69 | 70 | def compute_miou(self): 71 | self.mIOU = np.mean(self.IOU_list) 72 | return self.mIOU 73 | 74 | # See 
http://cdn.iiit.ac.in/cdn/cvit.iiit.ac.in/images/ConferencePapers/2017/DocUsingDeepFeatures.pdf 75 | def compute_iu(self): 76 | self.IU = self.true_positives / (self.true_positives + self.false_positives + self.false_negatives) \ 77 | if (self.true_positives + self.false_positives + self.false_negatives) > 0 else 0 78 | return self.IU 79 | 80 | def compute_accuracy(self): 81 | self.accuracy = (self.true_positives + self.true_negatives)/self.total_elements if self.total_elements > 0 else 0 82 | 83 | def save_to_json(self, json_filename: str) -> None: 84 | export_dic = self.__dict__.copy() 85 | del export_dic['SE_list'] 86 | 87 | with open(json_filename, 'w') as outfile: 88 | json.dump(export_dic, outfile) 89 | 90 | 91 | def intersection_over_union(cnt1, cnt2, shape_mask): 92 | mask1 = np.zeros(shape_mask, np.uint8) 93 | mask1 = cv2.fillConvexPoly(mask1, cnt1.astype(np.int32), 1).astype(np.int8) 94 | mask2 = np.zeros(shape_mask, np.uint8) 95 | mask2 = cv2.fillConvexPoly(mask2, cnt2.astype(np.int32), 1).astype(np.int8) 96 | return np.sum(mask1 & mask2) / np.sum(mask1 | mask2) 97 | -------------------------------------------------------------------------------- /dh_segment/utils/labels.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | __license__ = "GPL" 3 | 4 | import tensorflow as tf 5 | import numpy as np 6 | import os 7 | from typing import Tuple 8 | 9 | 10 | def label_image_to_class(label_image: tf.Tensor, classes_file: str) -> tf.Tensor: 11 | classes_color_values = get_classes_color_from_file(classes_file) 12 | # Convert label_image [H,W,3] to the classes [H,W],int32 according to the classes [C,3] 13 | with tf.name_scope('LabelAssign'): 14 | if len(label_image.get_shape()) == 3: 15 | diff = tf.cast(label_image[:, :, None, :], tf.float32) - tf.constant(classes_color_values[None, None, :, :]) # [H,W,C,3] 16 | elif len(label_image.get_shape()) == 4: 17 | diff = tf.cast(label_image[:, :, :, None, :], tf.float32) - tf.constant( 18 | classes_color_values[None, None, None, :, :]) # [B,H,W,C,3] 19 | else: 20 | raise NotImplementedError('Length is : {}'.format(len(label_image.get_shape()))) 21 | 22 | pixel_class_diff = tf.reduce_sum(tf.square(diff), axis=-1) # [H,W,C] or [B,H,W,C] 23 | class_label = tf.argmin(pixel_class_diff, axis=-1) # [H,W] or [B,H,W] 24 | return class_label 25 | 26 | 27 | def class_to_label_image(class_label: tf.Tensor, classes_file: str) -> tf.Tensor: 28 | classes_color_values = get_classes_color_from_file(classes_file) 29 | return tf.gather(classes_color_values, tf.cast(class_label, dtype=tf.int32)) 30 | 31 | 32 | def multilabel_image_to_class(label_image: tf.Tensor, classes_file: str) -> tf.Tensor: 33 | """ 34 | Combines image annotations with classes info of the txt file to create the input label for the training. 

35 | 36 | :param label_image: annotated image [H,W,Ch] or [B,H,W,Ch] (Ch = color channels) 37 | :param classes_file: the filename of the txt file containing the class info 38 | :return: [H,W,Cl] or [B,H,W,Cl] (Cl = number of classes) 39 | """ 40 | classes_color_values, colors_labels = get_classes_color_from_file_multilabel(classes_file) 41 | # Convert label_image [H,W,3] to the classes [H,W,C],int32 according to the classes [C,3] 42 | with tf.name_scope('LabelAssign'): 43 | if len(label_image.get_shape()) == 3: 44 | diff = tf.cast(label_image[:, :, None, :], tf.float32) - tf.constant(classes_color_values[None, None, :, :]) # [H,W,C,3] 45 | elif len(label_image.get_shape()) == 4: 46 | diff = tf.cast(label_image[:, :, :, None, :], tf.float32) - tf.constant( 47 | classes_color_values[None, None, None, :, :]) # [B,H,W,C,3] 48 | else: 49 | raise NotImplementedError('Length is : {}'.format(len(label_image.get_shape()))) 50 | 51 | pixel_class_diff = tf.reduce_sum(tf.square(diff), axis=-1) # [H,W,C] or [B,H,W,C] 52 | class_label = tf.argmin(pixel_class_diff, axis=-1) # [H,W] or [B,H,W] 53 | 54 | return tf.gather(colors_labels, class_label) > 0 55 | 56 | 57 | def multiclass_to_label_image(class_label_tensor: tf.Tensor, classes_file: str) -> tf.Tensor: 58 | 59 | classes_color_values, colors_labels = get_classes_color_from_file_multilabel(classes_file) 60 | 61 | n_classes = colors_labels.shape[1] 62 | c = np.zeros((2,)*n_classes+(3,), np.int32) 63 | for c_value, inds in zip(classes_color_values, colors_labels): 64 | c[tuple(inds)] = c_value 65 | 66 | with tf.name_scope('Label2Img'): 67 | return tf.gather_nd(c, tf.cast(class_label_tensor, tf.int32)) 68 | 69 | 70 | def get_classes_color_from_file(classes_file: str) -> np.ndarray: 71 | if not os.path.exists(classes_file): 72 | raise FileNotFoundError(classes_file) 73 | result = np.loadtxt(classes_file).astype(np.float32) 74 | assert result.shape[1] == 3, "Color file should represent RGB values" 75 | return result 76 | 77 | 78 | def get_n_classes_from_file(classes_file: str) -> int: 79 | return get_classes_color_from_file(classes_file).shape[0] 80 | 81 | 82 | def get_classes_color_from_file_multilabel(classes_file: str) -> Tuple[np.ndarray, np.array]: 83 | """ 84 | Get classes and code labels from txt file. 85 | This function deals with the case of elements with multiple labels. 
86 | 87 | :param classes_file: file containing the classes (usually named *classes.txt*) 88 | :return: for each class the RGB color (array size [N, 3]); and the label's code (array size [N, C]), 89 | with N the number of combinations and C the number of classes 90 | """ 91 | if not os.path.exists(classes_file): 92 | raise FileNotFoundError(classes_file) 93 | result = np.loadtxt(classes_file).astype(np.float32) 94 | assert result.shape[1] > 3, "The number of columns should be greater in multilabel framework" 95 | colors = result[:, :3] 96 | labels = result[:, 3:] 97 | return colors, labels.astype(np.int32) 98 | 99 | 100 | def get_n_classes_from_file_multilabel(classes_file: str) -> int: 101 | return get_classes_color_from_file_multilabel(classes_file)[1].shape[1] 102 | -------------------------------------------------------------------------------- /dh_segment/utils/misc.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | __license__ = "GPL" 3 | 4 | import tensorflow as tf 5 | import json 6 | import pickle 7 | from hashlib import sha1 8 | from random import shuffle 9 | 10 | 11 | def parse_json(filename): 12 | with open(filename, 'r') as f: 13 | return json.load(f) 14 | 15 | 16 | def dump_json(filename, dict): 17 | with open(filename, 'w') as f: 18 | json.dump(dict, f, indent=4, sort_keys=True) 19 | 20 | 21 | def load_pickle(filename): 22 | with open(filename, 'rb') as f: 23 | return pickle.load(f) 24 | 25 | 26 | def dump_pickle(filename, obj): 27 | with open(filename, 'wb') as f: 28 | return pickle.dump(obj, f) 29 | 30 | 31 | def hash_dict(params): 32 | return sha1(json.dumps(params, sort_keys=True).encode()).hexdigest() 33 | 34 | def shuffled(l: list) -> list: 35 | ll = l.copy() 36 | shuffle(ll) 37 | return ll 38 | -------------------------------------------------------------------------------- /dh_segment/utils/params_config.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | __author__ = "solivr" 3 | __license__ = "GPL" 4 | 5 | import os 6 | import warnings 7 | from random import shuffle 8 | 9 | 10 | class PredictionType: 11 | """ 12 | 13 | :cvar CLASSIFICATION: 14 | :cvar REGRESSION: 15 | :cvar MULTILABEL: 16 | """ 17 | CLASSIFICATION = 'CLASSIFICATION' 18 | REGRESSION = 'REGRESSION' 19 | MULTILABEL = 'MULTILABEL' 20 | 21 | @classmethod 22 | def parse(cls, prediction_type): 23 | if prediction_type == 'CLASSIFICATION': 24 | return PredictionType.CLASSIFICATION 25 | elif prediction_type == 'REGRESSION': 26 | return PredictionType.REGRESSION 27 | elif prediction_type == 'MULTILABEL': 28 | return PredictionType.MULTILABEL 29 | else: 30 | raise NotImplementedError('Unknown prediction type : {}'.format(prediction_type)) 31 | 32 | 33 | class BaseParams: 34 | def to_dict(self): 35 | return self.__dict__ 36 | 37 | @classmethod 38 | def from_dict(cls, d): 39 | result = cls() 40 | keys = result.to_dict().keys() 41 | for k, v in d.items(): 42 | assert k in keys, k 43 | setattr(result, k, v) 44 | result.check_params() 45 | return result 46 | 47 | def check_params(self): 48 | pass 49 | 50 | 51 | class VGG16ModelParams: 52 | PRETRAINED_MODEL_FILE = 'pretrained_models/vgg_16.ckpt' 53 | INTERMEDIATE_CONV = [ 54 | [(256, 3)] 55 | ] 56 | UPSCALE_PARAMS = [ 57 | [(32, 3)], 58 | [(64, 3)], 59 | [(128, 3)], 60 | [(256, 3)], 61 | [(512, 3)], 62 | [(512, 3)] 63 | ] 64 | SELECTED_LAYERS_UPSCALING = [ 65 | True, 66 | True, # Must have same length as vgg_upscale_params 67 | True, 68 | 
True, 69 | False, 70 | False 71 | ] 72 | CORRECTED_VERSION = None 73 | 74 | 75 | class ResNetModelParams: 76 | PRETRAINED_MODEL_FILE = 'pretrained_models/resnet_v1_50.ckpt' 77 | INTERMEDIATE_CONV = None 78 | UPSCALE_PARAMS = [ 79 | # (Filter size (depth bottleneck's output), number of bottleneck) 80 | (32, 0), 81 | (64, 0), 82 | (128, 0), 83 | (256, 0), 84 | (512, 0) 85 | ] 86 | SELECTED_LAYERS_UPSCALING = [ 87 | # Must have the same length as resnet_upscale_params 88 | True, 89 | True, 90 | True, 91 | True, 92 | True 93 | ] 94 | CORRECT_VERSION = False 95 | 96 | 97 | class UNetModelParams: 98 | PRETRAINED_MODEL_FILE = None 99 | INTERMEDIATE_CONV = None 100 | UPSCALE_PARAMS = None 101 | SELECTED_LAYERS_UPSCALING = None 102 | CORRECT_VERSION = False 103 | 104 | 105 | class ModelParams(BaseParams): 106 | """Parameters related to the model 107 | 108 | """ 109 | def __init__(self, **kwargs): 110 | self.batch_norm = kwargs.get('batch_norm', True) # type: bool 111 | self.batch_renorm = kwargs.get('batch_renorm', True) # type: bool 112 | self.weight_decay = kwargs.get('weight_decay', 1e-6) # type: float 113 | self.n_classes = kwargs.get('n_classes', None) # type: int 114 | self.pretrained_model_name = kwargs.get('pretrained_model_name', None) # type: str 115 | self.max_depth = kwargs.get('max_depth', 512) # type: int 116 | 117 | if self.pretrained_model_name == 'vgg16': 118 | model_class = VGG16ModelParams 119 | elif self.pretrained_model_name == 'resnet50': 120 | model_class = ResNetModelParams 121 | elif self.pretrained_model_name == 'unet': 122 | model_class = UNetModelParams 123 | else: 124 | raise NotImplementedError 125 | 126 | self.pretrained_model_file = kwargs.get('pretrained_model_file', model_class.PRETRAINED_MODEL_FILE) 127 | self.intermediate_conv = kwargs.get('intermediate_conv', model_class.INTERMEDIATE_CONV) 128 | self.upscale_params = kwargs.get('upscale_params', model_class.UPSCALE_PARAMS) 129 | self.selected_levels_upscaling = kwargs.get('selected_levels_upscaling', model_class.SELECTED_LAYERS_UPSCALING) 130 | self.correct_resnet_version = kwargs.get('correct_resnet_version', model_class.CORRECT_VERSION) 131 | self.check_params() 132 | 133 | def check_params(self): 134 | # Pretrained model name check 135 | # assert self.upscale_params is not None and self.selected_levels_upscaling is not None, \ 136 | # 'Model parameters cannot be None' 137 | if self.upscale_params is not None and self.selected_levels_upscaling is not None: 138 | 139 | assert len(self.upscale_params) == len(self.selected_levels_upscaling), \ 140 | 'Upscaling levels and selection levels must have the same lengths (in model_params definition), ' \ 141 | '{} != {}'.format(len(self.upscale_params), 142 | len(self.selected_levels_upscaling)) 143 | 144 | # assert os.path.isfile(self.pretrained_model_file), \ 145 | # 'Pretrained weights file {} not found'.format(self.pretrained_model_file) 146 | if not os.path.isfile(self.pretrained_model_file): 147 | warnings.warn('WARNING - Default pretrained weights file in {} was not found. 
' 148 | 'Have you changed the default pretrained file ?'.format(self.pretrained_model_file)) 149 | 150 | 151 | class TrainingParams(BaseParams): 152 | """Parameters to configure training process 153 | 154 | :ivar n_epochs: number of epoch for training 155 | :vartype n_epochs: int 156 | :ivar evaluate_every_epoch: the model will be evaluated every `n` epochs 157 | :vartype evaluate_every_epoch: int 158 | :ivar learning_rate: the starting learning rate value 159 | :vartype learning_rate: float 160 | :ivar exponential_learning: option to use exponential learning rate 161 | :vartype exponential_learning: bool 162 | :ivar batch_size: size of batch 163 | :vartype batch_size: int 164 | :ivar data_augmentation: option to use data augmentation (by default is set to False) 165 | :vartype data_augmentation: bool 166 | :ivar data_augmentation_flip_lr: option to use image flipping in right-left direction 167 | :vartype data_augmentation_flip_lr: bool 168 | :ivar data_augmentation_flip_ud: option to use image flipping in up down direction 169 | :vartype data_augmentation_flip_ud: bool 170 | :ivar data_augmentation_color: option to use data augmentation with color 171 | :vartype data_augmentation_color: bool 172 | :ivar data_augmentation_max_rotation: maximum angle of rotation (in radians) for data augmentation 173 | :vartype data_augmentation_max_rotation: float 174 | :ivar data_augmentation_max_scaling: maximum scale of zooming during data augmentation (range: [0,1]) 175 | :vartype data_augmentation_max_scaling: float 176 | :ivar make_patches: option to crop image into patches. This will cut the entire image in several patches 177 | :vartype make_patches: bool 178 | :ivar patch_shape: shape of the patches 179 | :vartype patch_shape: tuple 180 | :ivar input_resized_size: size (in pixel) of the image after resizing. The original ratio is kept. If no resizing \ 181 | is wanted, set it to -1 182 | :vartype input_resized_size: int 183 | :ivar weights_labels: weight given to each label. Should be a list of length = number of classes 184 | :vartype weights_labels: list 185 | :ivar training_margin: size of the margin to add to the images. This is particularly useful when training with \ 186 | patches 187 | :vartype training_margin: int 188 | :ivar local_entropy_ratio: 189 | :vartype local_entropy_ratio: float 190 | :ivar local_entropy_sigma: 191 | :vartype local_entropy_sigma: float 192 | :ivar focal_loss_gamma: value of gamma for the focal loss. 
See paper : https://arxiv.org/abs/1708.02002 193 | :vartype focal_loss_gamma: float 194 | """ 195 | def __init__(self, **kwargs): 196 | self.n_epochs = kwargs.get('n_epochs', 20) 197 | self.evaluate_every_epoch = kwargs.get('evaluate_every_epoch', 10) 198 | self.learning_rate = kwargs.get('learning_rate', 1e-5) 199 | self.exponential_learning = kwargs.get('exponential_learning', True) 200 | self.batch_size = kwargs.get('batch_size', 5) 201 | self.data_augmentation = kwargs.get('data_augmentation', False) 202 | self.data_augmentation_flip_lr = kwargs.get('data_augmentation_flip_lr', False) 203 | self.data_augmentation_flip_ud = kwargs.get('data_augmentation_flip_ud', False) 204 | self.data_augmentation_color = kwargs.get('data_augmentation_color', False) 205 | self.data_augmentation_max_rotation = kwargs.get('data_augmentation_max_rotation', 0.2) 206 | self.data_augmentation_max_scaling = kwargs.get('data_augmentation_max_scaling', 0.05) 207 | self.make_patches = kwargs.get('make_patches', True) 208 | self.patch_shape = kwargs.get('patch_shape', (300, 300)) 209 | self.input_resized_size = int(kwargs.get('input_resized_size', 72e4)) # (600*1200) 210 | self.weights_labels = kwargs.get('weights_labels') 211 | self.weights_evaluation_miou = kwargs.get('weights_evaluation_miou', None) 212 | self.training_margin = kwargs.get('training_margin', 16) 213 | self.local_entropy_ratio = kwargs.get('local_entropy_ratio', 0.) 214 | self.local_entropy_sigma = kwargs.get('local_entropy_sigma', 3) 215 | self.focal_loss_gamma = kwargs.get('focal_loss_gamma', 0.) 216 | 217 | def check_params(self) -> None: 218 | """Checks if there is no parameter inconsistency 219 | """ 220 | assert self.training_margin*2 < min(self.patch_shape) 221 | -------------------------------------------------------------------------------- /doc/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line. 5 | SPHINXOPTS = 6 | SPHINXBUILD = sphinx-build 7 | SPHINXPROJ = dhsegment 8 | SOURCEDIR = . 9 | BUILDDIR = _build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 
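The parameter containers defined in `dh_segment/utils/params_config.py` above are built from keyword arguments, or from a dict via `BaseParams.from_dict` (which also runs `check_params()`). A minimal sketch, assuming `dh_segment` is installed and using illustrative values only:

```python
from dh_segment.utils import ModelParams, TrainingParams

# Illustrative values; any key that is omitted falls back to the defaults above.
# A warning is emitted if the default pretrained checkpoint has not been downloaded yet.
model_params = ModelParams(pretrained_model_name='resnet50', n_classes=2)
print(model_params.pretrained_model_file)   # pretrained_models/resnet_v1_50.ckpt

training_params = TrainingParams.from_dict({
    'n_epochs': 30,
    'batch_size': 4,
    'make_patches': True,
    'patch_shape': (300, 300),
    'input_resized_size': int(72e4),        # pixel budget, ~600 * 1200
})
# from_dict() has already called check_params(), which asserts e.g.
# training_margin * 2 < min(patch_shape).
```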
19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) -------------------------------------------------------------------------------- /doc/_static/cbad.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dhlab-epfl/dhSegment/cca94e94aec52baa9350eaaa60c006d7fde103b7/doc/_static/cbad.jpg -------------------------------------------------------------------------------- /doc/_static/cini.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dhlab-epfl/dhSegment/cca94e94aec52baa9350eaaa60c006d7fde103b7/doc/_static/cini.jpg -------------------------------------------------------------------------------- /doc/_static/cini_input.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dhlab-epfl/dhSegment/cca94e94aec52baa9350eaaa60c006d7fde103b7/doc/_static/cini_input.jpg -------------------------------------------------------------------------------- /doc/_static/cini_labels.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dhlab-epfl/dhSegment/cca94e94aec52baa9350eaaa60c006d7fde103b7/doc/_static/cini_labels.jpg -------------------------------------------------------------------------------- /doc/_static/diva.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dhlab-epfl/dhSegment/cca94e94aec52baa9350eaaa60c006d7fde103b7/doc/_static/diva.jpg -------------------------------------------------------------------------------- /doc/_static/diva_preds.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dhlab-epfl/dhSegment/cca94e94aec52baa9350eaaa60c006d7fde103b7/doc/_static/diva_preds.png -------------------------------------------------------------------------------- /doc/_static/ornaments.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dhlab-epfl/dhSegment/cca94e94aec52baa9350eaaa60c006d7fde103b7/doc/_static/ornaments.jpg -------------------------------------------------------------------------------- /doc/_static/page.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dhlab-epfl/dhSegment/cca94e94aec52baa9350eaaa60c006d7fde103b7/doc/_static/page.jpg -------------------------------------------------------------------------------- /doc/_static/system.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dhlab-epfl/dhSegment/cca94e94aec52baa9350eaaa60c006d7fde103b7/doc/_static/system.png -------------------------------------------------------------------------------- /doc/_static/tensorboard_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dhlab-epfl/dhSegment/cca94e94aec52baa9350eaaa60c006d7fde103b7/doc/_static/tensorboard_1.png -------------------------------------------------------------------------------- /doc/_static/tensorboard_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dhlab-epfl/dhSegment/cca94e94aec52baa9350eaaa60c006d7fde103b7/doc/_static/tensorboard_2.png 
-------------------------------------------------------------------------------- /doc/_static/tensorboard_3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dhlab-epfl/dhSegment/cca94e94aec52baa9350eaaa60c006d7fde103b7/doc/_static/tensorboard_3.png -------------------------------------------------------------------------------- /doc/changelog.rst: -------------------------------------------------------------------------------- 1 | ========= 2 | Changelog 3 | ========= 4 | 5 | .. Unreleased 6 | ---------- 7 | 8 | 0.6.0 - 2020-01-09 9 | ------------------ 10 | 11 | Added 12 | ^^^^^ 13 | 14 | * Added ``TableCell`` object in PAGEXML. (@CrazyCrud) 15 | 16 | Changed 17 | ^^^^^^^ 18 | 19 | * Changed default value of ``min_area`` value to 0 in ``find_boxes`` and ``find_polygons``. 20 | 21 | Fixed 22 | ^^^^^ 23 | 24 | * If export directory is empty at the end of training, export a model anyway. 25 | * In ``post_processing/polygon_detection`` corrected the output of ``cv2.findContours`` to be compatible with OpenCV >= 4.0. 26 | 27 | 28 | 0.5.0 - 2019-08-14 29 | ------------------ 30 | 31 | Added 32 | ^^^^^ 33 | 34 | * ``https`` can now be used for PAGEXML schema. 35 | * All the ``PAGE`` objects can now have ``custom_attribute`` (this is mainly for tagging purposes). 36 | 37 | 38 | Changed 39 | ^^^^^^^ 40 | 41 | * The ``exps`` folders contains now only two examples that can be used as demos. The other experiments have been removed. 42 | * Installation of ``dh_segment`` package is now done via pip (using ``setup.py``) except for ``tensorflow`` package which is installed with anaconda. 43 | * ``setup.py`` has more flexible package versions. 44 | * Forced integer conversion when exporting coordinates to XML format. 45 | 46 | Fixed 47 | ^^^^^ 48 | 49 | * In page ``demo.py`` a empty ``Border`` is now created if no region has been detected. 50 | 51 | 52 | Removed 53 | ^^^^^^^ 54 | 55 | * Experiments in ``exps`` folder have been removed, except for ``page`` and ``cbad``. 56 | 57 | 58 | 0.4.0 - 2019-04-10 59 | ------------------ 60 | 61 | Added 62 | ^^^^^ 63 | 64 | * Input data can be a .csv file with format ``,``. 65 | * ``dh_segment.io.via`` helper functions to generate/export groundtruth from/to VGG Image Annotation tool. 66 | * ``Point.array_to_point`` to export a ``np.array`` into a list of ``Point``. 67 | * PAGEXML Regions can now contain a custom attribute (Transkribus output of region annotation) 68 | * ``Page.to_json()`` method for json formatting. 69 | 70 | Changed 71 | ^^^^^^^ 72 | 73 | * ``tensorflow`` v1.13 and ``opencv`` v4.0 are now used. 74 | * mIOU metric for evaluation during training (instead of accuracy). 75 | * TextLines are sorted according to their mean `y` coordinate when exported. 76 | 77 | Fixed 78 | ^^^^^ 79 | 80 | * Variable names typos in ``input.py`` and ``train.py``. 81 | * Documentation of the quickstart demo. 82 | 83 | Removed 84 | ^^^^^^^ 85 | -------------------------------------------------------------------------------- /doc/conf.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 3 | # Configuration file for the Sphinx documentation builder. 4 | # 5 | # This file does only contain a selection of the most common options. 
For a 6 | # full list see the documentation: 7 | # http://www.sphinx-doc.org/en/master/config 8 | 9 | # -- Path setup -------------------------------------------------------------- 10 | 11 | # If extensions (or modules to document with autodoc) are in another directory, 12 | # add these directories to sys.path here. If the directory is relative to the 13 | # documentation root, use os.path.abspath to make it absolute, like shown here. 14 | # 15 | import os 16 | import sys 17 | sys.path.insert(0, os.path.abspath('..')) 18 | 19 | 20 | # -- Project information ----------------------------------------------------- 21 | 22 | project = 'dhSegment' 23 | copyright = '2018, Digital Humanities Lab - EPFL' 24 | author = 'Sofia ARES OLIVEIRA, Benoit SEGUIN' 25 | 26 | # The short X.Y version 27 | version = '' 28 | # The full version, including alpha/beta/rc tags 29 | release = '' 30 | 31 | 32 | # -- General configuration --------------------------------------------------- 33 | 34 | # If your documentation needs a minimal Sphinx version, state it here. 35 | # 36 | # needs_sphinx = '1.0' 37 | 38 | # Add any Sphinx extension module names here, as strings. They can be 39 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 40 | # ones. 41 | extensions = [ 42 | 'sphinx.ext.autodoc', 43 | 'sphinx.ext.autosummary', 44 | 'sphinx.ext.coverage', 45 | 'sphinx.ext.githubpages', 46 | 'sphinxcontrib.bibtex', # for bibtex 47 | 'sphinx_autodoc_typehints', # for typing 48 | ] 49 | 50 | # Add any paths that contain templates here, relative to this directory. 51 | templates_path = ['_templates'] 52 | 53 | # The suffix(es) of source filenames. 54 | # You can specify multiple suffix as a list of string: 55 | # 56 | # source_suffix = ['.rst', '.md'] 57 | source_suffix = '.rst' 58 | 59 | # The master toctree document. 60 | master_doc = 'index' 61 | 62 | # The language for content autogenerated by Sphinx. Refer to documentation 63 | # for a list of supported languages. 64 | # 65 | # This is also used if you do content translation via gettext catalogs. 66 | # Usually you set "language" from the command line for these cases. 67 | language = None 68 | 69 | # List of patterns, relative to source directory, that match files and 70 | # directories to ignore when looking for source files. 71 | # This pattern also affects html_static_path and html_extra_path . 72 | exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store'] 73 | 74 | # The name of the Pygments (syntax highlighting) style to use. 75 | pygments_style = 'sphinx' 76 | 77 | 78 | # -- Options for HTML output ------------------------------------------------- 79 | 80 | # The theme to use for HTML and HTML Help pages. See the documentation for 81 | # a list of builtin themes. 82 | # 83 | html_theme = 'sphinx_rtd_theme' # alabaster, haiku, nature, pyramid, agogo, bizstyle, sphinx_rtd_theme 84 | 85 | # Theme options are theme-specific and customize the look and feel of a theme 86 | # further. For a list of options available for each theme, see the 87 | # documentation. 88 | # 89 | # html_theme_options = {} 90 | 91 | # Add any paths that contain custom static files (such as style sheets) here, 92 | # relative to this directory. They are copied after the builtin static files, 93 | # so a file named "default.css" will overwrite the builtin "default.css". 94 | html_static_path = ['_static'] 95 | 96 | # Custom sidebar templates, must be a dictionary that maps document names 97 | # to template names. 
98 | # 99 | # The default sidebars (for documents that don't match any pattern) are 100 | # defined by theme itself. Builtin themes are using these templates by 101 | # default: ``['localtoc.html', 'relations.html', 'sourcelink.html', 102 | # 'searchbox.html']``. 103 | # 104 | # html_sidebars = {} 105 | 106 | 107 | # -- Options for HTMLHelp output --------------------------------------------- 108 | 109 | # Output file base name for HTML help builder. 110 | htmlhelp_basename = 'dhsegmentdoc' 111 | 112 | 113 | # -- Options for LaTeX output ------------------------------------------------ 114 | 115 | latex_elements = { 116 | # The paper size ('letterpaper' or 'a4paper'). 117 | # 118 | # 'papersize': 'letterpaper', 119 | 120 | # The font size ('10pt', '11pt' or '12pt'). 121 | # 122 | # 'pointsize': '10pt', 123 | 124 | # Additional stuff for the LaTeX preamble. 125 | # 126 | # 'preamble': '', 127 | 128 | # Latex figure (float) alignment 129 | # 130 | # 'figure_align': 'htbp', 131 | } 132 | 133 | # Grouping the document tree into LaTeX files. List of tuples 134 | # (source start file, target name, title, 135 | # author, documentclass [howto, manual, or own class]). 136 | latex_documents = [ 137 | (master_doc, 'dhsegment.tex', 'dhsegment Documentation', 138 | author, 'manual'), 139 | ] 140 | 141 | 142 | # -- Options for manual page output ------------------------------------------ 143 | 144 | # One entry per manual page. List of tuples 145 | # (source start file, name, description, authors, manual section). 146 | man_pages = [ 147 | (master_doc, 'dhsegment', 'dhsegment Documentation', 148 | [author], 1) 149 | ] 150 | 151 | 152 | # -- Options for Texinfo output ---------------------------------------------- 153 | 154 | # Grouping the document tree into Texinfo files. List of tuples 155 | # (source start file, target name, title, author, 156 | # dir menu entry, description, category) 157 | texinfo_documents = [ 158 | (master_doc, 'dhsegment', 'dhsegment Documentation', 159 | author, 'dhsegment', 'One line description of project.', 160 | 'Miscellaneous'), 161 | ] 162 | 163 | 164 | # -- Extension configuration ------------------------------------------------- 165 | 166 | autodoc_mock_imports = [ 167 | # 'numpy', 168 | 'scipy', 169 | 'tensorflow', 170 | 'pandas', 171 | 'sklearn', 172 | 'skimage', 173 | 'shapely', 174 | 'typing', 175 | 'cv2', 176 | 'tqdm', 177 | 'imageio', 178 | 'PIL' 179 | ] 180 | -------------------------------------------------------------------------------- /doc/index.rst: -------------------------------------------------------------------------------- 1 | .. dhsegment documentation master file, created by 2 | sphinx-quickstart on Mon Oct 1 17:17:21 2018. 3 | 4 | ================================================================ 5 | dhSegment : Generic framework for historical document processing 6 | ================================================================ 7 | 8 | .. toctree:: 9 | :maxdepth: 1 10 | 11 | intro/intro 12 | start/index 13 | reference/index 14 | references 15 | changelog 16 | 17 | **dhSegment** is a tool for Historical Document Processing. Its generic approach allows segmenting regions and 18 | extracting content from different types of documents. See some examples of applications in the :ref:`usecases-label` section. 19 | 20 | The complete description of the system can be found in the corresponding `paper`_ :cite:`oliveiraseguin2018dhsegment`. 21 | 22 | ..
_paper: https://arxiv.org/abs/1804.10371 23 | 24 | Indices and tables 25 | ------------------ 26 | 27 | * :ref:`genindex` 28 | * :ref:`modindex` 29 | * :ref:`search` 30 | 31 | 32 | Acknowledgement 33 | ^^^^^^^^^^^^^^^ 34 | 35 | This work has been partly funded by the European Union’s Horizon 2020 research and 36 | innovation programme under grant agreement No 674943. -------------------------------------------------------------------------------- /doc/intro/intro.rst: -------------------------------------------------------------------------------- 1 | ============ 2 | Introduction 3 | ============ 4 | 5 | What is dhSegment? 6 | ------------------ 7 | 8 | .. image:: ../_static/system.png 9 | :width: 60 % 10 | :align: center 11 | :alt: dhSegment system 12 | 13 | 14 | dhSegment is a generic approach for Historical Document Processing. 15 | It relies on a Convolutional Neural Network to do the heavy lifting of predicting pixelwise characteristics. 16 | Then simple image processing operations are provided to extract the components of interest (boxes, polygons, lines, masks, ...). 17 | 18 | A few key facts: 19 | 20 | - You only need to provide a list of images with annotated masks, which can easily be created with image editing software (GIMP, Photoshop). You only need to draw the elements you care about! 21 | 22 | - Allows classifying each pixel across multiple classes, with the possibility of assigning multiple labels per pixel. 23 | 24 | - On-the-fly data augmentation and efficient batching. 25 | 26 | - Leverages a state-of-the-art pre-trained network (Resnet50) to lower the need for training data and improve generalization. 27 | 28 | - Monitor training in TensorBoard very easily. 29 | 30 | - A set of simple image processing operations is already implemented so that the post-processing steps only take a couple of lines. 31 | 32 | What sort of training data do I need? 33 | --------------------------------------- 34 | 35 | Each training sample consists of an image of a document and its corresponding parts to be predicted. 36 | 37 | .. image:: ../_static/cini_input.jpg 38 | :width: 45 % 39 | :alt: example image input 40 | .. image:: ../_static/cini_labels.jpg 41 | :width: 45 % 42 | :alt: example label 43 | 44 | Additionally, a text file encoding the RGB values of the classes needs to be provided. 45 | In this case, if we want the classes 'background', 'document' and 'photograph' to be classes 0, 1, and 2 46 | respectively, we need to encode their colors line by line: :: 47 | 48 | 0 255 0 49 | 255 0 0 50 | 0 0 255 51 | 52 | .. _usecases-label: 53 | 54 | Use cases 55 | --------- 56 | 57 | Page Segmentation 58 | ^^^^^^^^^^^^^^^^^ 59 | 60 | .. image:: ../_static/page.jpg 61 | :width: 50 % 62 | :alt: page extraction use case 63 | 64 | Dataset : READ-BAD :cite:`gruning2018read` annotated by :cite:`tensmeyer2017pagenet`. 65 | 66 | 67 | Layout Analysis 68 | ^^^^^^^^^^^^^^^ 69 | 70 | .. image:: ../_static/diva.jpg 71 | :width: 45 % 72 | :alt: diva use case 73 | .. image:: ../_static/diva_preds.png 74 | :width: 45 % 75 | :alt: diva predictions use case 76 | 77 | Dataset : DIVA-HisDB :cite:`simistira2016diva`. 78 | 79 | Ornament Extraction 80 | ^^^^^^^^^^^^^^^^^^^ 81 | 82 | .. image:: ../_static/ornaments.jpg 83 | :width: 50 % 84 | :alt: ornaments use case 85 | 86 | Dataset : BCU collection. 87 | 88 | 89 | Line Detection 90 | ^^^^^^^^^^^^^^ 91 | 92 | .. image:: ../_static/cbad.jpg 93 | :width: 70 % 94 | :alt: line extraction use case 95 | 96 | Dataset : READ-BAD :cite:`gruning2018read`.
97 | 98 | 99 | Document Segmentation 100 | ^^^^^^^^^^^^^^^^^^^^^ 101 | 102 | .. image:: ../_static/cini.jpg 103 | :width: 70 % 104 | :alt: cini photo collection extraction use case 105 | 106 | Dataset : Photo-collection from the Cini Foundation. 107 | 108 | 109 | Tensorboard Integration 110 | ----------------------- 111 | The TensorBoard integration allows to visualize your TensorFlow graph, plot metrics 112 | and show the images and predictions during the execution of the graph. 113 | 114 | .. image:: ../_static/tensorboard_1.png 115 | :width: 65 % 116 | :alt: tensorboard example 1 117 | .. image:: ../_static/tensorboard_2.png 118 | :width: 65 % 119 | :alt: tensorboard example 2 120 | .. image:: ../_static/tensorboard_3.png 121 | :width: 65 % 122 | :alt: tensorboard example 3 -------------------------------------------------------------------------------- /doc/reference/index.rst: -------------------------------------------------------------------------------- 1 | =============== 2 | Reference guide 3 | =============== 4 | 5 | .. automodule:: dh_segment 6 | 7 | .. toctree:: 8 | :maxdepth: 1 9 | 10 | network 11 | io 12 | inference 13 | post_processing 14 | utils -------------------------------------------------------------------------------- /doc/reference/inference.rst: -------------------------------------------------------------------------------- 1 | ========= 2 | Inference 3 | ========= 4 | 5 | .. automodule:: dh_segment.inference 6 | :members: 7 | :undoc-members: 8 | -------------------------------------------------------------------------------- /doc/reference/io.rst: -------------------------------------------------------------------------------- 1 | .. comment 2 | Interface 3 | ========= 4 | 5 | Input functions for ``tf.Estimator`` 6 | ------------------------------------ 7 | 8 | Input function 9 | 10 | .. autosummary:: 11 | input.input_fn 12 | 13 | Data augmentation 14 | 15 | .. autosummary:: 16 | data_augmentation_fn 17 | extract_patches_fn 18 | rotate_crop 19 | 20 | Resizing function 21 | 22 | .. autosummary:: 23 | dh_segment.io.input_utils.resize_image 24 | dh_segment.io.input_utils.load_and_resize_image 25 | 26 | 27 | Tensorflow serving functions 28 | ---------------------------- 29 | 30 | .. autosummary:: 31 | dh_segment.io.input.serving_input_filename 32 | dh_segment.io.input.serving_input_image 33 | 34 | ---- 35 | 36 | PAGE XML and JSON import / export 37 | --------------------------------- 38 | 39 | PAGE classes 40 | 41 | .. autosummary:: 42 | dh_segment.io.PAGE.Point 43 | dh_segment.io.PAGE.Text 44 | dh_segment.io.PAGE.Border 45 | dh_segment.io.PAGE.TextRegion 46 | dh_segment.io.PAGE.TextLine 47 | dh_segment.io.PAGE.GraphicRegion 48 | dh_segment.io.PAGE.TableRegion 49 | dh_segment.io.PAGE.SeparatorRegion 50 | dh_segment.io.PAGE.GroupSegment 51 | dh_segment.io.PAGE.Metadata 52 | dh_segment.io.PAGE.Page 53 | 54 | Abstract classes 55 | 56 | .. autosummary:: 57 | dh_segment.io.PAGE.BaseElement 58 | dh_segment.io.PAGE.Region 59 | 60 | Parsing and helpers 61 | 62 | .. autosummary:: 63 | dh_segment.io.PAGE.parse_file 64 | dh_segment.io.PAGE.json_serialize 65 | 66 | ---- 67 | 68 | ============== 69 | Input / Output 70 | ============== 71 | 72 | .. automodule:: dh_segment.io 73 | :members: 74 | :undoc-members: 75 | 76 | .. automodule:: dh_segment.io.PAGE 77 | :members: 78 | :undoc-members: 79 | 80 | .. 
automodule:: dh_segment.io.via 81 | :members: 82 | :undoc-members: 83 | :exclude-members: main, init_logger -------------------------------------------------------------------------------- /doc/reference/network.rst: -------------------------------------------------------------------------------- 1 | Network architecture 2 | ==================== 3 | 4 | Here is the dhsegment architecture definition 5 | 6 | ----- 7 | 8 | .. automodule:: dh_segment.network 9 | :members: 10 | :undoc-members: 11 | 12 | -------------------------------------------------------------------------------- /doc/reference/post_processing.rst: -------------------------------------------------------------------------------- 1 | =============== 2 | Post processing 3 | =============== 4 | 5 | .. automodule:: dh_segment.post_processing 6 | :members: 7 | :undoc-members: 8 | 9 | -------------------------------------------------------------------------------- /doc/reference/utils.rst: -------------------------------------------------------------------------------- 1 | ========= 2 | Utilities 3 | ========= 4 | 5 | .. automodule:: dh_segment.utils 6 | :members: 7 | :undoc-members: 8 | -------------------------------------------------------------------------------- /doc/references.bib: -------------------------------------------------------------------------------- 1 | @inproceedings{oliveiraseguin2018dhsegment, 2 | title={dhSegment: A generic deep-learning approach for document segmentation}, 3 | author={Ares Oliveira, Sofia and Seguin, Benoit and Kaplan, Frederic}, 4 | booktitle={Frontiers in Handwriting Recognition (ICFHR), 2018 16th International Conference on}, 5 | pages={7--12}, 6 | year={2018}, 7 | organization={IEEE} 8 | } 9 | 10 | @inproceedings{tensmeyer2017pagenet, 11 | title={Pagenet: Page boundary extraction in historical handwritten documents}, 12 | author={Tensmeyer, Chris and Davis, Brian and Wigington, Curtis and Lee, Iain and Barrett, Bill}, 13 | booktitle={Proceedings of the 4th International Workshop on Historical Document Imaging and Processing}, 14 | pages={59--64}, 15 | year={2017}, 16 | organization={ACM} 17 | } 18 | 19 | @inproceedings{gruning2018read, 20 | title={READ-BAD: A new dataset and evaluation scheme for baseline detection in archival documents}, 21 | author={Gr{\"u}ning, Tobias and Labahn, Roger and Diem, Markus and Kleber, Florian and Fiel, Stefan}, 22 | booktitle={2018 13th IAPR International Workshop on Document Analysis Systems (DAS)}, 23 | pages={351--356}, 24 | year={2018}, 25 | organization={IEEE} 26 | } 27 | 28 | @inproceedings{simistira2016diva, 29 | title={Diva-hisdb: A precisely annotated large dataset of challenging medieval manuscripts}, 30 | author={Simistira, Foteini and Seuret, Mathias and Eichenberger, Nicole and Garz, Angelika and Liwicki, Marcus and Ingold, Rolf}, 31 | booktitle={Frontiers in Handwriting Recognition (ICFHR), 2016 15th International Conference on}, 32 | pages={471--476}, 33 | year={2016}, 34 | organization={IEEE} 35 | } 36 | 37 | -------------------------------------------------------------------------------- /doc/references.rst: -------------------------------------------------------------------------------- 1 | ========== 2 | References 3 | ========== 4 | 5 | .. 
bibliography:: references.bib 6 | :cited: 7 | :style: alpha -------------------------------------------------------------------------------- /doc/start/annotating.rst: -------------------------------------------------------------------------------- 1 | Creating groundtruth data 2 | ------------------------- 3 | 4 | Using GIMP or Photoshop 5 | ^^^^^^^^^^^^^^^^^^^^^^^ 6 | Create directly your masks using your favorite image editor. You just have to draw the regions you want to extract 7 | with a different color for each label. 8 | 9 | Using VGG Image Annotator (VIA) 10 | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ 11 | `VGG Image Annotator (VIA) `_ is an image annotation tool that can be 12 | used to define regions in an image and create textual descriptions of those regions. You can either use it 13 | `online `_ or 14 | `download the application `_. 15 | 16 | From the exported annotations (in JSON format), you'll have to generate the corresponding image masks. 17 | See the :ref:`ref_via` in the ``via`` module. 18 | 19 | When assigning attributes to your annotated regions, you should favour attributes of type "dropdown", "checkbox" 20 | and "radio" and avoid "text" type in order to ease the parsing of the exported file (avoid typos and formatting errors). 21 | 22 | **Example of how to create individual masks from VIA annotation file** 23 | 24 | .. code:: python 25 | 26 | from dh_segment.io import via 27 | 28 | collection = 'mycollection' 29 | annotation_file = 'via_sample.json' 30 | masks_dir = '/home/project/generated_masks' 31 | images_dir = './my_images' 32 | 33 | # Load all the data in the annotation file 34 | # (the file may be an exported project or an export of the annotations) 35 | via_data = via.load_annotation_data(annotation_file) 36 | 37 | # In the case of an exported project file, you can set ``only_img_annotations=True`` 38 | # to get only the image annotations 39 | via_annotations = via.load_annotation_data(annotation_file, only_img_annotations=True) 40 | 41 | # Collect the annotated regions 42 | working_items = via.collect_working_items(via_annotations, collection, images_dir) 43 | 44 | # Collect the attributes and options 45 | if '_via_attributes' in via_data.keys(): 46 | list_attributes = via.parse_via_attributes(via_data['_via_attributes']) 47 | else: 48 | list_attributes = via.get_via_attributes(via_annotations) 49 | 50 | # Create one mask per option per attribute 51 | via.create_masks(masks_dir, working_items, list_attributes, collection) 52 | 53 | .. Using Transkribus 54 | ^^^^^^^^^^^^^^^^^ 55 | It is possible to generate PAGEXML files with `Transkribus `_. 56 | However the helpers functions to generate the masks images from XML have not been implemented yet. -------------------------------------------------------------------------------- /doc/start/demo.rst: -------------------------------------------------------------------------------- 1 | Demo 2 | ---- 3 | 4 | This demo shows the usage of dhSegment for page document extraction. 5 | It trains a model from scratch (optional) using the READ-BAD dataset :cite:`gruning2018read` 6 | and the annotations of `Pagenet`_ :cite:`tensmeyer2017pagenet` (annotator1 is used). 7 | In order to limit memory usage, the images in the dataset we provide have been downsized to have 1M pixels each. 8 | 9 | .. _Pagenet: https://github.com/ctensmeyer/pagenet/tree/master/annotations 10 | 11 | 12 | **How to** 13 | 14 | 0. If you have not yet done so, clone the repository : :: 15 | 16 | git clone https://github.com/dhlab-epfl/dhSegment.git 17 | 18 | 1. 
Get the annotated dataset `here`_, which already contains the folders ``images`` and ``labels`` 19 | for the training, validation and testing sets. Unzip it into ``demo/pages``. :: 20 | 21 | cd demo/ 22 | wget https://github.com/dhlab-epfl/dhSegment/releases/download/v0.2/pages.zip 23 | unzip pages.zip 24 | cd .. 25 | 26 | .. _here: https://github.com/dhlab-epfl/dhSegment/releases/download/v0.2/pages.zip 27 | 28 | 2. (Only needed if training from scratch) Download the pretrained weights for ResNet : :: 29 | 30 | cd pretrained_models/ 31 | python download_resnet_pretrained_model.py 32 | cd .. 33 | 34 | 3. You can train the model from scratch with: ``python train.py with demo/demo_config.json`` 35 | but because this takes quite some time, we recommend skipping this step and just downloading the 36 | `provided model`_ (download and unzip it in ``demo/model``) :: 37 | 38 | cd demo/ 39 | wget https://github.com/dhlab-epfl/dhSegment/releases/download/v0.2/model.zip 40 | unzip model.zip 41 | cd .. 42 | 43 | .. _provided model : https://github.com/dhlab-epfl/dhSegment/releases/download/v0.2/model.zip 44 | 45 | 4. (Only if training from scratch) You can visualize the training progress in TensorBoard by running 46 | ``tensorboard --logdir .`` in the ``demo`` folder. 47 | 48 | 5. Run ``python demo.py`` 49 | 50 | 6. Have a look at the results in ``demo/processed_images`` 51 | 52 | -------------------------------------------------------------------------------- /doc/start/index.rst: -------------------------------------------------------------------------------- 1 | Quickstart 2 | ========== 3 | 4 | .. toctree:: 5 | install 6 | annotating 7 | training 8 | demo -------------------------------------------------------------------------------- /doc/start/install.rst: -------------------------------------------------------------------------------- 1 | Installation 2 | ------------ 3 | 4 | It is recommended to install ``tensorflow`` (or ``tensorflow-gpu``) independently using the Anaconda distribution, 5 | in order to make sure all dependencies are properly installed. 6 | 7 | 1. Clone the repository using ``git clone https://github.com/dhlab-epfl/dhSegment.git`` 8 | 9 | 2. Install Anaconda or Miniconda (`installation procedure `_) 10 | 11 | 3. Create a virtual environment and activate it :: 12 | 13 | conda create -n dh_segment python=3.6 14 | source activate dh_segment 15 | 16 | 17 | 4. Install dhSegment dependencies with ``pip install git+https://github.com/dhlab-epfl/dhSegment`` 18 | 19 | 5. Install TensorFlow 1.13 with conda ``conda install tensorflow-gpu=1.13.1``. 20 | -------------------------------------------------------------------------------- /doc/start/training.rst: -------------------------------------------------------------------------------- 1 | Training 2 | -------- 3 | 4 | .. note:: A good NVIDIA GPU (at least 6GB of RAM) is most likely necessary to train your own models. We assume CUDA 5 | and cuDNN are installed. 6 | 7 | **Input data** 8 | 9 | You need to have your training data in a folder containing an ``images`` folder and a ``labels`` folder. 10 | The pairs (images, labels) need to have the same name (it is not mandatory to have the same file extension, 11 | however we recommend having the label images as ``.png`` files). 12 | 13 | The annotated images in the ``labels`` folder are (usually) RGB images with the regions to segment annotated with 14 | a specific color. 15 | 16 | ..
note:: It is now also possible to use a `csv` file containing the pairs ``original_image_filename``, 17 | ``label_image_filename`` as input data. 18 | 19 | To input a ``csv`` file instead of the two folders ``images`` and ``labels``, 20 | the content should be formatted in the following way: :: 21 | 22 | mypath/myfolder/original_image_filename1,mypath/myfolder/label_image_filename1 23 | mypath/myfolder/original_image_filename2,mypath/myfolder/label_image_filename2 24 | 25 | 26 | 27 | **The classes.txt file** 28 | 29 | The file containing the classes has the format shown below, where each row corresponds to one class 30 | (including the 'negative' or 'background' class) and each row has 3 values for the 3 RGB channels. 31 | Each class needs to have a different color code. :: 32 | 33 | classes.txt 34 | 35 | 0 0 0 36 | 0 255 0 37 | ... 38 | 39 | 40 | **Config file with ``sacred``** 41 | 42 | The `sacred`_ package is used to manage experiments and training runs. Have a look at its documentation to use it properly. 43 | 44 | In order to train a model, you should run ``python train.py with <your-config-file.json>`` (for example ``python train.py with demo/demo_config.json``). 45 | 46 | .. _sacred: https://sacred.readthedocs.io/en/latest/quickstart.html 47 | 48 | 49 | Multilabel classification training 50 | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ 51 | 52 | In case you want to be able to assign multiple labels to elements, the ``classes.txt`` file must be changed. 53 | Besides the color code, you need to add an *attribution* code to each color. The attribution code has length `n_classes` 54 | and indicates which classes are assigned to the color. 55 | 56 | Take for example 3 classes {A, B, C} and the following possible labelling combinations: 57 | 58 | - A (color code ``(0 255 0)``) with attribution code ``1 0 0`` 59 | - B (color code ``(255 0 0)``) with attribution code ``0 1 0`` 60 | - C (color code ``(0 0 255)``) with attribution code ``0 0 1`` 61 | - AB (color code ``(128 128 128)``) with attribution code ``1 1 0`` 62 | - BC (color code ``(0 255 255)``) with attribution code ``0 1 1`` 63 | 64 | The attribution code has value ``1`` when the label is assigned and ``0`` when it is not.
65 | (The attribution code ``1 0 1`` would mean that the color annotates elements that belong to classes A and C) 66 | 67 | In our example the ``classes.txt`` file would then look like : :: 68 | 69 | 70 | classes.txt 71 | 72 | 0 0 0 0 0 0 73 | 0 255 0 1 0 0 74 | 255 0 0 0 1 0 75 | 0 0 255 0 0 1 76 | 128 128 128 1 1 0 77 | 0 255 255 0 1 1 78 | -------------------------------------------------------------------------------- /doc/tutorials/index.rst: -------------------------------------------------------------------------------- 1 | Tutorials 2 | ========= 3 | -------------------------------------------------------------------------------- /exps/README.md: -------------------------------------------------------------------------------- 1 | ## Experiments 2 | 3 | - `page` experiment on [PageNet dataset](https://dl.acm.org/citation.cfm?id=3151522) 4 | - `cBAD` experiment on [READ-BAD dataset](https://arxiv.org/abs/1705.03311) 5 | -------------------------------------------------------------------------------- /exps/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dhlab-epfl/dhSegment/cca94e94aec52baa9350eaaa60c006d7fde103b7/exps/__init__.py -------------------------------------------------------------------------------- /exps/cbad/README.md: -------------------------------------------------------------------------------- 1 | # cBAD experiment 2 | 3 | ## Dataset 4 | Dataset from [ICDAR 2017 Competition on Baseline Detection in Archival Documents (cBAD)](https://zenodo.org/record/835441) ([paper](https://arxiv.org/abs/1705.03311)) 5 | 6 | ## Demo 7 | 8 | 1. Download the pretrained weights. 9 | ``` shell 10 | $ cd pretrained_models/ 11 | $ python download_resnet_pretrained_model.py 12 | $ cd .. 13 | ``` 14 | 15 | 2. Run the script `make_cbad.py`, which will download the dataset and create the masks. This may take some time (10-20 min). 16 | 17 | ``` shell 18 | $ cd exps/cbad 19 | $ python make_cbad.py --downloading_dir ../../data/cbad-dataset --masks_dir ../../data/cbad-masks 20 | $ cd ../.. 21 | ``` 22 | 23 | 3. To train a model, run `python train.py with demo/demo_cbad_config.json` from the root directory. 24 | If you changed the default directory of `--masks_dir`, make sure to update the file `demo_cbad_config.json`. 25 | 26 | 4. To use the trained model on new data, run the `demo_processing.py` script : 27 | ``` shell 28 | $ cd exps/cbad 29 | $ python demo_processing.py ../../data/cbad-masks/simple/test/images/*.jpg 30 | --model_dir ../../demo/cbad_simple_model/export/ 31 | --output_dir ../../demo/baseline_extraction_output 32 | --draw_extractions 1 33 | $ cd ../.. 34 | ``` 35 | 5. Have a look at the results in the folder `demo/baseline_extraction_output`.
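36 | 37 | Optionally, the exported PAGE XML files can be read back for a quick sanity check. The snippet below is only a minimal sketch (not part of the demo scripts): it assumes the XML files produced in step 4 are in `demo/baseline_extraction_output`, that it is run from `exps/cbad/`, and it only relies on `PAGE.parse_file` and the `image_width`/`image_height` attributes that `evaluation.py` in this folder already uses. 38 | 39 | ``` python 40 | from glob import glob 41 | from dh_segment.io import PAGE 42 | 43 | # Assumed output location from step 4; adjust if you used a different --output_dir 44 | for xml_filename in glob('../../demo/baseline_extraction_output/*.xml'): 45 |     page = PAGE.parse_file(xml_filename)  # parsed PAGE XML document 46 |     print(xml_filename, page.image_width, page.image_height) 47 | ```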
-------------------------------------------------------------------------------- /exps/cbad/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | __author__ = 'solivr' -------------------------------------------------------------------------------- /exps/cbad/demo_processing.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | __author__ = "solivr" 3 | __license__ = "GPL" 4 | 5 | import tensorflow as tf 6 | import os 7 | from typing import List 8 | from dh_segment.inference import LoadedModel 9 | from dh_segment.io import PAGE 10 | from process import line_extraction_v1 11 | from imageio import imread, imsave 12 | from tqdm import tqdm 13 | import click 14 | 15 | 16 | @click.command() 17 | @click.argument('filenames_to_process', nargs=-1) 18 | @click.option('--model_dir', help="The directory of te model to use") 19 | @click.option('--output_dir', help="Directory to output the PAGEXML files") 20 | @click.option('--draw_extractions', help="If true, the extracted lines will be drawn and exported to the output_dir") 21 | def baseline_extraction(model_dir: str, 22 | filenames_to_process: List[str], 23 | output_dir: str, 24 | draw_extractions: bool=False, 25 | config: tf.ConfigProto=None) -> None: 26 | """ 27 | Given a model directory this function will load the model and apply it to the given files. 28 | 29 | :param model_dir: Directory containing the saved model 30 | :param filenames_to_process: filenames of the images to process 31 | :param output_dir: output directory to save the predictions (probability images) 32 | :param draw_extractions: 33 | :param config: ``ConfigProto`` object for ``tf.Session``. 34 | :return: 35 | """ 36 | 37 | os.makedirs(output_dir, exist_ok=True) 38 | if draw_extractions: 39 | drawing_dir = os.path.join(output_dir, 'drawings') 40 | os.makedirs(drawing_dir) 41 | 42 | with tf.Session(config=config): 43 | # Load the model 44 | m = LoadedModel(model_dir, predict_mode='filename_original_shape') 45 | for filename in tqdm(filenames_to_process, desc='Prediction'): 46 | # Inference 47 | prediction = m.predict(filename) 48 | # Take the first element of the 'probs' dictionary (batch size = 1) 49 | probs = prediction['probs'][0] 50 | original_shape = probs.shape 51 | 52 | # The baselines probs are on the second channel 53 | baseline_probs = probs[:, :, 1] 54 | contours, _ = line_extraction_v1(baseline_probs, low_threshold=0.2, high_threshold=0.4, sigma=1.5) 55 | 56 | basename = os.path.basename(filename).split('.')[0] 57 | 58 | # Compute the ratio to save the coordinates in the original image coordinates reference. 
59 | ratio = (original_shape[0] / probs.shape[0], original_shape[1] / probs.shape[1]) 60 | xml_filename = os.path.join(output_dir, basename + '.xml') 61 | page_object = PAGE.save_baselines(xml_filename, contours, ratio, predictions_shape=probs.shape[:2]) 62 | 63 | # If specified, saves the images with the annotated baselines 64 | if draw_extractions: 65 | image = imread(filename) 66 | page_object.draw_baselines(image, color=(255, 0, 0), thickness=5) 67 | 68 | basename = os.path.basename(filename) 69 | imsave(os.path.join(drawing_dir, basename), image) 70 | 71 | if __name__ == '__main__': 72 | baseline_extraction() 73 | -------------------------------------------------------------------------------- /exps/cbad/evaluation.py: -------------------------------------------------------------------------------- 1 | import io 2 | import os 3 | import subprocess 4 | from glob import glob 5 | import pandas as pd 6 | from tqdm import tqdm 7 | from dh_segment.io import PAGE 8 | from .process import extract_lines 9 | 10 | CBAD_JAR = './cBAD/TranskribusBaseLineEvaluationScheme_v0.1.3/' \ 11 | 'TranskribusBaseLineEvaluationScheme-0.1.3-jar-with-dependencies.jar' 12 | PP_PARAMS = {'sigma': 1.5, 'low_threshold': 0.2, 'high_threshold': 0.4} 13 | 14 | 15 | def eval_fn(input_dir: str, 16 | groudtruth_dir: str, 17 | output_dir: str=None, 18 | post_process_params: dict=PP_PARAMS, 19 | channel_baselines: int=1, 20 | jar_tool_path: str=CBAD_JAR, 21 | masks_dir: str=None) -> dict: 22 | """ 23 | Evaluates a model against the selected set ('groundtruth_dir' contains XML files) 24 | 25 | :param input_dir: Input directory containing probability maps (.npy) 26 | :param groudtruth_dir: directory containing XML groundtruths 27 | :param output_dir: output directory for results 28 | :param post_process_params: parameters for post-processing of probability maps 29 | :param channel_baselines: the baseline class channel 30 | :param jar_tool_path: path to cBAD evaluation tool (.jar file) 31 | :param masks_dir: optional, directory where binary masks of the page are stored (.png) 32 | :return: 33 | """ 34 | 35 | if output_dir is None: 36 | output_dir = input_dir 37 | 38 | # Apply post processing and find lines 39 | for file in tqdm(glob(os.path.join(input_dir, '*.npy'))): 40 | basename = os.path.basename(file).split('.')[0] 41 | gt_xml_filename = os.path.join(groudtruth_dir, basename + '.xml') 42 | gt_page_xml = PAGE.parse_file(gt_xml_filename) 43 | 44 | original_shape = [gt_page_xml.image_height, gt_page_xml.image_width] 45 | 46 | _, _ = extract_lines(file, output_dir, original_shape, post_process_params, 47 | channel_baselines=channel_baselines, mask_dir=masks_dir) 48 | 49 | # Create pairs predicted XML - groundtruth XML to be evaluated 50 | xml_pred_filenames_list = glob(os.path.join(output_dir, '*.xml')) 51 | xml_filenames_tuples = list() 52 | for xml_filename in xml_pred_filenames_list: 53 | basename = os.path.basename(xml_filename) 54 | gt_xml_filename = os.path.join(groudtruth_dir, basename) 55 | 56 | xml_filenames_tuples.append((gt_xml_filename, xml_filename)) 57 | 58 | gt_pages_list_filename = os.path.join(output_dir, 'gt_pages_simple.lst') 59 | generated_pages_list_filename = os.path.join(output_dir, 'generated_pages_simple.lst') 60 | with open(gt_pages_list_filename, 'w') as f: 61 | f.writelines('\n'.join([s[0] for s in xml_filenames_tuples])) 62 | with open(generated_pages_list_filename, 'w') as f: 63 | f.writelines('\n'.join([s[1] for s in xml_filenames_tuples])) 64 | 65 | # Evaluation using the Java tool 66 | cmd = 'java
-jar {} {} {}'.format(jar_tool_path, gt_pages_list_filename, generated_pages_list_filename) 67 | result = subprocess.check_output(cmd, shell=True).decode() 68 | with open(os.path.join(output_dir, 'scores.txt'), 'w') as f: 69 | f.write(result) 70 | parse_score_txt(result, os.path.join(output_dir, 'scores.csv')) 71 | 72 | # Parse results from output of tool 73 | lines = result.splitlines() 74 | avg_precision = float(next(filter(lambda l: 'Avg (over pages) P value:' in l, lines)).split()[-1]) 75 | avg_recall = float(next(filter(lambda l: 'Avg (over pages) R value:' in l, lines)).split()[-1]) 76 | f_measure = float(next(filter(lambda l: 'Resulting F_1 value:' in l, lines)).split()[-1]) 77 | 78 | print('P {}, R {}, F {}'.format(avg_precision, avg_recall, f_measure)) 79 | 80 | return { 81 | 'avg_precision': avg_precision, 82 | 'avg_recall': avg_recall, 83 | 'f_measure': f_measure 84 | } 85 | 86 | 87 | def parse_score_txt(score_txt: str, output_csv: str): 88 | lines = score_txt.splitlines() 89 | header_ind = next((i for i, l in enumerate(lines) 90 | if l == '#P value, #R value, #F_1 value, #TruthFileName, #HypoFileName')) 91 | final_line = next((i for i, l in enumerate(lines) if i > header_ind and l == '')) 92 | csv_data = '\n'.join(lines[header_ind:final_line]) 93 | df = pd.read_csv(io.StringIO(csv_data)) 94 | df = df.rename(columns={k: k.strip() for k in df.columns}) 95 | df['#HypoFileName'] = [os.path.basename(f).split('.')[0] for f in df['#HypoFileName']] 96 | del df['#TruthFileName'] 97 | df = df.rename(columns={'#P value': 'P', '#R value': 'R', '#F_1 value': 'F_1', '#HypoFileName': 'basename'}) 98 | df = df.reindex(columns=['basename', 'F_1', 'P', 'R']) 99 | df = df.sort_values('F_1', ascending=True) 100 | df.to_csv(output_csv, index=False) 101 | -------------------------------------------------------------------------------- /exps/cbad/example_evaluation.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": { 7 | "collapsed": true 8 | }, 9 | "outputs": [], 10 | "source": [ 11 | "import os\n", 12 | "from scipy.misc import imread\n", 13 | "from tqdm import tqdm\n", 14 | "from .evaluation import eval_fn\n", 15 | "from .process import prediction_fn, extract_lines" 16 | ] 17 | }, 18 | { 19 | "cell_type": "code", 20 | "execution_count": null, 21 | "metadata": {}, 22 | "outputs": [], 23 | "source": [ 24 | "model_dirs_list = ['model1/export/timestamp/', \n", 25 | " 'model2/export/timestamp/']" 26 | ] 27 | }, 28 | { 29 | "cell_type": "code", 30 | "execution_count": null, 31 | "metadata": {}, 32 | "outputs": [], 33 | "source": [ 34 | "set_dir = './baseline_dataset/images/'" 35 | ] 36 | }, 37 | { 38 | "cell_type": "markdown", 39 | "metadata": {}, 40 | "source": [ 41 | "## Prediction" 42 | ] 43 | }, 44 | { 45 | "cell_type": "code", 46 | "execution_count": null, 47 | "metadata": {}, 48 | "outputs": [], 49 | "source": [ 50 | "for model_dir in model_dirs_list:\n", 51 | " output_dir = '{}'.format(os.path.sep).join(model_dir.split(os.path.sep)[:-3] + ['predictions'])\n", 52 | " prediction_fn(model_dir, set_dir, output_dir)" 53 | ] 54 | }, 55 | { 56 | "cell_type": "code", 57 | "execution_count": null, 58 | "metadata": {}, 59 | "outputs": [], 60 | "source": [ 61 | "" 62 | ] 63 | }, 64 | { 65 | "cell_type": "markdown", 66 | "metadata": {}, 67 | "source": [ 68 | "## Evaluation" 69 | ] 70 | }, 71 | { 72 | "cell_type": "code", 73 | "execution_count": null, 74 | "metadata": {}, 75 | 
"outputs": [], 76 | "source": [ 77 | "CBAD_JAR = './TranskribusBaseLineEvaluationScheme_v0.1.3/' \\\n", 78 | " 'TranskribusBaseLineEvaluationScheme-0.1.3-jar-with-dependencies.jar'\n", 79 | "gt_dir = './dataset/test/gt/'\n", 80 | "pred_dir_list = ['./model1/preds_test/',\n", 81 | " './model2/preds_test/']" 82 | ] 83 | }, 84 | { 85 | "cell_type": "code", 86 | "execution_count": null, 87 | "metadata": {}, 88 | "outputs": [], 89 | "source": [ 90 | "post_process_params = {'sigma': 1.5,\n", 91 | " 'low_threshold': 0.2,\n", 92 | " 'high_threshold': 0.4}" 93 | ] 94 | }, 95 | { 96 | "cell_type": "code", 97 | "execution_count": null, 98 | "metadata": {}, 99 | "outputs": [], 100 | "source": [ 101 | "list_results = list()\n", 102 | "for pred_dir in pred_dir_list:\n", 103 | " list_results.append(eval_fn(pred_dir, gt_dir, pred_dir, post_process_params, CBAD_JAR))" 104 | ] 105 | }, 106 | { 107 | "cell_type": "code", 108 | "execution_count": null, 109 | "metadata": {}, 110 | "outputs": [], 111 | "source": [ 112 | "" 113 | ] 114 | } 115 | ], 116 | "metadata": { 117 | "kernelspec": { 118 | "display_name": "Python 2", 119 | "language": "python", 120 | "name": "python2" 121 | }, 122 | "language_info": { 123 | "codemirror_mode": { 124 | "name": "ipython", 125 | "version": 2.0 126 | }, 127 | "file_extension": ".py", 128 | "mimetype": "text/x-python", 129 | "name": "python", 130 | "nbconvert_exporter": "python", 131 | "pygments_lexer": "ipython2", 132 | "version": "2.7.6" 133 | } 134 | }, 135 | "nbformat": 4, 136 | "nbformat_minor": 0 137 | } -------------------------------------------------------------------------------- /exps/cbad/make_cbad.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | __author__ = "solivr" 3 | __license__ = "GPL" 4 | 5 | import os 6 | import click 7 | # from utils import cbad_download, cbad_set_generator, split_set_for_eval 8 | from utils import cbad_set_generator, split_set_for_eval 9 | from exps.commonutils import cbad_download, CBAD_TRAIN_COMPLEX_FOLDER, CBAD_TEST_COMPLEX_FOLDER, CBAD_TRAIN_SIMPLE_FOLDER, CBAD_TEST_SIMPLE_FOLDER 10 | 11 | 12 | @click.command() 13 | @click.option('--downloading_dir', help='Directory to download the cBAD-ICDAR17 dataset') 14 | @click.option('--masks_dir', help="Directory where to output the generated masks") 15 | def generate_cbad_dataset(downloading_dir: str, masks_dir: str): 16 | # Check if dataset has already been downloaded 17 | if os.path.exists(downloading_dir): 18 | print('Dataset has already been downloaded. 
Skipping process.') 19 | else: 20 | # Download dataset 21 | cbad_download(downloading_dir) 22 | 23 | # Create masks 24 | dirs_tuple = [(os.path.join(downloading_dir, CBAD_TRAIN_COMPLEX_FOLDER), os.path.join(masks_dir, 'complex', 'train')), 25 | (os.path.join(downloading_dir, CBAD_TEST_COMPLEX_FOLDER), os.path.join(masks_dir, 'complex', 'test')), 26 | (os.path.join(downloading_dir, CBAD_TRAIN_SIMPLE_FOLDER), os.path.join(masks_dir, 'simple', 'train')), 27 | (os.path.join(downloading_dir, CBAD_TEST_SIMPLE_FOLDER), os.path.join(masks_dir, 'simple', 'test'))] 28 | 29 | print('Creating sets') 30 | for dir_tuple in dirs_tuple: 31 | input_dir, output_dir = dir_tuple 32 | os.makedirs(output_dir, exist_ok=True) 33 | # For each set create the folder with the annotated data 34 | cbad_set_generator(input_dir=input_dir, 35 | output_dir=output_dir, 36 | img_size=2e6, 37 | draw_baselines=True, 38 | draw_endpoints=False) 39 | 40 | # Split the 'official' train set into training and validation set 41 | if 'train' in output_dir: 42 | print('Make eval set from the given training data (0.15/0.85 eval/train)') 43 | csv_filename = os.path.join(output_dir, 'set_data.csv') 44 | split_set_for_eval(csv_filename) 45 | print('Done!') 46 | 47 | 48 | if __name__ == '__main__': 49 | generate_cbad_dataset() 50 | -------------------------------------------------------------------------------- /exps/cbad/process.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple, List 2 | import numpy as np 3 | from scipy.ndimage import label 4 | import cv2 5 | import os 6 | import tensorflow as tf 7 | from tqdm import tqdm 8 | from glob import glob 9 | from imageio import imsave, imread 10 | import PIL 11 | from dh_segment.utils import dump_pickle 12 | from dh_segment.post_processing.binarization import hysteresis_thresholding, cleaning_probs 13 | from dh_segment.post_processing.line_vectorization import find_lines 14 | from dh_segment.io import PAGE 15 | from dh_segment.inference import LoadedModel 16 | 17 | 18 | def prediction_fn(model_dir: str, 19 | input_dir: str, 20 | output_dir: str=None, 21 | config: tf.ConfigProto=None) -> None: 22 | """ 23 | Given a model directory this function will load the model and apply it to the files (.jpg, .png) found in input_dir. 
24 | The predictions will be saved in output_dir as .npy files (values ranging [0,255]) 25 | 26 | :param model_dir: Directory containing the saved model 27 | :param input_dir: input directory where the images to predict are 28 | :param output_dir: output directory to save the predictions (probability images) 29 | :param config: ConfigProto object to pass to the session in order to define which GPU to use 30 | :return: 31 | """ 32 | if not output_dir: 33 | # For model_dir of style model_name/export/timestamp/ this will create a folder model_name/predictions' 34 | output_dir = '{}'.format(os.path.sep).join(model_dir.split(os.path.sep)[:-3] + ['predictions']) 35 | 36 | os.makedirs(output_dir, exist_ok=True) 37 | filenames_to_predict = glob(os.path.join(input_dir, '*.jpg')) + glob(os.path.join(input_dir, '*.png')) 38 | 39 | with tf.Session(config=config): 40 | m = LoadedModel(model_dir, predict_mode='filename_original_shape') 41 | for filename in tqdm(filenames_to_predict, desc='Prediction'): 42 | pred = m.predict(filename)['probs'][0] 43 | np.save(os.path.join(output_dir, os.path.basename(filename).split('.')[0]), np.uint8(255 * pred)) 44 | 45 | 46 | def cbad_post_processing_fn(probs: np.array, 47 | baseline_chanel: int=1, 48 | sigma: float=2.5, 49 | low_threshold: float=0.8, 50 | high_threshold: float=0.9, 51 | filter_width: float=0, 52 | vertical_maxima: bool=False, 53 | output_basename=None) -> Tuple[List[np.ndarray], np.ndarray]: 54 | """ 55 | Given a probability map, returns the contour of lines and the corresponding mask. 56 | Saves the results in .pkl file if requested. 57 | 58 | :param probs: output of the model (probabilities) in range [0, 255] 59 | :param baseline_chanel: channel where the baseline class is detected 60 | :param sigma: sigma value for gaussian filtering 61 | :param low_threshold: hysteresis low threshold 62 | :param high_threshold: hysteresis high threshold 63 | :param filter_width: percentage of the image width to filter out lines that are close to borders (default 0.0) 64 | :param output_basename: name of file to save the intermediaty result as .pkl file. 
65 | :param vertical_maxima: set to True to use vertical local maxima as candidates for the hysteresis thresholding 66 | :return: contours, mask 67 | WARNING : contours IN OPENCV format List[np.ndarray(n_points, 1, (x,y))] 68 | """ 69 | 70 | contours, lines_mask = line_extraction_v1(probs[:, :, baseline_chanel], sigma, low_threshold, high_threshold, 71 | filter_width, vertical_maxima) 72 | if output_basename is not None: 73 | dump_pickle(output_basename+'.pkl', (contours, lines_mask.shape)) 74 | return contours, lines_mask 75 | 76 | 77 | def line_extraction_v1(probs: np.ndarray, 78 | low_threshold: float, 79 | high_threshold: float, 80 | sigma: float=0.0, 81 | filter_width: float=0.00, 82 | vertical_maxima: bool=False) -> Tuple[List[np.ndarray], np.ndarray]: 83 | """ 84 | Given a probability map, returns the contour of lines and the corresponding mask 85 | 86 | :param probs: probability map (numpy array) 87 | :param low_threshold: hysteresis low threshold 88 | :param high_threshold: hysteresis high threshold 89 | :param sigma: sigma value for gaussian filtering 90 | :param filter_width: percentage of the image width to filter out lines that are close to borders (default 0.0) 91 | :param vertical_maxima: set to True to use vertical local maxima as candidates for the hysteresis thresholding 92 | :return: 93 | """ 94 | # Smooth 95 | probs2 = cleaning_probs(probs, sigma=sigma) 96 | 97 | lines_mask = hysteresis_thresholding(probs2, low_threshold, high_threshold, 98 | candidates_mask=vertical_local_maxima(probs2) if vertical_maxima else None) 99 | # Remove lines touching border 100 | # lines_mask = remove_borders(lines_mask) 101 | 102 | # Extract polygons from line mask 103 | contours = find_lines(lines_mask) 104 | 105 | filtered_contours = [] 106 | page_width = probs.shape[1] 107 | for cnt in contours: 108 | centroid_x, centroid_y = np.mean(cnt, axis=0)[0] 109 | if centroid_x < filter_width*page_width or centroid_x > (1-filter_width)*page_width: 110 | continue 111 | # if cv2.arcLength(cnt, False) < filter_width*page_width: 112 | # continue 113 | filtered_contours.append(cnt) 114 | 115 | return filtered_contours, lines_mask 116 | 117 | 118 | def vertical_local_maxima(probs: np.ndarray) -> np.ndarray: 119 | local_maxima = np.zeros_like(probs, dtype=bool) 120 | local_maxima[1:-1] = (probs[1:-1] >= probs[:-2]) & (probs[2:] <= probs[1:-1]) 121 | local_maxima = cv2.morphologyEx(local_maxima.astype(np.uint8), cv2.MORPH_CLOSE, np.ones((5, 5), dtype=np.uint8)) 122 | return local_maxima > 0 123 | 124 | 125 | def remove_borders(mask: np.ndarray, margin: int=5) -> np.ndarray: 126 | tmp = mask.copy() 127 | tmp[:margin] = 1 128 | tmp[-margin:] = 1 129 | tmp[:, :margin] = 1 130 | tmp[:, -margin:] = 1 131 | label_components, count = label(tmp, np.ones((3, 3))) 132 | result = mask.copy() 133 | border_component = label_components[0, 0] 134 | result[label_components == border_component] = 0 135 | return result 136 | 137 | 138 | def extract_lines(npy_filename: str, 139 | output_dir: str, 140 | original_shape: list, 141 | post_process_params: dict, 142 | channel_baselines: int=1, 143 | mask_dir: str=None, 144 | debug: bool=False): 145 | """ 146 | From the prediction files (probs) (.npy) finds and extracts the lines into PAGE-XML format. 
147 | 148 | :param npy_filename: filename of saved predictions (probs) in range (0,255) 149 | :param output_dir: output directory to save the xml files 150 | :param original_shape: shape of the original input image (to rescale the extracted lines if necessary) 151 | :param post_process_params: params for line detection (sigma, thresholds, ...) 152 | :param channel_baselines: channel where the baseline class is detected 153 | :param mask_dir: directory containing masks of the page in order to improve the line extraction 154 | :param debug: if True, will output the binary image of the extracted lines 155 | :return: contours of lines (open cv format), binary image of lines (lines mask) 156 | """ 157 | 158 | os.makedirs(output_dir, exist_ok=True) 159 | 160 | basename = os.path.basename(npy_filename).split('.')[0] 161 | 162 | pred = np.load(npy_filename)/255 # type: np.ndarray 163 | lines_prob = pred[:, :, channel_baselines] 164 | 165 | if mask_dir is not None: 166 | mask = imread(os.path.join(mask_dir, basename + '.png'), mode='L') 167 | mask = np.array(PIL.Image.fromarray(mask, mode='L').resize(lines_prob.shape, resample=PIL.Image.BILINEAR)) 168 | lines_prob[mask == 0] = 0. 169 | 170 | contours, lines_mask = line_extraction_v1(lines_prob, **post_process_params) 171 | 172 | if debug: 173 | imsave(os.path.join(output_dir, '{}_bin.jpg'.format(basename)), lines_mask) 174 | 175 | ratio = (original_shape[0] / pred.shape[0], original_shape[1] / pred.shape[1]) 176 | xml_filename = os.path.join(output_dir, basename + '.xml') 177 | PAGE.save_baselines(xml_filename, contours, ratio, predictions_shape=pred.shape[:2]) 178 | 179 | return contours, lines_mask 180 | -------------------------------------------------------------------------------- /exps/commonutils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | __author__ = "solivr" 3 | __license__ = "GPL" 4 | 5 | import os 6 | from tqdm import tqdm 7 | import urllib 8 | import zipfile 9 | import numpy as np 10 | import cv2 11 | from imageio import imsave 12 | 13 | RANDOM_SEED = 0 14 | np.random.seed(RANDOM_SEED) 15 | 16 | CBAD_TRAIN_COMPLEX_FOLDER = 'cbad-icdar2017-train-complex-documents' 17 | CBAD_TEST_COMPLEX_FOLDER = 'cbad-icdar2017-test-complex-documents' 18 | CBAD_TRAIN_SIMPLE_FOLDER = 'cbad-icdar2017-train-simple-documents' 19 | CBAD_TEST_SIMPLE_FOLDER = 'cbad-icdar2017-test-simple-documents' 20 | 21 | 22 | def get_page_filename(image_filename: str) -> str: 23 | """ 24 | Given a path to a .jpg or .png file, get the corresponding .xml file. 25 | 26 | :param image_filename: filename of the image 27 | :return: the filename of the corresponding .xml file, raises exception if .xml file does not exist 28 | """ 29 | page_filename = os.path.join(os.path.dirname(image_filename), 30 | 'page', 31 | '{}.xml'.format(os.path.basename(image_filename)[:-4])) 32 | 33 | if os.path.exists(page_filename): 34 | return page_filename 35 | else: 36 | raise FileNotFoundError 37 | 38 | 39 | def get_image_label_basename(image_filename: str) -> str: 40 | """ 41 | Creates a new filename composed of the beginning of the folder/collection (e.g.
EPFL, ABP) and the original filename 42 | 43 | :param image_filename: path of the image filename 44 | :return: 45 | """ 46 | # Get acronym followed by name of file 47 | directory, basename = os.path.split(image_filename) 48 | acronym = directory.split(os.path.sep)[-1].split('_')[0] 49 | return '{}_{}'.format(acronym, basename.split('.')[0]) 50 | 51 | 52 | def save_and_resize(img: np.array, 53 | filename: str, 54 | size=None, 55 | nearest: bool=False) -> None: 56 | """ 57 | Resizes the image if necessary and saves it. The resizing will keep the image ratio 58 | 59 | :param img: the image to resize and save (numpy array) 60 | :param filename: filename of the saved image 61 | :param size: size of the image after resizing (in pixels). The ratio of the original image will be kept 62 | :param nearest: whether to use nearest interpolation method (default to False) 63 | :return: 64 | """ 65 | if size is not None: 66 | h, w = img.shape[:2] 67 | ratio = float(np.sqrt(size/(h*w))) 68 | resized = cv2.resize(img, (int(w*ratio), int(h*ratio)), 69 | interpolation=cv2.INTER_NEAREST if nearest else cv2.INTER_LINEAR) 70 | imsave(filename, resized) 71 | else: 72 | imsave(filename, img) 73 | 74 | 75 | # ------------------------------ 76 | 77 | 78 | def _progress_hook(t): 79 | last_b = [0] 80 | 81 | def update_to(b: int=1, bsize: int=1, tsize: int=None): 82 | """ 83 | Adapted from: source unknown 84 | :param b: Number of blocks transferred so far [default: 1]. 85 | :param bsize: Size of each block (in tqdm units) [default: 1]. 86 | :param tsize: Total size (in tqdm units). If [default: None] remains unchanged. 87 | """ 88 | if tsize is not None: 89 | t.total = tsize 90 | t.update((b - last_b[0]) * bsize) 91 | last_b[0] = b 92 | 93 | return update_to 94 | 95 | 96 | def cbad_download(output_dir: str): 97 | """ 98 | Download BAD-READ dataset. 
99 | 100 | :param output_dir: folder where the data will be downloaded 101 | :return: 102 | """ 103 | os.makedirs(output_dir, exist_ok=True) 104 | zip_filename = os.path.join(output_dir, 'cbad-icdar17.zip') 105 | 106 | with tqdm(unit='B', unit_scale=True, unit_divisor=1024, miniters=1, desc="Downloading cBAD-ICDAR17 dataset") as t: 107 | urllib.request.urlretrieve('https://zenodo.org/record/1491441/files/READ-ICDAR2017-cBAD-dataset-v4.zip', 108 | zip_filename, reporthook=_progress_hook(t)) 109 | print('cBAD-ICDAR2017 dataset downloaded successfully!') 110 | print('Extracting files ...') 111 | with zipfile.ZipFile(zip_filename, 'r') as zip_ref: 112 | zip_ref.extractall(output_dir) 113 | 114 | # Renaming 115 | os.rename(os.path.join(output_dir, 'Test-Baseline Competition - Complex Documents'), 116 | os.path.join(output_dir, CBAD_TEST_COMPLEX_FOLDER)) 117 | os.rename(os.path.join(output_dir, 'Test-Baseline Competition - Simple Documents'), 118 | os.path.join(output_dir, CBAD_TEST_SIMPLE_FOLDER)) 119 | os.rename(os.path.join(output_dir, 'Train-Baseline Competition - Complex Documents'), 120 | os.path.join(output_dir, CBAD_TRAIN_COMPLEX_FOLDER)) 121 | os.rename(os.path.join(output_dir, 'Train-Baseline Competition - Simple Documents'), 122 | os.path.join(output_dir, CBAD_TRAIN_SIMPLE_FOLDER)) 123 | 124 | os.remove(zip_filename) 125 | print('Files extracted and renamed in {}'.format(output_dir)) 126 | -------------------------------------------------------------------------------- /exps/page/README.md: -------------------------------------------------------------------------------- 1 | # Page experiment 2 | Based on the paper ["PageNet: Page Boundary Extraction in Historical Handwritten Documents."](https://dl.acm.org/citation.cfm?id=3151522) 3 | 4 | 5 | ## Dataset 6 | The page annotations come from this [repository](https://github.com/ctensmeyer/pagenet/tree/master/annotations). We use READ-cBAD data with _annotator 1_ and _set1_. 7 | 8 | ## Demo 9 | 10 | 1. Download the pretrained weights. 11 | ``` shell 12 | $ cd pretrained_models/ 13 | $ python download_resnet_pretrained_model.py 14 | $ cd .. 15 | ``` 16 | 17 | 2. Run the script `make_page.py`, which will download the dataset and create the masks. This may take some time (10-20 min). 18 | 19 | ``` shell 20 | $ cd exps/page 21 | $ python make_page.py --downloading_dir ../../data/cbad-dataset --masks_dir ../../data/page-masks 22 | $ cd ../.. 23 | ``` 24 | 25 | 26 | 3. To train a model, run `python train.py with demo/demo_page_config.json` from the root directory. 27 | If you changed the default directory of `--masks_dir`, make sure to update the file `demo_page_config.json`. 28 | 29 | 4. To use the trained model on new data, run the `demo_processing.py` script : 30 | ``` shell 31 | $ cd exps/page 32 | $ python demo_processing.py ../../data/page-masks/test/images/*.jpg 33 | --model_dir ../../demo/page_model/export/ 34 | --output_dir ../../demo/page_extraction_output 35 | --draw_extractions 1 36 | $ cd ../.. 37 | ``` 38 | 39 | 5. Have a look at the results in the folder `demo/page_extraction_output`.
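40 | 41 | As a quick sanity check, the exported PAGE XML files can also be read back with `dh_segment.io.PAGE`. This is only a sketch: it assumes the XML files produced in step 4 are in `demo/page_extraction_output`, that it is run from `exps/page/`, and that the parsed `Page` object exposes a `page_border` attribute mirroring the `page_border` argument set in `demo_processing.py`; `PAGE.parse_file` and the image size attributes are the same ones used in `exps/cbad/evaluation.py`. 42 | 43 | ``` python 44 | from glob import glob 45 | from dh_segment.io import PAGE 46 | 47 | # Assumed output location from step 4; adjust if you used a different --output_dir 48 | for xml_filename in glob('../../demo/page_extraction_output/*.xml'): 49 |     page = PAGE.parse_file(xml_filename) 50 |     # page_border may be an empty Border if no page region was detected (see demo_processing.py) 51 |     print(xml_filename, page.image_width, page.image_height, page.page_border) 52 | ```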
-------------------------------------------------------------------------------- /exps/page/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | __author__ = "solivr" 3 | __license__ = "GPL" 4 | -------------------------------------------------------------------------------- /exps/page/demo_processing.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | __author__ = "solivr" 3 | __license__ = "GPL" 4 | 5 | import tensorflow as tf 6 | from typing import List 7 | import os 8 | import cv2 9 | from imageio import imread, imsave 10 | import numpy as np 11 | import click 12 | from tqdm import tqdm 13 | from dh_segment.inference import LoadedModel 14 | from process import page_post_processing_fn 15 | from dh_segment.post_processing.boxes_detection import find_boxes 16 | from dh_segment.io import PAGE 17 | 18 | @click.command() 19 | @click.argument('filenames_to_process', nargs=-1) 20 | @click.option('--model_dir', help="The directory of the model to use") 21 | @click.option('--output_dir', help="Directory to output the PAGE XML files") 22 | @click.option('--draw_extractions', help="If true, the extracted page boxes will be drawn and exported to the output_dir") 23 | def page_extraction(model_dir: str, 24 | filenames_to_process: List[str], 25 | output_dir: str, 26 | draw_extractions: bool=False, 27 | config: tf.ConfigProto=None): 28 | 29 | os.makedirs(output_dir, exist_ok=True) 30 | if draw_extractions: 31 | drawing_dir = os.path.join(output_dir, 'drawings') 32 | os.makedirs(drawing_dir, exist_ok=True) 33 | 34 | with tf.Session(config=config): 35 | # Load the model 36 | m = LoadedModel(model_dir, predict_mode='filename') 37 | for filename in tqdm(filenames_to_process, desc='Prediction'): 38 | # Inference 39 | prediction = m.predict(filename) 40 | probs = prediction['probs'][0] 41 | original_shape = prediction['original_shape'] 42 | 43 | probs = probs / np.max(probs) # Normalize to be in [0, 1] 44 | # Binarize the predictions 45 | page_bin = page_post_processing_fn(probs, threshold=-1) 46 | 47 | # Upscale to have full resolution image (cv2 uses (w,h) and not (h,w) for giving shapes) 48 | bin_upscaled = cv2.resize(page_bin.astype(np.uint8, copy=False), 49 | tuple(original_shape[::-1]), interpolation=cv2.INTER_NEAREST) 50 | 51 | # Find quadrilateral enclosing the page 52 | pred_page_coords = find_boxes(bin_upscaled.astype(np.uint8, copy=False), 53 | mode='min_rectangle', min_area=0.2, n_max_boxes=1) 54 | 55 | if pred_page_coords is not None: 56 | # Write corners points into a .txt file 57 | 58 | # Create page region and XML file 59 | page_border = PAGE.Border(coords=PAGE.Point.cv2_to_point_list(pred_page_coords[:, None, :])) 60 | 61 | if draw_extractions: 62 | # Draw page box on original image and export it. 
Add also box coordinates to the txt file 63 | original_img = imread(filename, pilmode='RGB') 64 | cv2.polylines(original_img, [pred_page_coords[:, None, :]], True, (0, 0, 255), thickness=5) 65 | 66 | basename = os.path.basename(filename).split('.')[0] 67 | imsave(os.path.join(drawing_dir, '{}_boxes.jpg'.format(basename)), original_img) 68 | 69 | else: 70 | print('No box found in {}'.format(filename)) 71 | page_border = PAGE.Border() 72 | 73 | page_xml = PAGE.Page(image_filename=filename, image_width=original_shape[1], image_height=original_shape[0], 74 | page_border=page_border) 75 | xml_filename = os.path.join(output_dir, '{}.xml'.format(os.path.basename(filename).split('.')[0])) # computed here because basename is only defined when drawings are exported 76 | page_xml.write_to_file(xml_filename, creator_name='PageExtractor') 77 | 78 | 79 | if __name__ == '__main__': 80 | page_extraction() 81 | -------------------------------------------------------------------------------- /exps/page/evaluation.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | __author__ = "solivr" 3 | __license__ = "GPL" 4 | 5 | from tqdm import tqdm 6 | from glob import glob 7 | import os 8 | from imageio import imread 9 | import numpy as np 10 | from .process import extract_page 11 | from dh_segment.utils.evaluation import intersection_over_union, Metrics 12 | 13 | 14 | PP_PARAMS = {'threshold': -1, 'kernel_size': 5} 15 | 16 | 17 | def eval_fn(input_dir: str, groundtruth_dir: str, post_process_params: dict=PP_PARAMS) -> Metrics: 18 | """ 19 | 20 | :param input_dir: directory containing the predictions .npy files (range [0, 255]) 21 | :param groundtruth_dir: directory containing the ground truth images (.png) (must have the same name as predictions 22 | files in input_dir) 23 | :param post_process_params: params for post processing fn 24 | :return: Metrics object containing all the necessary metrics 25 | """ 26 | global_metrics = Metrics() 27 | for file in tqdm(glob(os.path.join(input_dir, '*.npy'))): 28 | basename = os.path.basename(file).split('.')[0] 29 | 30 | prediction = np.load(file) 31 | label_image = imread(os.path.join(groundtruth_dir, '{}.png'.format(basename)), pilmode='L') 32 | 33 | pred_box = extract_page(prediction / np.max(prediction), post_process_params=post_process_params) 34 | label_box = extract_page(label_image / np.max(label_image), min_area=0.0) 35 | 36 | if pred_box is not None and label_box is not None: 37 | iou = intersection_over_union(label_box[:, None, :], pred_box[:, None, :], label_image.shape) 38 | global_metrics.IOU_list.append(iou) 39 | else: 40 | global_metrics.IOU_list.append(0) 41 | 42 | global_metrics.compute_miou() 43 | print('EVAL --- mIOU : {}\n'.format(global_metrics.mIOU)) 44 | 45 | return global_metrics 46 | -------------------------------------------------------------------------------- /exps/page/example_evaluation.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": { 7 | "collapsed": true 8 | }, 9 | "outputs": [], 10 | "source": [ 11 | "import os\n", 12 | "from glob import glob\n", 13 | "from .process import prediction_fn" 14 | ] 15 | }, 16 | { 17 | "cell_type": "code", 18 | "execution_count": null, 19 | "metadata": {}, 20 | "outputs": [], 21 | "source": [ 22 | "model_dirs_list = ['model1/export/timestamp/', \n", 23 | " 'model2/export/timestamp/']\n", 24 | "input_dir = 'dataset_page/set/images/'\n", 25 | "output_dir_name = 'out_predictions'" 26 | ] 27 | }, 28 | { 29 | "cell_type": 
"markdown", 30 | "metadata": {}, 31 | "source": [ 32 | "## Predictions" 33 | ] 34 | }, 35 | { 36 | "cell_type": "code", 37 | "execution_count": null, 38 | "metadata": {}, 39 | "outputs": [], 40 | "source": [ 41 | "for model_dir in model_dirs_list:\n", 42 | " output_dir = '{}'.format(os.path.sep).join(model_dir.split(os.path.sep)[:-3] + [output_dir_name])\n", 43 | " os.makedirs(output_dir, exist_ok=True)\n", 44 | " prediction_fn(model_dir, input_dir, output_dir)" 45 | ] 46 | }, 47 | { 48 | "cell_type": "code", 49 | "execution_count": null, 50 | "metadata": {}, 51 | "outputs": [], 52 | "source": [ 53 | "" 54 | ] 55 | }, 56 | { 57 | "cell_type": "markdown", 58 | "metadata": {}, 59 | "source": [ 60 | "## Evaluation" 61 | ] 62 | }, 63 | { 64 | "cell_type": "code", 65 | "execution_count": null, 66 | "metadata": {}, 67 | "outputs": [], 68 | "source": [ 69 | "from .evaluation import eval_fn\n", 70 | "import numpy as np" 71 | ] 72 | }, 73 | { 74 | "cell_type": "code", 75 | "execution_count": null, 76 | "metadata": {}, 77 | "outputs": [], 78 | "source": [ 79 | "pred_dir_list = [os.path.abspath(os.path.join(md, '..', '..', output_dir_name)) \n", 80 | " for md in model_dirs_list]\n", 81 | "gt_dir = 'dataset_page/set/labels/'" 82 | ] 83 | }, 84 | { 85 | "cell_type": "code", 86 | "execution_count": null, 87 | "metadata": {}, 88 | "outputs": [], 89 | "source": [ 90 | "list_metrics = list()\n", 91 | "for pred_dir in pred_dir_list:\n", 92 | " metrics = eval_fn(pred_dir, gt_dir)\n", 93 | " list_metrics.append(metrics)" 94 | ] 95 | }, 96 | { 97 | "cell_type": "code", 98 | "execution_count": null, 99 | "metadata": {}, 100 | "outputs": [], 101 | "source": [ 102 | "list_mIOUs = [m.mIOU for m in list_metrics]\n", 103 | "\n", 104 | "print('MIOU : {:.03} +- {:.03} ([{:.03}, {:.03}])'.format(np.mean(list_mIOUs), \n", 105 | " np.std(list_mIOUs), \n", 106 | " np.min(list_mIOUs), \n", 107 | " np.max(list_mIOUs)))" 108 | ] 109 | }, 110 | { 111 | "cell_type": "markdown", 112 | "metadata": {}, 113 | "source": [ 114 | "## Export" 115 | ] 116 | }, 117 | { 118 | "cell_type": "code", 119 | "execution_count": null, 120 | "metadata": {}, 121 | "outputs": [], 122 | "source": [ 123 | "import json\n", 124 | "export_metric_filename = './metrics_page.json'" 125 | ] 126 | }, 127 | { 128 | "cell_type": "code", 129 | "execution_count": null, 130 | "metadata": {}, 131 | "outputs": [], 132 | "source": [ 133 | "with open(export_metric_filename, 'w', encoding='utf8') as f:\n", 134 | " json.dump({'{}'.format(i+1): vars(m) for i, m in enumerate(list_metrics)}, f, indent=4)" 135 | ] 136 | } 137 | ], 138 | "metadata": { 139 | "kernelspec": { 140 | "display_name": "Python 2", 141 | "language": "python", 142 | "name": "python2" 143 | }, 144 | "language_info": { 145 | "codemirror_mode": { 146 | "name": "ipython", 147 | "version": 2.0 148 | }, 149 | "file_extension": ".py", 150 | "mimetype": "text/x-python", 151 | "name": "python", 152 | "nbconvert_exporter": "python", 153 | "pygments_lexer": "ipython2", 154 | "version": "2.7.6" 155 | } 156 | }, 157 | "nbformat": 4, 158 | "nbformat_minor": 0 159 | } -------------------------------------------------------------------------------- /exps/page/example_processing.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": { 7 | "collapsed": true 8 | }, 9 | "outputs": [], 10 | "source": [ 11 | "import os\n", 12 | "from glob import glob\n", 13 | "from tqdm import tqdm\n", 14 | 
"import numpy as np\n", 15 | "import tempfile\n", 16 | "from .process import prediction_fn, extract_page, format_quad_to_string" 17 | ] 18 | }, 19 | { 20 | "cell_type": "code", 21 | "execution_count": null, 22 | "metadata": {}, 23 | "outputs": [], 24 | "source": [ 25 | "model_dir = 'model1/export/timestamp/'\n", 26 | "input_dir = 'dataset_page/set/images/'\n", 27 | "output_dir = './out_pages'" 28 | ] 29 | }, 30 | { 31 | "cell_type": "code", 32 | "execution_count": null, 33 | "metadata": {}, 34 | "outputs": [], 35 | "source": [ 36 | "pp_params = {'threshold': -1, 'kernel_size': 5}" 37 | ] 38 | }, 39 | { 40 | "cell_type": "code", 41 | "execution_count": null, 42 | "metadata": {}, 43 | "outputs": [], 44 | "source": [ 45 | "with tempfile.TemporaryDirectory() as tmpdirname:\n", 46 | " prediction_fn(model_dir, input_dir, tmpdirname)\n", 47 | " \n", 48 | " # Export page coordinates in txt file\n", 49 | " with open(os.path.join(output_dir, 'pages.txt'), 'w') as f:\n", 50 | " for filename in tqdm(glob(os.path.join(tmpdirname, '*.npy'))):\n", 51 | " \n", 52 | " prediction = np.load(filename)\n", 53 | " pred_box = extract_page(prediction / np.max(prediction), **pp_params)\n", 54 | " \n", 55 | " f.write('{},{}\\n'.format(filename, format_quad_to_string(pred_box)))" 56 | ] 57 | } 58 | ], 59 | "metadata": { 60 | "kernelspec": { 61 | "display_name": "Python 2", 62 | "language": "python", 63 | "name": "python2" 64 | }, 65 | "language_info": { 66 | "codemirror_mode": { 67 | "name": "ipython", 68 | "version": 2.0 69 | }, 70 | "file_extension": ".py", 71 | "mimetype": "text/x-python", 72 | "name": "python", 73 | "nbconvert_exporter": "python", 74 | "pygments_lexer": "ipython2", 75 | "version": "2.7.6" 76 | } 77 | }, 78 | "nbformat": 4, 79 | "nbformat_minor": 0 80 | } -------------------------------------------------------------------------------- /exps/page/make_page.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | __author__ = "solivr" 3 | __license__ = "GPL" 4 | 5 | import os 6 | import click 7 | from exps.commonutils import cbad_download 8 | from utils import page_files_download, page_set_annotator, \ 9 | format_txt_file, TRAIN_TXT_FILENAME, TEST_TXT_FILENAME, EVAL_TXT_FILENAME 10 | 11 | 12 | @click.command() 13 | @click.option('--downloading_dir', help='Directory to download the cBAD-ICDAR17 dataset') 14 | @click.option('--masks_dir', help="Directory where to output the generated masks") 15 | def generate_page_dataset(downloading_dir: str, masks_dir: str): 16 | # Check if dataset has already been downloaded 17 | if os.path.exists(downloading_dir): 18 | print('Dataset has already been downloaded at {}. Skipping process.'.format(downloading_dir)) 19 | else: 20 | # Download dataset 21 | cbad_download(downloading_dir) 22 | 23 | page_txt_folder = os.path.join(downloading_dir, 'page-txt-files') 24 | if os.path.exists(page_txt_folder): 25 | print('Page txt files have already been downloaded at {}. 
Skipping process.'.format(page_txt_folder)) 26 | else: 27 | # Download files 28 | page_files_download(page_txt_folder) 29 | 30 | tuple_train = (os.path.join(page_txt_folder, TRAIN_TXT_FILENAME), 31 | '_formatted.'.join(TRAIN_TXT_FILENAME.split('.')), 32 | os.path.join(masks_dir, 'train')) 33 | tuple_test = (os.path.join(page_txt_folder, TEST_TXT_FILENAME), 34 | '_formatted.'.join(TEST_TXT_FILENAME.split('.')), 35 | os.path.join(masks_dir, 'test')) 36 | tuple_eval = (os.path.join(page_txt_folder, EVAL_TXT_FILENAME), 37 | '_formatted.'.join(EVAL_TXT_FILENAME.split('.')), 38 | os.path.join(masks_dir, 'eval')) 39 | 40 | print('Creating sets') 41 | for tup in [tuple_train, tuple_test, tuple_eval]: 42 | input_txt_filename, output_txt_filename, set_masks_dir = tup 43 | 44 | # Format txt files 45 | format_txt_file(input_txt_filename, output_txt_filename, downloading_dir) 46 | 47 | # Create masks 48 | os.makedirs(set_masks_dir, exist_ok=True) 49 | page_set_annotator(output_txt_filename, set_masks_dir) 50 | 51 | print('Done!') 52 | 53 | if __name__ == '__main__': 54 | generate_page_dataset() 55 | -------------------------------------------------------------------------------- /exps/page/process.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | __author__ = "solivr" 3 | __license__ = "GPL" 4 | 5 | import tensorflow as tf 6 | import os 7 | import numpy as np 8 | from tqdm import tqdm 9 | from glob import glob 10 | from dh_segment.inference import LoadedModel 11 | from imageio import imsave 12 | from dh_segment.post_processing import binarization 13 | from dh_segment.post_processing.boxes_detection import find_boxes 14 | 15 | 16 | def prediction_fn(model_dir: str, input_dir: str, output_dir: str=None, tf_config: tf.ConfigProto=None) -> None: 17 | """ 18 | Given a model directory this function will load the model and apply it to the files (.jpg, .png) found in input_dir. 
19 | The predictions will be saved in output_dir as .npy files (values ranging [0,255]) 20 | 21 | :param model_dir: Directory containing the saved model 22 | :param input_dir: input directory where the images to predict are 23 | :param output_dir: output directory to save the predictions (probability images) 24 | :return: 25 | """ 26 | if not output_dir: 27 | # For model_dir of style model_name/export/timestamp/ this will create a folder model_name/predictions 28 | output_dir = '{}'.format(os.path.sep).join(model_dir.split(os.path.sep)[:-3] + ['predictions']) 29 | 30 | os.makedirs(output_dir, exist_ok=True) 31 | filenames_to_predict = glob(os.path.join(input_dir, '*.jpg')) + glob(os.path.join(input_dir, '*.png')) 32 | # Load model 33 | with tf.Session(config=tf_config): 34 | m = LoadedModel(model_dir, predict_mode='filename_original_shape') 35 | for filename in tqdm(filenames_to_predict, desc='Prediction'): 36 | pred = m.predict(filename)['probs'][0] 37 | np.save(os.path.join(output_dir, os.path.basename(filename).split('.')[0]), np.uint8(255 * pred)) 38 | 39 | 40 | def page_post_processing_fn(probs: np.ndarray, threshold: float=0.5, output_basename: str=None, 41 | kernel_size: int = 5) -> np.ndarray: 42 | """ 43 | Computes the binary mask of the detected Page from the probabilities output by the network 44 | 45 | :param probs: array in range [0, 1] of shape HxWx2 46 | :param threshold: threshold between [0 and 1], if negative Otsu's adaptive threshold will be used 47 | :param output_basename: if given, the resulting mask is also saved as '<output_basename>.png' 48 | :param kernel_size: size of kernel for morphological cleaning 49 | """ 50 | 51 | mask = binarization.thresholding(probs[:, :, 1], threshold=threshold) 52 | result = binarization.cleaning_binary(mask, kernel_size=kernel_size) 53 | 54 | if output_basename is not None: 55 | imsave('{}.png'.format(output_basename), result*255) 56 | return result 57 | 58 | 59 | def format_quad_to_string(quad): 60 | s = '' 61 | for corner in quad: 62 | s += '{},{},'.format(corner[0], corner[1]) 63 | return s[:-1] 64 | 65 | 66 | def extract_page(prediction: np.ndarray, min_area: float=0.2, post_process_params: dict=None) -> np.ndarray: 67 | """ 68 | Given an image with probabilities, post-processes it and extracts one box 69 | 70 | :param prediction: probability mask [0, 1] 71 | :param min_area: minimum area to be considered as a valid extraction 72 | :param post_process_params: params for the page post-processing function 73 | :return: coordinates of the extracted box (None if no box is found) 74 | """ 75 | if post_process_params: 76 | post_pred = page_post_processing_fn(prediction, **post_process_params) 77 | else: 78 | post_pred = prediction 79 | pred_box = find_boxes(np.uint8(post_pred), mode='quadrilateral', min_area=min_area, n_max_boxes=1) 80 | 81 | return pred_box 82 | -------------------------------------------------------------------------------- /exps/page/utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | __author__ = "solivr" 3 | __license__ = "GPL" 4 | 5 | from imageio import imread, imsave 6 | import numpy as np 7 | import cv2 8 | import os 9 | import re 10 | from tqdm import tqdm 11 | import urllib.request 12 | from exps.commonutils import _progress_hook, CBAD_TEST_SIMPLE_FOLDER, CBAD_TEST_COMPLEX_FOLDER, \ 13 | CBAD_TRAIN_SIMPLE_FOLDER, CBAD_TRAIN_COMPLEX_FOLDER 14 | 15 | TRAIN_FILE_URL = 'https://raw.githubusercontent.com/ctensmeyer/pagenet/master/annotations/cbad_train_annotator_1.txt' 16 | TEST_FILE_URL = 
'https://raw.githubusercontent.com/ctensmeyer/pagenet/master/annotations/cbad_test_annotator_1.txt' 17 | EVAL_FILE_URL = 'https://raw.githubusercontent.com/ctensmeyer/pagenet/master/annotations/cbad_val_annotator_1.txt' 18 | 19 | TRAIN_TXT_FILENAME = 'page_train.txt' 20 | TEST_TXT_FILENAME = 'page_test.txt' 21 | EVAL_TXT_FILENAME = 'page_eval.txt' 22 | 23 | 24 | def get_coords_form_txt_line(line: str)-> tuple: 25 | """ 26 | gets the coordinates of the page from the txt file (line-wise) 27 | 28 | :param line: line of the .txt file 29 | :return: coordinates, filename 30 | """ 31 | splits = line.split(',') 32 | full_filename = splits[0] 33 | splits = splits[1:] 34 | if splits[-1] in ['SINGLE', 'ABNORMAL']: 35 | coords_simple = np.reshape(np.array(splits[:-1], dtype=int), (4, 2)) 36 | # coords_double = None 37 | coords = coords_simple 38 | else: 39 | coords_simple = np.reshape(np.array(splits[:8], dtype=int), (4, 2)) 40 | # coords_double = np.reshape(np.array(splits[-4:], dtype=int), (2, 2)) 41 | # coords = (coords_simple, coords_double) 42 | coords = coords_simple 43 | 44 | return coords, full_filename 45 | 46 | 47 | def make_binary_mask(txt_file): 48 | """ 49 | From export txt file with filnenames and coordinates of qudrilaterals, generate binary mask of page 50 | 51 | :param txt_file: txt file filename 52 | :return: 53 | """ 54 | for line in open(txt_file, 'r'): 55 | dirname, _ = os.path.split(txt_file) 56 | c, full_name = get_coords_form_txt_line(line) 57 | img = imread(full_name) 58 | label_img = np.zeros((img.shape[0], img.shape[1]), np.uint8) 59 | label_img = cv2.fillPoly(label_img, [c[:, None, :]], 255) 60 | basename = os.path.basename(full_name) 61 | imsave(os.path.join(dirname, '{}_bin.png'.format(basename.split('.')[0])), label_img) 62 | 63 | 64 | def page_set_annotator(txt_filename: str, output_dir: str): 65 | """ 66 | Given a txt file (filename, coords corners), generates a dataset of images + labels 67 | 68 | :param txt_filename: File (txt) containing list of images 69 | :param input_dir: Root directory to original images 70 | :param output_dir: Output directory for generated dataset 71 | :return: 72 | """ 73 | 74 | output_img_dir = os.path.join(output_dir, 'images') 75 | output_label_dir = os.path.join(output_dir, 'labels') 76 | os.makedirs(output_img_dir, exist_ok=True) 77 | os.makedirs(output_label_dir, exist_ok=True) 78 | 79 | for line in tqdm(open(txt_filename, 'r')): 80 | coords, image_filename = get_coords_form_txt_line(line) 81 | 82 | try: 83 | img = imread(image_filename) 84 | except FileNotFoundError: 85 | print('File {} not found'.format(image_filename)) 86 | continue 87 | label_img = np.zeros((img.shape[0], img.shape[1], 3)) 88 | 89 | label_img = cv2.fillPoly(label_img, [coords], (255, 0, 0)) 90 | # if coords_double is not None: 91 | # label_img = cv2.polylines(label_img, [coords_double], False, color=(0, 0, 0), thickness=50) 92 | 93 | collection, filename = image_filename.split(os.path.sep)[-2:] 94 | 95 | imsave(os.path.join(output_img_dir, '{}_{}.jpg'.format(collection.split('_')[0], filename.split('.')[0])), img.astype(np.uint8)) 96 | imsave(os.path.join(output_label_dir, '{}_{}.png'.format(collection.split('_')[0], filename.split('.')[0])), label_img.astype(np.uint8)) 97 | 98 | # Class file 99 | classes = np.stack([(0, 0, 0), (255, 0, 0)]) 100 | np.savetxt(os.path.join(output_dir, 'classes.txt'), classes, fmt='%d') 101 | 102 | # ----------------------------- 103 | 104 | 105 | def page_files_download(output_dir: str) -> None: 106 | """ 107 | Download Page txt files 
from github repository. 108 | 109 | :param output_dir: folder where to download the data 110 | :return: 111 | """ 112 | os.makedirs(output_dir, exist_ok=True) 113 | train_filename = os.path.join(output_dir, TRAIN_TXT_FILENAME) 114 | test_filename = os.path.join(output_dir, TEST_TXT_FILENAME) 115 | eval_filename = os.path.join(output_dir, EVAL_TXT_FILENAME) 116 | 117 | with tqdm(unit='B', unit_scale=True, unit_divisor=1024, miniters=1, desc="Downloading train file") as t: 118 | urllib.request.urlretrieve(TRAIN_FILE_URL, train_filename, reporthook=_progress_hook(t)) 119 | with tqdm(unit='B', unit_scale=True, unit_divisor=1024, miniters=1, desc="Downloading test file") as t: 120 | urllib.request.urlretrieve(TEST_FILE_URL, test_filename, reporthook=_progress_hook(t)) 121 | with tqdm(unit='B', unit_scale=True, unit_divisor=1024, miniters=1, desc="Downloading eval file") as t: 122 | urllib.request.urlretrieve(EVAL_FILE_URL, eval_filename, reporthook=_progress_hook(t)) 123 | 124 | print('Page files downloaded successfully!') 125 | 126 | 127 | def format_txt_file(input_txt_filename: str, 128 | output_txt_filename: str, 129 | cbad_data_folder: str) -> None: 130 | """ 131 | Transforms the relative path of the images into absolute path. 132 | 133 | :param input_txt_filename: original downloaded .txt filename 134 | :param output_txt_filename: filename of the formatted content 135 | :param cbad_data_folder: path to the folder containing the READ-BAD data 136 | :return: 137 | """ 138 | final_tokens = list() 139 | for line in open(input_txt_filename, 'r'): 140 | tokens = line.split(',') 141 | filename = tokens[0] 142 | full_filename = os.path.join(os.path.abspath(cbad_data_folder), filename) 143 | 144 | if 'complex' in filename: 145 | pattern = 'complex' 146 | candidate_folders = [CBAD_TRAIN_COMPLEX_FOLDER, CBAD_TEST_COMPLEX_FOLDER] 147 | elif 'simple' in filename: 148 | pattern = 'simple' 149 | candidate_folders = [CBAD_TRAIN_SIMPLE_FOLDER, CBAD_TEST_SIMPLE_FOLDER] 150 | else: 151 | raise Exception 152 | 153 | option1 = re.sub(pattern, candidate_folders[0], full_filename) 154 | option2 = re.sub(pattern, candidate_folders[1], full_filename) 155 | # .JPG files 156 | option3 = re.sub(pattern, candidate_folders[0], full_filename.split('.')[0] + '.JPG') 157 | option4 = re.sub(pattern, candidate_folders[1], full_filename.split('.')[0] + '.JPG') 158 | 159 | if os.path.exists(option1): 160 | tokens[0] = option1 161 | elif os.path.exists(option2): 162 | tokens[0] = option2 163 | elif os.path.exists(option3): 164 | tokens[0] = option3 165 | elif os.path.exists(option4): 166 | tokens[0] = option4 167 | else: 168 | raise FileNotFoundError('for {}'.format(filename)) 169 | 170 | final_tokens.append(','.join(tokens)) 171 | 172 | with open(output_txt_filename, 'w') as f: 173 | for line in final_tokens: 174 | f.write(line) 175 | -------------------------------------------------------------------------------- /general_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "training_params" : { 3 | "learning_rate": 1e-5, 4 | "batch_size": 16, 5 | "make_patches": true, 6 | "n_epochs": 30, 7 | "patch_shape": [300, 300], 8 | "data_augmentation" : true, 9 | "data_augmentation_max_rotation" : 0.2, 10 | "data_augmentation_max_scaling" : 0.2, 11 | "data_augmentation_flip_lr": false, 12 | "data_augmentation_flip_ud": false, 13 | "data_augmentation_color": false, 14 | "evaluate_every_epoch" : 10 15 | }, 16 | "model_params": { 17 | "batch_norm": true, 18 | "batch_renorm": true, 19 | 
"selected_levels_upscaling": [ 20 | true, 21 | true, 22 | true, 23 | true, 24 | true 25 | ] 26 | }, 27 | "pretrained_model_name" : "resnet50", 28 | "prediction_type": "CLASSIFICATION", 29 | "gpu" : "0" 30 | } -------------------------------------------------------------------------------- /pretrained_models/download_resnet_pretrained_model.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import urllib.request 4 | import tarfile 5 | import os 6 | from tqdm import tqdm 7 | 8 | 9 | def progress_hook(t): 10 | last_b = [0] 11 | 12 | def update_to(b=1, bsize=1, tsize=None): 13 | """ 14 | b : int, optional 15 | Number of blocks transferred so far [default: 1]. 16 | bsize : int, optional 17 | Size of each block (in tqdm units) [default: 1]. 18 | tsize : int, optional 19 | Total size (in tqdm units). If [default: None] remains unchanged. 20 | """ 21 | if tsize is not None: 22 | t.total = tsize 23 | t.update((b - last_b[0]) * bsize) 24 | last_b[0] = b 25 | 26 | return update_to 27 | 28 | 29 | if __name__ == '__main__': 30 | tar_filename = 'resnet_v1_50.tar.gz' 31 | with tqdm(unit='B', unit_scale=True, unit_divisor=1024, miniters=1, 32 | desc="Downloading pre-trained weights") as t: 33 | urllib.request.urlretrieve('http://download.tensorflow.org/models/resnet_v1_50_2016_08_28.tar.gz', tar_filename, 34 | reporthook=progress_hook(t)) 35 | tar = tarfile.open(tar_filename) 36 | tar.extractall() 37 | tar.close() 38 | print('Resnet pre-trained weights downloaded!') 39 | os.remove(tar_filename) 40 | -------------------------------------------------------------------------------- /pretrained_models/download_vgg_pretrained_model.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import urllib.request 4 | import tarfile 5 | import os 6 | from tqdm import tqdm 7 | 8 | 9 | def progress_hook(t): 10 | last_b = [0] 11 | 12 | def update_to(b=1, bsize=1, tsize=None): 13 | """ 14 | b : int, optional 15 | Number of blocks transferred so far [default: 1]. 16 | bsize : int, optional 17 | Size of each block (in tqdm units) [default: 1]. 18 | tsize : int, optional 19 | Total size (in tqdm units). If [default: None] remains unchanged. 
20 | """ 21 | if tsize is not None: 22 | t.total = tsize 23 | t.update((b - last_b[0]) * bsize) 24 | last_b[0] = b 25 | 26 | return update_to 27 | 28 | 29 | if __name__ == '__main__': 30 | tar_filename = 'vgg_16.tar.gz' 31 | with tqdm(unit='B', unit_scale=True, unit_divisor=1024, miniters=1, 32 | desc="Downloading pre-trained weights") as t: 33 | urllib.request.urlretrieve('http://download.tensorflow.org/models/vgg_16_2016_08_28.tar.gz', tar_filename, 34 | reporthook=progress_hook(t)) 35 | tar = tarfile.open(tar_filename) 36 | tar.extractall() 37 | tar.close() 38 | print('VGG-16 pre-trained weights downloaded!') 39 | os.remove(tar_filename) 40 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | from setuptools import setup, find_packages 3 | 4 | setup(name='dh_segment', 5 | version='0.6.0', 6 | license='GPL', 7 | url='https://github.com/dhlab-epfl/dhSegment', 8 | description='Generic framework for historical document processing', 9 | packages=find_packages(), 10 | project_urls={ 11 | 'Paper': 'https://arxiv.org/abs/1804.10371', 12 | 'Source Code': 'https://github.com/dhlab-epfl/dhSegment' 13 | }, 14 | install_requires=[ 15 | 'imageio>=2.5', 16 | 'pandas>=0.24.2', 17 | 'shapely>=1.6.4', 18 | 'scikit-learn>=0.20.3', 19 | 'scikit-image>=0.15.0', 20 | 'opencv-python>=4.0.1', 21 | 'tqdm>=4.31.1', 22 | 'sacred==0.7.4', # 0.7.5 causes an error 23 | 'requests>=2.21.0', 24 | 'click>=7.0' 25 | ], 26 | extras_require={ 27 | 'doc': [ 28 | 'sphinx', 29 | 'sphinx-autodoc-typehints', 30 | 'sphinx-rtd-theme', 31 | 'sphinxcontrib-bibtex', 32 | 'sphinxcontrib-websupport' 33 | ], 34 | }, 35 | zip_safe=False) 36 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import tensorflow as tf 3 | # Tensorflow logging level 4 | from logging import WARNING # import DEBUG, INFO, ERROR for more/less verbosity 5 | 6 | tf.logging.set_verbosity(WARNING) 7 | from dh_segment import estimator_fn, utils 8 | from dh_segment.io import input 9 | import json 10 | from glob import glob 11 | import numpy as np 12 | 13 | try: 14 | import better_exceptions 15 | except ImportError: 16 | print('/!\ W -- Not able to import package better_exceptions') 17 | pass 18 | from tqdm import trange 19 | from sacred import Experiment 20 | import pandas as pd 21 | 22 | ex = Experiment('dhSegment_experiment') 23 | 24 | 25 | @ex.config 26 | def default_config(): 27 | train_data = None # Directory with training data 28 | eval_data = None # Directory with validation data 29 | model_output_dir = None # Directory to output tf model 30 | restore_model = False # Set to true to continue training 31 | classes_file = None # txt file with classes values (unused for REGRESSION) 32 | gpu = '' # GPU to be used for training 33 | prediction_type = utils.PredictionType.CLASSIFICATION # One of CLASSIFICATION, REGRESSION or MULTILABEL 34 | pretrained_model_name = 'resnet50' 35 | model_params = utils.ModelParams(pretrained_model_name=pretrained_model_name).to_dict() # Model parameters 36 | training_params = utils.TrainingParams().to_dict() # Training parameters 37 | if prediction_type == utils.PredictionType.CLASSIFICATION: 38 | assert classes_file is not None 39 | model_params['n_classes'] = utils.get_n_classes_from_file(classes_file) 40 | elif prediction_type == 
utils.PredictionType.REGRESSION: 41 | model_params['n_classes'] = 1 42 | elif prediction_type == utils.PredictionType.MULTILABEL: 43 | assert classes_file is not None 44 | model_params['n_classes'] = utils.get_n_classes_from_file_multilabel(classes_file) 45 | 46 | 47 | @ex.automain 48 | def run(train_data, eval_data, model_output_dir, gpu, training_params, _config): 49 | # Create output directory 50 | if not os.path.isdir(model_output_dir): 51 | os.makedirs(model_output_dir) 52 | else: 53 | assert _config.get('restore_model'), \ 54 | '{0} already exists, you cannot use it as output directory. ' \ 55 | 'Set "restore_model=True" to continue training, or delete dir "rm -r {0}"'.format(model_output_dir) 56 | # Save config 57 | with open(os.path.join(model_output_dir, 'config.json'), 'w') as f: 58 | json.dump(_config, f, indent=4, sort_keys=True) 59 | 60 | # Create export directory for saved models 61 | saved_model_dir = os.path.join(model_output_dir, 'export') 62 | if not os.path.isdir(saved_model_dir): 63 | os.makedirs(saved_model_dir) 64 | 65 | training_params = utils.TrainingParams.from_dict(training_params) 66 | 67 | session_config = tf.ConfigProto() 68 | session_config.gpu_options.visible_device_list = str(gpu) 69 | session_config.gpu_options.per_process_gpu_memory_fraction = 0.9 70 | estimator_config = tf.estimator.RunConfig().replace(session_config=session_config, 71 | save_summary_steps=10, 72 | keep_checkpoint_max=1) 73 | estimator = tf.estimator.Estimator(estimator_fn.model_fn, model_dir=model_output_dir, 74 | params=_config, config=estimator_config) 75 | 76 | def get_dirs_or_files(input_data): 77 | if os.path.isdir(input_data): 78 | image_input, labels_input = os.path.join(input_data, 'images'), os.path.join(input_data, 'labels') 79 | # Check if training dir exists 80 | assert os.path.isdir(image_input), "{} is not a directory".format(image_input) 81 | assert os.path.isdir(labels_input), "{} is not a directory".format(labels_input) 82 | 83 | elif os.path.isfile(input_data) and input_data.endswith('.csv'): 84 | image_input = input_data 85 | labels_input = None 86 | else: 87 | raise TypeError('input_data {} is neither a directory nor a csv file'.format(input_data)) 88 | return image_input, labels_input 89 | 90 | train_input, train_labels_input = get_dirs_or_files(train_data) 91 | if eval_data is not None: 92 | eval_input, eval_labels_input = get_dirs_or_files(eval_data) 93 | 94 | # Configure exporter 95 | serving_input_fn = input.serving_input_filename(training_params.input_resized_size) 96 | if eval_data is not None: 97 | exporter = tf.estimator.BestExporter(serving_input_receiver_fn=serving_input_fn, exports_to_keep=2) 98 | else: 99 | exporter = tf.estimator.LatestExporter(name='SimpleExporter', serving_input_receiver_fn=serving_input_fn, 100 | exports_to_keep=5) 101 | 102 | for i in trange(0, training_params.n_epochs, training_params.evaluate_every_epoch, desc='Evaluated epochs'): 103 | estimator.train(input.input_fn(train_input, 104 | input_label_dir=train_labels_input, 105 | num_epochs=training_params.evaluate_every_epoch, 106 | batch_size=training_params.batch_size, 107 | data_augmentation=training_params.data_augmentation, 108 | make_patches=training_params.make_patches, 109 | image_summaries=True, 110 | params=_config, 111 | num_threads=32)) 112 | 113 | if eval_data is not None: 114 | eval_result = estimator.evaluate(input.input_fn(eval_input, 115 | input_label_dir=eval_labels_input, 116 | batch_size=1, 117 | data_augmentation=False, 118 | make_patches=False, 119 | 
image_summaries=False, 120 | params=_config, 121 | num_threads=32)) 122 | else: 123 | eval_result = None 124 | 125 | exporter.export(estimator, saved_model_dir, checkpoint_path=None, eval_result=eval_result, 126 | is_the_final_export=False) 127 | 128 | # If export directory is empty, export a model anyway 129 | if not os.listdir(saved_model_dir): 130 | final_exporter = tf.estimator.FinalExporter(name='FinalExporter', serving_input_receiver_fn=serving_input_fn) 131 | final_exporter.export(estimator, saved_model_dir, checkpoint_path=None, eval_result=eval_result, 132 | is_the_final_export=True) 133 | --------------------------------------------------------------------------------
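Once training is done, `train.py` exports SavedModels under `<model_output_dir>/export/<timestamp>/`. A minimal sketch (not part of the repository; paths are placeholders) of loading such an export for inference with the `LoadedModel` class used in the demo scripts above:

``` python
# Minimal inference sketch; the model path and image path are placeholders.
import tensorflow as tf
from dh_segment.inference import LoadedModel

model_dir = 'page_model/export/1544789243/'  # placeholder: one of the exported timestamp folders

with tf.Session():
    m = LoadedModel(model_dir, predict_mode='filename')
    prediction = m.predict('my_image.jpg')  # placeholder image path
    probs = prediction['probs'][0]          # per-class probability maps
    original_shape = prediction['original_shape']
    print(probs.shape, original_shape)
```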