├── exps ├── Cini │ ├── __init__.py │ ├── cini_post_processing.py │ ├── cini_evaluation.py │ └── cini_process_set.py ├── DIBCO │ ├── __init__.py │ ├── dibco_post_processing.py │ ├── dibco_evaluation.py │ └── dibco_dataset_generator.py ├── Ornaments │ ├── __init__.py │ ├── ornaments_post_processing.py │ ├── ornaments_dataset_generator.py │ ├── ornaments_process_eval.py │ ├── ornaments_evaluation.py │ └── ornaments_process_set.py ├── cbad │ ├── __init__.py │ ├── example_evaluation.ipynb │ ├── evaluation.py │ ├── utils.py │ └── process.py ├── diva │ ├── __init__.py │ ├── evaluation.py │ ├── process.py │ ├── example_evaluation.ipynb │ └── utils.py ├── page │ ├── __init__.py │ ├── README.md │ ├── evaluation.py │ ├── example_processing.ipynb │ ├── utils.py │ ├── process.py │ └── example_evaluation.ipynb ├── README.md ├── __init__.py └── _misc │ ├── worker.py │ ├── layout_generate_dataset.py │ └── post_process_evaluation.py ├── doc ├── tutorials │ └── index.rst ├── _static │ ├── cbad.jpg │ ├── cini.jpg │ ├── diva.jpg │ ├── page.jpg │ ├── system.png │ ├── cini_input.jpg │ ├── diva_preds.png │ ├── ornaments.jpg │ ├── cini_labels.jpg │ ├── tensorboard_1.png │ ├── tensorboard_2.png │ └── tensorboard_3.png ├── start │ ├── index.rst │ ├── install.rst │ ├── demo.rst │ ├── annotating.rst │ └── training.rst ├── references.rst ├── reference │ ├── utils.rst │ ├── inference.rst │ ├── post_processing.rst │ ├── network.rst │ ├── index.rst │ └── io.rst ├── Makefile ├── changelog.rst ├── index.rst ├── references.bib ├── intro │ └── intro.rst └── conf.py ├── dh_segment_text ├── network │ ├── pretrained_models │ │ ├── mobilenet │ │ │ ├── __init__.py │ │ │ └── encoder.py │ │ ├── resnet50 │ │ │ └── __init__.py │ │ ├── __init__.py │ │ └── vgg16.py │ ├── __init__.py │ ├── model.py │ └── simple_decoder.py ├── embeddings │ ├── __init__.py │ ├── encoder.py │ ├── pca_encoder.py │ ├── embeddings_utils.py │ ├── conv2d_encoder.py │ └── conv1d_encoder.py ├── inference │ └── __init__.py ├── post_processing │ ├── __init__.py │ ├── polygon_detection.py │ ├── binarization.py │ ├── line_vectorization.py │ └── boxes_detection.py ├── __init__.py ├── utils │ ├── __init__.py │ ├── misc.py │ ├── evaluation.py │ ├── labels.py │ └── params_config.py └── io │ └── __init__.py ├── .readthedocs.yml ├── environment.yml ├── demo └── demo_config.json ├── general_config.json ├── setup.py ├── embeddings_config.json ├── .gitignore ├── demo.py ├── README.md └── dh_segment_train.py /exps/Cini/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /exps/DIBCO/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /exps/Ornaments/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /doc/tutorials/index.rst: -------------------------------------------------------------------------------- 1 | Tutorials 2 | ========= 3 | -------------------------------------------------------------------------------- /dh_segment_text/network/pretrained_models/mobilenet/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- 
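For orientation, here is a minimal inference sketch tying together the `dh_segment_text.inference` and `dh_segment_text.post_processing` packages listed in the tree above. It is a hypothetical example: the `LoadedModel` constructor/`predict()` interface is assumed to match upstream dhSegment (this text-augmented fork may expect additional embedding inputs), and all paths are placeholders.

import numpy as np
import tensorflow as tf
from dh_segment_text.inference import LoadedModel
from dh_segment_text.post_processing import find_polygonal_regions

model_dir = 'demo/page_model/export/<timestamp>/'      # placeholder: exported model directory
image_filename = 'demo/pages/test/images/example.jpg'  # placeholder: input image

with tf.Session():
    # Interface assumed from upstream dhSegment, not verified against this fork.
    model = LoadedModel(model_dir, predict_mode='filename')
    probs = model.predict(image_filename)['probs'][0]

mask = (probs[:, :, 1] > 0.5).astype(np.uint8)        # assumes class 1 is the foreground class
regions = find_polygonal_regions(mask, min_area=0.1)  # polygon coordinates, largest area first
print(regions)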
/dh_segment_text/network/pretrained_models/resnet50/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /exps/cbad/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | __author__ = 'solivr' -------------------------------------------------------------------------------- /exps/diva/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | __author__ = "solivr" 3 | __license__ = "GPL" -------------------------------------------------------------------------------- /exps/page/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | __author__ = "solivr" 3 | __license__ = "GPL" 4 | -------------------------------------------------------------------------------- /doc/_static/cbad.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dhlab-epfl/dhSegment-text/HEAD/doc/_static/cbad.jpg -------------------------------------------------------------------------------- /doc/_static/cini.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dhlab-epfl/dhSegment-text/HEAD/doc/_static/cini.jpg -------------------------------------------------------------------------------- /doc/_static/diva.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dhlab-epfl/dhSegment-text/HEAD/doc/_static/diva.jpg -------------------------------------------------------------------------------- /doc/_static/page.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dhlab-epfl/dhSegment-text/HEAD/doc/_static/page.jpg -------------------------------------------------------------------------------- /doc/_static/system.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dhlab-epfl/dhSegment-text/HEAD/doc/_static/system.png -------------------------------------------------------------------------------- /.readthedocs.yml: -------------------------------------------------------------------------------- 1 | python: 2 | version: 3.5 3 | pip_install: true 4 | extra_requirements: 5 | - doc -------------------------------------------------------------------------------- /doc/_static/cini_input.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dhlab-epfl/dhSegment-text/HEAD/doc/_static/cini_input.jpg -------------------------------------------------------------------------------- /doc/_static/diva_preds.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dhlab-epfl/dhSegment-text/HEAD/doc/_static/diva_preds.png -------------------------------------------------------------------------------- /doc/_static/ornaments.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dhlab-epfl/dhSegment-text/HEAD/doc/_static/ornaments.jpg -------------------------------------------------------------------------------- /doc/_static/cini_labels.jpg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/dhlab-epfl/dhSegment-text/HEAD/doc/_static/cini_labels.jpg -------------------------------------------------------------------------------- /doc/_static/tensorboard_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dhlab-epfl/dhSegment-text/HEAD/doc/_static/tensorboard_1.png -------------------------------------------------------------------------------- /doc/_static/tensorboard_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dhlab-epfl/dhSegment-text/HEAD/doc/_static/tensorboard_2.png -------------------------------------------------------------------------------- /doc/_static/tensorboard_3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dhlab-epfl/dhSegment-text/HEAD/doc/_static/tensorboard_3.png -------------------------------------------------------------------------------- /doc/start/index.rst: -------------------------------------------------------------------------------- 1 | Quickstart 2 | ========== 3 | 4 | .. toctree:: 5 | install 6 | annotating 7 | training 8 | demo -------------------------------------------------------------------------------- /doc/references.rst: -------------------------------------------------------------------------------- 1 | ========== 2 | References 3 | ========== 4 | 5 | .. bibliography:: references.bib 6 | :cited: 7 | :style: alpha -------------------------------------------------------------------------------- /doc/reference/utils.rst: -------------------------------------------------------------------------------- 1 | ========= 2 | Utilities 3 | ========= 4 | 5 | .. automodule:: dh_segment.utils 6 | :members: 7 | :undoc-members: 8 | -------------------------------------------------------------------------------- /doc/reference/inference.rst: -------------------------------------------------------------------------------- 1 | ========= 2 | Inference 3 | ========= 4 | 5 | .. automodule:: dh_segment.inference 6 | :members: 7 | :undoc-members: 8 | -------------------------------------------------------------------------------- /dh_segment_text/network/pretrained_models/__init__.py: -------------------------------------------------------------------------------- 1 | from .resnet50.encoder import ResnetV1_50 2 | from .vgg16 import VGG16 3 | from .mobilenet.encoder import MobileNetV2 4 | -------------------------------------------------------------------------------- /doc/reference/post_processing.rst: -------------------------------------------------------------------------------- 1 | =============== 2 | Post processing 3 | =============== 4 | 5 | .. automodule:: dh_segment.post_processing 6 | :members: 7 | :undoc-members: 8 | 9 | -------------------------------------------------------------------------------- /doc/reference/network.rst: -------------------------------------------------------------------------------- 1 | Network architecture 2 | ==================== 3 | 4 | Here is the dhsegment architecture definition 5 | 6 | ----- 7 | 8 | .. 
automodule:: dh_segment.network 9 | :members: 10 | :undoc-members: 11 | 12 | -------------------------------------------------------------------------------- /dh_segment_text/embeddings/__init__.py: -------------------------------------------------------------------------------- 1 | __all__ = ['EmbeddingsEncoder', 'PCAEncoder', 'Conv1dEncoder', 'Conv2dEncoder'] 2 | 3 | from .encoder import * 4 | from .pca_encoder import * 5 | from .conv1d_encoder import * 6 | from .conv2d_encoder import * 7 | -------------------------------------------------------------------------------- /doc/reference/index.rst: -------------------------------------------------------------------------------- 1 | =============== 2 | Reference guide 3 | =============== 4 | 5 | .. automodule:: dh_segment 6 | 7 | .. toctree:: 8 | :maxdepth: 1 9 | 10 | network 11 | io 12 | inference 13 | post_processing 14 | utils -------------------------------------------------------------------------------- /dh_segment_text/embeddings/encoder.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from abc import ABC, abstractmethod 3 | from typing import List 4 | 5 | class EmbeddingsEncoder(ABC): 6 | @abstractmethod 7 | def __call__(self, embeddings: tf.Tensor, embeddings_map: tf.Tensor, target_shape: tf.Tensor, target_dim: int, is_training: bool=False) -> tf.Tensor: 8 | pass 9 | -------------------------------------------------------------------------------- /dh_segment_text/network/__init__.py: -------------------------------------------------------------------------------- 1 | _MODEL = [ 2 | 'Encoder', 3 | 'Decoder', 4 | ] 5 | 6 | _SIMPLEDECODER = [ 7 | 'SimpleDecoder' 8 | ] 9 | 10 | _PRETRAINED = [ 11 | 'ResnetV1_50', 12 | 'VGG16' 13 | ] 14 | __all__ = _MODEL + _SIMPLEDECODER + _PRETRAINED 15 | 16 | from .model import * 17 | from .simple_decoder import * 18 | from .pretrained_models import * 19 | -------------------------------------------------------------------------------- /dh_segment_text/inference/__init__.py: -------------------------------------------------------------------------------- 1 | r""" 2 | The :mod:`dh_segment.inference` module implements the function related to the usage of a dhSegment model, 3 | for instance to use a trained model to inference on new data. 4 | 5 | Loading a model 6 | --------------- 7 | 8 | .. autosummary:: 9 | LoadedModel 10 | 11 | 12 | ----- 13 | """ 14 | 15 | __all__ = ['LoadedModel'] 16 | 17 | from .loader import * -------------------------------------------------------------------------------- /exps/page/README.md: -------------------------------------------------------------------------------- 1 | ### Page experiment 2 | Based on paper ["PageNet: Page Boundary Extraction in Historical Handwritten Documents."](https://dl.acm.org/citation.cfm?id=3151522) 3 | 4 | 5 | #### Dataset 6 | The page annotations come from this [repository](https://github.com/ctensmeyer/pagenet/tree/master/annotations). We use READ-cBAD data with _annotator 1_ and _set1_. 7 | 8 | `utils.page_dataset_generator` is used to generate the label images. 
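For illustration, the label generation conceptually amounts to rasterizing each annotated page quadrilateral into a two-colour mask, as sketched below (a minimal sketch only; the actual logic lives in `utils.py`, and the green foreground colour is an assumed classes layout).

```python
# Conceptual sketch only -- not the repository's actual generator.
import cv2
import numpy as np

def make_page_label(image_shape, page_quad, page_color=(0, 255, 0)):
    """Rasterize a page quadrilateral (4x2 array of (x, y) points) into a label image."""
    label = np.zeros((image_shape[0], image_shape[1], 3), dtype=np.uint8)  # background stays black
    cv2.fillPoly(label, [np.asarray(page_quad, dtype=np.int32)], page_color)
    return label
```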
9 | 10 | #### Results -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: dh_segment 2 | channels: 3 | - defaults 4 | dependencies: 5 | - imageio=2.5.0 6 | - numpy=1.16.2 7 | - pandas=0.24.2 8 | - pillow=5.4.1 9 | - python=3.6 10 | - scikit-image=0.14.2 11 | - scikit-learn=0.20.3 12 | - scipy=1.2.1 13 | - setuptools=40.8.0 14 | - shapely=1.6.4 15 | - tensorflow-gpu==1.13.1 16 | - tqdm=4.31.1 17 | - requests=2.21.0 18 | - pip: 19 | - better-exceptions==0.2.1 20 | - opencv-python==4.0.1.23 21 | - sacred==0.7.4 22 | - sphinx 23 | - sphinx-autodoc-typehints 24 | - sphinx-rtd-theme 25 | - sphinxcontrib-bibtex 26 | -------------------------------------------------------------------------------- /exps/README.md: -------------------------------------------------------------------------------- 1 | ## Experiments 2 | 3 | This folder contains code that is helpful for experiments on various datasets. 4 | This is a bit messy and _under refactoring_! 5 | 6 | - [x] `page` experiment on [PageNet dataset](https://dl.acm.org/citation.cfm?id=3151522) 7 | - [x] `cBAD` experiment on [READ-BAD dataset](https://arxiv.org/abs/1705.03311) 8 | - [x] `DIVA` experiment on [DIVA-HisDB dataset](http://diuf.unifr.ch/main/hisdoc/sites/diuf.unifr.ch.main.hisdoc/files/uploads/hisdoc2.0-publications/2016-icfhr-divahisdb.pdf) 9 | - [ ] `ornaments` experiment on private dataset 10 | - [ ] `cini` experiment on private dataset 11 | - [ ] `dibco` experiment 12 | -------------------------------------------------------------------------------- /doc/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line. 5 | SPHINXOPTS = 6 | SPHINXBUILD = sphinx-build 7 | SPHINXPROJ = dhsegment 8 | SOURCEDIR = . 9 | BUILDDIR = _build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) -------------------------------------------------------------------------------- /exps/DIBCO/dibco_post_processing.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | __author__ = 'solivr' 3 | 4 | import numpy as np 5 | from scipy.misc import imsave 6 | import cv2 7 | 8 | 9 | def dibco_binarization_fn(probs: np.ndarray, threshold=0.5, output_basename=None): 10 | probs = probs[:, :, 1] 11 | if threshold < 0: 12 | probs = np.uint8(probs * 255) 13 | # Otsu's thresholding 14 | blur = cv2.GaussianBlur(probs, (5, 5), 0) 15 | thresh_val, bin_img = cv2.threshold(blur, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU) 16 | result = np.uint8(bin_img / 255) 17 | else: 18 | result = (probs > threshold).astype(np.uint8) 19 | 20 | if output_basename is not None: 21 | imsave(output_basename + '.png', result*255) 22 | return result -------------------------------------------------------------------------------- /demo/demo_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "training_params" : { 3 | "learning_rate": 5e-5, 4 | "batch_size": 1, 5 | "make_patches": false, 6 | "training_margin" : 0, 7 | "n_epochs": 30, 8 | "data_augmentation" : true, 9 | "data_augmentation_max_rotation" : 0.2, 10 | "data_augmentation_max_scaling" : 0.2, 11 | "data_augmentation_flip_lr": true, 12 | "data_augmentation_flip_ud": true, 13 | "data_augmentation_color": false, 14 | "evaluate_every_epoch" : 10 15 | }, 16 | "pretrained_model_name" : "resnet50", 17 | "prediction_type": "CLASSIFICATION", 18 | "train_data" : "demo/pages/train/", 19 | "eval_data" : "demo/pages/val_a1", 20 | "classes_file" : "demo/pages/classes.txt", 21 | "model_output_dir" : "demo/page_model", 22 | "gpu" : "0" 23 | } -------------------------------------------------------------------------------- /exps/__init__.py: -------------------------------------------------------------------------------- 1 | from .DIVA.diva_post_processing import diva_post_processing_fn 2 | from .Ornaments.ornaments_post_processing import ornaments_post_processing_fn 3 | from .Page.page_post_processing import page_post_processing_fn 4 | from .cBAD.cbad_post_processing import cbad_post_processing_fn 5 | from .DIBCO.dibco_post_processing import dibco_binarization_fn 6 | #from .Cini.cini_post_processing import cini_post_processing_fn 7 | from .DIVA.diva_evaluation import diva_evaluate_folder 8 | #from .Cini.cini_evaluation import cini_evaluate_folder 9 | from .cBAD.cbad_evaluation import cbad_evaluate_folder 10 | from .DIBCO.dibco_evaluation import dibco_evaluate_folder 11 | from .Ornaments.ornaments_evaluation import ornament_evaluate_folder 12 | from .Page.page_evaluation import page_evaluate_folder 13 | from .evaluation.base import evaluate_epoch -------------------------------------------------------------------------------- /general_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "training_params" : { 3 | "learning_rate": 1e-5, 4 | "batch_size": 16, 5 | "make_patches": true, 6 | "n_epochs": 30, 7 | "patch_shape": [300, 300], 8 | "data_augmentation" : true, 9 | "data_augmentation_max_rotation" : 0.2, 10 | "data_augmentation_max_scaling" : 0.2, 11 | "data_augmentation_flip_lr": false, 12 | "data_augmentation_flip_ud": false, 13 | "data_augmentation_color": false, 14 | "evaluate_every_epoch" : 10 15 | }, 16 | "model_params": { 17 
| "batch_norm": true, 18 | "batch_renorm": true, 19 | "selected_levels_upscaling": [ 20 | true, 21 | true, 22 | true, 23 | true, 24 | true 25 | ] 26 | }, 27 | "pretrained_model_name" : "resnet50", 28 | "prediction_type": "CLASSIFICATION", 29 | "gpu" : "0" 30 | } -------------------------------------------------------------------------------- /dh_segment_text/post_processing/__init__.py: -------------------------------------------------------------------------------- 1 | r""" 2 | The :mod:`dh_segment.post_processing` module contains functions to post-process probability maps. 3 | 4 | **Binarization** 5 | 6 | .. autosummary:: 7 | thresholding 8 | cleaning_binary 9 | 10 | **Detection** 11 | 12 | .. autosummary:: 13 | find_boxes 14 | find_polygonal_regions 15 | 16 | **Vectorization** 17 | 18 | .. autosummary:: 19 | find_lines 20 | 21 | ------ 22 | 23 | """ 24 | 25 | _BINARIZATION = [ 26 | 'thresholding', 27 | 'cleaning_binary', 28 | 29 | ] 30 | 31 | _DETECTION = [ 32 | 'find_boxes', 33 | 'find_polygonal_regions' 34 | ] 35 | 36 | _VECTORIZATION = [ 37 | 'find_lines' 38 | ] 39 | 40 | __all__ = _BINARIZATION + _DETECTION + _VECTORIZATION 41 | 42 | from .binarization import * 43 | from .boxes_detection import * 44 | from .line_vectorization import * 45 | from .polygon_detection import * 46 | 47 | -------------------------------------------------------------------------------- /doc/changelog.rst: -------------------------------------------------------------------------------- 1 | ========= 2 | Changelog 3 | ========= 4 | 5 | Unreleased 6 | ---------- 7 | 8 | 0.4.0 - 2019-04-10 9 | ------------------ 10 | Added 11 | ^^^^^ 12 | 13 | * Input data can be a .csv file with format ``,``. 14 | * ``dh_segment.io.via`` helper functions to generate/export groundtruth from/to VGG Image Annotation tool. 15 | * ``Point.array_to_point`` to export a ``np.array`` into a list of ``Point``. 16 | * PAGEXML Regions can now contain a custom attribute (Transkribus output of region annotation) 17 | * ``Page.to_json()`` method for json formatting. 18 | 19 | Changed 20 | ^^^^^^^ 21 | 22 | * ``tensorflow`` v1.13 and ``opencv`` v4.0 are now used. 23 | * mIOU metric for evaluation during training (instead of accuracy). 24 | * TextLines are sorted according to their mean `y` coordinate when exported. 25 | 26 | Fixed 27 | ^^^^^ 28 | 29 | * Variable names typos in ``input.py`` and ``train.py``. 30 | * Documentation of the quickstart demo. 31 | 32 | Removed 33 | ^^^^^^^ 34 | -------------------------------------------------------------------------------- /doc/start/install.rst: -------------------------------------------------------------------------------- 1 | Installation 2 | ------------ 3 | 4 | Using ``pip`` 5 | ^^^^^^^^^^^^^ 6 | 7 | 1. Clone the repository using ``git clone https://github.com/dhlab-epfl/dhSegment.git`` 8 | 9 | 2. Create and activate a virtualenv :: 10 | 11 | virtualenv myvirtualenvs/dh_segment 12 | source myvirtualenvs/dh_segment/bin/activate 13 | 14 | 3. Install the dependencies using ``pip`` (this will look for the ``setup.py`` file) :: 15 | 16 | pip install git+https://github.com/dhlab-epfl/dhSegment 17 | 18 | Using Anaconda 19 | ^^^^^^^^^^^^^^ 20 | 21 | 1. Install Anaconda or Miniconda (`installation procedure `_) 22 | 23 | 2. Clone the repository: ``git clone https://github.com/dhlab-epfl/dhSegment.git`` 24 | 25 | 3. Create a virtual environment with all the packages: ``conda env create -f environment.yml`` 26 | 27 | 4. Then activate the environment with ``source activate dh_segment`` 28 | 29 | 30 | 5. 
To be able to import the package (i.e ``import dh_segment``) in your code, you have to run : :: 31 | 32 | python setup.py install 33 | 34 | -------------------------------------------------------------------------------- /doc/index.rst: -------------------------------------------------------------------------------- 1 | .. dhsegment documentation master file, created by 2 | sphinx-quickstart on Mon Oct 1 17:17:21 2018. 3 | 4 | ================================================================ 5 | dhSegment : Generic framework for historical document processing 6 | ================================================================ 7 | 8 | .. toctree:: 9 | :maxdepth: 1 10 | 11 | intro/intro 12 | start/index 13 | reference/index 14 | references 15 | changelog 16 | 17 | **dhSegment** is a tool for Historical Document Processing. Its generic approach allows to segment regions and 18 | extract content from different type of documents. See some example of applications in the :ref:`usecases-label` section. 19 | 20 | The complete description of the system can be found in the corresponding `paper`_ :cite:`oliveiraseguin2018dhsegment` . 21 | 22 | .. _paper: https://arxiv.org/abs/1804.10371 23 | 24 | Indices and tables 25 | ------------------ 26 | 27 | * :ref:`genindex` 28 | * :ref:`modindex` 29 | * :ref:`search` 30 | 31 | 32 | Acknowledgement 33 | ^^^^^^^^^^^^^^^ 34 | 35 | This work has been partly funded by the European Union’s Horizon 2020 research and 36 | innovation programme under grant agreement No 674943. -------------------------------------------------------------------------------- /dh_segment_text/embeddings/pca_encoder.py: -------------------------------------------------------------------------------- 1 | from .encoder import EmbeddingsEncoder 2 | from .embeddings_utils import batch_resize_and_gather 3 | import tensorflow as tf 4 | import numpy as np 5 | 6 | class PCAEncoder(EmbeddingsEncoder): 7 | def __init__(self, pca_mean_path: str, pca_components_path: str, target_dim: int): 8 | self.pca_mean = tf.constant(np.load(pca_mean_path), dtype=tf.float32) 9 | self.pca_components = tf.constant(np.load(pca_components_path), dtype=tf.float32) 10 | self.target_dim = target_dim 11 | 12 | def __call__(self, embeddings: tf.Tensor, embeddings_map: tf.Tensor, target_shape: tf.Tensor, is_training: bool=False) -> tf.Tensor: 13 | with tf.variable_scope("PCAEncoder"): 14 | reduced_components = tf.transpose(self.pca_components[:self.target_dim]) 15 | reduced_embeddings = tf.einsum('aij,jk->aik', (embeddings-self.pca_mean), reduced_components) 16 | embeddings_feature_map = batch_resize_and_gather(embeddings_map, 17 | target_shape, 18 | reduced_embeddings) 19 | embeddings_feature_map.set_shape([None, None, None, self.target_dim]) 20 | return embeddings_feature_map 21 | 22 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | from setuptools import setup, find_packages 3 | 4 | setup(name='dh_segment_text', 5 | version='0.4.0', 6 | license='GPL', 7 | url='https://github.com/dhlab-epfl/dhSegment', 8 | description='Generic framework for historical document processing', 9 | packages=find_packages(exclude=['exps*']), 10 | project_urls={ 11 | 'Paper': 'https://arxiv.org/abs/1804.10371', 12 | 'Source Code': 'https://github.com/dhlab-epfl/dhSegment' 13 | }, 14 | scripts=['dh_segment_train.py'], 15 | install_requires=[ 16 | 'tensorflow-gpu==1.13.1', 17 | 'numpy==1.16.2', 
18 | 'imageio==2.5.0', 19 | 'pandas==0.24.2', 20 | 'scipy==1.2.1', 21 | 'shapely==1.6.4', 22 | 'scikit-learn==0.20.3', 23 | 'scikit-image==0.15.0', 24 | 'opencv-python==4.0.1.23', 25 | 'tqdm==4.31.1', 26 | 'sacred==0.7.4', 27 | 'requests==2.21.0' 28 | ], 29 | extras_require={ 30 | 'doc': [ 31 | 'sphinx==1.8.1', 32 | 'sphinx-autodoc-typehints==1.3.0', 33 | 'sphinx-rtd-theme==0.4.1', 34 | 'sphinxcontrib-bibtex==0.4.0', 35 | 'sphinxcontrib-websupport' 36 | ], 37 | }, 38 | zip_safe=False) 39 | -------------------------------------------------------------------------------- /exps/Ornaments/ornaments_post_processing.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | __author__ = 'solivr' 3 | 4 | import cv2 5 | import numpy as np 6 | from scipy.misc import imsave 7 | 8 | 9 | def ornaments_post_processing_fn(probs: np.ndarray, threshold: float=0.5, ksize_open: tuple=(5, 5), 10 | ksize_close: tuple=(7, 7), output_basename: str=None) -> np.ndarray: 11 | """ 12 | 13 | :param probs: 14 | :param threshold: 15 | :param ksize_open: 16 | :param ksize_close: 17 | :param output_basename: 18 | :return: 19 | """ 20 | probs = probs[:, :, 1] 21 | if threshold < 0: # Otsu thresholding 22 | probs_ch = np.uint8(probs * 255) 23 | blur = cv2.GaussianBlur(probs_ch, (5, 5), 0) 24 | thresh_val, bin_img = cv2.threshold(blur, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU) 25 | mask = bin_img / 255 26 | else: 27 | mask = probs > threshold 28 | # TODO : adaptive kernel (not hard-coded) 29 | mask = cv2.morphologyEx((mask.astype(np.uint8) * 255), cv2.MORPH_OPEN, kernel=np.ones(ksize_open)) 30 | mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel=np.ones(ksize_close)) 31 | 32 | result = mask / 255 33 | 34 | if output_basename is not None: 35 | imsave('{}.png'.format(output_basename), result*255) 36 | return result 37 | -------------------------------------------------------------------------------- /dh_segment_text/network/model.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import tensorflow as tf 4 | from abc import ABC, abstractmethod 5 | from typing import List, Union, Tuple, Optional, Dict 6 | 7 | 8 | class Encoder(ABC): 9 | @abstractmethod 10 | def __call__(self, images: tf.Tensor, is_training=False) -> List[tf.Tensor]: 11 | """ 12 | 13 | :param images: [NxHxWx3] float32 [0..255] input images 14 | :return: a list of the feature maps in decreasing spatial resolution (first element is most likely the input \ 15 | image itself, then the output of the first pooling op, etc...) 
16 | """ 17 | pass 18 | 19 | def pretrained_information(self) -> Tuple[Optional[str], Union[None, List, Dict]]: 20 | """ 21 | 22 | :return: The filename of the pretrained checkpoint and the corresponding variables (List of Dict mapping) \ 23 | or `None` if no-pretraining is done 24 | """ 25 | return None, None 26 | 27 | 28 | class Decoder(ABC): 29 | @abstractmethod 30 | def __call__(self, feature_maps: List[tf.Tensor], num_classes: int, is_training=False) -> tf.Tensor: 31 | """ 32 | 33 | :param feature_maps: list of feature maps, in decreasing spatial resolution, first one being at the original \ 34 | resolution 35 | :return: [N,H,W,num_classes] float32 tensor of logit scores 36 | """ 37 | pass 38 | -------------------------------------------------------------------------------- /dh_segment_text/__init__.py: -------------------------------------------------------------------------------- 1 | # _MODEL = [ 2 | # 'inference_vgg16', 3 | # 'inference_resnet_v1_50', 4 | # 'inference_u_net', 5 | # 'vgg_16_fn', 6 | # 'resnet_v1_50_fn' 7 | # ] 8 | # 9 | # _INPUT = [ 10 | # 'input_fn', 11 | # 'serving_input_filename', 12 | # 'serving_input_image', 13 | # 'data_augmentation_fn', 14 | # 'rotate_crop', 15 | # 'resize_image', 16 | # 'load_and_resize_image', 17 | # 'extract_patches_fn', 18 | # 'local_entropy' 19 | # ] 20 | # 21 | # _ESTIMATOR = [ 22 | # 'model_fn' 23 | # ] 24 | # 25 | # _LOADER = [ 26 | # 'LoadedModel' 27 | # ] 28 | # 29 | # _UTILS = [ 30 | # 'PredictionType', 31 | # 'VGG16ModelParams', 32 | # 'ResNetModelParams', 33 | # 'UNetModelParams', 34 | # 'ModelParams', 35 | # 'TrainingParams', 36 | # 'label_image_to_class', 37 | # 'class_to_label_image', 38 | # 'multilabel_image_to_class', 39 | # 'multiclass_to_label_image', 40 | # 'get_classes_color_from_file', 41 | # 'get_n_classes_from_file', 42 | # 'get_classes_color_from_file_multilabel', 43 | # 'get_n_classes_from_file_multilabel', 44 | # '_get_image_shape_tensor', 45 | # ] 46 | # 47 | # __all__ = _MODEL + _INPUT + _ESTIMATOR + _LOADER + _UTILS 48 | # 49 | # from dh_segment.model.pretrained_models import * 50 | # 51 | # from dh_segment.network import * 52 | # from .estimator_fn import * 53 | # from .io import * 54 | # from .network import * 55 | # from .inference import * 56 | # from .utils import * -------------------------------------------------------------------------------- /dh_segment_text/embeddings/embeddings_utils.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | def batch_resize_maps(batch_maps: tf.Tensor, target_shape: tf.Tensor, scope: str="ResizeMaps"): 4 | with tf.variable_scope(scope): 5 | maps_resized = tf.squeeze( 6 | tf.image.resize_nearest_neighbor( 7 | tf.expand_dims(batch_maps, axis=-1), 8 | target_shape 9 | ), axis=-1) 10 | return maps_resized 11 | 12 | def batch_gather(batch_maps: tf.Tensor, embeddings: tf.Tensor, scope: str="BatchGather"): 13 | with tf.variable_scope(scope): 14 | b = tf.shape(batch_maps)[0] 15 | x = tf.shape(batch_maps)[1] 16 | y = tf.shape(batch_maps)[2] 17 | batches_range = tf.expand_dims(tf.expand_dims(tf.range(b), axis=-1), axis=-1) 18 | batch_indices = tf.tile(batches_range, (1, x, y)) 19 | batch_maps_indices = tf.stack([batch_indices, batch_maps], axis=-1) 20 | embeddings_feature_map = tf.gather_nd(embeddings, batch_maps_indices) 21 | return embeddings_feature_map 22 | 23 | def batch_resize_and_gather(batch_maps: tf.Tensor, 24 | target_shape: tf.Tensor, 25 | embeddings: tf.Tensor, 26 | scope: str="BatchResizeGather"): 27 | 
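    """
    Resize the per-pixel token-index maps to target_shape (nearest neighbour),
    then gather, for every pixel, the embedding vector of the token it refers to.

    :param batch_maps: [N, H, W] integer tensor of per-pixel token indices
    :param target_shape: [2] tensor giving the output spatial shape (height, width)
    :param embeddings: [N, n_tokens, dim] float tensor of token embeddings
    :param scope: variable scope name
    :return: [N, target_height, target_width, dim] tensor of per-pixel embeddings
    """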
with tf.variable_scope(scope): 28 | batch_maps_resized = batch_resize_maps(batch_maps, target_shape) 29 | embeddings_feature_map = batch_gather(batch_maps_resized, embeddings) 30 | return embeddings_feature_map 31 | -------------------------------------------------------------------------------- /doc/references.bib: -------------------------------------------------------------------------------- 1 | @inproceedings{oliveiraseguin2018dhsegment, 2 | title={dhSegment: A generic deep-learning approach for document segmentation}, 3 | author={Ares Oliveira, Sofia and Seguin, Benoit and Kaplan, Frederic}, 4 | booktitle={Frontiers in Handwriting Recognition (ICFHR), 2018 16th International Conference on}, 5 | pages={7--12}, 6 | year={2018}, 7 | organization={IEEE} 8 | } 9 | 10 | @inproceedings{tensmeyer2017pagenet, 11 | title={Pagenet: Page boundary extraction in historical handwritten documents}, 12 | author={Tensmeyer, Chris and Davis, Brian and Wigington, Curtis and Lee, Iain and Barrett, Bill}, 13 | booktitle={Proceedings of the 4th International Workshop on Historical Document Imaging and Processing}, 14 | pages={59--64}, 15 | year={2017}, 16 | organization={ACM} 17 | } 18 | 19 | @inproceedings{gruning2018read, 20 | title={READ-BAD: A new dataset and evaluation scheme for baseline detection in archival documents}, 21 | author={Gr{\"u}ning, Tobias and Labahn, Roger and Diem, Markus and Kleber, Florian and Fiel, Stefan}, 22 | booktitle={2018 13th IAPR International Workshop on Document Analysis Systems (DAS)}, 23 | pages={351--356}, 24 | year={2018}, 25 | organization={IEEE} 26 | } 27 | 28 | @inproceedings{simistira2016diva, 29 | title={Diva-hisdb: A precisely annotated large dataset of challenging medieval manuscripts}, 30 | author={Simistira, Foteini and Seuret, Mathias and Eichenberger, Nicole and Garz, Angelika and Liwicki, Marcus and Ingold, Rolf}, 31 | booktitle={Frontiers in Handwriting Recognition (ICFHR), 2016 15th International Conference on}, 32 | pages={471--476}, 33 | year={2016}, 34 | organization={IEEE} 35 | } 36 | 37 | -------------------------------------------------------------------------------- /exps/_misc/worker.py: -------------------------------------------------------------------------------- 1 | from train import ex 2 | import argparse 3 | import glob 4 | import os 5 | import time 6 | import json 7 | 8 | 9 | if __name__ == '__main__': 10 | print('Starting worker') 11 | 12 | ap = argparse.ArgumentParser() 13 | ap.add_argument("-c", "--configs-dir", required=True, help="Folder with configs file") 14 | ap.add_argument("-f", "--failed-configs-dir", required=True, help="Folder with failed experiments") 15 | args = ap.parse_args() 16 | 17 | while True: 18 | config_files = glob.glob(os.path.join(args.configs_dir, '**/*.json'), recursive=True) 19 | if len(config_files) == 0: 20 | time.sleep(3) 21 | continue 22 | 23 | # Found a config file 24 | config_file = config_files[0] 25 | print('Found config file : {}'.format(config_file)) 26 | with open(config_file, 'r') as f: 27 | config = json.load(f) 28 | try: 29 | os.remove(config_file) 30 | except Exception: 31 | print('Some worker already processed this config file!') 32 | continue 33 | 34 | print("Running config") 35 | try: 36 | res = ex.run(config_updates=config) 37 | except Exception as e: 38 | print(e) 39 | print('----------------ERROR----------------') 40 | filename = os.path.relpath(config_file, args.configs_dir) 41 | print('Experiment {} failed : {}'.format(filename, e)) 42 | output_file = 
os.path.join(args.failed_configs_dir, filename) 43 | os.makedirs(os.path.dirname(output_file), exist_ok=True) 44 | with open(output_file, 'w') as f: 45 | json.dump(config, f) 46 | continue 47 | print("Running Done") 48 | -------------------------------------------------------------------------------- /embeddings_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "train_data": "/path/to/train.csv", 3 | "eval_data": "/path/to/val.csv", 4 | "model_output_dir": "/path/to/model_output", 5 | "restore_model": true, 6 | "classes_file": "/path/to/classes.txt", 7 | "gpu": "0", 8 | "use_embeddings": true, 9 | "embeddings_dim": 300, 10 | "prediction_type": "CLASSIFICATION", 11 | "model_params": { 12 | "encoder_name": "dh_segment_text.network.pretrained_models.ResnetV1_50", 13 | "encoder_params": { 14 | "concat_level": 0, 15 | "weight_decay": 1e-06 16 | }, 17 | "decoder_name": "dh_segment_text.network.SimpleDecoder", 18 | "decoder_params": { 19 | "upsampling_dims": [32, 64, 128, 256, 512], 20 | "max_depth": 2348, 21 | "weight_decay": 1e-06 22 | }, 23 | "n_classes": 5 24 | }, 25 | "embeddings_params": { 26 | "target_dim": 300, 27 | "encoder_name": "dh_segment_text.embeddings.PCAEncoder", 28 | "encoder_params": { 29 | "pca_components_path": "/path/to/pca_std.npy", 30 | "pca_mean_path": "/path/to/pca_mean.npy" 31 | } 32 | }, 33 | "training_params": { 34 | "n_epochs": 50, 35 | "evaluate_every_epoch": 10, 36 | "learning_rate": 0.0001, 37 | "exponential_learning": true, 38 | "batch_size": 4, 39 | "data_augmentation": true, 40 | "data_augmentation_flip_lr": false, 41 | "data_augmentation_flip_ud": false, 42 | "data_augmentation_color": false, 43 | "data_augmentation_max_rotation": 0.01, 44 | "data_augmentation_max_scaling": 0.2, 45 | "make_patches": false, 46 | "input_resized_size": 500000.0, 47 | "training_margin": 0 48 | } 49 | } 50 | -------------------------------------------------------------------------------- /exps/page/evaluation.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | __author__ = "solivr" 3 | __license__ = "GPL" 4 | 5 | from tqdm import tqdm 6 | from glob import glob 7 | import os 8 | from imageio import imread 9 | import numpy as np 10 | from .process import extract_page 11 | from dh_segment.post_processing.evaluation import intersection_over_union, Metrics 12 | 13 | 14 | PP_PARAMS = {'threshold': -1, 'kernel_size': 5} 15 | 16 | 17 | def eval_fn(input_dir: str, groundtruth_dir: str, post_process_params: dict=PP_PARAMS) -> Metrics: 18 | """ 19 | 20 | :param input_dir: directory containing the predictions .npy files (range [0, 255]) 21 | :param groundtruth_dir: directory containing the ground truth images (.png) (must have the same name as predictions 22 | files in input_dir) 23 | :param post_process_params: params for post processing fn 24 | :return: Metrics object containing all the necessary metrics 25 | """ 26 | global_metrics = Metrics() 27 | for file in tqdm(glob(os.path.join(input_dir, '*.npy'))): 28 | basename = os.path.basename(file).split('.')[0] 29 | 30 | prediction = np.load(file) 31 | label_image = imread(os.path.join(groundtruth_dir, '{}.png'.format(basename)), pilmode='L') 32 | 33 | pred_box = extract_page(prediction / np.max(prediction), **post_process_params) 34 | label_box = extract_page(label_image / np.max(label_image), min_area=0.0) 35 | 36 | if pred_box is not None and label_box is not None: 37 | iou = intersection_over_union(label_box[:, None, :], 
pred_box[:, None, :], label_image.shape) 38 | global_metrics.IOU_list.append(iou) 39 | else: 40 | global_metrics.IOU_list.append(0) 41 | 42 | global_metrics.compute_miou() 43 | print('EVAL --- mIOU : {}\n'.format(global_metrics.mIOU)) 44 | 45 | return global_metrics 46 | -------------------------------------------------------------------------------- /dh_segment_text/utils/__init__.py: -------------------------------------------------------------------------------- 1 | r""" 2 | The :mod:`dh_segment.utils` module contains the parameters for config with `sacred`_ package, 3 | image label vizualization functions and miscelleanous helpers. 4 | 5 | Parameters 6 | ---------- 7 | 8 | .. autosummary:: 9 | ModelParams 10 | TrainingParams 11 | 12 | Label image helpers 13 | ------------------- 14 | 15 | .. autosummary:: 16 | label_image_to_class 17 | class_to_label_image 18 | multilabel_image_to_class 19 | multiclass_to_label_image 20 | get_classes_color_from_file 21 | get_n_classes_from_file 22 | get_classes_color_from_file_multilabel 23 | get_n_classes_from_file_multilabel 24 | 25 | Evaluation utils 26 | ---------------- 27 | 28 | .. autosummary:: 29 | Metrics 30 | intersection_over_union 31 | 32 | Miscellaneous helpers 33 | --------------------- 34 | 35 | .. autosummary:: 36 | parse_json 37 | dump_json 38 | load_pickle 39 | dump_pickle 40 | hash_dict 41 | 42 | .. _sacred : https://sacred.readthedocs.io/en/latest/index.html 43 | 44 | ------ 45 | """ 46 | 47 | _PARAMSCONFIG = [ 48 | 'PredictionType', 49 | 'ModelParams', 50 | 'TrainingParams' 51 | ] 52 | 53 | 54 | _LABELS = [ 55 | 'label_image_to_class', 56 | 'class_to_label_image', 57 | 'multilabel_image_to_class', 58 | 'multiclass_to_label_image', 59 | 'get_classes_color_from_file', 60 | 'get_n_classes_from_file', 61 | 'get_classes_color_from_file_multilabel', 62 | 'get_n_classes_from_file_multilabel' 63 | ] 64 | 65 | _MISC = [ 66 | 'parse_json', 67 | 'dump_json', 68 | 'load_pickle', 69 | 'dump_pickle', 70 | 'hash_dict' 71 | ] 72 | 73 | _EVALUATION = [ 74 | 'Metrics', 75 | 'intersection_over_union' 76 | ] 77 | 78 | __all__ = _PARAMSCONFIG + _LABELS + _MISC + _EVALUATION 79 | 80 | from .params_config import * 81 | from .labels import * 82 | from .misc import * 83 | from .evaluation import * -------------------------------------------------------------------------------- /dh_segment_text/post_processing/polygon_detection.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import cv2 4 | import numpy as np 5 | import math 6 | from shapely import geometry 7 | 8 | 9 | def find_polygonal_regions(image_mask: np.ndarray, min_area: float=0.1, n_max_polygons: int=math.inf) -> list: 10 | """ 11 | Finds the shapes in a binary mask and returns their coordinates as polygons. 12 | 13 | :param image_mask: Uint8 binary 2D array 14 | :param min_area: minimum area the polygon should have in order to be considered as valid 15 | (value within [0,1] representing a percent of the total size of the image) 16 | :param n_max_polygons: maximum number of boxes that can be found (default inf). 17 | This will select n_max_boxes with largest area. 18 | :return: list of length n_max_polygons containing polygon's n coordinates [[x1, y1], ... 
[xn, yn]] 19 | """ 20 | 21 | _, contours, _ = cv2.findContours(image_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) 22 | if contours is None: 23 | print('No contour found') 24 | return None 25 | found_polygons = list() 26 | 27 | for c in contours: 28 | if len(c) < 3: # A polygon cannot have less than 3 points 29 | continue 30 | polygon = geometry.Polygon([point[0] for point in c]) 31 | # Check that polygon has area greater than minimal area 32 | if polygon.area >= min_area*np.prod(image_mask.shape[:2]): 33 | found_polygons.append( 34 | (np.array([point for point in polygon.exterior.coords], dtype=np.uint), polygon.area) 35 | ) 36 | 37 | # sort by area 38 | found_polygons = [fp for fp in found_polygons if fp is not None] 39 | found_polygons = sorted(found_polygons, key=lambda x: x[1], reverse=True) 40 | 41 | if found_polygons: 42 | return [fp[0] for i, fp in enumerate(found_polygons) if i <= n_max_polygons] 43 | else: 44 | return None 45 | -------------------------------------------------------------------------------- /doc/start/demo.rst: -------------------------------------------------------------------------------- 1 | Demo 2 | ---- 3 | 4 | This demo shows the usage of dhSegment for page document extraction. 5 | It trains a model from scratch (optional) using the READ-BAD dataset :cite:`gruning2018read` 6 | and the annotations of `Pagenet`_ :cite:`tensmeyer2017pagenet` (annotator1 is used). 7 | In order to limit memory usage, the images in the dataset we provide have been downsized to have 1M pixels each. 8 | 9 | .. _Pagenet: https://github.com/ctensmeyer/pagenet/tree/master/annotations 10 | 11 | 12 | **How to** 13 | 14 | 0. If you have not yet done so, clone the repository : :: 15 | 16 | git clone https://github.com/dhlab-epfl/dhSegment.git 17 | 18 | 1. Get the annotated dataset `here`_, which already contains the folders ``images`` and ``labels`` 19 | for training, validation and testing set. Unzip it into ``demo/pages``. :: 20 | 21 | cd demo/ 22 | wget https://github.com/dhlab-epfl/dhSegment/releases/download/v0.2/pages.zip 23 | unzip pages.zip 24 | cd .. 25 | 26 | .. _here: https://github.com/dhlab-epfl/dhSegment/releases/download/v0.2/pages.zip 27 | 28 | 2. (Only needed if training from scratch) Download the pretrained weights for ResNet : :: 29 | 30 | cd pretrained_models/ 31 | python download_resnet_pretrained_model.py 32 | cd .. 33 | 34 | 3. You can train the model from scratch with: ``python train.py with demo/demo_config.json`` 35 | but because this takes quite some time, we recommend you to skip this and just download the 36 | `provided model`_ (download and unzip it in ``demo/model``) :: 37 | 38 | cd demo/ 39 | wget https://github.com/dhlab-epfl/dhSegment/releases/download/v0.2/model.zip 40 | unzip model.zip 41 | cd .. 42 | 43 | .. _provided model : https://github.com/dhlab-epfl/dhSegment/releases/download/v0.2/model.zip 44 | 45 | 4. (Only if training from scratch) You can visualize the progresses in tensorboard by running 46 | ``tensorboard --logdir .`` in the ``demo`` folder. 47 | 48 | 5. Run ``python demo.py`` 49 | 50 | 6. Have a look at the results in ``demo/processed_images`` 51 | 52 | -------------------------------------------------------------------------------- /doc/reference/io.rst: -------------------------------------------------------------------------------- 1 | .. comment 2 | Interface 3 | ========= 4 | 5 | Input functions for ``tf.Estimator`` 6 | ------------------------------------ 7 | 8 | Input function 9 | 10 | .. 
autosummary:: 11 | input.input_fn 12 | 13 | Data augmentation 14 | 15 | .. autosummary:: 16 | data_augmentation_fn 17 | extract_patches_fn 18 | rotate_crop 19 | 20 | Resizing function 21 | 22 | .. autosummary:: 23 | dh_segment.io.input_utils.resize_image 24 | dh_segment.io.input_utils.load_and_resize_image 25 | 26 | 27 | Tensorflow serving functions 28 | ---------------------------- 29 | 30 | .. autosummary:: 31 | dh_segment.io.input.serving_input_filename 32 | dh_segment.io.input.serving_input_image 33 | 34 | ---- 35 | 36 | PAGE XML and JSON import / export 37 | --------------------------------- 38 | 39 | PAGE classes 40 | 41 | .. autosummary:: 42 | dh_segment.io.PAGE.Point 43 | dh_segment.io.PAGE.Text 44 | dh_segment.io.PAGE.Border 45 | dh_segment.io.PAGE.TextRegion 46 | dh_segment.io.PAGE.TextLine 47 | dh_segment.io.PAGE.GraphicRegion 48 | dh_segment.io.PAGE.TableRegion 49 | dh_segment.io.PAGE.SeparatorRegion 50 | dh_segment.io.PAGE.GroupSegment 51 | dh_segment.io.PAGE.Metadata 52 | dh_segment.io.PAGE.Page 53 | 54 | Abstract classes 55 | 56 | .. autosummary:: 57 | dh_segment.io.PAGE.BaseElement 58 | dh_segment.io.PAGE.Region 59 | 60 | Parsing and helpers 61 | 62 | .. autosummary:: 63 | dh_segment.io.PAGE.parse_file 64 | dh_segment.io.PAGE.json_serialize 65 | 66 | ---- 67 | 68 | ============== 69 | Input / Output 70 | ============== 71 | 72 | .. automodule:: dh_segment.io 73 | :members: 74 | :undoc-members: 75 | 76 | .. automodule:: dh_segment.io.PAGE 77 | :members: 78 | :undoc-members: 79 | 80 | .. automodule:: dh_segment.io.via 81 | :members: 82 | :undoc-members: 83 | :exclude-members: main, init_logger -------------------------------------------------------------------------------- /dh_segment_text/embeddings/conv2d_encoder.py: -------------------------------------------------------------------------------- 1 | from .encoder import EmbeddingsEncoder 2 | from .embeddings_utils import batch_resize_and_gather 3 | import tensorflow as tf 4 | import numpy as np 5 | 6 | class Conv2dEncoder(EmbeddingsEncoder): 7 | def __init__(self, target_dim: int, starting_dim: int=256, max_conv: int=-1): 8 | self.target_dim = target_dim 9 | self.starting_dim = starting_dim 10 | max_power = int(np.round(np.log2(self.starting_dim))) 11 | min_power = int(np.floor(np.log2(self.target_dim))) 12 | if max_conv == -1: 13 | max_conv = (max_power-min_power)+1 14 | self.conv_sizes = np.logspace(min_power,max_power,max_conv, base=2).astype(int)[::-1][:-1] 15 | 16 | 17 | def __call__(self, embeddings: tf.Tensor, embeddings_map: tf.Tensor, target_shape: tf.Tensor) -> tf.Tensor: 18 | with tf.variable_scope("Conv2D_encoder"): 19 | embeddings_feature_map = batch_resize_and_gather(embeddings_map, 20 | target_shape, 21 | embeddings) 22 | with tf.variable_scope("Encoder"): 23 | if self.target_dim >= self.starting_dim: 24 | raise IndexError(f"Target dim was bigger than {self.starting_dim}, got {self.target_dim}") 25 | reduced_embeddings = embeddings 26 | for i, conv_size in enumerate(self.conv_sizes): 27 | reduced_embeddings = tf.contrib.layers.conv1d(reduced_embeddings, conv_size, (1), scope='conv_%01d'%i) 28 | reduced_embeddings = tf.contrib.layers.conv1d(reduced_embeddings, self.target_dim, (1), scope='conv_final') 29 | embeddings_feature_map = batch_resize_and_gather(embeddings_map, 30 | target_shape, 31 | reduced_embeddings) 32 | embeddings_feature_map.set_shape([None, None, None, self.target_dim]) 33 | return embeddings_feature_map 34 | 35 | 
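# ---------------------------------------------------------------------------
# Illustrative usage sketch, kept as a comment so it is not executed as part of
# this module: how an EmbeddingsEncoder (PCAEncoder, Conv1dEncoder, Conv2dEncoder)
# is meant to be called. Paths, shapes and the dummy PCA arrays are assumptions.
#
#   import numpy as np
#   import tensorflow as tf
#   from dh_segment_text.embeddings import PCAEncoder
#
#   np.save('/tmp/pca_mean.npy', np.zeros(300, dtype=np.float32))      # dummy PCA mean
#   np.save('/tmp/pca_components.npy', np.eye(300, dtype=np.float32))  # dummy PCA basis
#   encoder = PCAEncoder('/tmp/pca_mean.npy', '/tmp/pca_components.npy', target_dim=300)
#
#   embeddings = tf.placeholder(tf.float32, [None, None, 300])     # [batch, n_tokens, dim]
#   embeddings_map = tf.placeholder(tf.int32, [None, None, None])  # [batch, H, W] token index per pixel
#   target_shape = tf.constant([150, 100], dtype=tf.int32)         # feature-map height, width
#   feature_map = encoder(embeddings, embeddings_map, target_shape)  # -> [batch, 150, 100, 300]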
-------------------------------------------------------------------------------- /exps/Ornaments/ornaments_dataset_generator.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | __author__ = 'solivr' 3 | 4 | import os 5 | from glob import glob 6 | from sklearn.model_selection import train_test_split 7 | import argparse 8 | from scipy.misc import imsave, imread 9 | from tqdm import tqdm 10 | import numpy as np 11 | 12 | 13 | INPUT_FOLDER = '/home/datasets/ornaments/sets_all_ornaments/all_pages/images/' 14 | 15 | 16 | def generate_set(filenames, output_dir, set: str): 17 | """ 18 | 19 | :param filenames: 20 | :param output_dir: 21 | :param set: Should be 'train', 'test' or 'validation' 22 | :return: 23 | """ 24 | 25 | out_dir = os.path.join(output_dir, set) 26 | os.makedirs(os.path.join(out_dir, 'images'), exist_ok=True) 27 | os.makedirs(os.path.join(out_dir, 'labels'), exist_ok=True) 28 | for file in tqdm(filenames, desc='Generated files'): 29 | basename = os.path.split(file)[1].split('.')[0] 30 | imsave(os.path.join(out_dir, 'images', '{}.jpg'.format(basename)), imread(file)) 31 | imsave(os.path.join(out_dir, 'labels', '{}.png'.format(basename)), 32 | imread(os.path.abspath(os.path.join(file, '..', '..', 'labels', '{}.png'.format(basename))))) 33 | 34 | # Class file 35 | classes = np.stack([(0, 0, 0), (0, 255, 0)]) 36 | np.savetxt(os.path.join(output_dir, 'classes.txt'), classes, fmt='%d') 37 | 38 | 39 | if __name__ == '__main__': 40 | parser = argparse.ArgumentParser() 41 | parser.add_argument('-i', '--input_folder', required=True, type=str, help='Input_folder where the images are') 42 | parser.add_argument('-o', '--output_dir', required=True, type=str, help='Output directory for generated dataset') 43 | args = vars(parser.parse_args()) 44 | 45 | filenames = glob(os.path.join(args.get('input_folder'), '*')) 46 | 47 | # Split 0.2 test, 0.7 train, 0.1 eval 48 | split, test_split = train_test_split(filenames, test_size=0.2, train_size=0.8, random_state=1) 49 | train_split, validation_split = train_test_split(split, test_size=0.125, train_size=0.875, random_state=1) 50 | 51 | generate_set(train_split, args.get('output_dir'), 'train') 52 | generate_set(test_split, args.get('output_dir'), 'test') 53 | generate_set(validation_split, args.get('output_dir'), 'validation') 54 | 55 | 56 | 57 | -------------------------------------------------------------------------------- /exps/page/example_processing.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": { 7 | "collapsed": true 8 | }, 9 | "outputs": [], 10 | "source": [ 11 | "import os\n", 12 | "from glob import glob\n", 13 | "from tqdm import tqdm\n", 14 | "import numpy as np\n", 15 | "import tempfile\n", 16 | "from .process import prediction_fn, extract_page, format_quad_to_string" 17 | ] 18 | }, 19 | { 20 | "cell_type": "code", 21 | "execution_count": null, 22 | "metadata": {}, 23 | "outputs": [], 24 | "source": [ 25 | "model_dir = 'model1/export/timestamp/'\n", 26 | "input_dir = 'dataset_page/set/images/'\n", 27 | "output_dir = './out_pages'" 28 | ] 29 | }, 30 | { 31 | "cell_type": "code", 32 | "execution_count": null, 33 | "metadata": {}, 34 | "outputs": [], 35 | "source": [ 36 | "pp_params = {'threshold': -1, 'kernel_size': 5}" 37 | ] 38 | }, 39 | { 40 | "cell_type": "code", 41 | "execution_count": null, 42 | "metadata": {}, 43 | "outputs": [], 44 | 
"source": [ 45 | "with tempfile.TemporaryDirectory() as tmpdirname:\n", 46 | " prediction_fn(model_dir, input_dir, tmpdirname)\n", 47 | " \n", 48 | " # Export page coordinates in txt file\n", 49 | " with open(os.path.join(output_dir, 'pages.txt'), 'w') as f:\n", 50 | " for filename in tqdm(glob(os.path.join(tmpdirname, '*.npy'))):\n", 51 | " \n", 52 | " prediction = np.load(filename)\n", 53 | " pred_box = extract_page(prediction / np.max(prediction), **pp_params)\n", 54 | " \n", 55 | " f.write('{},{}\\n'.format(filename, format_quad_to_string(pred_box)))" 56 | ] 57 | } 58 | ], 59 | "metadata": { 60 | "kernelspec": { 61 | "display_name": "Python 2", 62 | "language": "python", 63 | "name": "python2" 64 | }, 65 | "language_info": { 66 | "codemirror_mode": { 67 | "name": "ipython", 68 | "version": 2.0 69 | }, 70 | "file_extension": ".py", 71 | "mimetype": "text/x-python", 72 | "name": "python", 73 | "nbconvert_exporter": "python", 74 | "pygments_lexer": "ipython2", 75 | "version": "2.7.6" 76 | } 77 | }, 78 | "nbformat": 4, 79 | "nbformat_minor": 0 80 | } -------------------------------------------------------------------------------- /exps/DIBCO/dibco_evaluation.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | __author__ = 'solivr' 3 | 4 | import os 5 | from exps.evaluation.base import Metrics, compare_bin_prediction_to_label 6 | from glob import glob 7 | from scipy.misc import imread, imsave 8 | import cv2 9 | import numpy as np 10 | 11 | 12 | def dibco_evaluate_folder(output_folder: str, validation_dir: str, verbose=False, debug_folder=None) -> dict: 13 | filenames_processed = glob(os.path.join(output_folder, '*.png')) 14 | 15 | if debug_folder is not None: 16 | os.makedirs(debug_folder, exist_ok=True) 17 | 18 | global_metrics = Metrics() 19 | for filename in filenames_processed: 20 | post_processed = imread(filename, mode='L') 21 | post_processed = post_processed / np.max(post_processed) 22 | 23 | basename = os.path.basename(filename).split('.')[0] 24 | label_image = imread(os.path.join(validation_dir, 'labels', '{}.png'.format(basename)), mode='L') 25 | label_image_normalized = label_image / np.max(label_image) 26 | 27 | target_shape = (label_image.shape[1], label_image.shape[0]) 28 | bin_upscaled = cv2.resize(np.uint8(post_processed), target_shape, interpolation=cv2.INTER_NEAREST) 29 | 30 | # Compute errors 31 | metric = compare_bin_prediction_to_label(bin_upscaled, label_image_normalized) 32 | global_metrics += metric 33 | 34 | if debug_folder is not None: 35 | debug_image = np.zeros((*label_image.shape[:2], 3), np.uint8) 36 | debug_image[np.logical_and(bin_upscaled, label_image_normalized)] = [0, 255, 0] 37 | debug_image[np.logical_and(bin_upscaled, label_image_normalized == 0)] = [0, 0, 255] 38 | debug_image[np.logical_and(bin_upscaled == 0, label_image_normalized)] = [255, 0, 0] 39 | imsave(os.path.join(debug_folder, basename + '.png'), debug_image) 40 | 41 | global_metrics.compute_mse() 42 | global_metrics.compute_psnr() 43 | global_metrics.compute_prf() 44 | 45 | if verbose: 46 | print('EVAL --- PSNR : {}, R : {}, P : {}, FM : {}'.format(global_metrics.PSNR, global_metrics.recall, 47 | global_metrics.precision, global_metrics.f_measure)) 48 | 49 | return {k: v for k, v in vars(global_metrics).items() if k in ['MSE', 'PSNR', 'precision', 'recall', 'f_measure']} 50 | 51 | -------------------------------------------------------------------------------- /doc/start/annotating.rst: 
-------------------------------------------------------------------------------- 1 | Creating groundtruth data 2 | ------------------------- 3 | 4 | Using GIMP or Photoshop 5 | ^^^^^^^^^^^^^^^^^^^^^^^ 6 | Create directly your masks using your favorite image editor. You just have to draw the regions you want to extract 7 | with a different color for each label. 8 | 9 | Using VGG Image Annotator (VIA) 10 | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ 11 | `VGG Image Annotator (VIA) `_ is an image annotation tool that can be 12 | used to define regions in an image and create textual descriptions of those regions. You can either use it 13 | `online `_ or 14 | `download the application `_. 15 | 16 | From the exported annotations (in JSON format), you'll have to generate the corresponding image masks. 17 | See the :ref:`ref_via` in the ``via`` module. 18 | 19 | When assigning attributes to your annotated regions, you should favour attributes of type "dropdown", "checkbox" 20 | and "radio" and avoid "text" type in order to ease the parsing of the exported file (avoid typos and formatting errors). 21 | 22 | **Example of how to create individual masks from VIA annotation file** 23 | 24 | .. code:: python 25 | 26 | from dh_segment.io import via 27 | 28 | collection = 'mycollection' 29 | annotation_file = 'via_sample.json' 30 | masks_dir = '/home/project/generated_masks' 31 | images_dir = './my_images' 32 | 33 | # Load all the data in the annotation file 34 | # (the file may be an exported project or an export of the annotations) 35 | via_data = via.load_annotation_data(annotation_file) 36 | 37 | # In the case of an exported project file, you can set ``only_img_annotations=True`` 38 | # to get only the image annotations 39 | via_annotations = via.load_annotation_data(annotation_file, only_img_annotations=True) 40 | 41 | # Collect the annotated regions 42 | working_items = via.collect_working_items(via_annotations, collection, images_dir) 43 | 44 | # Collect the attributes and options 45 | if '_via_attributes' in via_data.keys(): 46 | list_attributes = via.parse_via_attributes(via_data['_via_attributes']) 47 | else: 48 | list_attributes = via.get_via_attributes(via_annotations) 49 | 50 | # Create one mask per option per attribute 51 | via.create_masks(masks_dir, working_items, list_attributes, collection) 52 | 53 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Created by https://www.gitignore.io/api/python 2 | # Edit at https://www.gitignore.io/?templates=python 3 | 4 | ### Python ### 5 | # Byte-compiled / optimized / DLL files 6 | __pycache__/ 7 | *.py[cod] 8 | *$py.class 9 | 10 | # C extensions 11 | *.so 12 | 13 | # Distribution / packaging 14 | .Python 15 | build/ 16 | develop-eggs/ 17 | dist/ 18 | downloads/ 19 | eggs/ 20 | .eggs/ 21 | lib/ 22 | lib64/ 23 | parts/ 24 | sdist/ 25 | var/ 26 | wheels/ 27 | pip-wheel-metadata/ 28 | share/python-wheels/ 29 | *.egg-info/ 30 | .installed.cfg 31 | *.egg 32 | MANIFEST 33 | 34 | # PyInstaller 35 | # Usually these files are written by a python script from a template 36 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
37 | *.manifest 38 | *.spec 39 | 40 | # Installer logs 41 | pip-log.txt 42 | pip-delete-this-directory.txt 43 | 44 | # Unit test / coverage reports 45 | htmlcov/ 46 | .tox/ 47 | .nox/ 48 | .coverage 49 | .coverage.* 50 | .cache 51 | nosetests.xml 52 | coverage.xml 53 | *.cover 54 | .hypothesis/ 55 | .pytest_cache/ 56 | 57 | # Translations 58 | *.mo 59 | *.pot 60 | 61 | # Django stuff: 62 | *.log 63 | local_settings.py 64 | db.sqlite3 65 | 66 | # Flask stuff: 67 | instance/ 68 | .webassets-cache 69 | 70 | # Scrapy stuff: 71 | .scrapy 72 | 73 | # Sphinx documentation 74 | docs/_build/ 75 | 76 | # PyBuilder 77 | target/ 78 | 79 | # Jupyter Notebook 80 | .ipynb_checkpoints 81 | 82 | # IPython 83 | profile_default/ 84 | ipython_config.py 85 | 86 | # pyenv 87 | .python-version 88 | 89 | # pipenv 90 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 91 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 92 | # having no cross-platform support, pipenv may install dependencies that don’t work, or not 93 | # install all needed dependencies. 94 | #Pipfile.lock 95 | 96 | # celery beat schedule file 97 | celerybeat-schedule 98 | 99 | # SageMath parsed files 100 | *.sage.py 101 | 102 | # Environments 103 | .env 104 | .venv 105 | env/ 106 | venv/ 107 | ENV/ 108 | env.bak/ 109 | venv.bak/ 110 | 111 | # Spyder project settings 112 | .spyderproject 113 | .spyproject 114 | 115 | # Rope project settings 116 | .ropeproject 117 | 118 | # mkdocs documentation 119 | /site 120 | 121 | # mypy 122 | .mypy_cache/ 123 | .dmypy.json 124 | dmypy.json 125 | 126 | # Pyre type checker 127 | .pyre/ 128 | 129 | # End of https://www.gitignore.io/api/python 130 | -------------------------------------------------------------------------------- /dh_segment_text/network/pretrained_models/mobilenet/encoder.py: -------------------------------------------------------------------------------- 1 | from ...model import Encoder 2 | import tensorflow as tf 3 | from .mobilenet_v2 import training_scope, mobilenet_base 4 | from typing import Tuple, Optional, Union, List, Dict 5 | from tensorflow.contrib import slim 6 | import os 7 | from ....utils.misc import get_data_folder, download_file 8 | import tarfile 9 | 10 | 11 | class MobileNetV2(Encoder): 12 | def __init__(self, train_batchnorm: bool=False, weight_decay: float=0.00004, batch_renorm: bool=True): 13 | self.train_batchnorm = train_batchnorm 14 | self.weight_decay = weight_decay 15 | self.batch_renorm = batch_renorm 16 | pretrained_dir = os.path.join(get_data_folder(), 'mobilenet_v2') 17 | self.pretrained_file = os.path.join(pretrained_dir, 'mobilenet_v2_1.0_224.ckpt') 18 | if not os.path.exists(self.pretrained_file+'.index'): 19 | print("Could not find pre-trained file {}, downloading it!".format(self.pretrained_file)) 20 | tar_filename = os.path.join(get_data_folder(), 'resnet_v1_50.tar.gz') 21 | download_file('https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_1.0_224.tgz', tar_filename) 22 | tar = tarfile.open(tar_filename) 23 | tar.extractall(path=pretrained_dir) 24 | tar.close() 25 | os.remove(tar_filename) 26 | assert os.path.exists(self.pretrained_file+'.index') 27 | print('Pre-trained weights downloaded!') 28 | 29 | def __call__(self, images: tf.Tensor, is_training=False) -> List[tf.Tensor]: 30 | outputs = [] 31 | 32 | with slim.arg_scope(training_scope(weight_decay=self.weight_decay, 33 | is_training=is_training and self.train_batchnorm)): 34 | normalized_images = 
(images / 127.5) - 1.0 35 | outputs.append(normalized_images) 36 | 37 | desired_endpoints = [ 38 | 'layer_2', 39 | 'layer_4', 40 | 'layer_7', 41 | 'layer_14', 42 | 'layer_18' 43 | ] 44 | 45 | _, endpoints = mobilenet_base(normalized_images) 46 | for d in desired_endpoints: 47 | outputs.append(endpoints[d]) 48 | 49 | return outputs 50 | 51 | def pretrained_information(self) -> Tuple[Optional[str], Union[None, List, Dict]]: 52 | return self.pretrained_file, [v for v in tf.global_variables() 53 | if 'MobilenetV2' in v.name and 'renorm' not in v.name] -------------------------------------------------------------------------------- /exps/diva/evaluation.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | __author__ = "solivr" 3 | __license__ = "GPL" 4 | 5 | from tqdm import tqdm 6 | from glob import glob 7 | import os 8 | from imageio import imsave 9 | import subprocess 10 | import tempfile 11 | import numpy as np 12 | from .process import diva_post_processing_fn 13 | from .utils import to_original_color_code, parse_diva_tool_output 14 | 15 | 16 | DIVA_JAR = './DIVA_Layout_Analysis_Evaluator/out/artifacts/LayoutAnalysisEvaluator.jar' 17 | PP_PARAMS = {'thresholds': [0.5, 0.5, 0.5], 'min_cc': 50} 18 | 19 | 20 | def eval_fn(input_dir: str, groundtruth_dir: str, output_filename: str, post_process_params: dict=PP_PARAMS, 21 | diva_jar: str=DIVA_JAR): 22 | """ 23 | 24 | :param input_dir: directory containing the predictions .npy files (range [0, 255]) 25 | :param groundtruth_dir: directory containing the ground truth images (.png) (must have the same name as predictions 26 | files in input_dir) 27 | :param output_filename: filename of the .txt file containing all the results computed by the Evaluation tool 28 | :param post_process_params: params for post processing fn 29 | :param diva_jar: path for the DIVA Evaluation Tool (.jar file) 30 | :return: mean IU 31 | """ 32 | results_list = list() 33 | with tempfile.TemporaryDirectory() as tmpdir: 34 | for file in tqdm(glob(os.path.join(input_dir, '*.npy'))): 35 | basename = os.path.basename(file).split('.')[0] 36 | 37 | pred = np.load(file) 38 | pp_preds = diva_post_processing_fn(pred/np.max(pred), **post_process_params) 39 | 40 | original_colors = to_original_color_code(np.uint8(pp_preds * 255)) 41 | pred_img_filename = os.path.join(tmpdir, '{}_orig_colors.png'.format(basename)) 42 | imsave(pred_img_filename, original_colors) 43 | 44 | label_image_filename = os.path.join(groundtruth_dir, basename + '.png') 45 | 46 | cmd = 'java -jar {} -gt {} -p {}'.format(diva_jar, label_image_filename, pred_img_filename) 47 | result = subprocess.check_output(cmd, shell=True).decode() 48 | results_list.append(result) 49 | 50 | mius = list() 51 | with open(output_filename, 'w') as f: 52 | for result in results_list: 53 | r = parse_diva_tool_output(result) 54 | mius.append(r['Mean_IU']) 55 | f.write(result) 56 | f.write('--- Mean IU : {}'.format(np.mean(mius))) 57 | 58 | print('Mean IU : {}'.format(np.mean(mius))) 59 | return np.mean(mius) 60 | -------------------------------------------------------------------------------- /dh_segment_text/post_processing/binarization.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | from scipy.ndimage import label 4 | 5 | 6 | def thresholding(probs: np.ndarray, threshold: float=-1) -> np.ndarray: 7 | """ 8 | Computes the binary mask of the detected Page from the probabilities output by network. 
9 | 10 | :param probs: array in range [0, 1] of shape HxWx2 11 | :param threshold: threshold between [0 and 1], if negative Otsu's adaptive threshold will be used 12 | :return: binary mask 13 | """ 14 | 15 | if threshold < 0: # Otsu's thresholding 16 | probs = np.uint8(probs * 255) 17 | #TODO Correct that weird gaussianBlur 18 | probs = cv2.GaussianBlur(probs, (5, 5), 0) 19 | 20 | thresh_val, bin_img = cv2.threshold(probs, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU) 21 | mask = np.uint8(bin_img / 255) 22 | else: 23 | mask = np.uint8(probs > threshold) 24 | 25 | return mask 26 | 27 | 28 | def cleaning_binary(mask: np.ndarray, kernel_size: int=5) -> np.ndarray: 29 | """ 30 | Uses mathematical morphology to clean and remove small elements from binary images. 31 | 32 | :param mask: the binary image to clean 33 | :param kernel_size: size of the kernel 34 | :return: the cleaned mask 35 | """ 36 | 37 | ksize_open = (kernel_size, kernel_size) 38 | ksize_close = (kernel_size, kernel_size) 39 | mask = cv2.morphologyEx((mask.astype(np.uint8, copy=False) * 255), cv2.MORPH_OPEN, kernel=np.ones(ksize_open)) 40 | mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel=np.ones(ksize_close)) 41 | return np.uint8(mask / 255) 42 | 43 | 44 | def hysteresis_thresholding(probs: np.array, low_threshold: float, high_threshold: float, 45 | candidates_mask: np.ndarray=None) -> np.ndarray: 46 | low_mask = probs > low_threshold 47 | if candidates_mask is not None: 48 | low_mask = candidates_mask & low_mask 49 | # Connected components extraction 50 | label_components, count = label(low_mask, np.ones((3, 3))) 51 | # Keep components with high threshold elements 52 | good_labels = np.unique(label_components[low_mask & (probs > high_threshold)]) 53 | label_masks = np.zeros((count + 1,), bool) 54 | label_masks[good_labels] = 1 55 | return label_masks[label_components] 56 | 57 | 58 | def cleaning_probs(probs: np.ndarray, sigma: float) -> np.ndarray: 59 | # Smooth 60 | if sigma > 0.: 61 | return cv2.GaussianBlur(probs, (int(3*sigma)*2+1, int(3*sigma)*2+1), sigma) 62 | elif sigma == 0.: 63 | return cv2.fastNlMeansDenoising((probs*255).astype(np.uint8), h=20)/255 64 | else: # Negative sigma, do not do anything 65 | return probs 66 | -------------------------------------------------------------------------------- /dh_segment_text/io/__init__.py: -------------------------------------------------------------------------------- 1 | r""" 2 | The :mod:`dh_segment.io` module implements input / output functions and classes. 3 | 4 | Input functions for ``tf.Estimator`` 5 | ------------------------------------ 6 | 7 | **Input function** 8 | 9 | .. autosummary:: 10 | input_fn 11 | 12 | **Data augmentation** 13 | 14 | .. autosummary:: 15 | data_augmentation_fn 16 | extract_patches_fn 17 | rotate_crop 18 | 19 | **Resizing function** 20 | 21 | .. autosummary:: 22 | resize_image 23 | load_and_resize_image 24 | 25 | 26 | Tensorflow serving functions 27 | ---------------------------- 28 | 29 | .. autosummary:: 30 | serving_input_filename 31 | serving_input_image 32 | 33 | ---- 34 | 35 | PAGE XML and JSON import / export 36 | --------------------------------- 37 | 38 | **PAGE classes** 39 | 40 | .. autosummary:: 41 | PAGE.Point 42 | PAGE.Text 43 | PAGE.Border 44 | PAGE.TextRegion 45 | PAGE.TextLine 46 | PAGE.GraphicRegion 47 | PAGE.TableRegion 48 | PAGE.SeparatorRegion 49 | PAGE.GroupSegment 50 | PAGE.Metadata 51 | PAGE.Page 52 | 53 | **Abstract classes** 54 | 55 | .. 
autosummary:: 56 | PAGE.BaseElement 57 | PAGE.Region 58 | 59 | **Parsing and helpers** 60 | 61 | .. autosummary:: 62 | PAGE.parse_file 63 | PAGE.json_serialize 64 | 65 | ---- 66 | 67 | .. _ref_via: 68 | 69 | VGG Image Annotator helpers 70 | --------------------------- 71 | 72 | 73 | **VIA objects** 74 | 75 | .. autosummary:: 76 | via.WorkingItem 77 | via.VIAttribute 78 | 79 | 80 | **Creating masks with VIA annotations** 81 | 82 | .. autosummary:: 83 | via.load_annotation_data 84 | via.export_annotation_dict 85 | via.get_annotations_per_file 86 | via.parse_via_attributes 87 | via.get_via_attributes 88 | via.collect_working_items 89 | via.create_masks 90 | 91 | 92 | **Formatting in VIA JSON format** 93 | 94 | .. autosummary:: 95 | via.create_via_region_from_coordinates 96 | via.create_via_annotation_single_image 97 | 98 | ---- 99 | 100 | """ 101 | 102 | 103 | _INPUT = [ 104 | 'input_fn', 105 | 'serving_input_filename', 106 | 'serving_input_image', 107 | 'data_augmentation_fn', 108 | 'rotate_crop', 109 | 'resize_image', 110 | 'load_and_resize_image', 111 | 'extract_patches_fn', 112 | 'local_entropy' 113 | ] 114 | 115 | # _PAGE_OBJECTS = [ 116 | # 'Point', 117 | # 'Text', 118 | # 'Region', 119 | # 'TextLine', 120 | # 'GraphicRegion', 121 | # 'TextRegion', 122 | # 'TableRegion', 123 | # 'SeparatorRegion', 124 | # 'Border', 125 | # 'Metadata', 126 | # 'GroupSegment', 127 | # 'Page' 128 | # ] 129 | # 130 | # _PAGE_FN = [ 131 | # 'parse_file', 132 | # 'json_serialize' 133 | # ] 134 | 135 | __all__ = _INPUT # + _PAGE_OBJECTS + _PAGE_FN 136 | 137 | from .input import * 138 | from .input_utils import * 139 | from . import PAGE 140 | from . import via 141 | 142 | -------------------------------------------------------------------------------- /dh_segment_text/post_processing/line_vectorization.py: -------------------------------------------------------------------------------- 1 | from skimage.graph import MCP_Connect 2 | from skimage.morphology import skeletonize 3 | from skimage.measure import label as skimage_label 4 | from sklearn.metrics.pairwise import euclidean_distances 5 | from scipy.signal import convolve2d 6 | from collections import defaultdict 7 | import numpy as np 8 | 9 | 10 | def find_lines(lines_mask: np.ndarray) -> list: 11 | """ 12 | Finds the longest central line for each connected component in the given binary mask. 
13 | 14 | :param lines_mask: Binary mask of the detected line-areas 15 | :return: a list of Opencv-style polygonal lines (each contour encoded as [N,1,2] elements where each tuple is (x,y) ) 16 | """ 17 | # Make sure one-pixel wide 8-connected mask 18 | lines_mask = skeletonize(lines_mask) 19 | 20 | class MakeLineMCP(MCP_Connect): 21 | def __init__(self, *args, **kwargs): 22 | super().__init__(*args, **kwargs) 23 | self.connections = dict() 24 | self.scores = defaultdict(lambda: np.inf) 25 | 26 | def create_connection(self, id1, id2, pos1, pos2, cost1, cost2): 27 | k = (min(id1, id2), max(id1, id2)) 28 | s = cost1 + cost2 29 | if self.scores[k] > s: 30 | self.connections[k] = (pos1, pos2, s) 31 | self.scores[k] = s 32 | 33 | def get_connections(self, subsample=5): 34 | results = dict() 35 | for k, (pos1, pos2, s) in self.connections.items(): 36 | path = np.concatenate([self.traceback(pos1), self.traceback(pos2)[::-1]]) 37 | results[k] = path[::subsample] 38 | return results 39 | 40 | def goal_reached(self, int_index, float_cumcost): 41 | if float_cumcost > 0: 42 | return 2 43 | else: 44 | return 0 45 | 46 | if np.sum(lines_mask) == 0: 47 | return [] 48 | # Find extremities points 49 | end_points_candidates = np.stack(np.where((convolve2d(lines_mask, np.ones((3, 3)), mode='same') == 2) & lines_mask)).T 50 | connected_components = skimage_label(lines_mask, connectivity=2) 51 | # Group endpoint by connected components and keep only the two points furthest away 52 | d = defaultdict(list) 53 | for pt in end_points_candidates: 54 | d[connected_components[pt[0], pt[1]]].append(pt) 55 | end_points = [] 56 | for pts in d.values(): 57 | d = euclidean_distances(np.stack(pts), np.stack(pts)) 58 | i, j = np.unravel_index(d.argmax(), d.shape) 59 | end_points.append(pts[i]) 60 | end_points.append(pts[j]) 61 | end_points = np.stack(end_points) 62 | 63 | mcp = MakeLineMCP(~lines_mask) 64 | mcp.find_costs(end_points) 65 | connections = mcp.get_connections() 66 | if not np.all(np.array(sorted([i for k in connections.keys() for i in k])) == np.arange(len(end_points))): 67 | print('Warning : find_lines seems weird') 68 | return [c[:, None, ::-1] for c in connections.values()] 69 | -------------------------------------------------------------------------------- /exps/Ornaments/ornaments_process_eval.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | __author__ = 'solivr' 3 | 4 | import argparse 5 | import os 6 | import json 7 | from tqdm import tqdm 8 | import numpy as np 9 | from glob import glob 10 | from ornaments_evaluation import ornament_evaluate_folder 11 | from ornaments_post_processing import ornaments_post_processing_fn 12 | import tempfile 13 | 14 | 15 | PARAMS = {"threshold": 0.6, "ksize_open": [0, 0], "ksize_close": [0, 0]} 16 | MIOU_THRESHOD = 0.8 17 | MIN_AREA = 0.005 18 | 19 | if __name__ == '__main__': 20 | parser = argparse.ArgumentParser() 21 | parser.add_argument('-d', '--npy_directory', type=str, required=True, 22 | help='Directory containing bin .png files') 23 | parser.add_argument('-gt', '--gt_directory', type=str, required=True, 24 | help='Directory containing images and labels for evaluation') 25 | parser.add_argument('-o', '--output_directory', type=str, required=True, 26 | help='Output directory') 27 | parser.add_argument('-p', '--params_file', type=str, required=False, 28 | help='JSON params file') 29 | args = parser.parse_args() 30 | args = vars(args) 31 | 32 | output_dir = args.get('output_directory') 33 | npy_dir = 
args.get('npy_directory') 34 | # os.makedirs(output_dir) 35 | 36 | if args.get('params_file') is None: 37 | print('No params file found') 38 | params_list = [PARAMS] 39 | else: 40 | with open(args.get('params_file'), 'r') as f: 41 | configs_data = json.load(f) 42 | # If the file contains a list of configurations 43 | if 'configs' in configs_data.keys(): 44 | params_list = configs_data['configs'] 45 | assert isinstance(params_list, list) 46 | # Or if there is a single configuration 47 | else: 48 | params_list = [configs_data] 49 | 50 | npy_files = glob(os.path.join(npy_dir, '*.npy')) 51 | for params in params_list: 52 | new_output_dir = output_dir + 'th{}_a{}_{}'.format(MIOU_THRESHOD, MIN_AREA, np.random.randint(0, 1000)) 53 | os.makedirs(new_output_dir) 54 | 55 | with tempfile.TemporaryDirectory() as tmpdir: 56 | for filename in tqdm(npy_files): 57 | probs = np.load(filename) 58 | _ = ornaments_post_processing_fn(probs/np.max(probs), **params, 59 | output_basename=os.path.join(tmpdir, 60 | os.path.basename(filename).split('.')[0])) 61 | 62 | measures = ornament_evaluate_folder(tmpdir, args.get('gt_directory'), min_area=MIN_AREA, 63 | miou_threshold=MIOU_THRESHOD, debug_folder=new_output_dir) 64 | 65 | with open(os.path.join(new_output_dir, 'validation_scores.json'), 'w') as f: 66 | json.dump(measures, f) 67 | with open(os.path.join(new_output_dir, 'post_process_params.json'), 'w') as f: 68 | json.dump(params, f) 69 | -------------------------------------------------------------------------------- /exps/diva/process.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | __author__ = "solivr" 3 | __license__ = "GPL" 4 | 5 | import tensorflow as tf 6 | from dh_segment.loader import LoadedModel 7 | from glob import glob 8 | import numpy as np 9 | from tqdm import tqdm 10 | import os 11 | import cv2 12 | from typing import List 13 | from imageio import imsave 14 | 15 | 16 | def prediction_fn(model_dir: str, input_dir: str, output_dir: str=None): 17 | 18 | if not output_dir: 19 | # For model_dir of style model_name/export/timestamp/ this will create a folder model_name/predictions' 20 | output_dir = '{}'.format(os.path.sep).join(model_dir.split(os.path.sep)[:-3] + ['predictions']) 21 | 22 | os.makedirs(output_dir, exist_ok=True) 23 | filenames_to_predict = glob(os.path.join(input_dir, '*.jpg')) 24 | # Load model 25 | with tf.Session(): 26 | m = LoadedModel(model_dir, 'filename_original_shape') 27 | for filename in tqdm(filenames_to_predict, desc='Prediction'): 28 | pred = m.predict(filename)['probs'][0] 29 | np.save(os.path.join(output_dir, os.path.basename(filename).split('.')[0]), np.uint8(255 * pred)) 30 | 31 | 32 | def diva_post_processing_fn(probs: np.array, thresholds: List[float]=[0.5, 0.5, 0.5], min_cc: int=0, 33 | page_mask: np.array=None, output_basename: str=None) -> np.ndarray: 34 | """ 35 | 36 | :param probs: array in range [0, 1] of shape HxWx3 37 | :param thresholds: list of length 3 corresponding to the threshold for each channel 38 | :param min_cc: minimum size of connected components to keep 39 | :param border_removal: removes pixels in left and right border of the image that are within a certain margin 40 | :param output_basename: 41 | :return: 42 | """ 43 | # border_margin = probs.shape[1] * 0.02 44 | final_mask = np.zeros_like(probs, dtype=np.uint8) 45 | # Compute binary mask for each class (each channel) 46 | 47 | if page_mask is not None: 48 | probs = (page_mask > 0)[:, :, None] * probs 49 | 50 | for ch in 
range(probs.shape[-1]): 51 | probs_ch = probs[:, :, ch] 52 | if thresholds[ch] < 0: # Otsu thresholding 53 | probs_ch = np.uint8(probs_ch * 255) 54 | blur = cv2.GaussianBlur(probs_ch, (5, 5), 0) 55 | thresh_val, bin_img = cv2.threshold(blur, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU) 56 | bin_img = bin_img / 255 57 | else: 58 | bin_img = probs_ch > thresholds[ch] 59 | 60 | if min_cc > 0: 61 | _, labeled_cc = cv2.connectedComponents(bin_img.astype(np.uint8), connectivity=8) 62 | for lab in np.unique(labeled_cc): 63 | mask = labeled_cc == lab 64 | if np.sum(mask) < min_cc: 65 | labeled_cc[mask] = 0 66 | final_mask[:, :, ch] = bin_img * (labeled_cc > 0) 67 | else: 68 | final_mask[:, :, ch] = bin_img 69 | 70 | result = final_mask.astype(int) 71 | 72 | if output_basename is not None: 73 | imsave('{}.png'.format(output_basename), result*255) 74 | 75 | return result 76 | -------------------------------------------------------------------------------- /doc/start/training.rst: -------------------------------------------------------------------------------- 1 | Training 2 | -------- 3 | 4 | .. note:: A good nvidia GPU (6GB RAM at least) is most likely necessary to train your own models. We assume CUDA 5 | and cuDNN are installed. 6 | 7 | **Input data** 8 | 9 | You need to have your training data in a folder containing ``images`` folder and ``labels`` folder. 10 | The pairs (images, labels) need to have the same name (it is not mandatory to have the same extension file, 11 | however we recommend having the label images as ``.png`` files). 12 | 13 | The annotated images in ``label`` folder are (usually) RGB images with the regions to segment annotated with 14 | a specific color. 15 | 16 | .. note:: It is now also possible to use a `csv` file containing the pairs ``original_image_filename``, 17 | ``label_image_filename`` as input data. 18 | 19 | To input a ``csv`` file instead of the two folders ``images`` and ``labels``, 20 | the content should be formatted in the following way: :: 21 | 22 | mypath/myfolder/original_image_filename1,mypath/myfolder/label_image_filename1 23 | mypath/myfolder/original_image_filename2,mypath/myfolder/label_image_filename2 24 | 25 | 26 | 27 | **The class.txt file** 28 | 29 | The file containing the classes has the format shown below, where each row corresponds to one class 30 | (including 'negative' or 'background' class) and each row has 3 values for the 3 RGB values. 31 | Of course each class needs to have a different code. :: 32 | 33 | classes.txt 34 | 35 | 0 0 0 36 | 0 255 0 37 | ... 38 | 39 | 40 | **Config file with ``sacred``** 41 | 42 | `sacred`_ package is used to deal with experiments and trainings. Have a look at the documentation to use it properly. 43 | 44 | In order to train a model, you should run ``python train.py with `` 45 | 46 | .. _sacred: https://sacred.readthedocs.io/en/latest/quickstart.html 47 | 48 | 49 | Multilabel classification training 50 | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ 51 | 52 | In case you want to be able to assign multiple labels to elements, the ``classes.txt`` file must be changed. 53 | Besides the color code, you need to add an *attribution* code to each color. The attribution code has length `n_classes` 54 | and indicates which classes are assigned to the color. 
55 | 56 | Take for example 3 classes {A, B, C} and the following possible labelling combinations: 57 | 58 | - A (color code ``(0 255 0)``) with attribution code ``1 0 0`` 59 | - B (color code ``(255 0 0)``) with attribution code ``0 1 0`` 60 | - C (color code ``(0 0 255)``) with attribution code ``0 0 1`` 61 | - AB (color code ``(128 128 128)``) with attribution code ``1 1 0`` 62 | - BC (color code ``(0 255 255)``) with attribution code ``0 1 1`` 63 | 64 | The attributions code has value ``1`` when the label is assigned and ``0`` when it's not. 65 | (The attribution code ``1 0 1`` would mean that the color annotates elements that belong to classes A and C) 66 | 67 | In our example the ``classes.txt`` file would then look like : :: 68 | 69 | 70 | classes.txt 71 | 72 | 0 0 0 0 0 0 73 | 0 255 0 1 0 0 74 | 255 0 0 0 1 0 75 | 0 0 255 0 0 1 76 | 128 128 128 1 1 0 77 | 0 255 255 0 1 1 78 | -------------------------------------------------------------------------------- /dh_segment_text/utils/misc.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | __license__ = "GPL" 3 | 4 | import tensorflow as tf 5 | import json 6 | import pickle 7 | from hashlib import sha1 8 | from typing import Any 9 | import importlib 10 | import os 11 | import urllib.request 12 | import tarfile 13 | import os 14 | from tqdm import tqdm 15 | from random import shuffle 16 | import random 17 | 18 | 19 | def parse_json(filename): 20 | with open(filename, 'r') as f: 21 | return json.load(f) 22 | 23 | 24 | def dump_json(filename, dict): 25 | with open(filename, 'w') as f: 26 | json.dump(dict, f, indent=4, sort_keys=True) 27 | 28 | 29 | def load_pickle(filename): 30 | with open(filename, 'rb') as f: 31 | return pickle.load(f) 32 | 33 | 34 | def dump_pickle(filename, obj): 35 | with open(filename, 'wb') as f: 36 | return pickle.dump(obj, f) 37 | 38 | 39 | def hash_dict(params): 40 | return sha1(json.dumps(params, sort_keys=True).encode()).hexdigest() 41 | 42 | 43 | def shuffled(l: list, seed: int) -> list: 44 | random.seed(seed) 45 | ll = l.copy() 46 | shuffle(ll) 47 | return ll 48 | 49 | 50 | def get_class_from_name(full_class_name: str) -> Any: 51 | """ 52 | Tries to load the class from its naming, will import the corresponding module. 53 | Raises an Error if it does not work. 54 | 55 | :param full_class_name: full name of the class, for instance `foo.bar.Baz` 56 | :return: the loaded class 57 | """ 58 | module_name, class_name = full_class_name.rsplit('.', maxsplit=1) 59 | # load the module, will raise ImportError if module cannot be loaded 60 | m = importlib.import_module(module_name) 61 | # get the class, will raise AttributeError if class cannot be found 62 | c = getattr(m, class_name) 63 | return c 64 | 65 | 66 | def get_data_folder() -> str: 67 | folder = os.path.join(os.path.expanduser('~'), '.dh_segment') 68 | os.makedirs(folder, exist_ok=True) 69 | return folder 70 | 71 | 72 | def download_file(url: str, output_file: str): 73 | """ 74 | 75 | :param url: 76 | :param output_file: 77 | :return: 78 | """ 79 | def progress_hook(t): 80 | last_b = [0] 81 | 82 | def update_to(b=1, bsize=1, tsize=None): 83 | """ 84 | b : int, optional 85 | Number of blocks transferred so far [default: 1]. 86 | bsize : int, optional 87 | Size of each block (in tqdm units) [default: 1]. 88 | tsize : int, optional 89 | Total size (in tqdm units). If [default: None] remains unchanged. 
90 | """ 91 | if tsize is not None: 92 | t.total = tsize 93 | t.update((b - last_b[0]) * bsize) 94 | last_b[0] = b 95 | 96 | return update_to 97 | 98 | with tqdm(unit='B', unit_scale=True, unit_divisor=1024, miniters=1, 99 | desc="Downloading pre-trained weights") as t: 100 | urllib.request.urlretrieve(url, output_file, reporthook=progress_hook(t)) 101 | -------------------------------------------------------------------------------- /exps/Cini/cini_post_processing.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | from doc_seg.utils import dump_pickle 4 | 5 | 6 | def cini_post_processing_fn(preds: np.ndarray, 7 | clean_predictions=False, 8 | advanced=False, 9 | output_basename=None): 10 | # class 0 -> cardboard 11 | # class 1 -> background 12 | # class 2 -> photograph 13 | 14 | def get_cleaned_prediction(prediction): 15 | # Perform Erosion and Dilation 16 | if not clean_predictions: 17 | return prediction 18 | opening = cv2.morphologyEx(prediction, cv2.MORPH_OPEN, np.ones((5, 5))) 19 | closing = cv2.medianBlur(opening, 11) 20 | return closing 21 | 22 | class_predictions = np.argmax(preds, axis=-1) 23 | 24 | # get cardboard rectangle 25 | cardboard_prediction = get_cleaned_prediction((class_predictions == 0).astype(np.uint8)) 26 | _, contours, hierarchy = cv2.findContours(cardboard_prediction, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE) 27 | cardboard_contour = np.concatenate(contours) # contours[np.argmax([cv2.contourArea(c) for c in contours])] 28 | cardboard_rectangle = cv2.minAreaRect(cardboard_contour) 29 | # If extracted cardboard too small compared to scan size, get cardboard+image prediction 30 | if cv2.contourArea(cv2.boxPoints(cardboard_rectangle)) < 0.20*cardboard_prediction.size: 31 | cardboard_prediction = get_cleaned_prediction(((class_predictions == 0) | (class_predictions == 2)).astype(np.uint8)) 32 | _, contours, hierarchy = cv2.findContours(cardboard_prediction, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE) 33 | cardboard_contour = np.concatenate(contours) # contours[np.argmax([cv2.contourArea(c) for c in contours])] 34 | cardboard_rectangle = cv2.minAreaRect(cardboard_contour) 35 | 36 | image_prediction = (class_predictions == 2).astype(np.uint8) 37 | if advanced: 38 | # Force the image prediction to be inside the extracted cardboard 39 | mask = np.zeros_like(image_prediction) 40 | cv2.fillConvexPoly(mask, cv2.boxPoints(cardboard_rectangle).astype(np.int32), 1) 41 | image_prediction = mask * image_prediction 42 | eroded_mask = cv2.erode(mask, np.ones((20, 20))) 43 | image_prediction = image_prediction | (~cardboard_prediction & eroded_mask) 44 | 45 | image_prediction = get_cleaned_prediction(image_prediction) 46 | _, contours, hierarchy = cv2.findContours(image_prediction, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE) 47 | contours = sorted(contours, key=cv2.contourArea, reverse=True) 48 | # Take the biggest contour or two biggest if similar size (two images in the page) 49 | image_contour = contours[0] if len(contours) == 1 or (cv2.contourArea(contours[0]) > 0.5*cv2.contourArea(contours[1])) \ 50 | else np.concatenate(contours[0:2]) 51 | image_rectangle = cv2.minAreaRect(image_contour) 52 | 53 | if output_basename is not None: 54 | dump_pickle(output_basename+'.pkl', { 55 | 'shape': preds.shape[:2], 56 | 'cardboard_rectangle': cardboard_rectangle, 57 | 'image_rectangle': image_rectangle 58 | }) 59 | return cardboard_rectangle, image_rectangle 
-------------------------------------------------------------------------------- /dh_segment_text/embeddings/conv1d_encoder.py: -------------------------------------------------------------------------------- 1 | from .encoder import EmbeddingsEncoder 2 | from .embeddings_utils import batch_resize_and_gather 3 | import tensorflow as tf 4 | from tensorflow.contrib import layers 5 | from tensorflow.contrib.slim import arg_scope 6 | import numpy as np 7 | 8 | class Conv1dEncoder(EmbeddingsEncoder): 9 | def __init__(self, target_dim: int, starting_dim: int=256, max_conv: int=-1, renorm=False, weight_decay: float=0.): 10 | self.target_dim = target_dim 11 | self.starting_dim = starting_dim 12 | max_power = int(np.round(np.log2(self.starting_dim))) 13 | min_power = int(np.floor(np.log2(self.target_dim))) 14 | if max_conv == -1: 15 | max_conv = (max_power-min_power)+1 16 | self.conv_sizes = np.logspace(min_power,max_power,max_conv, base=2).astype(int)[::-1][:-1] 17 | 18 | self.batch_norm_params = { 19 | "renorm": renorm, 20 | "renorm_clipping": {'rmax': 100, 'rmin': 0.1, 'dmax': 1}, 21 | "renorm_momentum": 0.98 22 | } 23 | self.weight_decay = weight_decay 24 | 25 | 26 | 27 | def __call__(self, embeddings: tf.Tensor, embeddings_map: tf.Tensor, target_shape: tf.Tensor, is_training=False) -> tf.Tensor: 28 | 29 | batch_norm_fn = lambda x: tf.layers.batch_normalization(x, axis=-1, training=is_training, 30 | name='batch_norm', **self.batch_norm_params) 31 | 32 | with tf.variable_scope("Conv1D_encoder"): 33 | with tf.variable_scope("Encoder"): 34 | with arg_scope([layers.conv1d], 35 | normalizer_fn=batch_norm_fn, 36 | weights_regularizer=layers.l2_regularizer(self.weight_decay)): 37 | if self.target_dim >= self.starting_dim: 38 | raise IndexError(f"Target dim was bigger than {self.starting_dim}, got {self.target_dim}") 39 | reduced_embeddings = embeddings 40 | for i, conv_size in enumerate(self.conv_sizes): 41 | reduced_embeddings = tf.contrib.layers.conv1d(reduced_embeddings, conv_size, (1), scope='conv_%01d'%i) 42 | reduced_embeddings = tf.contrib.layers.conv1d(reduced_embeddings, self.target_dim, (1), scope='conv_final') 43 | embeddings_feature_map = batch_resize_and_gather(embeddings_map, 44 | target_shape, 45 | reduced_embeddings) 46 | embeddings_feature_map.set_shape([None, None, None, self.target_dim]) 47 | embeddings_feature_map_first_dims = embeddings_feature_map[:,:,:,:3] 48 | embeddings_feature_map_first_dims = tf.div( 49 | tf.subtract( 50 | embeddings_feature_map_first_dims, 51 | tf.reduce_min(embeddings_feature_map_first_dims) 52 | ), 53 | tf.subtract( 54 | tf.reduce_max(embeddings_feature_map_first_dims), 55 | tf.reduce_min(embeddings_feature_map_first_dims) 56 | ) 57 | ) 58 | tf.summary.image('summary/embeddings_encoded', embeddings_feature_map_first_dims, max_outputs=1) 59 | 60 | return embeddings_feature_map 61 | 62 | -------------------------------------------------------------------------------- /exps/Cini/cini_evaluation.py: -------------------------------------------------------------------------------- 1 | import os 2 | from glob import glob 3 | from scipy.misc import imread, imsave, imresize 4 | import cv2 5 | import numpy as np 6 | from cini_post_processing import cini_post_processing_fn 7 | from doc_seg.utils import load_pickle 8 | import pandas as pd 9 | 10 | 11 | def cini_evaluate_folder(output_folder: str, validation_dir: str, verbose=False, debug_folder=None) -> dict: 12 | filenames_processed = glob(os.path.join(output_folder, '*.pkl')) 13 | 14 | if debug_folder is not None: 
15 | os.makedirs(debug_folder, exist_ok=True) 16 | 17 | iou_cardboards = [] 18 | iou_images = [] 19 | basenames = [] 20 | for filename in filenames_processed: 21 | basename = os.path.basename(filename).split('.')[0] 22 | 23 | data = load_pickle(filename) 24 | 25 | cardboard_coords, image_coords, shape = data['cardboard_rectangle'], data['image_rectangle'], data['shape'] 26 | 27 | # Open label image 28 | label_path = os.path.join(validation_dir, 'labels', '{}.png'.format(basename)) 29 | if not os.path.exists(label_path): 30 | label_path = label_path.replace('.png', '.jpg') 31 | label_image = imread(label_path, mode='RGB') 32 | label_image = imresize(label_image, shape[:2]) 33 | label_predictions = np.stack([ 34 | label_image[:, :, 0] > 250, 35 | label_image[:, :, 1] > 250, 36 | label_image[:, :, 2] > 250 37 | ], axis=-1).astype(np.float32) 38 | label_cardboard_coords, label_image_coords = cini_post_processing_fn(label_predictions, 39 | clean_predictions=False) 40 | 41 | # Compute errors 42 | def intersection_over_union(cnt1, cnt2): 43 | mask1 = np.zeros(shape, np.uint8) 44 | cv2.fillConvexPoly(mask1, cv2.boxPoints(cnt1).astype(np.int32), 1) 45 | mask2 = np.zeros(shape, np.uint8) 46 | cv2.fillConvexPoly(mask2, cv2.boxPoints(cnt2).astype(np.int32), 1) 47 | return np.sum(mask1 & mask2) / np.sum(mask1 | mask2) 48 | 49 | iou_cardboard = intersection_over_union(label_cardboard_coords, cardboard_coords) 50 | iou_image = intersection_over_union(label_image_coords, image_coords) 51 | 52 | iou_cardboards.append(iou_cardboard) 53 | iou_images.append(iou_image) 54 | basenames.append(basename) 55 | 56 | if debug_folder is not None: 57 | img_filename = os.path.join(validation_dir, 'images', '{}.jpg'.format(basename)) 58 | img = imresize(imread(img_filename), shape) 59 | cv2.polylines(img, cv2.boxPoints(cardboard_coords).astype(np.int32)[None], True, (255, 0, 0), 4) 60 | cv2.polylines(img, cv2.boxPoints(image_coords).astype(np.int32)[None], True, (0, 0, 255), 4) 61 | imsave(os.path.join(debug_folder, '{}.jpg'.format(basename)), img) 62 | 63 | result = { 64 | 'cardboard_mean_iou': np.mean(iou_cardboards), 65 | 'image_mean_iou': np.mean(iou_images), 66 | } 67 | 68 | if debug_folder is not None: 69 | df = pd.DataFrame(data=list(zip(basenames, iou_cardboards, iou_images)), 70 | columns=['basename', 'iou_cardboard', 'iou_image']) 71 | df = df.sort_values('iou_image', ascending=True) 72 | df.to_csv(os.path.join(debug_folder, 'scores.csv'), index=False) 73 | 74 | if verbose: 75 | print(result) 76 | return result 77 | 78 | 79 | -------------------------------------------------------------------------------- /exps/page/utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | __author__ = "solivr" 3 | __license__ = "GPL" 4 | 5 | from imageio import imread, imsave 6 | import numpy as np 7 | import cv2 8 | import os 9 | from tqdm import tqdm 10 | 11 | 12 | def get_coords_form_txt_line(line: str)-> tuple: 13 | """ 14 | gets the coordinates of the page from the txt file (line-wise) 15 | :param line: line of the .txt file 16 | :return: coordinates, filename 17 | """ 18 | splits = line.split(',') 19 | full_filename = splits[0] 20 | splits = splits[1:] 21 | if splits[-1] in ['SINGLE', 'ABNORMAL']: 22 | coords_simple = np.reshape(np.array(splits[:-1], dtype=int), (4, 2)) 23 | # coords_double = None 24 | coords = coords_simple 25 | else: 26 | coords_simple = np.reshape(np.array(splits[:8], dtype=int), (4, 2)) 27 | # coords_double = 
np.reshape(np.array(splits[-4:], dtype=int), (2, 2)) 28 | # coords = (coords_simple, coords_double) 29 | coords = coords_simple 30 | 31 | return coords, full_filename 32 | 33 | 34 | def make_binary_mask(txt_file): 35 | """ 36 | From export txt file with filnenames and coordinates of qudrilaterals, generate binary mask of page 37 | :param txt_file: txt file filename 38 | :return: 39 | """ 40 | for line in open(txt_file, 'r'): 41 | dirname, _ = os.path.split(txt_file) 42 | c, full_name = get_coords_form_txt_line(line) 43 | img = imread(full_name) 44 | label_img = np.zeros((img.shape[0], img.shape[1]), np.uint8) 45 | label_img = cv2.fillPoly(label_img, [c[:, None, :]], 255) 46 | basename = os.path.basename(full_name) 47 | imsave(os.path.join(dirname, '{}_bin.png'.format(basename.split('.')[0])), label_img) 48 | 49 | 50 | def page_dataset_generator(txt_filename: str, input_dir: str, output_dir: str): 51 | """ 52 | Given a txt file (filename, coords corners), generates a dataset of images + labels 53 | :param txt_filename: File (txt) containing list of images 54 | :param input_dir: Root directory to original images 55 | :param output_dir: Output directory for generated dataset 56 | :return: 57 | """ 58 | 59 | output_img_dir = os.path.join(output_dir, 'images') 60 | output_label_dir = os.path.join(output_dir, 'labels') 61 | os.makedirs(output_img_dir, exist_ok=True) 62 | os.makedirs(output_label_dir, exist_ok=True) 63 | 64 | for line in tqdm(open(txt_filename, 'r')): 65 | coords, full_filename = get_coords_form_txt_line(line) 66 | 67 | try: 68 | img = imread(os.path.join(input_dir, full_filename)) 69 | except FileNotFoundError: 70 | print('File {} not found'.format(full_filename)) 71 | continue 72 | label_img = np.zeros((img.shape[0], img.shape[1], 3)) 73 | 74 | label_img = cv2.fillPoly(label_img, [coords], (255, 0, 0)) 75 | # if coords_double is not None: 76 | # label_img = cv2.polylines(label_img, [coords_double], False, color=(0, 0, 0), thickness=50) 77 | 78 | col, filename = full_filename.split(os.path.sep)[-2:] 79 | 80 | imsave(os.path.join(output_img_dir, '{}_{}.jpg'.format(col.split('_')[0], filename.split('.')[0])), img) 81 | imsave(os.path.join(output_label_dir, '{}_{}.png'.format(col.split('_')[0], filename.split('.')[0])), label_img) 82 | 83 | # Class file 84 | classes = np.stack([(0, 0, 0), (255, 0, 0)]) 85 | np.savetxt(os.path.join(output_dir, 'classes.txt'), classes, fmt='%d') 86 | -------------------------------------------------------------------------------- /exps/page/process.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | __author__ = "solivr" 3 | __license__ = "GPL" 4 | 5 | import tensorflow as tf 6 | import os 7 | import numpy as np 8 | from tqdm import tqdm 9 | from glob import glob 10 | from dh_segment.loader import LoadedModel 11 | from imageio import imsave 12 | from dh_segment.post_processing import binarization 13 | from dh_segment.post_processing.boxes_detection import find_boxes 14 | 15 | 16 | def prediction_fn(model_dir: str, input_dir: str, output_dir: str=None) -> None: 17 | """ 18 | Given a model directory this function will load the model and apply it to the files (.jpg, .png) found in input_dir. 
19 |     The predictions will be saved in output_dir as .npy files (values in the range [0, 255])
20 |     :param model_dir: Directory containing the saved model
21 |     :param input_dir: input directory where the images to predict are
22 |     :param output_dir: output directory to save the predictions (probability images)
23 |     :return:
24 |     """
25 |     if not output_dir:
26 |         # For model_dir of style model_name/export/timestamp/ this will create a folder model_name/predictions
27 |         output_dir = '{}'.format(os.path.sep).join(model_dir.split(os.path.sep)[:-3] + ['predictions'])
28 | 
29 |     os.makedirs(output_dir, exist_ok=True)
30 |     filenames_to_predict = glob(os.path.join(input_dir, '*.jpg')) + glob(os.path.join(input_dir, '*.png'))
31 |     # Load model
32 |     with tf.Session():
33 |         m = LoadedModel(model_dir, predict_mode='filename_original_shape')
34 |         for filename in tqdm(filenames_to_predict, desc='Prediction'):
35 |             pred = m.predict(filename)['probs'][0]
36 |             np.save(os.path.join(output_dir, os.path.basename(filename).split('.')[0]), np.uint8(255 * pred))
37 | 
38 | 
39 | def page_post_processing_fn(probs: np.ndarray, threshold: float=0.5, output_basename: str=None,
40 |                             kernel_size: int = 5) -> np.ndarray:
41 |     """
42 |     Computes the binary mask of the detected Page from the probabilities output by the network
43 |     :param probs: array in range [0, 1] of shape HxWx2
44 |     :param threshold: threshold between [0 and 1], if negative Otsu's adaptive threshold will be used
45 |     :param output_basename: if not None, basename used to save the resulting mask as a .png file
46 |     :param kernel_size: size of kernel for morphological cleaning
47 |     """
48 | 
49 |     mask = binarization.thresholding(probs[:, :, 1], threshold=threshold)
50 |     result = binarization.cleaning_binary(mask, kernel_size=kernel_size)
51 | 
52 |     if output_basename is not None:
53 |         imsave('{}.png'.format(output_basename), result*255)
54 |     return result
55 | 
56 | 
57 | def format_quad_to_string(quad):
58 |     s = ''
59 |     for corner in quad:
60 |         s += '{},{},'.format(corner[0], corner[1])
61 |     return s[:-1]
62 | 
63 | 
64 | def extract_page(prediction: np.ndarray, min_area: float=0.2, post_process_params: dict=None) -> list:
65 |     """
66 |     Given an image with probabilities, post-processes it and extracts one box
67 |     :param prediction: probability mask [0, 1]
68 |     :param min_area: minimum area to be considered as a valid extraction
69 |     :param post_process_params: params for the page post-processing function
70 |     :return: list of coordinates of the box
71 |     """
72 |     if post_process_params:
73 |         post_pred = page_post_processing_fn(prediction, **post_process_params)
74 |     else:
75 |         post_pred = prediction
76 |     pred_box = find_boxes(np.uint8(post_pred), mode='quadrilateral', min_area=min_area, n_max_boxes=1)
77 | 
78 |     return pred_box
79 | 
--------------------------------------------------------------------------------
/exps/cbad/example_evaluation.ipynb:
--------------------------------------------------------------------------------
 1 | {
 2 |  "cells": [
 3 |   {
 4 |    "cell_type": "code",
 5 |    "execution_count": null,
 6 |    "metadata": {
 7 |     "collapsed": true
 8 |    },
 9 |    "outputs": [],
10 |    "source": [
11 |     "import os\n",
12 |     "from scipy.misc import imread\n",
13 |     "from tqdm import tqdm\n",
14 |     "from .evaluation import eval_fn\n",
15 |     "from .process import prediction_fn, extract_lines"
16 |    ]
17 |   },
18 |   {
19 |    "cell_type": "code",
20 |    "execution_count": null,
21 |    "metadata": {},
22 |    "outputs": [],
23 |    "source": [
24 |     "model_dirs_list = ['model1/export/timestamp/', \n",
25 |     "                   'model2/export/timestamp/']"
26 |    ]
27 |   },
28 |   {
29 |    "cell_type": "code",
30 | "execution_count": null, 31 | "metadata": {}, 32 | "outputs": [], 33 | "source": [ 34 | "set_dir = './baseline_dataset/images/'" 35 | ] 36 | }, 37 | { 38 | "cell_type": "markdown", 39 | "metadata": {}, 40 | "source": [ 41 | "## Prediction" 42 | ] 43 | }, 44 | { 45 | "cell_type": "code", 46 | "execution_count": null, 47 | "metadata": {}, 48 | "outputs": [], 49 | "source": [ 50 | "for model_dir in model_dirs_list:\n", 51 | " output_dir = '{}'.format(os.path.sep).join(model_dir.split(os.path.sep)[:-3] + ['predictions'])\n", 52 | " prediction_fn(model_dir, set_dir, output_dir)" 53 | ] 54 | }, 55 | { 56 | "cell_type": "code", 57 | "execution_count": null, 58 | "metadata": {}, 59 | "outputs": [], 60 | "source": [ 61 | "" 62 | ] 63 | }, 64 | { 65 | "cell_type": "markdown", 66 | "metadata": {}, 67 | "source": [ 68 | "## Evaluation" 69 | ] 70 | }, 71 | { 72 | "cell_type": "code", 73 | "execution_count": null, 74 | "metadata": {}, 75 | "outputs": [], 76 | "source": [ 77 | "CBAD_JAR = './TranskribusBaseLineEvaluationScheme_v0.1.3/' \\\n", 78 | " 'TranskribusBaseLineEvaluationScheme-0.1.3-jar-with-dependencies.jar'\n", 79 | "gt_dir = './dataset/test/gt/'\n", 80 | "pred_dir_list = ['./model1/preds_test/',\n", 81 | " './model2/preds_test/']" 82 | ] 83 | }, 84 | { 85 | "cell_type": "code", 86 | "execution_count": null, 87 | "metadata": {}, 88 | "outputs": [], 89 | "source": [ 90 | "post_process_params = {'sigma': 1.5,\n", 91 | " 'low_threshold': 0.2,\n", 92 | " 'high_threshold': 0.4}" 93 | ] 94 | }, 95 | { 96 | "cell_type": "code", 97 | "execution_count": null, 98 | "metadata": {}, 99 | "outputs": [], 100 | "source": [ 101 | "list_results = list()\n", 102 | "for pred_dir in pred_dir_list:\n", 103 | " list_results.append(eval_fn(pred_dir, gt_dir, pred_dir, post_process_params, CBAD_JAR))" 104 | ] 105 | }, 106 | { 107 | "cell_type": "code", 108 | "execution_count": null, 109 | "metadata": {}, 110 | "outputs": [], 111 | "source": [ 112 | "" 113 | ] 114 | } 115 | ], 116 | "metadata": { 117 | "kernelspec": { 118 | "display_name": "Python 2", 119 | "language": "python", 120 | "name": "python2" 121 | }, 122 | "language_info": { 123 | "codemirror_mode": { 124 | "name": "ipython", 125 | "version": 2.0 126 | }, 127 | "file_extension": ".py", 128 | "mimetype": "text/x-python", 129 | "name": "python", 130 | "nbconvert_exporter": "python", 131 | "pygments_lexer": "ipython2", 132 | "version": "2.7.6" 133 | } 134 | }, 135 | "nbformat": 4, 136 | "nbformat_minor": 0 137 | } -------------------------------------------------------------------------------- /doc/intro/intro.rst: -------------------------------------------------------------------------------- 1 | ============ 2 | Introduction 3 | ============ 4 | 5 | What is dhSegment? 6 | ------------------ 7 | 8 | .. image:: ../_static/system.png 9 | :width: 60 % 10 | :align: center 11 | :alt: dhSegment system 12 | 13 | 14 | dhSegment is a generic approach for Historical Document Processing. 15 | It relies on a Convolutional Neural Network to do the heavy lifting of predicting pixelwise characteristics. 16 | Then simple image processing operations are provided to extract the components of interest (boxes, polygons, lines, masks, ...) 17 | 18 | A few key facts: 19 | 20 | - You only need to provide a list of images with annotated masks, which can easily be created with an image editing software (Gimp, Photoshop). You only need to draw the elements you care about! 
21 | 22 | - Allows to classify each pixel across multiple classes, with the possibility of assigning multiple labels per pixel. 23 | 24 | - On-the-fly data augmentation, and efficient batching of batches. 25 | 26 | - Leverages a state-of-the-art pre-trained network (Resnet50) to lower the need for training data and improve generalization. 27 | 28 | - Monitor training on Tensorboard very easily. 29 | 30 | - A list of simple image processing operations are already implemented such that the post-processing steps only take a couple of lines. 31 | 32 | What sort of training data do I need? 33 | --------------------------------------- 34 | 35 | Each training sample consists in an image of a document and its corresponding parts to be predicted. 36 | 37 | .. image:: ../_static/cini_input.jpg 38 | :width: 45 % 39 | :alt: example image input 40 | .. image:: ../_static/cini_labels.jpg 41 | :width: 45 % 42 | :alt: example label 43 | 44 | Additionally, a text file encoding the RGB values of the classes needs to be provided. 45 | In this case if we want the classes 'background', 'document' and 'photograph' to be respectively 46 | classes 0, 1, and 2 we need to encode their color line-by-line: :: 47 | 48 | 0 255 0 49 | 255 0 0 50 | 0 0 255 51 | 52 | .. _usecases-label: 53 | 54 | Use cases 55 | --------- 56 | 57 | Page Segmentation 58 | ^^^^^^^^^^^^^^^^^ 59 | 60 | .. image:: ../_static/page.jpg 61 | :width: 50 % 62 | :alt: page extraction use case 63 | 64 | Dataset : READ-BAD :cite:`gruning2018read` annotated by :cite:`tensmeyer2017pagenet`. 65 | 66 | 67 | Layout Analysis 68 | ^^^^^^^^^^^^^^^ 69 | 70 | .. image:: ../_static/diva.jpg 71 | :width: 45 % 72 | :alt: diva use case 73 | .. image:: ../_static/diva_preds.png 74 | :width: 45 % 75 | :alt: diva predictions use case 76 | 77 | Dataset : DIVA-HisDB :cite:`simistira2016diva`. 78 | 79 | Ornament Extraction 80 | ^^^^^^^^^^^^^^^^^^^ 81 | 82 | .. image:: ../_static/ornaments.jpg 83 | :width: 50 % 84 | :alt: ornaments use case 85 | 86 | Dataset : BCU collection. 87 | 88 | 89 | Line Detection 90 | ^^^^^^^^^^^^^^ 91 | 92 | .. image:: ../_static/cbad.jpg 93 | :width: 70 % 94 | :alt: line extraction use case 95 | 96 | Dataset : READ-BAD :cite:`gruning2018read`. 97 | 98 | 99 | Document Segmentation 100 | ^^^^^^^^^^^^^^^^^^^^^ 101 | 102 | .. image:: ../_static/cini.jpg 103 | :width: 70 % 104 | :alt: cini photo collection extraction use case 105 | 106 | Dataset : Photo-collection from the Cini Foundation. 107 | 108 | 109 | Tensorboard Integration 110 | ----------------------- 111 | The TensorBoard integration allows to visualize your TensorFlow graph, plot metrics 112 | and show the images and predictions during the execution of the graph. 113 | 114 | .. image:: ../_static/tensorboard_1.png 115 | :width: 65 % 116 | :alt: tensorboard example 1 117 | .. image:: ../_static/tensorboard_2.png 118 | :width: 65 % 119 | :alt: tensorboard example 2 120 | .. 
image:: ../_static/tensorboard_3.png
121 |    :width: 65 %
122 |    :alt: tensorboard example 3
--------------------------------------------------------------------------------
/dh_segment_text/utils/evaluation.py:
--------------------------------------------------------------------------------
 1 | #!/usr/bin/env python
 2 | __author__ = "solivr"
 3 | __license__ = "GPL"
 4 | 
 5 | import numpy as np
 6 | import json
 7 | import cv2
 8 | 
 9 | 
10 | class Metrics:
11 |     def __init__(self):
12 |         self.total_elements = 0
13 |         self.true_positives = 0
14 |         self.true_negatives = 0
15 |         self.false_positives = 0
16 |         self.false_negatives = 0
17 |         self.SE_list = list()
18 |         self.IOU_list = list()
19 | 
20 |         self.MSE = 0
21 |         self.psnr = 0
22 |         self.mIOU = 0
23 |         self.IU = 0
24 |         self.accuracy = 0
25 |         self.recall = 0
26 |         self.precision = 0
27 |         self.f_measure = 0
28 | 
29 |     def __add__(self, other):
30 |         if isinstance(other, self.__class__):
31 |             summable_attr = ['total_elements', 'false_negatives', 'false_positives', 'true_positives', 'true_negatives']
32 |             addlist_attr = ['SE_list', 'IOU_list']
33 |             m = Metrics()
34 |             for k, v in self.__dict__.items():
35 |                 if k in summable_attr:
36 |                     setattr(m, k, self.__dict__[k] + other.__dict__[k])
37 |                 elif k in addlist_attr:
38 |                     mse1 = [self.__dict__[k]] if not isinstance(self.__dict__[k], list) else self.__dict__[k]
39 |                     mse2 = [other.__dict__[k]] if not isinstance(other.__dict__[k], list) else other.__dict__[k]
40 | 
41 |                     setattr(m, k, mse1 + mse2)
42 |             return m
43 |         else:
44 |             raise NotImplementedError
45 | 
46 |     def __radd__(self, other):
47 |         return self.__add__(other)
48 | 
49 |     def compute_mse(self):
50 |         self.MSE = np.sum(self.SE_list) / self.total_elements if self.total_elements > 0 else np.inf
51 |         return self.MSE
52 | 
53 |     def compute_psnr(self):
54 |         if self.MSE != 0:
55 |             self.psnr = 10 * np.log10((1 ** 2) / self.MSE)
56 |             return self.psnr
57 |         else:
58 |             print('Cannot compute PSNR, MSE is 0.')
59 | 
60 |     def compute_prf(self, beta=1):
61 |         self.recall = self.true_positives / (self.true_positives + self.false_negatives) \
62 |             if (self.true_positives + self.false_negatives) > 0 else 0
63 |         self.precision = self.true_positives / (self.true_positives + self.false_positives) \
64 |             if (self.true_positives + self.false_positives) > 0 else 0
65 |         self.f_measure = ((1 + beta ** 2) * self.recall * self.precision) / (self.recall + (beta ** 2) * self.precision) \
66 |             if (self.recall + self.precision) > 0 else 0
67 | 
68 |         return self.recall, self.precision, self.f_measure
69 | 
70 |     def compute_miou(self):
71 |         self.mIOU = np.mean(self.IOU_list)
72 |         return self.mIOU
73 | 
74 |     # See http://cdn.iiit.ac.in/cdn/cvit.iiit.ac.in/images/ConferencePapers/2017/DocUsingDeepFeatures.pdf
75 |     def compute_iu(self):
76 |         self.IU = self.true_positives / (self.true_positives + self.false_positives + self.false_negatives) \
77 |             if (self.true_positives + self.false_positives + self.false_negatives) > 0 else 0
78 |         return self.IU
79 | 
80 |     def compute_accuracy(self):
81 |         self.accuracy = (self.true_positives + self.true_negatives)/self.total_elements if self.total_elements > 0 else 0
82 | 
83 |     def save_to_json(self, json_filename: str) -> None:
84 |         export_dic = self.__dict__.copy()
85 |         del export_dic['SE_list']
86 | 
87 |         with open(json_filename, 'w') as outfile:
88 |             json.dump(export_dic, outfile)
89 | 
90 | 
91 | def intersection_over_union(cnt1, cnt2, shape_mask):
92 |     mask1 = np.zeros(shape_mask, np.uint8)
93 |     mask1 = cv2.fillConvexPoly(mask1, cnt1.astype(np.int32), 1).astype(np.int8)
94
| mask2 = np.zeros(shape_mask, np.uint8) 95 | mask2 = cv2.fillConvexPoly(mask2, cnt2.astype(np.int32), 1).astype(np.int8) 96 | return np.sum(mask1 & mask2) / np.sum(mask1 | mask2) 97 | -------------------------------------------------------------------------------- /dh_segment_text/network/pretrained_models/vgg16.py: -------------------------------------------------------------------------------- 1 | from tensorflow.contrib import slim, layers 2 | import tensorflow as tf 3 | from tensorflow.contrib.slim import nets 4 | import numpy as np 5 | from ..model import Encoder 6 | import os 7 | import tarfile 8 | from ...utils.misc import get_data_folder, download_file 9 | 10 | _VGG_MEANS = [123.68, 116.78, 103.94] 11 | 12 | 13 | def mean_substraction(input_tensor, means=_VGG_MEANS): 14 | return tf.subtract(input_tensor, np.array(means)[None, None, None, :], name='MeanSubstraction') 15 | 16 | 17 | class VGG16(Encoder): 18 | """VGG-16 implementation 19 | 20 | :ivar blocks: number of blocks (vgg blocks) 21 | :ivar weight_decay: weight decay value 22 | :ivar pretrained_file: path to the file (.ckpt) containing the pretrained weights 23 | """ 24 | def __init__(self, blocks: int=5, weight_decay: float=0.0005): 25 | self.blocks = blocks 26 | self.weight_decay = weight_decay 27 | self.pretrained_file = os.path.join(get_data_folder(), 'vgg_16.ckpt') 28 | if not os.path.exists(self.pretrained_file): 29 | print("Could not find pre-trained file {}, downloading it!".format(self.pretrained_file)) 30 | tar_filename = os.path.join(get_data_folder(), 'vgg_16.tar.gz') 31 | download_file('http://download.tensorflow.org/models/vgg_16_2016_08_28.tar.gz', tar_filename) 32 | tar = tarfile.open(tar_filename) 33 | tar.extractall(path=get_data_folder()) 34 | tar.close() 35 | os.remove(tar_filename) 36 | assert os.path.exists(self.pretrained_file) 37 | print('Pre-trained weights downloaded!') 38 | 39 | def pretrained_information(self): 40 | return self.pretrained_file, [v for v in tf.global_variables() 41 | if 'vgg_16' in v.name 42 | and 'renorm' not in v.name] 43 | 44 | def __call__(self, images: tf.Tensor, is_training=False): 45 | outputs = [] 46 | 47 | with slim.arg_scope(nets.vgg.vgg_arg_scope(weight_decay=self.weight_decay)): 48 | with tf.variable_scope(None, 'vgg_16', [images]) as sc: 49 | input_tensor = mean_substraction(images) 50 | outputs.append(input_tensor) 51 | end_points_collection = sc.original_name_scope + '_end_points' 52 | # Collect outputs for conv2d, fully_connected and max_pool2d. 
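# The encoder collects one feature map per resolution: the mean-subtracted input, then the
# output of each pooling layer (pool1 ... pool5, depending on `blocks`). This list of
# intermediate maps is what the decoder (e.g. SimpleDecoder) later upsamples and concatenates,
# U-Net style.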
53 | with slim.arg_scope( 54 | [layers.conv2d, layers.fully_connected, layers.max_pool2d], 55 | outputs_collections=end_points_collection): 56 | net = layers.repeat( 57 | input_tensor, 2, layers.conv2d, 64, [3, 3], scope='conv1') 58 | net = layers.max_pool2d(net, [2, 2], scope='pool1') 59 | outputs.append(net) 60 | if self.blocks >= 2: 61 | net = layers.repeat(net, 2, layers.conv2d, 128, [3, 3], scope='conv2') 62 | net = layers.max_pool2d(net, [2, 2], scope='pool2') 63 | outputs.append(net) 64 | if self.blocks >= 3: 65 | net = layers.repeat(net, 3, layers.conv2d, 256, [3, 3], scope='conv3') 66 | net = layers.max_pool2d(net, [2, 2], scope='pool3') 67 | outputs.append(net) 68 | if self.blocks >= 4: 69 | net = layers.repeat(net, 3, layers.conv2d, 512, [3, 3], scope='conv4') 70 | net = layers.max_pool2d(net, [2, 2], scope='pool4') 71 | outputs.append(net) 72 | if self.blocks >= 5: 73 | net = layers.repeat(net, 3, layers.conv2d, 512, [3, 3], scope='conv5') 74 | net = layers.max_pool2d(net, [2, 2], scope='pool5') 75 | outputs.append(net) 76 | 77 | return outputs 78 | -------------------------------------------------------------------------------- /exps/diva/example_evaluation.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": { 7 | "collapsed": true 8 | }, 9 | "outputs": [], 10 | "source": [ 11 | "import os\n", 12 | "import numpy as np\n", 13 | "from tqdm import tqdm" 14 | ] 15 | }, 16 | { 17 | "cell_type": "code", 18 | "execution_count": null, 19 | "metadata": {}, 20 | "outputs": [], 21 | "source": [ 22 | "from dh_segment.loader import LoadedModel\n", 23 | "from .process import prediction_fn" 24 | ] 25 | }, 26 | { 27 | "cell_type": "code", 28 | "execution_count": null, 29 | "metadata": {}, 30 | "outputs": [], 31 | "source": [ 32 | "model_dirs_list = ['diva_0/export/timestamp/', \n", 33 | " 'diva_1/export/timestamp/']\n", 34 | "input_dir = 'diva_test_set/images/'\n", 35 | "output_dir_name = 'out_predictions'" 36 | ] 37 | }, 38 | { 39 | "cell_type": "markdown", 40 | "metadata": {}, 41 | "source": [ 42 | "## Prediction" 43 | ] 44 | }, 45 | { 46 | "cell_type": "code", 47 | "execution_count": null, 48 | "metadata": {}, 49 | "outputs": [], 50 | "source": [ 51 | "for model_dir in model_dirs_list:\n", 52 | " output_dir = '{}'.format(os.path.sep).join(model_dir.split(os.path.sep)[:-3] + [output_dir_name])\n", 53 | " os.makedirs(output_dir, exist_ok=True)\n", 54 | " prediction_fn(model_dir, input_dir, output_dir)" 55 | ] 56 | }, 57 | { 58 | "cell_type": "code", 59 | "execution_count": null, 60 | "metadata": {}, 61 | "outputs": [], 62 | "source": [ 63 | "" 64 | ] 65 | }, 66 | { 67 | "cell_type": "markdown", 68 | "metadata": {}, 69 | "source": [ 70 | "## Evaluation" 71 | ] 72 | }, 73 | { 74 | "cell_type": "code", 75 | "execution_count": null, 76 | "metadata": {}, 77 | "outputs": [], 78 | "source": [ 79 | "from .evaluation import eval_fn" 80 | ] 81 | }, 82 | { 83 | "cell_type": "code", 84 | "execution_count": null, 85 | "metadata": {}, 86 | "outputs": [], 87 | "source": [ 88 | "pred_dir_list = [os.path.abspath(os.path.join(md, '..', '..', output_dir_name)) \n", 89 | " for md in model_dirs_list]\n", 90 | "gt_dir = 'diva_test_set/pixel-level-gt/'\n", 91 | "DIVA_JAR = 'DIVA_Layout_Analysis_Evaluator/out/artifacts/LayoutAnalysisEvaluator.jar'" 92 | ] 93 | }, 94 | { 95 | "cell_type": "code", 96 | "execution_count": null, 97 | "metadata": {}, 98 | "outputs": [], 99 | "source": [ 
100 | "list_mIUs = list()\n", 101 | "for i, pred_dir in enumerate(pred_dir_list):\n", 102 | " output_txt_filename = './{i}_diva_results.txt'\n", 103 | " mean_mIU = eval_fn(pred_dir, gt_dir, output_txt_filename, diva_jar=DIVA_JAR)\n", 104 | " list_mIUs.append(mean_mIU)" 105 | ] 106 | }, 107 | { 108 | "cell_type": "code", 109 | "execution_count": null, 110 | "metadata": {}, 111 | "outputs": [], 112 | "source": [ 113 | "print('MIOU : {:.03} +- {:.03} ([{:.03}, {:.03}])'.format(np.mean(list_mIUs), \n", 114 | " np.std(list_mIUs), \n", 115 | " np.min(list_mIUs), \n", 116 | " np.max(list_mIUs)))" 117 | ] 118 | }, 119 | { 120 | "cell_type": "code", 121 | "execution_count": null, 122 | "metadata": {}, 123 | "outputs": [], 124 | "source": [ 125 | "" 126 | ] 127 | } 128 | ], 129 | "metadata": { 130 | "kernelspec": { 131 | "display_name": "Python 2", 132 | "language": "python", 133 | "name": "python2" 134 | }, 135 | "language_info": { 136 | "codemirror_mode": { 137 | "name": "ipython", 138 | "version": 2.0 139 | }, 140 | "file_extension": ".py", 141 | "mimetype": "text/x-python", 142 | "name": "python", 143 | "nbconvert_exporter": "python", 144 | "pygments_lexer": "ipython2", 145 | "version": "2.7.6" 146 | } 147 | }, 148 | "nbformat": 4, 149 | "nbformat_minor": 0 150 | } -------------------------------------------------------------------------------- /exps/page/example_evaluation.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": { 7 | "collapsed": true 8 | }, 9 | "outputs": [], 10 | "source": [ 11 | "import os\n", 12 | "from glob import glob\n", 13 | "from .process import prediction_fn" 14 | ] 15 | }, 16 | { 17 | "cell_type": "code", 18 | "execution_count": null, 19 | "metadata": {}, 20 | "outputs": [], 21 | "source": [ 22 | "model_dirs_list = ['model1/export/timestamp/', \n", 23 | " 'model2/export/timestamp/']\n", 24 | "input_dir = 'dataset_page/set/images/'\n", 25 | "output_dir_name = 'out_predictions'" 26 | ] 27 | }, 28 | { 29 | "cell_type": "markdown", 30 | "metadata": {}, 31 | "source": [ 32 | "## Predictions" 33 | ] 34 | }, 35 | { 36 | "cell_type": "code", 37 | "execution_count": null, 38 | "metadata": {}, 39 | "outputs": [], 40 | "source": [ 41 | "for model_dir in model_dirs_list:\n", 42 | " output_dir = '{}'.format(os.path.sep).join(model_dir.split(os.path.sep)[:-3] + [output_dir_name])\n", 43 | " os.makedirs(output_dir, exist_ok=True)\n", 44 | " prediction_fn(model_dir, input_dir, output_dir)" 45 | ] 46 | }, 47 | { 48 | "cell_type": "code", 49 | "execution_count": null, 50 | "metadata": {}, 51 | "outputs": [], 52 | "source": [ 53 | "" 54 | ] 55 | }, 56 | { 57 | "cell_type": "markdown", 58 | "metadata": {}, 59 | "source": [ 60 | "## Evaluation" 61 | ] 62 | }, 63 | { 64 | "cell_type": "code", 65 | "execution_count": null, 66 | "metadata": {}, 67 | "outputs": [], 68 | "source": [ 69 | "from .evaluation import eval_fn\n", 70 | "import numpy as np" 71 | ] 72 | }, 73 | { 74 | "cell_type": "code", 75 | "execution_count": null, 76 | "metadata": {}, 77 | "outputs": [], 78 | "source": [ 79 | "pred_dir_list = [os.path.abspath(os.path.join(md, '..', '..', output_dir_name)) \n", 80 | " for md in model_dirs_list]\n", 81 | "gt_dir = 'dataset_page/set/labels/'" 82 | ] 83 | }, 84 | { 85 | "cell_type": "code", 86 | "execution_count": null, 87 | "metadata": {}, 88 | "outputs": [], 89 | "source": [ 90 | "list_metrics = list()\n", 91 | "for pred_dir in pred_dir_list:\n", 92 | " 
metrics = eval_fn(pred_dir, gt_dir)\n", 93 | " list_metrics.append(metrics)" 94 | ] 95 | }, 96 | { 97 | "cell_type": "code", 98 | "execution_count": null, 99 | "metadata": {}, 100 | "outputs": [], 101 | "source": [ 102 | "list_mIOUs = [m.mIOU for m in list_metrics]\n", 103 | "\n", 104 | "print('MIOU : {:.03} +- {:.03} ([{:.03}, {:.03}])'.format(np.mean(list_mIOUs), \n", 105 | " np.std(list_mIOUs), \n", 106 | " np.min(list_mIOUs), \n", 107 | " np.max(list_mIOUs)))" 108 | ] 109 | }, 110 | { 111 | "cell_type": "markdown", 112 | "metadata": {}, 113 | "source": [ 114 | "## Export" 115 | ] 116 | }, 117 | { 118 | "cell_type": "code", 119 | "execution_count": null, 120 | "metadata": {}, 121 | "outputs": [], 122 | "source": [ 123 | "import json\n", 124 | "export_metric_filename = './metrics_page.json'" 125 | ] 126 | }, 127 | { 128 | "cell_type": "code", 129 | "execution_count": null, 130 | "metadata": {}, 131 | "outputs": [], 132 | "source": [ 133 | "with open(export_metric_filename, 'w', encoding='utf8') as f:\n", 134 | " json.dump({'{}'.format(i+1): vars(m) for i, m in enumerate(list_metrics)}, f, indent=4)" 135 | ] 136 | } 137 | ], 138 | "metadata": { 139 | "kernelspec": { 140 | "display_name": "Python 2", 141 | "language": "python", 142 | "name": "python2" 143 | }, 144 | "language_info": { 145 | "codemirror_mode": { 146 | "name": "ipython", 147 | "version": 2.0 148 | }, 149 | "file_extension": ".py", 150 | "mimetype": "text/x-python", 151 | "name": "python", 152 | "nbconvert_exporter": "python", 153 | "pygments_lexer": "ipython2", 154 | "version": "2.7.6" 155 | } 156 | }, 157 | "nbformat": 4, 158 | "nbformat_minor": 0 159 | } -------------------------------------------------------------------------------- /exps/cbad/evaluation.py: -------------------------------------------------------------------------------- 1 | import io 2 | import os 3 | import subprocess 4 | from glob import glob 5 | import pandas as pd 6 | from tqdm import tqdm 7 | from dh_segment.post_processing import PAGE 8 | from .process import extract_lines 9 | 10 | CBAD_JAR = './cBAD/TranskribusBaseLineEvaluationScheme_v0.1.3/' \ 11 | 'TranskribusBaseLineEvaluationScheme-0.1.3-jar-with-dependencies.jar' 12 | PP_PARAMS = post_process_params = {'sigma': 1.5, 'low_threshold': 0.2, 'high_threshold': 0.4} 13 | 14 | 15 | def eval_fn(input_dir: str, groudtruth_dir: str, output_dir: str=None, post_process_params: dict=PP_PARAMS, 16 | jar_tool_path: str=CBAD_JAR, masks_dir: str=None) -> dict: 17 | """ 18 | 19 | :param input_dir: Input directory containing probability maps (.npy) 20 | :param groudtruth_dir: directory containg XML groundtruths 21 | :param output_dir: output directory for results 22 | :param post_process_params: parameters form post processing of probability maps 23 | :param jar_tool_path: path to cBAD evaluation tool (.jar file) 24 | :param masks_dir: optional, directory where binary masks of the page are stored (.png) 25 | :return: 26 | """ 27 | 28 | if output_dir is None: 29 | output_dir = input_dir 30 | 31 | # Apply post processing and find lines 32 | for file in tqdm(glob(os.path.join(input_dir, '*.npy'))): 33 | basename = os.path.basename(file).split('.')[0] 34 | gt_xml_filename = os.path.join(groudtruth_dir, basename + '.xml') 35 | gt_page_xml = PAGE.parse_file(gt_xml_filename) 36 | 37 | original_shape = [gt_page_xml.image_height, gt_page_xml.image_width] 38 | 39 | _, _ = extract_lines(file, output_dir, original_shape, **post_process_params, mask_dir=masks_dir) 40 | 41 | # Create pairs predicted XML - groundtruth XML to 
be evaluated 42 | xml_pred_filenames_list = glob(os.path.join(output_dir, '*.xml')) 43 | xml_filenames_tuples = list() 44 | for xml_filename in xml_pred_filenames_list: 45 | basename = os.path.basename(xml_filename) 46 | gt_xml_filename = os.path.join(groudtruth_dir, basename) 47 | 48 | xml_filenames_tuples.append((gt_xml_filename, xml_filename)) 49 | 50 | gt_pages_list_filename = os.path.join(output_dir, 'gt_pages_simple.lst') 51 | generated_pages_list_filename = os.path.join(output_dir, 'generated_pages_simple.lst') 52 | with open(gt_pages_list_filename, 'w') as f: 53 | f.writelines('\n'.join([s[0] for s in xml_filenames_tuples])) 54 | with open(generated_pages_list_filename, 'w') as f: 55 | f.writelines('\n'.join([s[1] for s in xml_filenames_tuples])) 56 | 57 | # Evaluation using JAVA Tool 58 | cmd = 'java -jar {} {} {}'.format(jar_tool_path, gt_pages_list_filename, generated_pages_list_filename) 59 | result = subprocess.check_output(cmd, shell=True).decode() 60 | with open(os.path.join(output_dir, 'scores.txt'), 'w') as f: 61 | f.write(result) 62 | parse_score_txt(result, os.path.join(output_dir, 'scores.csv')) 63 | 64 | # Parse results from output of tool 65 | lines = result.splitlines() 66 | avg_precision = float(next(filter(lambda l: 'Avg (over pages) P value:' in l, lines)).split()[-1]) 67 | avg_recall = float(next(filter(lambda l: 'Avg (over pages) R value:' in l, lines)).split()[-1]) 68 | f_measure = float(next(filter(lambda l: 'Resulting F_1 value:' in l, lines)).split()[-1]) 69 | 70 | print('P {}, R {}, F {}'.format(avg_precision, avg_recall, f_measure)) 71 | 72 | return { 73 | 'avg_precision': avg_precision, 74 | 'avg_recall': avg_recall, 75 | 'f_measure': f_measure 76 | } 77 | 78 | 79 | def parse_score_txt(score_txt, output_csv): 80 | lines = score_txt.splitlines() 81 | header_ind = next((i for i, l in enumerate(lines) 82 | if l == '#P value, #R value, #F_1 value, #TruthFileName, #HypoFileName')) 83 | final_line = next((i for i, l in enumerate(lines) if i > header_ind and l == '')) 84 | csv_data = '\n'.join(lines[header_ind:final_line]) 85 | df = pd.read_csv(io.StringIO(csv_data)) 86 | df = df.rename(columns={k: k.strip() for k in df.columns}) 87 | df['#HypoFileName'] = [os.path.basename(f).split('.')[0] for f in df['#HypoFileName']] 88 | del df['#TruthFileName'] 89 | df = df.rename(columns={'#P value': 'P', '#R value': 'R', '#F_1 value': 'F_1', '#HypoFileName': 'basename'}) 90 | df = df.reindex(columns=['basename', 'F_1', 'P', 'R']) 91 | df = df.sort_values('F_1', ascending=True) 92 | df.to_csv(output_csv, index=False) 93 | -------------------------------------------------------------------------------- /exps/_misc/layout_generate_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | from doc_seg_datasets import PAGE 3 | import cv2 4 | from scipy.misc import imread, imsave 5 | import numpy as np 6 | import matplotlib.pyplot as plt 7 | from glob import glob 8 | from tqdm import tqdm 9 | 10 | 11 | TARGET_WIDTH = 1200 12 | INPUT_DIR = '/home/seguin/document_datasets/layout_analysis' 13 | OUTPUT_DIR = '/scratch/benoit/layout_analysis_1200' 14 | 15 | 16 | def save_and_resize(img, filename, nearest=False): 17 | resized = cv2.resize(img, (TARGET_WIDTH, (img.shape[0]*TARGET_WIDTH)//img.shape[1]), 18 | interpolation=cv2.INTER_NEAREST if nearest else cv2.INTER_LINEAR) 19 | imsave(filename, resized) 20 | 21 | 22 | def process_one(image_dir, page_dir, output_dir, basename, colormap, color_labels): 23 | image_filename = 
os.path.join(image_dir, "{}.jpg".format(basename)) 24 | page_filename = os.path.join(page_dir, "{}.xml".format(basename)) 25 | 26 | page = PAGE.parse_file(page_filename) 27 | text_lines = [tl for tr in page.text_regions for tl in tr.text_lines] 28 | graphic_regions = page.graphic_regions 29 | img = imread(image_filename, mode='RGB') 30 | 31 | gt = np.zeros_like(img[:, :, 0]) 32 | mask1 = cv2.fillPoly(gt.copy(), [PAGE.Point.list_to_cv2poly(tl.coords) 33 | for tl in text_lines if 'comment' in tl.id], 1) 34 | mask2 = cv2.fillPoly(gt.copy(), [PAGE.Point.list_to_cv2poly(tl.coords) 35 | for tl in text_lines if not 'comment' in tl.id], 1) 36 | mask3 = cv2.fillPoly(gt.copy(), [PAGE.Point.list_to_cv2poly(tl.coords) 37 | for tl in graphic_regions], 1) 38 | arr = np.dstack([mask1, mask2, mask3]) 39 | 40 | gt_img = convert_array_masks(arr, colormap, color_labels) 41 | save_and_resize(img, os.path.join(output_dir, 'images', '{}.jpg'.format(basename))) 42 | save_and_resize(gt_img, os.path.join(output_dir, 'labels', '{}.png'.format(basename)), nearest=True) 43 | 44 | 45 | def make_cmap(N_CLASSES): 46 | # Generate the colors for the classes (with background class being 0,0,0) 47 | c_size = 2**N_CLASSES - 1 48 | cmap = np.concatenate([[[0, 0, 0]], plt.cm.Set1(np.arange(c_size) / (c_size))[:, :3]]) 49 | cmap = (cmap * 255).astype(np.uint8) 50 | assert N_CLASSES <= 8, "ARGH!! can not handle more than 8 classes" 51 | c_full_label = np.unpackbits(np.arange(2 ** N_CLASSES).astype(np.uint8)[:, None], axis=-1)[:, -N_CLASSES:] 52 | return cmap, c_full_label 53 | 54 | 55 | def convert_array_masks(arr, cmap, c_full_label): 56 | N_CLASSES = arr.shape[-1] 57 | c = np.zeros((2,) * N_CLASSES, np.int32) 58 | for i, inds in enumerate(c_full_label): 59 | c[tuple(inds)] = i 60 | c_ind = c[[arr[:, :, i] for i in range(arr.shape[-1])]] 61 | return cmap[c_ind] 62 | 63 | 64 | def save_cmap_to_txt(filename, cmap, c_full_label): 65 | np.savetxt(filename, np.concatenate([cmap, c_full_label], axis=1), fmt='%i') 66 | 67 | 68 | colormap, color_labels = make_cmap(3) 69 | 70 | train_basenames = [os.path.basename(p)[:-4] for p in glob('{}/img/training/*.jpg'.format(INPUT_DIR))] 71 | os.makedirs('{}/train/images'.format(OUTPUT_DIR)) 72 | os.makedirs('{}/train/labels'.format(OUTPUT_DIR)) 73 | for basename in tqdm(train_basenames): 74 | process_one(os.path.join(INPUT_DIR, 'img', 'training'), 75 | os.path.join(INPUT_DIR, 'PAGE-gt', 'training'), 76 | os.path.join(OUTPUT_DIR, 'train'), 77 | basename, colormap, color_labels) 78 | 79 | 80 | val_basenames = [os.path.basename(p)[:-4] for p in glob('{}/img/validation/*.jpg'.format(INPUT_DIR))] 81 | os.makedirs('{}/eval/images'.format(OUTPUT_DIR)) 82 | os.makedirs('{}/eval/labels'.format(OUTPUT_DIR)) 83 | for basename in tqdm(val_basenames): 84 | process_one(os.path.join(INPUT_DIR, 'img', 'validation'), 85 | os.path.join(INPUT_DIR, 'PAGE-gt', 'validation'), 86 | os.path.join(OUTPUT_DIR, 'eval'), 87 | basename, colormap, color_labels) 88 | 89 | 90 | test_basenames = [os.path.basename(p)[:-4] for p in glob('{}/img/public-test/*.jpg'.format(INPUT_DIR))] 91 | os.makedirs('{}/test/images'.format(OUTPUT_DIR)) 92 | os.makedirs('{}/test/labels'.format(OUTPUT_DIR)) 93 | for basename in tqdm(test_basenames): 94 | process_one(os.path.join(INPUT_DIR, 'img', 'public-test'), 95 | os.path.join(INPUT_DIR, 'PAGE-gt', 'public-test'), 96 | os.path.join(OUTPUT_DIR, 'test'), 97 | basename, colormap, color_labels) 98 | 99 | save_cmap_to_txt(os.path.join(OUTPUT_DIR, 'train', 'classes.txt'), colormap, color_labels) 
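# Format of the exported classes file: each row holds three RGB columns followed by one binary
# column per class, here three classes in the order (comment text lines, other text lines,
# graphic regions). The first row "0 0 0 0 0 0" is the background; a purely illustrative row
# such as "228 26 28 0 1 0" (hypothetical color) would map that color to "other text lines"
# only. This is the layout expected by get_classes_color_from_file_multilabel() in
# dh_segment_text/utils/labels.py, which reads the first three columns as colors and the
# remaining columns as the multilabel code.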
-------------------------------------------------------------------------------- /demo.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import os 4 | from glob import glob 5 | 6 | import cv2 7 | import numpy as np 8 | import tensorflow as tf 9 | from imageio import imread, imsave 10 | from tqdm import tqdm 11 | 12 | from dh_segment.io import PAGE 13 | from dh_segment.inference import LoadedModel 14 | from dh_segment.post_processing import boxes_detection, binarization 15 | 16 | # To output results in PAGE XML format (http://www.primaresearch.org/schema/PAGE/gts/pagecontent/2013-07-15/) 17 | PAGE_XML_DIR = './page_xml' 18 | 19 | 20 | def page_make_binary_mask(probs: np.ndarray, threshold: float=-1) -> np.ndarray: 21 | """ 22 | Computes the binary mask of the detected Page from the probabilities outputed by network 23 | :param probs: array with values in range [0, 1] 24 | :param threshold: threshold between [0 and 1], if negative Otsu's adaptive threshold will be used 25 | :return: binary mask 26 | """ 27 | 28 | mask = binarization.thresholding(probs, threshold) 29 | mask = binarization.cleaning_binary(mask, kernel_size=5) 30 | return mask 31 | 32 | 33 | def format_quad_to_string(quad): 34 | """ 35 | Formats the corner points into a string. 36 | :param quad: coordinates of the quadrilateral 37 | :return: 38 | """ 39 | s = '' 40 | for corner in quad: 41 | s += '{},{},'.format(corner[0], corner[1]) 42 | return s[:-1] 43 | 44 | 45 | if __name__ == '__main__': 46 | 47 | # If the model has been trained load the model, otherwise use the given model 48 | model_dir = 'demo/page_model/export' 49 | if not os.path.exists(model_dir): 50 | model_dir = 'demo/model/' 51 | 52 | input_files = glob('demo/pages/test_a1/images/*') 53 | 54 | output_dir = 'demo/processed_images' 55 | os.makedirs(output_dir, exist_ok=True) 56 | # PAGE XML format output 57 | output_pagexml_dir = os.path.join(output_dir, PAGE_XML_DIR) 58 | os.makedirs(output_pagexml_dir, exist_ok=True) 59 | 60 | # Store coordinates of page in a .txt file 61 | txt_coordinates = '' 62 | 63 | with tf.Session(): # Start a tensorflow session 64 | # Load the model 65 | m = LoadedModel(model_dir, predict_mode='filename') 66 | 67 | for filename in tqdm(input_files, desc='Processed files'): 68 | # For each image, predict each pixel's label 69 | prediction_outputs = m.predict(filename) 70 | probs = prediction_outputs['probs'][0] 71 | original_shape = prediction_outputs['original_shape'] 72 | probs = probs[:, :, 1] # Take only class '1' (class 0 is the background, class 1 is the page) 73 | probs = probs / np.max(probs) # Normalize to be in [0, 1] 74 | 75 | # Binarize the predictions 76 | page_bin = page_make_binary_mask(probs) 77 | 78 | # Upscale to have full resolution image (cv2 uses (w,h) and not (h,w) for giving shapes) 79 | bin_upscaled = cv2.resize(page_bin.astype(np.uint8, copy=False), 80 | tuple(original_shape[::-1]), interpolation=cv2.INTER_NEAREST) 81 | 82 | # Find quadrilateral enclosing the page 83 | pred_page_coords = boxes_detection.find_boxes(bin_upscaled.astype(np.uint8, copy=False), 84 | mode='min_rectangle', n_max_boxes=1) 85 | 86 | # Draw page box on original image and export it. 
Add also box coordinates to the txt file 87 | original_img = imread(filename, pilmode='RGB') 88 | if pred_page_coords is not None: 89 | cv2.polylines(original_img, [pred_page_coords[:, None, :]], True, (0, 0, 255), thickness=5) 90 | # Write corners points into a .txt file 91 | txt_coordinates += '{},{}\n'.format(filename, format_quad_to_string(pred_page_coords)) 92 | else: 93 | print('No box found in {}'.format(filename)) 94 | basename = os.path.basename(filename).split('.')[0] 95 | imsave(os.path.join(output_dir, '{}_boxes.jpg'.format(basename)), original_img) 96 | 97 | # Create page region and XML file 98 | page_border = PAGE.Border(coords=PAGE.Point.cv2_to_point_list(pred_page_coords[:, None, :])) 99 | page_xml = PAGE.Page(image_filename=filename, image_width=original_shape[1], image_height=original_shape[0], 100 | page_border=page_border) 101 | xml_filename = os.path.join(output_pagexml_dir, '{}.xml'.format(basename)) 102 | page_xml.write_to_file(xml_filename, creator_name='PageExtractor') 103 | 104 | # Save txt file 105 | with open(os.path.join(output_dir, 'pages.txt'), 'w') as f: 106 | f.write(txt_coordinates) 107 | -------------------------------------------------------------------------------- /exps/diva/utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | __author__ = "solivr" 3 | __license__ = "GPL" 4 | 5 | import json 6 | import numpy as np 7 | from collections import OrderedDict 8 | from tqdm import tqdm 9 | from glob import glob 10 | from imageio import imread, imsave 11 | import os 12 | 13 | 14 | MAP_COLORS = OrderedDict([('background', ((0, 0, 1), (0, 0, 0))), 15 | ('comment', ((0, 0, 2), (0, 0, 255))), 16 | ('decoration', ((0, 0, 4), (0, 255, 0))), 17 | ('text', ((0, 0, 8), (255, 0, 0))), 18 | ('comment_deco', ((0, 0, 6), (0, 255, 255))), 19 | ('text_comment', ((0, 0, 10), (255, 0, 255))), 20 | ('text_deco', ((0, 0, 12), (255, 255, 0)))]) 21 | 22 | 23 | def parse_diva_tool_output(score_txt: str, output_json_filename: str=None)-> dict: 24 | """ 25 | This fn parses the output of JAR DIVA Evaluation tool 26 | :param score_txt: filename of txt score containing output of DIVA evaluation tool 27 | :param output_json_filename: filename to output the parsed result in json 28 | :return: dict containing the parsed results 29 | """ 30 | def process_hlp_fn(string): 31 | """ 32 | Processes format : R=0.64,0.52 33 | """ 34 | key, vals = string.split('=') 35 | vals = vals.split(',') 36 | return {key: float(vals[0]), key + '_fw': float(vals[1])}, key 37 | 38 | def process_hlp_per_class_format(string, measure_key): 39 | """ 40 | Processes format : '0.26|0.77|0.77|0.77 41 | '""" 42 | return {measure_key: [float(t) for t in string.split('|')]} 43 | 44 | lines = score_txt.splitlines() 45 | 46 | dic_results = {'Mean_IU': float(lines[0].split(' = ')[1])} 47 | measures = lines[1].split(' ') 48 | dic_results = {**dic_results, **{m.split('=')[0]: float(m.split('=')[1]) for m in measures[:2]}} 49 | for m in measures[2:5]: 50 | eq, tab = m.split('[') 51 | dic, key = process_hlp_fn(eq) 52 | dic_results = {**dic_results, **dic} 53 | dic_results = {**dic_results, **process_hlp_per_class_format(tab[:-1], key + '_per_class')} 54 | sp = measures[-1].split('[') 55 | dic, key = process_hlp_fn(sp[0]) 56 | dic_results = {**dic_results, **dic} 57 | dic_results = {**dic_results, **process_hlp_per_class_format(sp[1][:-6], key + '_per_class')} 58 | dic_results = {**dic_results, **process_hlp_per_class_format(sp[2][:-1], sp[1][-5:-1])} 59 | 60 | if 
output_json_filename is not None: 61 | with open(output_json_filename, 'w') as f: 62 | json.dump(dic_results, f) 63 | 64 | return dic_results 65 | 66 | 67 | def to_original_color_code(bin_prediction): 68 | """ 69 | (0,0,0) : Background 70 | (255,0,0) : Text 71 | (0,255,0) : decoration 72 | (0,0,255) : comment 73 | 74 | RGB=0x000008: main text body 75 | RGB=0x000004: decoration 76 | RGB=0x000002: comment 77 | RGB=0x000001: background 78 | RGB=0x00000A: main text body+comment 79 | RGB=0x00000C: main text body+decoration 80 | RGB=0x000006: comment +decoration 81 | 82 | :param bin_prediction: 83 | :return: 84 | """ 85 | pred_original_colors = np.zeros_like(bin_prediction) 86 | for key, colors in MAP_COLORS.items(): 87 | pred_original_colors[np.all(bin_prediction == colors[1], axis=-1)] = colors[0] 88 | 89 | return pred_original_colors 90 | 91 | 92 | def diva_dataset_generator(input_dir: str, output_dir: str): 93 | """ 94 | 95 | :param input_dir: Input directory containing images and PAGE files 96 | :param output_dir: Output directory to save images and labels 97 | :return: 98 | """ 99 | 100 | img_filenames = glob(os.path.join(input_dir, 'img', '*.jpg')) 101 | output_img_dir = os.path.join(output_dir, 'images') 102 | output_label_dir = os.path.join(output_dir, 'labels') 103 | 104 | def annotate_one(gt_image: np.array, map_colors: dict=MAP_COLORS): 105 | label_img = np.zeros_like(gt_image) 106 | for key, colors in map_colors.items(): 107 | label_img[np.all(gt_image == colors[0], axis=-1)] = colors[1] 108 | 109 | return label_img 110 | 111 | for filename in tqdm(img_filenames): 112 | img = imread(filename, pilmode='RGB') 113 | basename = os.path.basename(filename).split('.')[0] 114 | filename_label = os.path.join(input_dir, 'pixel-level-gt', '{}.png'.format(basename)) 115 | gt_img = imread(filename_label, pilmode='RGB') 116 | label_image = annotate_one(gt_img, MAP_COLORS) 117 | 118 | # Save 119 | imsave(os.path.join(output_img_dir, '{}.jpg'.format(basename)), img) 120 | imsave(os.path.join(output_label_dir, '{}.png'.format(basename)), label_image) 121 | -------------------------------------------------------------------------------- /exps/DIBCO/dibco_dataset_generator.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | __author__ = 'solivr' 3 | 4 | import os 5 | from glob import glob 6 | from scipy.misc import imread, imsave 7 | import numpy as np 8 | from tqdm import tqdm 9 | import argparse 10 | from typing import List 11 | 12 | RANDOM_SEED = 0 13 | np.random.seed(RANDOM_SEED) 14 | 15 | 16 | def get_img_filenames(directory: str) -> List[str]: 17 | directory = os.path.abspath(directory) 18 | year = int(directory.split(os.path.sep)[-1].split('-')[0][-4:]) 19 | if year == 2010: 20 | extension = ['*.jpg', '*.bmp', '*.tif'] 21 | elif year == 2011 or year == 2012 or year == 2014: 22 | extension = ['*.png'] 23 | elif year == 2013 or year == 2009: 24 | extension = ['*.bmp'] 25 | elif year == 2016: 26 | extension = ['*[!_gt].bmp'] 27 | else: 28 | raise NotImplementedError 29 | 30 | files = [glob(os.path.join(directory, ext)) for ext in extension] 31 | return [item for l in files for item in l] 32 | 33 | 34 | def get_gt_filename(img_filename: str) -> str: 35 | directory, basename = os.path.os.path.split(img_filename) 36 | year = int(directory.split(os.path.sep)[-1].split('-')[0][-4:]) 37 | if year == 2009: 38 | extension = '.tiff' 39 | elif year == 2011: 40 | extension = '_GT.tiff' 41 | elif year == 2012: 42 | extension = '_GT.tif' 43 | elif 
year == 2013 or year == 2014 or year == 2010: 44 | extension = '_estGT.tiff' 45 | elif year == 2016: 46 | extension = '_gt.bmp' 47 | else: 48 | raise NotImplementedError 49 | 50 | return os.path.join(directory, '{}{}'.format(basename.split('.')[0], extension)) 51 | 52 | 53 | def get_exported_image_basename(image_filename: str) -> str: 54 | # Get acronym followed by name of file 55 | directory, basename = os.path.split(image_filename) 56 | acronym = directory.split(os.path.sep)[-1].split('-')[0] 57 | return '{}_{}'.format(acronym, basename.split('.')[0]) 58 | 59 | 60 | def generate_one_tuple(img_filename: str, output_dir: str) -> None: 61 | image = imread(img_filename, mode='RGB') 62 | basename_export = get_exported_image_basename(img_filename) 63 | imsave(os.path.join(output_dir, 'images', '{}.jpg'.format(basename_export)), image) 64 | gt = np.zeros_like(image) 65 | gt[:, :, 0] = ~imread(get_gt_filename(img_filename), mode='L') 66 | imsave(os.path.join(output_dir, 'labels', '{}.png'.format(basename_export)), gt) 67 | 68 | 69 | if __name__ == '__main__': 70 | parser = argparse.ArgumentParser() 71 | parser.add_argument('-i', '--input_dir', required=True, type=str, default=None, 72 | help='Input directory containing images and PAGE files') 73 | parser.add_argument('-o', '--output_dir', required=True, type=str, default=None, 74 | help='Output directory to save images and labels') 75 | parser.add_argument('-s', '--split', action='store_true', # default : false 76 | help='Split inputs into training and validation set') 77 | args = vars(parser.parse_args()) 78 | 79 | img_filenames = get_img_filenames(args.get('input_dir')) 80 | 81 | if args.get('split'): 82 | # Split data into training and validation set (0.9/0.1) 83 | train_inds = np.random.choice(len(img_filenames), size=int(0.9 * len(img_filenames)), replace=False) 84 | train_mask = np.zeros(len(img_filenames), dtype=np.bool_) 85 | train_mask[train_inds] = 1 86 | image_filenames_train = np.array(img_filenames)[train_mask] 87 | image_filenames_eval = np.array(img_filenames)[~train_mask] 88 | 89 | # Training set 90 | root_train_dir = os.path.join(args.get('output_dir'), 'train') 91 | os.makedirs(os.path.join(root_train_dir, 'images'), exist_ok=True) 92 | os.makedirs(os.path.join(root_train_dir, 'labels'), exist_ok=True) 93 | for img_filename in tqdm(image_filenames_train): 94 | generate_one_tuple(img_filename, root_train_dir) 95 | 96 | # Validation set 97 | root_val_dir = os.path.join(args.get('output_dir'), 'validation') 98 | os.makedirs(os.path.join(root_val_dir, 'images'), exist_ok=True) 99 | os.makedirs(os.path.join(root_val_dir, 'labels'), exist_ok=True) 100 | for img_filename in tqdm(image_filenames_eval): 101 | generate_one_tuple(img_filename, root_val_dir) 102 | else: 103 | root_test_dir = os.path.join(args.get('output_dir'), 'test') 104 | os.makedirs(os.path.join(root_test_dir, 'images'), exist_ok=True) 105 | os.makedirs(os.path.join(root_test_dir, 'labels'), exist_ok=True) 106 | for img_filename in tqdm(img_filenames): 107 | generate_one_tuple(img_filename, root_test_dir) 108 | 109 | # Class file 110 | classes = np.stack([(0, 0, 0), (255, 0, 0)]) 111 | np.savetxt(os.path.join(args.get('output_dir'), 'classes.txt'), classes, fmt='%d') -------------------------------------------------------------------------------- /dh_segment_text/utils/labels.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | __license__ = "GPL" 3 | 4 | import tensorflow as tf 5 | import numpy as np 6 | 
import os 7 | from typing import Tuple 8 | 9 | 10 | def label_image_to_class(label_image: tf.Tensor, classes_file: str) -> tf.Tensor: 11 | classes_color_values = get_classes_color_from_file(classes_file) 12 | # Convert label_image [H,W,3] to the classes [H,W],int32 according to the classes [C,3] 13 | with tf.name_scope('LabelAssign'): 14 | if len(label_image.get_shape()) == 3: 15 | diff = tf.cast(label_image[:, :, None, :], tf.float32) - tf.constant(classes_color_values[None, None, :, :]) # [H,W,C,3] 16 | elif len(label_image.get_shape()) == 4: 17 | diff = tf.cast(label_image[:, :, :, None, :], tf.float32) - tf.constant( 18 | classes_color_values[None, None, None, :, :]) # [B,H,W,C,3] 19 | else: 20 | raise NotImplementedError('Length is : {}'.format(len(label_image.get_shape()))) 21 | 22 | pixel_class_diff = tf.reduce_sum(tf.square(diff), axis=-1) # [H,W,C] or [B,H,W,C] 23 | class_label = tf.argmin(pixel_class_diff, axis=-1) # [H,W] or [B,H,W] 24 | return class_label 25 | 26 | 27 | def class_to_label_image(class_label: tf.Tensor, classes_file: str) -> tf.Tensor: 28 | classes_color_values = get_classes_color_from_file(classes_file) 29 | return tf.gather(classes_color_values, tf.cast(class_label, dtype=tf.int32)) 30 | 31 | 32 | def multilabel_image_to_class(label_image: tf.Tensor, classes_file: str) -> tf.Tensor: 33 | """ 34 | Combines image annotations with classes info of the txt file to create the input label for the training. 35 | 36 | :param label_image: annotated image [H,W,Ch] or [B,H,W,Ch] (Ch = color channels) 37 | :param classes_file: the filename of the txt file containing the class info 38 | :return: [H,W,Cl] or [B,H,W,Cl] (Cl = number of classes) 39 | """ 40 | classes_color_values, colors_labels = get_classes_color_from_file_multilabel(classes_file) 41 | # Convert label_image [H,W,3] to the classes [H,W,C],int32 according to the classes [C,3] 42 | with tf.name_scope('LabelAssign'): 43 | if len(label_image.get_shape()) == 3: 44 | diff = tf.cast(label_image[:, :, None, :], tf.float32) - tf.constant(classes_color_values[None, None, :, :]) # [H,W,C,3] 45 | elif len(label_image.get_shape()) == 4: 46 | diff = tf.cast(label_image[:, :, :, None, :], tf.float32) - tf.constant( 47 | classes_color_values[None, None, None, :, :]) # [B,H,W,C,3] 48 | else: 49 | raise NotImplementedError('Length is : {}'.format(len(label_image.get_shape()))) 50 | 51 | pixel_class_diff = tf.reduce_sum(tf.square(diff), axis=-1) # [H,W,C] or [B,H,W,C] 52 | class_label = tf.argmin(pixel_class_diff, axis=-1) # [H,W] or [B,H,W] 53 | 54 | return tf.gather(colors_labels, class_label) > 0 55 | 56 | 57 | def multiclass_to_label_image(class_label_tensor: tf.Tensor, classes_file: str) -> tf.Tensor: 58 | 59 | classes_color_values, colors_labels = get_classes_color_from_file_multilabel(classes_file) 60 | 61 | n_classes = colors_labels.shape[1] 62 | c = np.zeros((2,)*n_classes+(3,), np.int32) 63 | for c_value, inds in zip(classes_color_values, colors_labels): 64 | c[tuple(inds)] = c_value 65 | 66 | with tf.name_scope('Label2Img'): 67 | return tf.gather_nd(c, tf.cast(class_label_tensor, tf.int32)) 68 | 69 | 70 | def get_classes_color_from_file(classes_file: str) -> np.ndarray: 71 | if not os.path.exists(classes_file): 72 | raise FileNotFoundError(classes_file) 73 | result = np.loadtxt(classes_file).astype(np.float32) 74 | assert result.shape[1] == 3, "Color file should represent RGB values" 75 | return result 76 | 77 | 78 | def get_n_classes_from_file(classes_file: str) -> int: 79 | return 
get_classes_color_from_file(classes_file).shape[0] 80 | 81 | 82 | def get_classes_color_from_file_multilabel(classes_file: str) -> Tuple[np.ndarray, np.array]: 83 | """ 84 | Get classes and code labels from txt file. 85 | This function deals with the case of elements with multiple labels. 86 | 87 | :param classes_file: file containing the classes (usually named *classes.txt*) 88 | :return: for each class the RGB color (array size [N, 3]); and the label's code (array size [N, C]), 89 | with N the number of combinations and C the number of classes 90 | """ 91 | if not os.path.exists(classes_file): 92 | raise FileNotFoundError(classes_file) 93 | result = np.loadtxt(classes_file).astype(np.float32) 94 | assert result.shape[1] > 3, "The number of columns should be greater in multilabel framework" 95 | colors = result[:, :3] 96 | labels = result[:, 3:] 97 | return colors, labels.astype(np.int32) 98 | 99 | 100 | def get_n_classes_from_file_multilabel(classes_file: str) -> int: 101 | return get_classes_color_from_file_multilabel(classes_file)[1].shape[1] 102 | -------------------------------------------------------------------------------- /dh_segment_text/post_processing/boxes_detection.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | import math 4 | from shapely import geometry 5 | from scipy.spatial import KDTree 6 | 7 | 8 | def find_boxes(boxes_mask: np.ndarray, mode: str= 'min_rectangle', min_area: float=0.2, 9 | p_arc_length: float=0.01, n_max_boxes=math.inf) -> list: 10 | """ 11 | Finds the coordinates of the box in the binary image `boxes_mask`. 12 | 13 | :param boxes_mask: Binary image: the mask of the box to find. uint8, 2D array 14 | :param mode: 'min_rectangle' : minimum enclosing rectangle, can be rotated 15 | 'rectangle' : minimum enclosing rectangle, not rotated 16 | 'quadrilateral' : minimum polygon approximated by a quadrilateral 17 | :param min_area: minimum area of the box to be found. A value in percentage of the total area of the image. 18 | :param p_arc_length: used to compute the epsilon value to approximate the polygon with a quadrilateral. 19 | Only used when 'quadrilateral' mode is chosen. 20 | :param n_max_boxes: maximum number of boxes that can be found (default inf). 21 | This will select n_max_boxes with largest area. 22 | :return: list of length n_max_boxes containing boxes with 4 corners [[x1,y1], ..., [x4,y4]] 23 | """ 24 | 25 | assert len(boxes_mask.shape) == 2, \ 26 | 'Input mask must be a 2D array ! 
Mask is now of shape {}'.format(boxes_mask.shape) 27 | 28 | contours, _ = cv2.findContours(boxes_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) 29 | if contours is None: 30 | print('No contour found') 31 | return None 32 | found_boxes = list() 33 | 34 | h_img, w_img = boxes_mask.shape[:2] 35 | 36 | def validate_box(box: np.array) -> (np.array, float): 37 | """ 38 | 39 | :param box: array of 4 coordinates with format [[x1,y1], ..., [x4,y4]] 40 | :return: (box, area) 41 | """ 42 | polygon = geometry.Polygon([point for point in box]) 43 | if polygon.area > min_area * boxes_mask.size: 44 | 45 | # Correct out of range corners 46 | box = np.maximum(box, 0) 47 | box = np.stack((np.minimum(box[:, 0], boxes_mask.shape[1]), 48 | np.minimum(box[:, 1], boxes_mask.shape[0])), axis=1) 49 | 50 | # return box 51 | return box, polygon.area 52 | 53 | if mode not in ['quadrilateral', 'min_rectangle', 'rectangle']: 54 | raise NotImplementedError 55 | if mode == 'quadrilateral': 56 | for c in contours: 57 | epsilon = p_arc_length * cv2.arcLength(c, True) 58 | cnt = cv2.approxPolyDP(c, epsilon, True) 59 | # box = np.vstack(simplify_douglas_peucker(cnt[:, 0, :], 4)) 60 | 61 | # Find extreme points in Convex Hull 62 | hull_points = cv2.convexHull(cnt, returnPoints=True) 63 | # points = cnt 64 | points = hull_points 65 | if len(points) > 4: 66 | # Find closes points to corner using nearest neighbors 67 | tree = KDTree(points[:, 0, :]) 68 | _, ul = tree.query((0, 0)) 69 | _, ur = tree.query((w_img, 0)) 70 | _, dl = tree.query((0, h_img)) 71 | _, dr = tree.query((w_img, h_img)) 72 | box = np.vstack([points[ul, 0, :], points[ur, 0, :], 73 | points[dr, 0, :], points[dl, 0, :]]) 74 | elif len(hull_points) == 4: 75 | box = hull_points[:, 0, :] 76 | else: 77 | continue 78 | # Todo : test if it looks like a rectangle (2 sides must be more or less parallel) 79 | # todo : (otherwise we may end with strange quadrilaterals) 80 | if len(box) != 4: 81 | mode = 'min_rectangle' 82 | print('Quadrilateral has {} points. 
Switching to minimal rectangle mode'.format(len(box))) 83 | else: 84 | # found_box = validate_box(box) 85 | found_boxes.append(validate_box(box)) 86 | if mode == 'min_rectangle': 87 | for c in contours: 88 | rect = cv2.minAreaRect(c) 89 | box = np.int0(cv2.boxPoints(rect)) 90 | found_boxes.append(validate_box(box)) 91 | elif mode == 'rectangle': 92 | for c in contours: 93 | x, y, w, h = cv2.boundingRect(c) 94 | box = np.array([[x, y], [x + w, y], [x + w, y + h], [x, y + h]], dtype=int) 95 | found_boxes.append(validate_box(box)) 96 | # sort by area 97 | found_boxes = [fb for fb in found_boxes if fb is not None] 98 | found_boxes = sorted(found_boxes, key=lambda x: x[1], reverse=True) 99 | if n_max_boxes == 1: 100 | if found_boxes: 101 | return found_boxes[0][0] 102 | else: 103 | return None 104 | else: 105 | return [fb[0] for i, fb in enumerate(found_boxes) if i <= n_max_boxes] 106 | -------------------------------------------------------------------------------- /exps/Ornaments/ornaments_evaluation.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | __author__ = 'solivr' 3 | 4 | import os 5 | from glob import glob 6 | 7 | import cv2 8 | import numpy as np 9 | from scipy.misc import imread, imsave 10 | from tqdm import tqdm 11 | 12 | from doc_seg.post_processing.boxes_detection import find_box 13 | from exps.evaluation.base import Metrics 14 | 15 | 16 | def ornament_evaluate_folder(output_folder: str, validation_dir: str, debug_folder: str=None, 17 | verbose: bool=False, min_area: float=0.0, miou_threshold: float=0.8) -> dict: 18 | 19 | if debug_folder is not None: 20 | os.makedirs(debug_folder, exist_ok=True) 21 | 22 | filenames_binary_masks = glob(os.path.join(output_folder, '*.png')) 23 | 24 | global_metrics = Metrics() 25 | for filename in tqdm(filenames_binary_masks, desc='Evaluation'): 26 | basename = os.path.basename(filename).split('.')[0] 27 | 28 | # Open post_processed and label image 29 | post_processed_img = imread(filename) 30 | post_processed_img = post_processed_img / np.maximum(np.max(post_processed_img), 1) 31 | 32 | label_image = imread(os.path.join(validation_dir, 'labels', '{}.png'.format(basename)), mode='L') 33 | label_image = label_image / np.max(label_image) if np.max(label_image) > 0 else label_image 34 | 35 | # Upsample processed image to compare it to original image 36 | target_shape = (label_image.shape[1], label_image.shape[0]) 37 | bin_upscaled = cv2.resize(np.uint8(post_processed_img), target_shape, interpolation=cv2.INTER_NEAREST) 38 | 39 | pred_boxes = find_box(np.uint8(bin_upscaled), mode='min_rectangle', min_area=min_area, n_max_boxes=np.inf) 40 | label_boxes = find_box(np.uint8(label_image), mode='min_rectangle', min_area=min_area, n_max_boxes=np.inf) 41 | 42 | if debug_folder is not None: 43 | # imsave(os.path.join(debug_folder, '{}_bin.png'.format(basename)), np.uint8(bin_upscaled*255)) 44 | # orig_img = imread(os.path.join(validation_dir, 'images', '{}.jpg'.format(basename)), mode='RGB') 45 | orig_img = imread(os.path.join(validation_dir, 'images', '{}.png'.format(basename)), mode='RGB') 46 | cv2.polylines(orig_img, [lb[:, None, :] for lb in label_boxes], True, (0, 255, 0), thickness=15) 47 | if pred_boxes is not None: 48 | cv2.polylines(orig_img, [pb[:, None, :] for pb in pred_boxes], True, (0, 0, 255), thickness=15) 49 | imsave(os.path.join(debug_folder, '{}_boxes.jpg'.format(basename)), orig_img) 50 | 51 | def intersection_over_union(cnt1, cnt2): 52 | mask1 = np.zeros_like(label_image) 53 
| mask1 = cv2.fillConvexPoly(mask1, cnt1.astype(np.int32), 1).astype(np.int8) 54 | mask2 = np.zeros_like(label_image) 55 | mask2 = cv2.fillConvexPoly(mask2, cnt2.astype(np.int32), 1).astype(np.int8) 56 | 57 | return np.sum(mask1 & mask2) / np.sum(mask1 | mask2) 58 | 59 | def compute_metric_boxes(predicted_boxes: np.array, label_boxes: np.array, threshold: float=miou_threshold): 60 | # Todo test this fn 61 | metric = Metrics() 62 | if label_boxes is None: 63 | if predicted_boxes is None: 64 | metric.true_negatives += 1 65 | metric.total_elements += 1 66 | else: 67 | metric.false_negatives += len(predicted_boxes) 68 | 69 | else: 70 | for pb in predicted_boxes: 71 | best_iou = 0 72 | for lb in label_boxes: 73 | iou = intersection_over_union(pb[:, None, :], lb[:, None, :]) 74 | if iou > best_iou: 75 | best_iou = iou 76 | 77 | if best_iou > threshold: 78 | metric.true_positives += 1 79 | metric.IOU_list.append(best_iou) 80 | elif best_iou < 0.1: 81 | metric.false_negatives += 1 82 | else: 83 | metric.false_positives += 1 84 | metric.IOU_list.append(best_iou) 85 | 86 | metric.total_elements += len(label_boxes) 87 | return metric 88 | 89 | global_metrics += compute_metric_boxes(pred_boxes, label_boxes) 90 | 91 | global_metrics.compute_miou() 92 | global_metrics.compute_accuracy() 93 | global_metrics.compute_prf() 94 | print('EVAL --- mIOU : {}, accuracy : {}, precision : {}, ' 95 | 'recall : {}, f_measure : {}\n'.format(global_metrics.mIOU, global_metrics.accuracy, global_metrics.precision, 96 | global_metrics.recall, global_metrics.f_measure)) 97 | 98 | return { 99 | 'precision': global_metrics.precision, 100 | 'recall': global_metrics.recall, 101 | 'f_measure': global_metrics.f_measure, 102 | 'mIOU': global_metrics.mIOU 103 | } -------------------------------------------------------------------------------- /exps/Ornaments/ornaments_process_set.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | __author__ = 'solivr' 3 | 4 | import os 5 | import sys 6 | 7 | import tensorflow as tf 8 | 9 | sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir, os.path.pardir))) 10 | from doc_seg.loader import LoadedModel 11 | from ornaments_post_processing import ornaments_post_processing_fn 12 | from doc_seg.post_processing import boxes_detection 13 | from exps.evaluation.base import format_quad_to_string 14 | from tqdm import tqdm 15 | import numpy as np 16 | import argparse 17 | import json 18 | import tempfile 19 | from glob import glob 20 | from scipy.misc import imread, imsave 21 | import cv2 22 | 23 | 24 | def predict_on_set(filenames_to_predict, model_dir, output_dir): 25 | """ 26 | 27 | :param filenames_to_predict: 28 | :param model_dir: 29 | :param output_dir: 30 | :return: 31 | """ 32 | with tf.Session(): 33 | m = LoadedModel(model_dir, 'filename') 34 | for filename in tqdm(filenames_to_predict, desc='Prediction'): 35 | pred = m.predict(filename)['probs'][0] 36 | np.save(os.path.join(output_dir, os.path.basename(filename).split('.')[0]), 37 | np.uint8(255 * pred)) 38 | 39 | 40 | def find_ornament(img_filenames, dir_predictions, post_process_params, output_dir, debug=False): 41 | """ 42 | 43 | :param img_filenames: 44 | :param dir_predictions: 45 | :param post_process_params: 46 | :param output_dir: 47 | :return: 48 | """ 49 | 50 | with open(os.path.join(output_dir, 'ornaments.txt'), 'w') as f: 51 | for filename in tqdm(img_filenames, 'Post-processing'): 52 | orig_img = imread(filename, mode='RGB') 53 | 
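# For each input image: load the saved probability map (<basename>.npy), binarize it with
# ornaments_post_processing_fn, upscale the binary mask back to the original image size,
# detect up to 10 enclosing rectangles with find_box, then draw the boxes on the image and
# append their corner coordinates to ornaments.txt.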
basename = os.path.basename(filename).split('.')[0] 54 | 55 | filename_pred = os.path.join(dir_predictions, basename + '.npy') 56 | pred = np.load(filename_pred) 57 | page_bin = ornaments_post_processing_fn(pred / np.max(pred), **post_process_params) 58 | 59 | target_shape = (orig_img.shape[1], orig_img.shape[0]) 60 | bin_upscaled = cv2.resize(np.uint8(page_bin), target_shape, interpolation=cv2.INTER_NEAREST) 61 | if debug: 62 | imsave(os.path.join(output_dir, '{}_bin.png'.format(basename)), bin_upscaled) 63 | pred_box = boxes_detection.find_box(np.uint8(bin_upscaled), mode='min_rectangle', 64 | min_area=0.005, n_max_boxes=10) 65 | if pred_box is not None: 66 | for box in pred_box: 67 | cv2.polylines(orig_img, [box[:, None, :]], True, (0, 0, 255), thickness=15) 68 | else: 69 | print('No box found in {}'.format(filename)) 70 | imsave(os.path.join(output_dir, '{}_boxes.jpg'.format(basename)), orig_img) 71 | 72 | f.write('{},{}\n'.format(filename, [format_quad_to_string(box) for box in pred_box])) 73 | 74 | 75 | if __name__ == '__main__': 76 | parser = argparse.ArgumentParser() 77 | parser.add_argument('-m', '--model_dir', type=str, required=True, 78 | help='Directory of the model (should be of type ''*/export/)') 79 | parser.add_argument('-i', '--input_files', type=str, required=True, nargs='+', 80 | help='Folder containing the images to evaluate the model on') 81 | parser.add_argument('-o', '--output_dir', type=str, required=True, 82 | help='Folder containing the outputs (.npy predictions and visualization errors)') 83 | parser.add_argument('--post_process_params', type=str, default=None, 84 | help='JSON file containing the params for post-processing') 85 | parser.add_argument('--gpu', type=str, default='0', help='Which GPU to use') 86 | parser.add_argument('-pp', '--post_process_only', default=False, action='store_true', 87 | help='Skip the prediction step and only run the post-processing') 88 | args = parser.parse_args() 89 | args = vars(args) 90 | 91 | os.environ["CUDA_VISIBLE_DEVICES"] = args.get('gpu') 92 | model_dir = args.get('model_dir') 93 | input_files = args.get('input_files') 94 | if len(input_files) == 0: 95 | raise FileNotFoundError 96 | 97 | output_dir = args.get('output_dir') 98 | os.makedirs(output_dir, exist_ok=True) 99 | post_process_params = args.get('post_process_params') 100 | 101 | if post_process_params: 102 | with open(post_process_params, 'r') as f: 103 | post_process_params = json.load(f) 104 | post_process_params = post_process_params['params'] 105 | else: 106 | post_process_params = {"threshold": -1, "ksize_open": [5, 5], "ksize_close": [5, 5]} 107 | 108 | # Prediction 109 | with tempfile.TemporaryDirectory() as tmpdirname: 110 | npy_directory = output_dir 111 | if not args.get('post_process_only'): 112 | predict_on_set(input_files, model_dir, npy_directory) 113 | 114 | npy_files = glob(os.path.join(npy_directory, '*.npy')) 115 | find_ornament(input_files, npy_directory, post_process_params, output_dir) 116 | 117 | 118 | 119 | -------------------------------------------------------------------------------- /exps/Cini/cini_process_set.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | __author__ = 'solivr' 3 | 4 | import os 5 | import sys 6 | 7 | sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir, os.path.pardir))) 8 | from doc_seg.loader import LoadedModel 9 | from cini_post_processing import cini_post_processing_fn 10 | from cini_evaluation import cini_evaluate_folder 11 | 12 |
import tensorflow as tf 13 | from tqdm import tqdm 14 | import numpy as np 15 | import argparse 16 | from glob import glob 17 | from scipy.misc import imread, imresize, imsave 18 | import tempfile 19 | import json 20 | from doc_seg.post_processing import PAGE 21 | from doc_seg.utils import hash_dict, dump_json 22 | 23 | 24 | def predict_on_set(filenames_to_predict, model_dir, output_dir): 25 | """ 26 | 27 | :param filenames_to_predict: 28 | :param model_dir: 29 | :param output_dir: 30 | :return: 31 | """ 32 | with tf.Session(): 33 | m = LoadedModel(model_dir, 'filename') 34 | for filename in tqdm(filenames_to_predict, desc='Prediction'): 35 | pred = m.predict(filename)['probs'][0] 36 | np.save(os.path.join(output_dir, os.path.basename(filename).split('.')[0]), 37 | np.uint8(255 * pred)) 38 | 39 | 40 | def find_elements(img_filenames, dir_predictions, post_process_params, output_dir, debug=False, mask_dir: str=None): 41 | """ 42 | 43 | :param img_filenames: 44 | :param dir_predictions: 45 | :param post_process_params: 46 | :param output_dir: 47 | :return: 48 | """ 49 | 50 | os.makedirs(output_dir, exist_ok=True) 51 | 52 | for filename in tqdm(img_filenames, 'Post-processing'): 53 | orig_img = imread(filename, mode='RGB') 54 | basename = os.path.basename(filename).split('.')[0] 55 | 56 | filename_pred = os.path.join(dir_predictions, basename + '.npy') 57 | pred = np.load(filename_pred)/255 # type: np.ndarray 58 | 59 | contours, lines_mask = cini_post_processing_fn(pred, **post_process_params, 60 | output_basename=os.path.join(output_dir, basename)) 61 | 62 | 63 | if __name__ == '__main__': 64 | parser = argparse.ArgumentParser() 65 | parser.add_argument('-m', '--model-dir', type=str, required=True, 66 | help='Directory of the model (should be of type ''*/export/)') 67 | parser.add_argument('-i', '--input-files', type=str, required=True, nargs='+', 68 | help='Folder containing the images to evaluate the model on') 69 | parser.add_argument('-o', '--output-dir', type=str, required=True, 70 | help='Folder containing the outputs (.npy predictions and visualization errors)') 71 | parser.add_argument('-gt', '--ground_truth_dir', type=str, required=True, 72 | help='Ground truth directory containing the labeled images') 73 | parser.add_argument('--params-file', type=str, default=None, 74 | help='JSOn file containing the params for post-processing') 75 | parser.add_argument('--gpu', type=str, default='0', help='Which GPU to use') 76 | parser.add_argument('-pp', '--post-process-only', default=False, action='store_true', 77 | help='Whether to make or not the prediction') 78 | args = parser.parse_args() 79 | args = vars(args) 80 | 81 | os.environ["CUDA_VISIBLE_DEVICES"] = args.get('gpu') 82 | model_dir = args.get('model_dir') 83 | input_files = args.get('input_files') 84 | if len(input_files) == 0: 85 | raise FileNotFoundError 86 | 87 | output_dir = args.get('output_dir') 88 | os.makedirs(output_dir, exist_ok=True) 89 | 90 | # Prediction 91 | npy_directory = output_dir 92 | if not args.get('post_process_only'): 93 | predict_on_set(input_files, model_dir, npy_directory) 94 | 95 | npy_files = glob(os.path.join(npy_directory, '*.npy')) 96 | 97 | if args.get('params_file') is None: 98 | print('No params file found') 99 | params_list = [{"clean_predictions": True, "advanced": True}] 100 | else: 101 | with open(args.get('params_file'), 'r') as f: 102 | configs_data = json.load(f) 103 | # If the file contains a list of configurations 104 | if 'configs' in configs_data.keys(): 105 | params_list = 
configs_data['configs'] 106 | assert isinstance(params_list, list) 107 | # Or if there is a single configuration 108 | else: 109 | params_list = [configs_data] 110 | 111 | gt_dir = args.get('ground_truth_dir') 112 | 113 | for params in tqdm(params_list, desc='Params'): 114 | print(params) 115 | exp_dir = os.path.join(output_dir, '_' + hash_dict(params)) 116 | find_elements(input_files, npy_directory, params, exp_dir, debug=False) 117 | 118 | if gt_dir is not None: 119 | scores = cini_evaluate_folder(exp_dir, gt_dir, debug_folder=os.path.join(exp_dir, '_debug')) 120 | dump_json(os.path.join(exp_dir, 'post_process_config.json'), params) 121 | dump_json(os.path.join(exp_dir, 'scores.json'), scores) 122 | print('Scores : {}'.format(scores)) 123 | 124 | 125 | 126 | -------------------------------------------------------------------------------- /exps/_misc/post_process_evaluation.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | __author__ = 'solivr' 3 | 4 | import argparse 5 | import json 6 | import os 7 | from glob import glob 8 | from hashlib import sha1 9 | import sys 10 | sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir))) 11 | from exps import cbad_post_processing_fn, dibco_binarization_fn, cini_post_processing_fn, \ 12 | page_post_processing_fn, diva_post_processing_fn, ornaments_post_processing_fn 13 | from exps import cbad_evaluate_folder, dibco_evaluate_folder, cini_evaluate_folder, \ 14 | page_evaluate_folder, diva_evaluate_folder, ornament_evaluate_folder, evaluate_epoch 15 | from dh_segment.utils import parse_json 16 | from tqdm import tqdm 17 | from functools import partial 18 | 19 | POST_PROCESSING_DIR_NAME = 'post_processing' 20 | 21 | POST_PROCESSING_EVAL_FN_DICT = { 22 | 'cbad': (cbad_post_processing_fn, cbad_evaluate_folder), 23 | 'dibco': (dibco_binarization_fn, dibco_evaluate_folder), 24 | 'cini': (cini_post_processing_fn, cini_evaluate_folder), 25 | 'page': (page_post_processing_fn, page_evaluate_folder), 26 | 'diva': (diva_post_processing_fn, diva_evaluate_folder), 27 | 'ornaments': (ornaments_post_processing_fn, ornament_evaluate_folder) 28 | } 29 | 30 | 31 | def _hash_dict(params): 32 | return sha1(json.dumps(params, sort_keys=True).encode()).hexdigest() 33 | 34 | 35 | def evaluate_one_model(model_dir, validation_dir, post_processing_pair, post_processing_params, 36 | verbose=False, save_params=True, n_selected_epochs=None) -> None: 37 | """ 38 | Evaluate a combination model/post-process 39 | :param model_dir: 40 | :param validation_dir: 41 | :param post_processing_pair: 42 | :param post_processing_params: 43 | :param verbose: 44 | :param save_params: 45 | :return: 46 | """ 47 | eval_outputs_dir = os.path.join(model_dir, 'eval', 'epoch_*') 48 | list_saved_epochs = glob(eval_outputs_dir) 49 | list_saved_epochs.sort() 50 | if n_selected_epochs is not None: 51 | list_saved_epochs = list_saved_epochs[-n_selected_epochs:] 52 | if len(list_saved_epochs) == 0: 53 | print('No file found in : {}'.format(eval_outputs_dir)) 54 | return 55 | 56 | post_process_dir = os.path.join(model_dir, POST_PROCESSING_DIR_NAME, _hash_dict(post_processing_params)) 57 | os.makedirs(post_process_dir, exist_ok=True) 58 | 59 | validation_scores = dict() 60 | for saved_epoch in tqdm(list_saved_epochs, desc='Epoch dir'): 61 | epoch_dir_name = saved_epoch.split(os.path.sep)[-1] 62 | epoch, timestamp = (int(s) for s in epoch_dir_name.split('_')[1:3]) 63 | validation_scores[epoch_dir_name] = 
{**evaluate_epoch(saved_epoch, validation_dir, 64 | post_process_fn=partial(post_processing_pair[0], 65 | **post_processing_params), 66 | evaluation_fn=post_processing_pair[1] 67 | ), 68 | "epoch": epoch, 69 | "timestamp": timestamp 70 | } 71 | 72 | with open(os.path.join(post_process_dir, 'validation_scores.json'), 'w') as f: 73 | json.dump(validation_scores, f) 74 | 75 | if save_params: 76 | with open(os.path.join(post_process_dir, 'post_process_params.json'), 'w') as f: 77 | json.dump({'post_process_fn': post_processing_pair[0].__name__, 'params': post_processing_params}, f) 78 | 79 | 80 | if __name__ == '__main__': 81 | parser = argparse.ArgumentParser() 82 | parser.add_argument('-m', '--model-dir', type=str, required=True, nargs='+') 83 | parser.add_argument('-p', '--params-json-file', type=str, required=True) 84 | parser.add_argument('-t', '--task-type', type=str, required=True, 85 | help="Choose among: 'cbad', 'dibco', 'cini', 'page', 'diva', 'ornaments'") 86 | parser.add_argument('-v', '--verbose', type=bool, default=False) 87 | parser.add_argument('-ne', '--n_epochs', type=int, default=None, help='Number of selected epochs to evaluate') 88 | # Labels dir is not necessary anymore, can be obtained directly from model config 89 | # parser.add_argument('-l', '--labels-dir', type=str, required=True) 90 | args = vars(parser.parse_args()) 91 | 92 | # get the pair post-process fn and post-process eval 93 | post_processing_pair = POST_PROCESSING_EVAL_FN_DICT[args['task_type']] 94 | with open(args.get('params_json_file'), 'r') as f: 95 | configs_data = json.load(f) 96 | # If the file contains a list of configurations 97 | if 'configs' in configs_data.keys(): 98 | params_list = configs_data['configs'] 99 | assert isinstance(params_list, list) 100 | # Or if there is a single configuration 101 | else: 102 | params_list = [configs_data] 103 | 104 | model_dirs = args.get('model_dir') 105 | print('Found {} configs and {} model directories'.format(len(params_list), len(model_dirs))) 106 | 107 | for params in tqdm(params_list, desc='Params'): 108 | for model_dir in tqdm(model_dirs, desc='Model directory'): 109 | eval_data_dir = parse_json(os.path.join(model_dir, 'config.json'))['eval_dir'] 110 | evaluate_one_model(model_dir, eval_data_dir, 111 | post_processing_pair, 112 | params, args.get('verbose'), 113 | n_selected_epochs=args.get('n_epochs')) 114 | -------------------------------------------------------------------------------- /dh_segment_text/network/simple_decoder.py: -------------------------------------------------------------------------------- 1 | from .model import Decoder 2 | import tensorflow as tf 3 | from tensorflow.contrib import layers 4 | from tensorflow.contrib.slim import arg_scope 5 | from ..embeddings.encoder import EmbeddingsEncoder 6 | from typing import List, Union, Tuple, Type 7 | 8 | class SimpleDecoder(Decoder): 9 | """ 10 | Decoder that upsamples the encoder feature maps and optionally concatenates an embeddings map at a given level. 11 | :ivar upsampling_dims: number of filters of the convolution applied after each upsampling step 12 | :ivar max_depth: maximum number of channels of the incoming feature maps (deeper maps are reduced with 1x1 convolutions) 13 | :ivar weight_decay: L2 regularization factor 14 | :ivar batch_norm_params: parameters of the batch renormalization layers :ivar concat_level: index of the upsampling level at which the embeddings features are concatenated 15 | """ 16 | def __init__(self, upsampling_dims: List[int], max_depth: int = None, weight_decay: float=0., 17 | concat_level: int=-1): 18 | self.upsampling_dims = upsampling_dims 19 | self.max_depth = max_depth 20 | self.weight_decay = weight_decay 21 | self.concat_level = concat_level 22 | 23 | renorm = True 24 | self.batch_norm_params = { 25 | "renorm": renorm, 26 | "renorm_clipping": {'rmax': 100, 'rmin': 0.1, 'dmax': 1}, 27 | "renorm_momentum": 0.98 28 | } 29 | 30 | def __call__(self, feature_maps: List[tf.Tensor], num_classes: int, is_training=False, 
embeddings_encoder: Type[EmbeddingsEncoder]=None, 32 | embeddings: tf.Tensor=tf.zeros((1,300), dtype=tf.float32), 33 | embeddings_map: tf.Tensor=tf.zeros((200,200), dtype=tf.int32)): 34 | 35 | #batch_norm_fn = lambda x: tf.layers.batch_normalization(x, axis=-1, training=is_training, 36 | #name='batch_norm', **self.batch_norm_params) 37 | 38 | batch_norm_fn = lambda x: tf.contrib.layers.batch_norm(x, scale=True, renorm_decay=0.99, renorm=True, renorm_clipping= {'rmax': 100, 'rmin': 0.1, 'dmax': 1}) 39 | 40 | # Upsampling 41 | with tf.variable_scope('SimpleDecoder'): 42 | with arg_scope([layers.conv2d], 43 | normalizer_fn=batch_norm_fn, 44 | weights_regularizer=layers.l2_regularizer(self.weight_decay)): 45 | 46 | assert len(self.upsampling_dims) + 1 == len(feature_maps), \ 47 | 'Upscaling : length of {} does net match {}'.format(len(self.upsampling_dims), 48 | len(feature_maps)) 49 | 50 | # Force layers to not be too big to reduce memory usage 51 | for i, l in enumerate(feature_maps): 52 | if self.max_depth and l.get_shape()[-1] > self.max_depth: 53 | feature_maps[i] = layers.conv2d( 54 | inputs=l, 55 | num_outputs=self.max_depth, 56 | kernel_size=[1, 1], 57 | scope="dimreduc_{}".format(i), 58 | normalizer_fn=batch_norm_fn, 59 | activation_fn=None 60 | ) 61 | 62 | # Deconvolving loop 63 | out_tensor = feature_maps[-1] 64 | for i, f_map in reversed(list(enumerate(feature_maps[:-1]))): 65 | out_tensor = _upsample_concat(out_tensor, f_map, scope_name='upsample_{}'.format(i)) 66 | if i == self.concat_level: 67 | with tf.variable_scope('Embeddings'): 68 | embeddings_features = embeddings_encoder(embeddings, embeddings_map, tf.shape(out_tensor)[1:3], is_training) 69 | out_tensor = tf.concat([out_tensor, embeddings_features], axis=-1) 70 | out_tensor = layers.conv2d(inputs=out_tensor, 71 | num_outputs=self.upsampling_dims[i], 72 | kernel_size=[3, 3], 73 | scope="conv_{}".format(i)) 74 | if self.concat_level == 100: 75 | with tf.variable_scope('Embeddings'): 76 | embeddings_features = embeddings_encoder(embeddings, embeddings_map, tf.shape(out_tensor)[1:3], is_training) 77 | out_tensor = tf.concat([out_tensor, embeddings_features], axis=-1) 78 | 79 | 80 | logits = layers.conv2d(inputs=out_tensor, 81 | num_outputs=num_classes, 82 | activation_fn=None, 83 | kernel_size=[1, 1], 84 | scope="conv-logits") 85 | 86 | return logits 87 | 88 | 89 | def _get_image_shape_tensor(tensor: tf.Tensor) -> Union[Tuple[int, int], tf.Tensor]: 90 | """ 91 | Get the image shape of the tensor 92 | 93 | :param tensor: Input image tensor [N,H,W,...] 
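        (the static height and width may be unknown, e.g. when the graph is built for variable-sized inputs)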
94 | :return: a (int, int) tuple if shape is defined, otherwise the corresponding tf.Tensor value 95 | """ 96 | if tensor.get_shape()[1].value and \ 97 | tensor.get_shape()[2].value: 98 | target_shape = tensor.get_shape()[1:3] 99 | else: 100 | target_shape = tf.shape(tensor)[1:3] 101 | return target_shape 102 | 103 | 104 | def _upsample_concat(pooled_layer: tf.Tensor, previous_layer: tf.Tensor, scope_name: str='UpsampleConcat'): 105 | """ 106 | 107 | :param pooled_layer: [N,H,W,C] coarse layer 108 | :param previous_layer: [N,H',W',C'] fine layer (H'>H, and W'>W) 109 | :param scope_name: name of the tf.name_scope wrapping the upsampling and concatenation ops 110 | :return: [N,H',W',C+C'] concatenation of upsampled-`pooled_layer` and `previous_layer` 111 | """ 112 | with tf.name_scope(scope_name): 113 | # Upsamples the coarse level 114 | target_shape = _get_image_shape_tensor(previous_layer) 115 | upsampled_layer = tf.image.resize_images(pooled_layer, target_shape, 116 | method=tf.image.ResizeMethod.BILINEAR) 117 | # Concatenate the upsampled-coarse and the other feature_map 118 | input_tensor = tf.concat([upsampled_layer, previous_layer], 3) 119 | return input_tensor 120 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # dhSegment text 2 | 3 | This is a fork of the original [dhSegment repository](https://github.com/dhlab-epfl/dhSegment), developed to carry out experiments on combining visual and textual features, published in the paper **Combining Visual and Textual Features for Semantic Segmentation of Historical Newspapers** (see reference below). 4 | 5 | 6 | ## 1. Modifications 7 | 8 | Compared to the original dhSegment repository, the following modifications were made: 9 | 10 | - Change of the input pipeline to read embeddings; 11 | - Creation of embeddings maps with several dimensionality reduction algorithms; 12 | - Concatenation of the embeddings map inside the encoder or decoder. 13 | 14 | ## 2. Usage 15 | For general usage of dhSegment, see the [original documentation](https://dhsegment.readthedocs.io/). 16 | 17 | - The CSV file now needs four columns: image, label, embeddings, embeddings_map. 18 | - Different configuration options were added for choosing the various hyperparameters; they can be found in `dh_segment_text/utils/params_config.py` and in the encoder and decoder. 19 | - An example config can be found under `embeddings_config.json`. 20 | 21 | The training can be launched using the trainer script with `python dh_segment_train.py with /path/to/config.json`. 22 | 23 | ## 3. Data & Models 24 | 25 | **Pay attention to the terms of use of the material.** 26 | 27 | ### 3.1 Data 28 | 29 | #### Image annotations 30 | The folder contains image annotations, with one file per newspaper containing region annotations (label and coordinates) in [VIA](http://www.robots.ox.ac.uk/~vgg/software/via/) format (v2.0.10). A minimal reading sketch is given at the end of this subsection. 31 | 32 | The following licenses apply: 33 | - `luxwort.json`: those annotations are under a [CC0 1.0](https://creativecommons.org/publicdomain/zero/1.0/legalcode) license. Please refer to the rights statement specified for each image in the JSON file. 34 | 35 | - `GDL.json`, `IMP.json` and `JDG.json`: those annotations are under a [CC BY-SA 4.0](https://creativecommons.org/licenses/by-sa/4.0/legalcode) license. 
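As a rough illustration of how these region annotations can be consumed, the sketch below assumes the standard VIA 2.x JSON export layout (a dictionary keyed by image, each entry holding a list of `regions` with `shape_attributes` and `region_attributes`); the attribute names used in the released files may differ, so treat it as a starting point rather than a reference implementation.

```python
import json

# Hypothetical example: count annotated regions per label in a VIA export.
with open('GDL.json', 'r') as f:
    via_annotations = json.load(f)

counts = {}
for image_entry in via_annotations.values():
    for region in image_entry.get('regions', []):
        # Take the first region attribute value as the label (e.g. 'Serial', 'Weather', ...).
        label = next(iter(region.get('region_attributes', {}).values()), 'unknown')
        counts[label] = counts.get(label, 0) + 1
print(counts)
```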
36 | 37 | #### Image files 38 | *(these files are available on Zenodo, see badge below)* 39 | - Images of Swiss titles (GDL, IMP, JDG) are released as an asset of the current Github [release](https://github.com/dhlab-epfl/dhSegment-text/releases/tag/0.1), in the `images.zip` archive. 40 | **Terms of use**: Those images are under copyright (property of the journal *Le Temps* and of *ArcInfo*) and can be used for academic research or educational purposes only. Redistribution, publication or commercial use are not permitted. These terms of use are similar to the following right statement: http://rightsstatements.org/vocab/InC-EDU/1.0/ 41 | 42 | - Images of the Luxembourgish title are available through the IIIF endpoint of the National Library of Luxembourg (see URL in the annnotation file `luxwort.json`). 43 | 44 | ### 3.2 Trained models 45 | *(these files are available on Zenodo, see badge below)* 46 | 47 | Some of the best models are released as assets of the current Github [release](https://github.com/dhlab-epfl/dhSegment-text/releases/tag/0.1) in zip files. 48 | 49 | - **JDG_flair-FT**: this model was trained on JDG using french Flair and FastText embeddings. It is able to predict the four classes presented in the paper (`Serial`, `Weather`, `Death notice` and `Stocks`). 50 | - **Luxwort_obituary_flair-bpemb**: this model was trained on Luxwort using multilingual Flair and Byte-pair embeddings. It is able to predict the `Death notice` class. 51 | - **Luxwort_obituary_flair-FT_indomain**: this model was trained on Luxwort using in-domain Flair and FastText embeddings (trained on Luxwort data). It is also able to predict the `Death notice` class. 52 | 53 | Those models can be used to predict probabilities on new images using the same code as in the original dhSegment repository. 54 | One needs to adjust three parameters to the `predict` function: 1) `embeddings_path` (the path to the embeddings list), 2) `embeddings_map_path`(the path to the compressed embedding map), and 3) `embeddings_dim` (the size of the embeddings). 55 | 56 | Models are available under a [CC BY-SA 4.0](https://creativecommons.org/licenses/by-sa/4.0/) license. Please refer to the paper (see below) for further information or contact us. 57 | 58 | **DOI data and models:** 59 | 60 | [![DOI](https://zenodo.org/badge/DOI/10.5281/zenodo.3706863.svg)](https://doi.org/10.5281/zenodo.3706863) 61 | 62 | 63 | ## 4. Paper 64 | 65 | Please cite this paper if you are using the tool/datasets or find it relevant to your research: 66 | 67 | [*Combining Visual and Textual Features for Semantic Segmentation of Historical Newspapers*](https://infoscience.epfl.ch/record/282863?&ln=en). Barman Raphaël, Ehrmann Maud, Clematide Simon, Ares Oliveira Sofia, Kaplan Frédéric. 68 | 69 | 70 | ``` 71 | @article{barman_combining_2020, 72 | title = {{Combining Visual and Textual Features for Semantic Segmentation of Historical Newspapers}}, 73 | author = {Raphaël Barman and Maud Ehrmann and Simon Clematide and Sofia Ares Oliveira and Frédéric Kaplan}, 74 | journal= {Journal of Data Mining \& Digital Humanities}, 75 | volume= {HistoInformatics} 76 | DOI = {10.5281/zenodo.4065271}, 77 | year = {2021}, 78 | url = {https://jdmdh.episciences.org/7097}, 79 | } 80 | ``` 81 | **DOI paper:** 82 | 83 | [![DOI](https://zenodo.org/badge/DOI/10.5281/zenodo.4065271.svg)](https://doi.org/10.5281/zenodo.4065271) 84 | 85 | 86 | ## 5. Background & Acknowledgements 87 | 88 | This work was carried out in the frame of the master thesis of Raphaël Barman. 
89 | 90 | We warmly thank the journal [Le Temps](https://letemps.ch) (owner of *La Gazette de Lausanne* and the *Journal de Genève*) and the group [ArcInfo](https://www.arcinfo.ch/) (owner of *L'Impartial*) for accepting to share the related datasets for academic purposes. We also thank the [National Library of Luxembourg](https://bnl.public.lu/fr.html) for its support with all steps related to the *Luxemburger Wort* annotation release. 91 | 92 | This work was realized in the context of the ['*impresso* - Media Monitoring of the Past'](https://impresso-project.ch) project supported by the Swiss National Science Foundation under grant CR-SII5_173719. 93 | 94 | 95 | -------------------------------------------------------------------------------- /doc/conf.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 3 | # Configuration file for the Sphinx documentation builder. 4 | # 5 | # This file does only contain a selection of the most common options. For a 6 | # full list see the documentation: 7 | # http://www.sphinx-doc.org/en/master/config 8 | 9 | # -- Path setup -------------------------------------------------------------- 10 | 11 | # If extensions (or modules to document with autodoc) are in another directory, 12 | # add these directories to sys.path here. If the directory is relative to the 13 | # documentation root, use os.path.abspath to make it absolute, like shown here. 14 | # 15 | import os 16 | import sys 17 | sys.path.insert(0, os.path.abspath('..')) 18 | 19 | 20 | # -- Project information ----------------------------------------------------- 21 | 22 | project = 'dhSegment' 23 | copyright = '2018, Digital Humanities Lab - EPFL' 24 | author = 'Sofia ARES OLIVEIRA, Benoit SEGUIN' 25 | 26 | # The short X.Y version 27 | version = '' 28 | # The full version, including alpha/beta/rc tags 29 | release = '' 30 | 31 | 32 | # -- General configuration --------------------------------------------------- 33 | 34 | # If your documentation needs a minimal Sphinx version, state it here. 35 | # 36 | # needs_sphinx = '1.0' 37 | 38 | # Add any Sphinx extension module names here, as strings. They can be 39 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 40 | # ones. 41 | extensions = [ 42 | 'sphinx.ext.autodoc', 43 | 'sphinx.ext.autosummary', 44 | 'sphinx.ext.coverage', 45 | 'sphinx.ext.githubpages', 46 | 'sphinxcontrib.bibtex', # for bibtex 47 | 'sphinx_autodoc_typehints', # for typing 48 | ] 49 | 50 | # Add any paths that contain templates here, relative to this directory. 51 | templates_path = ['_templates'] 52 | 53 | # The suffix(es) of source filenames. 54 | # You can specify multiple suffix as a list of string: 55 | # 56 | # source_suffix = ['.rst', '.md'] 57 | source_suffix = '.rst' 58 | 59 | # The master toctree document. 60 | master_doc = 'index' 61 | 62 | # The language for content autogenerated by Sphinx. Refer to documentation 63 | # for a list of supported languages. 64 | # 65 | # This is also used if you do content translation via gettext catalogs. 66 | # Usually you set "language" from the command line for these cases. 67 | language = None 68 | 69 | # List of patterns, relative to source directory, that match files and 70 | # directories to ignore when looking for source files. 71 | # This pattern also affects html_static_path and html_extra_path . 72 | exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store'] 73 | 74 | # The name of the Pygments (syntax highlighting) style to use. 
75 | pygments_style = 'sphinx' 76 | 77 | 78 | # -- Options for HTML output ------------------------------------------------- 79 | 80 | # The theme to use for HTML and HTML Help pages. See the documentation for 81 | # a list of builtin themes. 82 | # 83 | html_theme = 'sphinx_rtd_theme' # alabaster, haiku, nature, pyramid, agogo, bizstyle, sphinx_rtd_theme 84 | 85 | # Theme options are theme-specific and customize the look and feel of a theme 86 | # further. For a list of options available for each theme, see the 87 | # documentation. 88 | # 89 | # html_theme_options = {} 90 | 91 | # Add any paths that contain custom static files (such as style sheets) here, 92 | # relative to this directory. They are copied after the builtin static files, 93 | # so a file named "default.css" will overwrite the builtin "default.css". 94 | html_static_path = ['_static'] 95 | 96 | # Custom sidebar templates, must be a dictionary that maps document names 97 | # to template names. 98 | # 99 | # The default sidebars (for documents that don't match any pattern) are 100 | # defined by theme itself. Builtin themes are using these templates by 101 | # default: ``['localtoc.html', 'relations.html', 'sourcelink.html', 102 | # 'searchbox.html']``. 103 | # 104 | # html_sidebars = {} 105 | 106 | 107 | # -- Options for HTMLHelp output --------------------------------------------- 108 | 109 | # Output file base name for HTML help builder. 110 | htmlhelp_basename = 'dhsegmentdoc' 111 | 112 | 113 | # -- Options for LaTeX output ------------------------------------------------ 114 | 115 | latex_elements = { 116 | # The paper size ('letterpaper' or 'a4paper'). 117 | # 118 | # 'papersize': 'letterpaper', 119 | 120 | # The font size ('10pt', '11pt' or '12pt'). 121 | # 122 | # 'pointsize': '10pt', 123 | 124 | # Additional stuff for the LaTeX preamble. 125 | # 126 | # 'preamble': '', 127 | 128 | # Latex figure (float) alignment 129 | # 130 | # 'figure_align': 'htbp', 131 | } 132 | 133 | # Grouping the document tree into LaTeX files. List of tuples 134 | # (source start file, target name, title, 135 | # author, documentclass [howto, manual, or own class]). 136 | latex_documents = [ 137 | (master_doc, 'dhsegment.tex', 'dhsegment Documentation', 138 | author, 'manual'), 139 | ] 140 | 141 | 142 | # -- Options for manual page output ------------------------------------------ 143 | 144 | # One entry per manual page. List of tuples 145 | # (source start file, name, description, authors, manual section). 146 | man_pages = [ 147 | (master_doc, 'dhsegment', 'dhsegment Documentation', 148 | [author], 1) 149 | ] 150 | 151 | 152 | # -- Options for Texinfo output ---------------------------------------------- 153 | 154 | # Grouping the document tree into Texinfo files. 
List of tuples 155 | # (source start file, target name, title, author, 156 | # dir menu entry, description, category) 157 | texinfo_documents = [ 158 | (master_doc, 'dhsegment', 'dhsegment Documentation', 159 | author, 'dhsegment', 'One line description of project.', 160 | 'Miscellaneous'), 161 | ] 162 | 163 | 164 | # -- Extension configuration ------------------------------------------------- 165 | 166 | autodoc_mock_imports = [ 167 | # 'numpy', 168 | 'scipy', 169 | 'tensorflow', 170 | 'pandas', 171 | 'sklearn', 172 | 'skimage', 173 | 'shapely', 174 | 'typing', 175 | 'cv2' 176 | ] 177 | -------------------------------------------------------------------------------- /exps/cbad/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | from glob import glob 4 | import cv2 5 | import numpy as np 6 | from imageio import imread, imsave 7 | from tqdm import tqdm 8 | from dh_segment.post_processing import PAGE 9 | 10 | TARGET_HEIGHT = 1100 11 | DRAWING_COLOR_BASELINES = (255, 0, 0) 12 | DRAWING_COLOR_POINTS = (0, 255, 0) 13 | 14 | RANDOM_SEED = 0 15 | np.random.seed(RANDOM_SEED) 16 | 17 | 18 | def get_page_filename(image_filename: str) -> str: 19 | return os.path.join(os.path.dirname(image_filename), 20 | 'page', 21 | '{}.xml'.format(os.path.basename(image_filename)[:-4])) 22 | 23 | 24 | def get_image_label_basename(image_filename: str) -> str: 25 | # Get acronym followed by name of file 26 | directory, basename = os.path.split(image_filename) 27 | acronym = directory.split(os.path.sep)[-1].split('_')[0] 28 | return '{}_{}'.format(acronym, basename.split('.')[0]) 29 | 30 | 31 | def save_and_resize(img: np.array, filename: str, size=None, nearest: bool=False) -> None: 32 | if size is not None: 33 | h, w = img.shape[:2] 34 | ratio = float(np.sqrt(size/(h*w))) 35 | resized = cv2.resize(img, (int(w*ratio), int(h*ratio)), 36 | interpolation=cv2.INTER_NEAREST if nearest else cv2.INTER_LINEAR) 37 | imsave(filename, resized) 38 | else: 39 | imsave(filename, img) 40 | 41 | 42 | def annotate_one_page(image_filename: str, output_dir: str, size: int, endpoints: bool=False, 43 | line_thickness: int=10, diameter_endpoint: int=20) -> None: 44 | page_filename = get_page_filename(image_filename) 45 | page = PAGE.parse_file(page_filename) 46 | text_lines = [tl for tr in page.text_regions for tl in tr.text_lines] 47 | img = imread(image_filename, pilmode='RGB') 48 | gt = np.zeros_like(img) 49 | 50 | if not endpoints: 51 | gt = cv2.polylines(gt, 52 | [(PAGE.Point.list_to_cv2poly(tl.baseline)[:, 0, :])[:, None, :] for tl in text_lines], 53 | isClosed=False, color=DRAWING_COLOR_BASELINES, 54 | thickness=int(line_thickness * (gt.shape[0] / TARGET_HEIGHT))) 55 | 56 | else: 57 | gt_lines = np.zeros_like(img[:, :, 0]) 58 | gt_lines = cv2.polylines(gt_lines, 59 | [(PAGE.Point.list_to_cv2poly(tl.baseline)[:, 0, :])[:, None, :] for tl in text_lines], 60 | isClosed=False, color=255, 61 | thickness=int(line_thickness * (gt_lines.shape[0] / TARGET_HEIGHT))) 62 | 63 | gt_points = np.zeros_like(img[:, :, 0]) 64 | for tl in text_lines: 65 | try: 66 | gt_points = cv2.circle(gt_points, (tl.baseline[0].x, tl.baseline[0].y), 67 | radius=int((diameter_endpoint/2 * (gt_points.shape[0]/TARGET_HEIGHT))), 68 | color=255, thickness=-1) 69 | gt_points = cv2.circle(gt_points, (tl.baseline[-1].x, tl.baseline[-1].y), 70 | radius=int((diameter_endpoint/2 * (gt_points.shape[0]/TARGET_HEIGHT))), 71 | color=255, thickness=-1) 72 | except IndexError: 73 | print('Length of baseline 
is {}'.format(len(tl.baseline))) 74 | 75 | gt[:, :, np.argmax(DRAWING_COLOR_BASELINES)] = gt_lines 76 | gt[:, :, np.argmax(DRAWING_COLOR_POINTS)] = gt_points 77 | 78 | save_and_resize(img, os.path.join(output_dir, 'images', '{}.jpg'.format(get_image_label_basename(image_filename))), 79 | size=size) 80 | save_and_resize(gt, os.path.join(output_dir, 'labels', '{}.png'.format(get_image_label_basename(image_filename))), 81 | size=size, nearest=True) 82 | shutil.copy(page_filename, os.path.join(output_dir, 'gt', '{}.xml'.format(get_image_label_basename(image_filename)))) 83 | 84 | 85 | def cbad_set_generator(input_dir: str, output_dir: str, img_size: int, line_thickness: int=4, 86 | draw_endpoints: bool=False, circle_thickness: int =20): 87 | """ 88 | 89 | :param input_dir: Input directory containing images and PAGE files 90 | :param output_dir: Output directory to save images and labels 91 | :param img_size: Size of the resized image (# pixels) 92 | :param line_thickness: Thickness of annotated baseline 93 | :param draw_endpoints: Predict beginning and end of baselines (True, False) 94 | :param circle_thickness: Diameter of annotated start/end points 95 | :return: 96 | """ 97 | 98 | # Get image filenames to process 99 | image_filenames_list = glob('{}/**/*.jpg'.format(input_dir)) 100 | 101 | # set 102 | os.makedirs(os.path.join('{}'.format(output_dir), 'images')) 103 | os.makedirs(os.path.join('{}'.format(output_dir), 'labels')) 104 | os.makedirs(os.path.join('{}'.format(output_dir), 'gt')) 105 | for image_filename in tqdm(image_filenames_list): 106 | annotate_one_page(image_filename, output_dir, img_size, draw_endpoints, line_thickness, circle_thickness) 107 | 108 | if draw_endpoints: 109 | classes = np.stack([(0, 0, 0), DRAWING_COLOR_BASELINES, DRAWING_COLOR_POINTS]) 110 | np.savetxt(os.path.join(output_dir, 'classes.txt'), classes, fmt='%d') 111 | else: 112 | classes = np.stack([(0, 0, 0), DRAWING_COLOR_BASELINES]) 113 | np.savetxt(os.path.join(output_dir, 'classes.txt'), classes, fmt='%d') 114 | 115 | 116 | def draw_lines_fn(xml_filename: str, output_dir: str): 117 | """ 118 | GIven an XML PAGE file, draws the corresponding lines in the original image. 
119 | :param xml_filename: filename of the PAGE-XML file to draw 120 | :param output_dir: output directory where the annotated image is saved 121 | :return: 122 | """ 123 | basename = os.path.basename(xml_filename).split('.')[0] 124 | generated_page = PAGE.parse_file(xml_filename) 125 | drawing_img = generated_page.image_filename 126 | generated_page.draw_baselines(drawing_img, color=(0, 0, 255)) 127 | imsave(os.path.join(output_dir, '{}.jpg'.format(basename)), drawing_img) -------------------------------------------------------------------------------- /exps/cbad/process.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple, List 2 | import numpy as np 3 | from scipy.ndimage import label 4 | import cv2 5 | import os 6 | import tensorflow as tf 7 | from tqdm import tqdm 8 | from glob import glob 9 | from imageio import imsave, imread 10 | from scipy.misc import imresize 11 | from dh_segment.utils import dump_pickle 12 | from dh_segment.post_processing.binarization import hysteresis_thresholding, cleaning_probs 13 | from dh_segment.post_processing.line_vectorization import find_lines 14 | from dh_segment.post_processing import PAGE 15 | from dh_segment.loader import LoadedModel 16 | 17 | 18 | def prediction_fn(model_dir: str, input_dir: str, output_dir: str=None) -> None: 19 | """ 20 | Given a model directory, this function will load the model and apply it to the files (.jpg, .png) found in input_dir. 21 | The predictions will be saved in output_dir as .npy files (values ranging [0,255]) 22 | :param model_dir: Directory containing the saved model 23 | :param input_dir: input directory where the images to predict are 24 | :param output_dir: output directory to save the predictions (probability images) 25 | :return: 26 | """ 27 | if not output_dir: 28 | # For a model_dir of the form model_name/export/timestamp/ this will create a folder model_name/predictions 29 | output_dir = '{}'.format(os.path.sep).join(model_dir.split(os.path.sep)[:-3] + ['predictions']) 30 | 31 | os.makedirs(output_dir, exist_ok=True) 32 | filenames_to_predict = glob(os.path.join(input_dir, '*.jpg')) + glob(os.path.join(input_dir, '*.png')) 33 | 34 | with tf.Session(): 35 | m = LoadedModel(model_dir, predict_mode='filename_original_shape') 36 | for filename in tqdm(filenames_to_predict, desc='Prediction'): 37 | pred = m.predict(filename)['probs'][0] 38 | np.save(os.path.join(output_dir, os.path.basename(filename).split('.')[0]), np.uint8(255 * pred)) 39 | 40 | 41 | def cbad_post_processing_fn(probs: np.array, 42 | sigma: float=2.5, 43 | low_threshold: float=0.8, 44 | high_threshold: float=0.9, 45 | filter_width: float=0, 46 | vertical_maxima: bool=False, 47 | output_basename=None) -> Tuple[List[np.ndarray], np.ndarray]: 48 | """ 49 | Finds the baselines in the probability map via smoothing, hysteresis thresholding and line vectorization. 50 | :param probs: output of the model (probabilities) in range [0, 255] 51 | :param sigma: standard deviation of the Gaussian used to smooth the probability map 52 | :param low_threshold: low threshold of the hysteresis thresholding 53 | :param high_threshold: high threshold of the hysteresis thresholding 54 | :param filter_width: fraction of the page width; lines whose centroid lies closer than this to the page border are discarded 55 | :param output_basename: if not None, the contours and the mask shape are dumped to `output_basename`.pkl 56 | :param vertical_maxima: if True, restricts the thresholding candidates to vertical local maxima 57 | :return: contours, mask 58 | WARNING : contours IN OPENCV format List[np.ndarray(n_points, 1, (x,y))] 59 | """ 60 | 61 | contours, lines_mask = line_extraction_v1(probs[:, :, 1], low_threshold=low_threshold, high_threshold=high_threshold, 62 | sigma=sigma, filter_width=filter_width, vertical_maxima=vertical_maxima) 63 | if output_basename is not None: 64 | dump_pickle(output_basename+'.pkl', (contours, lines_mask.shape)) 65 | return contours, lines_mask 66 | 67 | 68 | def line_extraction_v1(probs: np.ndarray, low_threshold: float, high_threshold: float, sigma: float=0.0, 69 | filter_width: float=0.00, vertical_maxima: bool=False) -> Tuple[List[np.ndarray], 
np.ndarray]: 70 | # Smooth 71 | probs2 = cleaning_probs(probs, sigma=sigma) 72 | 73 | lines_mask = hysteresis_thresholding(probs2, low_threshold, high_threshold, 74 | candidates_mask=vertical_local_maxima(probs2) if vertical_maxima else None) 75 | # Remove lines touching border 76 | # lines_mask = remove_borders(lines_mask) 77 | 78 | # Extract polygons from line mask 79 | contours = find_lines(lines_mask) 80 | 81 | filtered_contours = [] 82 | page_width = probs.shape[1] 83 | for cnt in contours: 84 | centroid_x, centroid_y = np.mean(cnt, axis=0)[0] 85 | if centroid_x < filter_width*page_width or centroid_x > (1-filter_width)*page_width: 86 | continue 87 | # if cv2.arcLength(cnt, False) < filter_width*page_width: 88 | # continue 89 | filtered_contours.append(cnt) 90 | 91 | return filtered_contours, lines_mask 92 | 93 | 94 | def vertical_local_maxima(probs: np.ndarray) -> np.ndarray: 95 | local_maxima = np.zeros_like(probs, dtype=bool) 96 | local_maxima[1:-1] = (probs[1:-1] >= probs[:-2]) & (probs[2:] <= probs[1:-1]) 97 | local_maxima = cv2.morphologyEx(local_maxima.astype(np.uint8), cv2.MORPH_CLOSE, np.ones((5, 5), dtype=np.uint8)) 98 | return local_maxima > 0 99 | 100 | 101 | def remove_borders(mask: np.ndarray, margin: int=5) -> np.ndarray: 102 | tmp = mask.copy() 103 | tmp[:margin] = 1 104 | tmp[-margin:] = 1 105 | tmp[:, :margin] = 1 106 | tmp[:, -margin:] = 1 107 | label_components, count = label(tmp, np.ones((3, 3))) 108 | result = mask.copy() 109 | border_component = label_components[0, 0] 110 | result[label_components == border_component] = 0 111 | return result 112 | 113 | 114 | def extract_lines(npy_filename: str, output_dir: str, original_shape: list, post_process_params: dict, 115 | mask_dir: str=None, debug: bool=False): 116 | """ 117 | From the prediction files (.npy probability maps), finds and extracts the lines into PAGE-XML format. 118 | :param npy_filename: filename of saved predictions (probs) in range (0,255) 119 | :param output_dir: output directory to save the xml files 120 | :param original_shape: shape of the original input image (to rescale the extracted lines if necessary) 121 | :param post_process_params: params for line detection (sigma, thresholds, ...) 122 | :param mask_dir: directory containing masks of the page in order to improve the line extraction 123 | :param debug: if True will output the binary image of the extracted lines 124 | :return: contours of lines (OpenCV format), binary image of lines (lines mask) 125 | """ 126 | 127 | os.makedirs(output_dir, exist_ok=True) 128 | 129 | basename = os.path.basename(npy_filename).split('.')[0] 130 | 131 | pred = np.load(npy_filename)/255 # type: np.ndarray 132 | lines_prob = pred[:, :, 1] 133 | 134 | if mask_dir is not None: 135 | mask = imread(os.path.join(mask_dir, basename + '.png'), mode='L') 136 | mask = imresize(mask, lines_prob.shape) 137 | lines_prob[mask == 0] = 0. 
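    # Note: post_process_params is unpacked into line_extraction_v1 below, so its keys must match the
    # signature above (low_threshold, high_threshold, sigma, filter_width, vertical_maxima).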
138 | 139 | contours, lines_mask = line_extraction_v1(lines_prob, **post_process_params) 140 | 141 | if debug: 142 | imsave(os.path.join(output_dir, '{}_bin.jpg'.format(basename)), lines_mask) 143 | 144 | ratio = (original_shape[0] / pred.shape[0], original_shape[1] / pred.shape[1]) 145 | xml_filename = os.path.join(output_dir, basename + '.xml') 146 | PAGE.save_baselines(xml_filename, contours, ratio, initial_shape=pred.shape[:2]) 147 | 148 | return contours, lines_mask 149 | -------------------------------------------------------------------------------- /dh_segment_train.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import os 4 | import tensorflow as tf 5 | # Tensorflow logging level 6 | from logging import WARNING # import DEBUG, INFO, ERROR for more/less verbosity 7 | 8 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' # or any {'0', '1', '2'} 9 | tf.logging.set_verbosity(WARNING) 10 | from dh_segment_text import estimator_fn, utils 11 | from dh_segment_text.io import input 12 | import json 13 | 14 | try: 15 | import better_exceptions 16 | except ImportError: 17 | print('/!\ W -- Not able to import package better_exceptions') 18 | pass 19 | from tqdm import trange 20 | from sacred import Experiment 21 | 22 | ex = Experiment('dhSegment_experiment') 23 | 24 | 25 | @ex.config 26 | def default_config(): 27 | train_data = None # Directory with training data 28 | eval_data = None # Directory with validation data 29 | model_output_dir = None # Directory to output tf model 30 | restore_model = False # Set to true to continue training 31 | classes_file = None # txt file with classes values (unused for REGRESSION) 32 | gpu = '' # GPU to be used for training 33 | use_embeddings = False 34 | weights_histogram = False 35 | seed_augment = False 36 | embeddings_dim = 300 37 | prediction_type = utils.PredictionType.CLASSIFICATION # One of CLASSIFICATION, REGRESSION or MULTILABEL 38 | model_params = utils.ModelParams().to_dict() # Model parameters 39 | embeddings_params = utils.EmbeddingsParams().to_dict() # Embeddings parameters 40 | training_params = utils.TrainingParams().to_dict() # Training parameters 41 | if prediction_type == utils.PredictionType.CLASSIFICATION: 42 | assert classes_file is not None 43 | model_params['n_classes'] = utils.get_n_classes_from_file(classes_file) 44 | elif prediction_type == utils.PredictionType.REGRESSION: 45 | model_params['n_classes'] = 1 46 | elif prediction_type == utils.PredictionType.MULTILABEL: 47 | assert classes_file is not None 48 | model_params['n_classes'] = utils.get_n_classes_from_file_multilabel(classes_file) 49 | 50 | 51 | @ex.automain 52 | def run(train_data, eval_data, model_output_dir, gpu, training_params, use_embeddings, embeddings_dim, _config): 53 | tf.set_random_seed(_config['seed']) 54 | # Create output directory 55 | if not os.path.isdir(model_output_dir): 56 | os.makedirs(model_output_dir) 57 | else: 58 | assert _config.get('restore_model'), \ 59 | '{0} already exists, you cannot use it as output directory. 
' \ 60 | 'Set "restore_model=True" to continue training, or delete dir "rm -r {0}"'.format(model_output_dir) 61 | # Save config 62 | with open(os.path.join(model_output_dir, 'config.json'), 'w') as f: 63 | json.dump(_config, f, indent=4, sort_keys=True) 64 | 65 | # Create export directory for saved models 66 | saved_model_dir = os.path.join(model_output_dir, 'export') 67 | if not os.path.isdir(saved_model_dir): 68 | os.makedirs(saved_model_dir) 69 | 70 | training_params = utils.TrainingParams.from_dict(training_params) 71 | 72 | session_config = tf.ConfigProto() 73 | session_config.gpu_options.visible_device_list = str(gpu) 74 | session_config.gpu_options.per_process_gpu_memory_fraction = 1.0 75 | estimator_config = tf.estimator.RunConfig().replace(session_config=session_config, 76 | save_summary_steps=10, 77 | keep_checkpoint_max=1, 78 | tf_random_seed=_config['seed']) 79 | estimator = tf.estimator.Estimator(estimator_fn.model_fn, model_dir=model_output_dir, 80 | params=_config, config=estimator_config) 81 | 82 | def get_dirs_or_files(input_data): 83 | if os.path.isdir(input_data): 84 | image_input, labels_input = os.path.join(input_data, 'images'), os.path.join(input_data, 'labels') 85 | # Check if training dir exists 86 | assert os.path.isdir(image_input), "{} is not a directory".format(image_input) 87 | assert os.path.isdir(labels_input), "{} is not a directory".format(labels_input) 88 | 89 | elif os.path.isfile(input_data) and input_data.endswith('.csv'): 90 | image_input = input_data 91 | labels_input = None 92 | else: 93 | raise TypeError('input_data {} is neither a directory nor a csv file'.format(input_data)) 94 | return image_input, labels_input 95 | 96 | train_input, train_labels_input = get_dirs_or_files(train_data) 97 | if eval_data is not None: 98 | eval_input, eval_labels_input = get_dirs_or_files(eval_data) 99 | 100 | # Configure exporter 101 | serving_input_fn = input.serving_input_filename(training_params.input_resized_size, use_embeddings=use_embeddings, embeddings_dim=embeddings_dim) 102 | exporter = tf.estimator.BestExporter(serving_input_receiver_fn=serving_input_fn, exports_to_keep=2) 103 | 104 | #if eval_data is not None: 105 | # exporter = tf.estimator.BestExporter(serving_input_receiver_fn=serving_input_fn, exports_to_keep=2) 106 | #else: 107 | # exporter = tf.estimator.LatestExporter(name='SimpleExporter', serving_input_receiver_fn=serving_input_fn, 108 | # exports_to_keep=5) 109 | 110 | nb_cores = os.cpu_count() 111 | if nb_cores: 112 | num_threads = min(nb_cores//2, 16) 113 | else: 114 | num_threads = 4 115 | 116 | for i in trange(0, training_params.n_epochs, training_params.evaluate_every_epoch, desc='Evaluated epochs'): 117 | estimator.train(input.input_fn(train_input, 118 | input_label_dir=train_labels_input, 119 | num_epochs=training_params.evaluate_every_epoch, 120 | batch_size=training_params.batch_size, 121 | data_augmentation=training_params.data_augmentation, 122 | make_patches=training_params.make_patches, 123 | image_summaries=True, 124 | params=_config, 125 | num_threads=num_threads, 126 | progressbar_description="Training".format(i), 127 | seed=_config['seed'])) 128 | 129 | if eval_data is not None: 130 | eval_result = estimator.evaluate(input.input_fn(eval_input, 131 | input_label_dir=eval_labels_input, 132 | batch_size=1, 133 | data_augmentation=False, 134 | make_patches=False, 135 | image_summaries=False, 136 | params=_config, 137 | num_threads=num_threads, 138 | progressbar_description="Evaluation" 139 | )) 140 | 141 | else: 142 | eval_result = 
None 143 | 144 | exporter.export(estimator, saved_model_dir, checkpoint_path=None, eval_result=eval_result, 145 | is_the_final_export=False) 146 | -------------------------------------------------------------------------------- /dh_segment_text/utils/params_config.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | __author__ = "solivr" 3 | __license__ = "GPL" 4 | 5 | from .misc import get_class_from_name 6 | from ..network.model import Encoder, Decoder 7 | from ..embeddings.encoder import EmbeddingsEncoder 8 | from typing import Type 9 | 10 | 11 | class PredictionType: 12 | """ 13 | 14 | :cvar CLASSIFICATION: 15 | :cvar REGRESSION: 16 | :cvar MULTILABEL: 17 | """ 18 | CLASSIFICATION = 'CLASSIFICATION' 19 | REGRESSION = 'REGRESSION' 20 | MULTILABEL = 'MULTILABEL' 21 | 22 | @classmethod 23 | def parse(cls, prediction_type) -> 'PredictionType': 24 | if prediction_type == 'CLASSIFICATION': 25 | return PredictionType.CLASSIFICATION 26 | elif prediction_type == 'REGRESSION': 27 | return PredictionType.REGRESSION 28 | elif prediction_type == 'MULTILABEL': 29 | return PredictionType.MULTILABEL 30 | else: 31 | raise NotImplementedError('Unknown prediction type : {}'.format(prediction_type)) 32 | 33 | 34 | class BaseParams: 35 | def to_dict(self): 36 | return self.__dict__ 37 | 38 | @classmethod 39 | def from_dict(cls, d): 40 | result = cls() 41 | keys = result.to_dict().keys() 42 | for k, v in d.items(): 43 | assert k in keys, k 44 | setattr(result, k, v) 45 | result.check_params() 46 | return result 47 | 48 | def check_params(self): 49 | pass 50 | 51 | 52 | class ModelParams(BaseParams): 53 | """ 54 | Parameters related to the model 55 | :param encoder_name: 56 | :param encoder_params: 57 | :param decoder_name: 58 | :param decoder_params: 59 | :param n_classes: 60 | """ 61 | def __init__(self, **kwargs): 62 | self.encoder_name = kwargs.get('encoder_name', 'dh_segment_text.network.pretrained_models.ResnetV1_50') # type: str 63 | self.encoder_params = kwargs.get('encoder_params', { 64 | 'weight_decay': 1e-6 65 | }) # type: dict 66 | self.decoder_name = kwargs.get('decoder_name', 'dh_segment_text.network.SimpleDecoder') # type: str 67 | self.decoder_params = kwargs.get('decoder_params', { 68 | 'upsampling_dims': [32, 64, 128, 256, 512], 69 | 'max_depth': 512, 70 | 'weight_decay': 1e-6 71 | }) # type: dict 72 | self.n_classes = kwargs.get('n_classes', None) # type: int 73 | 74 | self.check_params() 75 | 76 | def get_encoder(self) -> Type[Encoder]: 77 | encoder = get_class_from_name(self.encoder_name) 78 | assert issubclass(encoder, Encoder), "{} is not an Encoder".format(encoder) 79 | return encoder 80 | 81 | def get_decoder(self) -> Type[Decoder]: 82 | decoder = get_class_from_name(self.decoder_name) 83 | assert issubclass(decoder, Decoder), "{} is not a Decoder".format(decoder) 84 | return decoder 85 | 86 | def check_params(self): 87 | self.get_encoder() 88 | self.get_decoder() 89 | 90 | class EmbeddingsParams(BaseParams): 91 | def __init__(self, **kwargs): 92 | self.target_dim = kwargs.get('target_dim', 8) 93 | self.encoder_name = kwargs.get('encoder_name', "dh_segment_text.embeddings.PCAEncoder") 94 | self.encoder_params = kwargs.get('encoder_params', dict()) 95 | self.check_params() 96 | 97 | def get_encoder(self) -> Type[EmbeddingsEncoder]: 98 | encoder = get_class_from_name(self.encoder_name) 99 | assert issubclass(encoder, EmbeddingsEncoder), f"{encoder} is not an EmbeddingsEncoder" 100 | return encoder 101 | 102 | def 
check_params(self): 103 | self.get_encoder() 104 | 105 | 106 | class TrainingParams(BaseParams): 107 | """Parameters to configure training process 108 | 109 | :ivar n_epochs: number of epoch for training 110 | :vartype n_epochs: int 111 | :ivar evaluate_every_epoch: the model will be evaluated every `n` epochs 112 | :vartype evaluate_every_epoch: int 113 | :ivar learning_rate: the starting learning rate value 114 | :vartype learning_rate: float 115 | :ivar exponential_learning: option to use exponential learning rate 116 | :vartype exponential_learning: bool 117 | :ivar batch_size: size of batch 118 | :vartype batch_size: int 119 | :ivar data_augmentation: option to use data augmentation (by default is set to False) 120 | :vartype data_augmentation: bool 121 | :ivar data_augmentation_flip_lr: option to use image flipping in right-left direction 122 | :vartype data_augmentation_flip_lr: bool 123 | :ivar data_augmentation_flip_ud: option to use image flipping in up down direction 124 | :vartype data_augmentation_flip_ud: bool 125 | :ivar data_augmentation_color: option to use data augmentation with color 126 | :vartype data_augmentation_color: bool 127 | :ivar data_augmentation_max_rotation: maximum angle of rotation (in radians) for data augmentation 128 | :vartype data_augmentation_max_rotation: float 129 | :ivar data_augmentation_max_scaling: maximum scale of zooming during data augmentation (range: [0,1]) 130 | :vartype data_augmentation_max_scaling: float 131 | :ivar make_patches: option to crop image into patches. This will cut the entire image in several patches 132 | :vartype make_patches: bool 133 | :ivar patch_shape: shape of the patches 134 | :vartype patch_shape: tuple 135 | :ivar input_resized_size: size (in pixel) of the image after resizing. The original ratio is kept. If no resizing \ 136 | is wanted, set it to -1 137 | :vartype input_resized_size: int 138 | :ivar weights_labels: weight given to each label. Should be a list of length = number of classes 139 | :vartype weights_labels: list 140 | :ivar training_margin: size of the margin to add to the images. This is particularly useful when training with \ 141 | patches 142 | :vartype training_margin: int 143 | :ivar local_entropy_ratio: 144 | :vartype local_entropy_ratio: float 145 | :ivar local_entropy_sigma: 146 | :vartype local_entropy_sigma: float 147 | :ivar focal_loss_gamma: value of gamma for the focal loss. 
See paper : https://arxiv.org/abs/1708.02002 148 | :vartype focal_loss_gamma: float 149 | """ 150 | def __init__(self, **kwargs): 151 | self.n_epochs = kwargs.get('n_epochs', 20) 152 | self.evaluate_every_epoch = kwargs.get('evaluate_every_epoch', 10) 153 | self.learning_rate = kwargs.get('learning_rate', 1e-5) 154 | self.exponential_learning = kwargs.get('exponential_learning', True) 155 | self.cosine_restart_learning = kwargs.get('cosine_restart_learning', False) 156 | self.adamw_optimizer = kwargs.get('adamw_optimizer', False) 157 | self.batch_size = kwargs.get('batch_size', 5) 158 | self.data_augmentation = kwargs.get('data_augmentation', False) 159 | self.data_augmentation_flip_lr = kwargs.get('data_augmentation_flip_lr', False) 160 | self.data_augmentation_flip_ud = kwargs.get('data_augmentation_flip_ud', False) 161 | self.data_augmentation_color = kwargs.get('data_augmentation_color', False) 162 | self.data_augmentation_max_rotation = kwargs.get('data_augmentation_max_rotation', 0.2) 163 | self.data_augmentation_max_scaling = kwargs.get('data_augmentation_max_scaling', 0.05) 164 | self.make_patches = kwargs.get('make_patches', True) 165 | self.patch_shape = kwargs.get('patch_shape', (300, 300)) 166 | self.input_resized_size = int(kwargs.get('input_resized_size', 72e4)) # (600*1200) 167 | self.weights_labels = kwargs.get('weights_labels') 168 | self.weights_evaluation_miou = kwargs.get('weights_evaluation_miou', None) 169 | self.training_margin = kwargs.get('training_margin', 16) 170 | self.local_entropy_ratio = kwargs.get('local_entropy_ratio', 0.) 171 | self.local_entropy_sigma = kwargs.get('local_entropy_sigma', 3) 172 | self.focal_loss_gamma = kwargs.get('focal_loss_gamma', 0.) 173 | 174 | def check_params(self) -> None: 175 | """Checks if there is no parameter inconsistency 176 | """ 177 | assert self.training_margin*2 < min(self.patch_shape) 178 | --------------------------------------------------------------------------------
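The parameter classes above map directly onto the JSON configuration consumed by `dh_segment_train.py` (passed with sacred's `with /path/to/config.json`). The following sketch only uses keys defined in `params_config.py` and in the trainer's default config; the paths and values are placeholders, and the actual `embeddings_config.json` shipped with the repository may differ.

```json
{
  "train_data": "/path/to/train.csv",
  "eval_data": "/path/to/val.csv",
  "model_output_dir": "/path/to/model_output",
  "classes_file": "/path/to/classes.txt",
  "use_embeddings": true,
  "embeddings_dim": 300,
  "embeddings_params": {
    "target_dim": 8,
    "encoder_name": "dh_segment_text.embeddings.PCAEncoder"
  },
  "training_params": {
    "n_epochs": 40,
    "evaluate_every_epoch": 10,
    "batch_size": 4,
    "learning_rate": 1e-5,
    "data_augmentation": true,
    "make_patches": true,
    "patch_shape": [400, 400],
    "input_resized_size": 720000
  }
}
```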