├── .gitignore ├── LICENSE ├── README.md ├── chemschematicresolver ├── __init__.py ├── actions.py ├── clean.py ├── decorators.py ├── dict │ ├── spelling.txt │ └── superatom.txt ├── extract.py ├── io.py ├── model.py ├── ocr.py ├── parse.py ├── r_group.py ├── utils.py └── validate.py ├── setup.py └── tests ├── data └── S014372081630122X_gr1.jpg ├── eval_markush.py ├── eval_osra.py ├── test_actions.py ├── test_extract.py ├── test_io.py ├── test_model.py ├── test_ocr.py ├── test_parse.py ├── test_r_group.py ├── test_system.py └── test_validate.py /.gitignore: -------------------------------------------------------------------------------- 1 | **/.idea/ 2 | /tests/.cache/ 3 | /tests/train* 4 | /training/ 5 | **/output/ 6 | /stats/* 7 | debug.py 8 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Edward Beard 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ChemSchematicResolver 2 | **ChemSchematicResolver** is a toolkit for the automatic resolution of chemical schematic diagrams and their labels. You can find out how it works on the [website](http://www.chemschematicresolver.org) , and try out the online demo [here](http://www.chemschematicresolver.org/demo) 3 | 4 | ## Features 5 | 6 | - Extraction of generic R-Group structures 7 | - Automatic detection and download of schematic chemical diagrams from scientific articles 8 | - HTML and XML document format support from RSC and Elsevier 9 | - High-throughput capabilities 10 | - Direct extraction from image files 11 | - PNG, GIF, JPEG, TIFF image format support 12 | 13 | ## Installation 14 | 15 | Installation of ChemSchematicResolver is achieved using [conda](https://docs.conda.io/en/latest). 16 | 17 | First, install [Miniconda](https://docs.conda.io/en/latest/miniconda.html), which contains a complete Python distribution alongside the conda package manager. 
18 | 19 | Next, go to the command line terminal and create a working environment by typing 20 | 21 | conda create --name <my_env> python=3.6 22 | 23 | Once this is created, enter this environment with the command 24 | 25 | conda activate <my_env> 26 | 27 | There are two ways to continue the installation: via the anaconda cloud, or from source. 28 | 29 | ### Option 1 - Installation via anaconda 30 | 31 | *Please note that the following option will not work until release.* 32 | 33 | We recommend the installation of ChemSchematicResolver through the anaconda cloud. 34 | 35 | Simply type: 36 | 37 | conda install -c edbeard -c mcs07 -c conda-forge chemschematicresolver 38 | 39 | This command installs ChemSchematicResolver and all its dependencies from the author's channel. 40 | This includes [pyosra](https://github.com/edbeard/pyosra), the Python wrapper for the OSRA toolkit, and [ChemDataExtractor-CSR](https://github.com/edbeard/chemdataextractor-csr), the bespoke version of ChemDataExtractor containing diagram parsers. 41 | 42 | *This method of installation is currently supported on Linux machines only.* 43 | 44 | ### Option 2 - Installation from source 45 | 46 | *Please note that all following links will not work until release.* 47 | 48 | We strongly recommend installation via the conda cloud whenever possible, as all the dependencies are automatically handled. 49 | 50 | If this cannot be done, users are invited to compile the code from source. This is easiest to do through [conda build](https://docs.conda.io/projects/conda-build/en/latest/), by building and installing using the recipes [here](https://github.com/edbeard/conda_recipes). 51 | 52 | The following packages will need to be built from a recipe, in the order below: 53 | 54 | 1. **Pyosra**: [[recipe](https://github.com/edbeard/conda_recipes/tree/master/pyosra), [source code](https://github.com/edbeard/pyosra)] 55 | 56 | 2. **ChemDataExtractor-CSR**: [[recipe](https://github.com/edbeard/conda_recipes/tree/master/cde-csr/recipes/chemdataextractor), [source code](https://github.com/edbeard/chemdataextractor-csr)] 57 | 58 | 3. **ChemSchematicResolver**: [[recipe](https://github.com/edbeard/conda_recipes/tree/master/csr), [source code](https://github.com/edbeard/ChemSchematicResolver)] 59 | 60 | For each, enter the directory and run: 61 | 62 | conda build . 63 | 64 | to create a compressed tarball file, which contains the instructions for installing the code *(Please note that this can take up to 30 minutes to build)*. 65 | 66 | Move all compressed tarballs to a single directory, enter the directory and run: 67 | 68 | conda index . 69 | 70 | This converts the directory into a format emulating a conda channel. To install all the code and dependencies, simply run 71 | 72 | conda install -c <path/to/channel> chemschematicresolver 73 | 74 | And you should have everything installed! 75 | 76 | 77 | # Getting Started 78 | 79 | This section gives an introduction on how to get started with ChemSchematicResolver. This assumes you already have 80 | ChemSchematicResolver and all dependencies installed. 81 | 82 | ## Extract Image 83 | It's simplest to run ChemSchematicResolver on an image file. 84 | 85 | Open a Python terminal and import the library with: 86 | 87 | >>> import chemschematicresolver as csr 88 | 89 | Then run: 90 | 91 | >>> result = csr.extract_image('<path/to/image/file>') 92 | 93 | to perform the extraction. 94 | 95 | This runs ChemSchematicResolver on the image and returns a list of tuples as `result`. Each tuple consists of a list of label candidates and a SMILES string, and each tuple identifies a unique structure. For example: 96 | 97 | >>> print(result) 98 | [(['1a'], 'C1CCCCC1'), (['1b'], 'CC1CCCCC1')] 99 |
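If you want to post-process this output, a minimal sketch along the following lines (the image path is a placeholder) flattens the returned tuples into a label-to-SMILES dictionary:

    import chemschematicresolver as csr

    # Placeholder path - substitute the image you want to resolve
    result = csr.extract_image('<path/to/image/file>')

    # Each tuple pairs a list of label candidates with a SMILES string
    label_to_smiles = {}
    for labels, smiles in result:
        for label in labels:
            label_to_smiles[label] = smiles

    print(label_to_smiles)  # e.g. {'1a': 'C1CCCCC1', '1b': 'CC1CCCCC1'}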
100 | ## Extract Document 101 | 102 | To automatically extract the structures and labels of diagrams from an HTML or XML article, use the `extract_document` method instead: 103 | 104 | >>> result = csr.extract_document('<path/to/document>') 105 | 106 | If the user has permission to access the full article, this function will download all relevant images locally to a directory called *csr*, and extract from them automatically. The *csr* directory will then be deleted. 107 | 108 | The tool currently supports HTML documents from the [Royal Society of Chemistry](https://www.rsc.org/) and [Springer](https://www.springer.com), as well as XML files obtained using the [Elsevier Developers Portal](https://dev.elsevier.com/index.html). 109 | 110 | ChemSchematicResolver will return the complete chemical records extracted from the document with [ChemDataExtractor](http://chemdataextractor.org/), enriched with the extracted structure and raw label. For example: 111 | 112 | >>> print(result) 113 | {'labels': ['1a'], 'roles': ['compound'], 'melting_points': [{'value': '5', 'units': '°C'}], 'diagram': { 'smiles': 'C1CCCCC1', 'label': '1a' } } 114 | 115 | Alternatively, if you just want the structures and labels extracted from the images without the ChemDataExtractor output, run: 116 | 117 | >>> result = csr.extract_document('<path/to/document>', extract_all=False) 118 | 119 | which, for the above example, will return: 120 | 121 | >>> print(result) 122 | [(['1a'], 'C1CCCCC1')] 123 | -------------------------------------------------------------------------------- /chemschematicresolver/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | # -*- coding: utf-8 -*- 3 | """ 4 | ChemSchematicResolver 5 | ===================== 6 | 7 | Automatically extract data from schematic chemical diagrams 8 | """ 9 | 10 | from __future__ import absolute_import 11 | from __future__ import division 12 | from __future__ import print_function 13 | from __future__ import unicode_literals 14 | import logging 15 | 16 | 17 | __title__ = 'ChemSchematicResolver' 18 | __version__ = '0.0.1' 19 | __author__ = 'Ed Beard' 20 | __email__ = 'ed.beard94@gmail.com' 21 | __copyright__ = 'Copyright 2019 Ed Beard, All rights reserved.' 22 | 23 | 24 | log = logging.getLogger(__name__) 25 | log.addHandler(logging.NullHandler()) 26 | 27 | from .extract import extract_image, extract_images, extract_document, extract_documents 28 | -------------------------------------------------------------------------------- /chemschematicresolver/actions.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Image Processing Actions 4 | ======================== 5 | 6 | A toolkit of image processing actions for segmentation.
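Includes routines for binarisation and segmentation of a figure into panels, k-means classification of panels into diagrams and labels, merging of label fragments, diagram-label pairing and SMILES extraction via pyosra.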
7 | 8 | author: Ed Beard 9 | email: ejb207@cam.ac.uk, ed.beard94@gmail.com 10 | 11 | """ 12 | 13 | from __future__ import absolute_import 14 | from __future__ import division 15 | from __future__ import print_function 16 | from __future__ import unicode_literals 17 | import logging 18 | 19 | import numpy as np 20 | import os 21 | from skimage.util import pad 22 | from skimage.measure import regionprops 23 | 24 | import itertools 25 | import copy 26 | from scipy import ndimage as ndi 27 | from sklearn.cluster import KMeans 28 | import osra_rgroup 29 | 30 | from .model import Panel, Diagram, Label, Rect, Figure 31 | from .io import imsave, imdel 32 | from .clean import find_repeating_unit, clean_output 33 | from .utils import crop, skeletonize, binarize, binary_close, binary_floodfill, merge_rect, merge_overlap 34 | 35 | # Standard path to superatom dictionary file 36 | parent_dir = os.path.dirname(os.path.abspath(__file__)) 37 | superatom_file = os.path.join(parent_dir, 'dict', 'superatom.txt') 38 | spelling_file = os.path.join(parent_dir, 'dict', 'spelling.txt') 39 | 40 | 41 | log = logging.getLogger(__name__) 42 | 43 | 44 | def segment(fig): 45 | """ Segments image. 46 | 47 | :param fig: Input Figure 48 | :return panels: List of segmented Panel objects 49 | """ 50 | 51 | bin_fig = binarize(fig) 52 | 53 | bbox = fig.get_bounding_box() 54 | skel_pixel_ratio = skeletonize_area_ratio(fig, bbox) 55 | 56 | log.debug(" The skeletonized pixel ratio is %s" % skel_pixel_ratio) 57 | 58 | # Choose kernel size according to skeletonized pixel ratio 59 | if skel_pixel_ratio > 0.025: 60 | kernel = 4 61 | closed_fig = binary_close(bin_fig, size=kernel) 62 | log.debug("Segmentation kernel size = %s" % kernel) 63 | 64 | elif 0.02 < skel_pixel_ratio <= 0.025: 65 | kernel = 6 66 | closed_fig = binary_close(bin_fig, size=kernel) 67 | log.debug("Segmentation kernel size = %s" % kernel) 68 | 69 | elif 0.015 < skel_pixel_ratio <= 0.02: 70 | kernel = 10 71 | closed_fig = binary_close(bin_fig, size=kernel) 72 | log.debug("Segmentation kernel size = %s" % kernel) 73 | 74 | else: 75 | kernel = 15 76 | closed_fig = binary_close(bin_fig, size=kernel) 77 | log.debug("Segmentation kernel size = %s" % kernel) 78 | 79 | # Using a binary floodfill to identify panel regions 80 | fill_img = binary_floodfill(closed_fig) 81 | tag_img = binary_tag(fill_img) 82 | panels = get_bounding_box(tag_img) 83 | 84 | # Removing relatively tiny pixel islands that are determined to be noise 85 | area_threshold = fig.get_bounding_box().area / 200 86 | width_threshold = fig.get_bounding_box().width / 150 87 | panels = [panel for panel in panels if panel.area > area_threshold or panel.width > width_threshold] 88 | return panels 89 | 90 | 91 | def classify_kmeans(panels, fig, skel=True): 92 | """Takes input image and classifies through k means cluster of the panel area""" 93 | 94 | if len(panels) <= 1: 95 | raise Exception('Only one panel detected. Cannot cluster') 96 | return get_labels_and_diagrams_k_means_clustering(panels, fig, skel) 97 | 98 | 99 | def get_labels_and_diagrams_k_means_clustering(panels, fig, skel=True): 100 | """ Splits into labels and diagrams using K-means clustering by the skeletonized area ratio or panel height. 
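The two clusters are separated by mean panel area: the cluster with the larger mean area is taken to be the diagrams and the other the labels.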
101 | 102 | :param panels: List of Panel objects to be clustered 103 | :param fig: Input Figure 104 | :param skel: Boolean indication the clustering parameters to use 105 | :return Lists of Labels and Diagrams after clustering 106 | """ 107 | 108 | cluster_params = [] 109 | 110 | for panel in panels: 111 | if skel: 112 | cluster_params.append([skeletonize_area_ratio(fig, panel)]) 113 | else: 114 | cluster_params.append([panel.height]) 115 | 116 | all_params = np.array(cluster_params) 117 | 118 | km = KMeans(n_clusters=2) 119 | clusters = km.fit(all_params) 120 | 121 | group_1, group_2 = [], [] 122 | 123 | for i, cluster in enumerate(clusters.labels_): 124 | if cluster == 0: 125 | group_1.append(panels[i]) 126 | else: 127 | group_2.append(panels[i]) 128 | 129 | if np.nanmean([panel.area for panel in group_1]) > np.nanmean([panel.area for panel in group_2]): 130 | diags = group_1 131 | labels = group_2 132 | else: 133 | diags = group_2 134 | labels = group_1 135 | 136 | # Convert to appropriate types 137 | labels = [Label(label.left, label.right, label.top, label.bottom, label.tag) for label in labels] 138 | diags = [Diagram(diag.left, diag.right, diag.top, diag.bottom, diag.tag) for diag in diags] 139 | return labels, diags 140 | 141 | 142 | def preprocessing(labels, diags, fig): 143 | """Pre-processing steps before final K-means classification 144 | :param labels: List of Label objects 145 | :param diags: List of Diagram objects 146 | :param fig: Figure object 147 | 148 | :return out_labels: List of Labels after merging and re-tagging 149 | :return out_diags: List of Diagrams after re-tagging 150 | """ 151 | 152 | # Remove repeating unit indicators 153 | labels, diags = find_repeating_unit(labels, diags, fig) 154 | 155 | # Remove small pixel islands from diagrams 156 | diags = remove_diag_pixel_islands(diags, fig) 157 | 158 | # Merge labels together that are sufficiently local 159 | label_candidates_horizontally_merged = merge_label_horizontally(labels,fig) 160 | label_candidates_fully_merged = merge_labels_vertically(label_candidates_horizontally_merged) 161 | labels_converted = convert_panels_to_labels(label_candidates_fully_merged) 162 | 163 | # Re-tagging all diagrams and labels 164 | retagged_panels = retag_panels(labels_converted + diags) 165 | out_labels = retagged_panels[:len(labels_converted)] 166 | out_diags = retagged_panels[len(labels_converted):] 167 | 168 | return out_labels, out_diags 169 | 170 | 171 | def label_diags(labels, diags, fig_bbox): 172 | """ Pair all Diagrams to Labels. 
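Each Diagram's bounding box is expanded until it overlaps a Label; duplicate pairings are then retried along the most common compass direction of the successful pairings (defaulting to 'S' when there are none).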
173 | 174 | :param labels: List of Label objects 175 | :param diags: List of Diagram objects 176 | :param fig_bbox: Co-ordinates of the bounding box of the entire figure 177 | 178 | :returns: List of Diagrams with assigned Labels 179 | 180 | """ 181 | 182 | # Sort diagrams from largest to smallest 183 | diags.sort(key=lambda x: x.area, reverse=True) 184 | initial_sorting = [assign_label_to_diag(diag, labels, fig_bbox) for diag in diags] 185 | 186 | # Identify failures by the presence of duplicate labels 187 | failed_diag_label = get_duplicate_labelling(initial_sorting) 188 | 189 | if len(failed_diag_label) == 0: 190 | return initial_sorting 191 | 192 | # Find average position of label relative to diagram for successful pairings (denoted with compass points: NSEW) 193 | successful_diag_label = [diag for diag in diags if diag not in failed_diag_label] 194 | 195 | # Where no sucessful pairings found, attempt looking 'South' for all diagrams (most common relative label position) 196 | if len(successful_diag_label) == 0: 197 | altered_sorting = [assign_label_to_diag_postprocessing(diag, labels, 'S', fig_bbox) for diag in failed_diag_label] 198 | if len(get_duplicate_labelling(altered_sorting)) != 0: 199 | altered_sorting = initial_sorting 200 | pass 201 | else: 202 | return altered_sorting 203 | else: 204 | # Get compass positions of labels relative to diagram 205 | diag_compass = [diag.compass_position(diag.label) for diag in successful_diag_label if diag.label] 206 | mode_compass = max(diag_compass, key=diag_compass.count) 207 | 208 | # Expand outwards in compass direction for all failures 209 | altered_sorting = [assign_label_to_diag_postprocessing(diag, labels, mode_compass, fig_bbox) for diag in failed_diag_label] 210 | 211 | # Check for duplicates after relabelling 212 | failed_diag_label = get_duplicate_labelling(altered_sorting + successful_diag_label) 213 | successful_diag_label = [diag for diag in successful_diag_label if diag not in failed_diag_label] 214 | 215 | # If no duplicates return all results 216 | if len(failed_diag_label) == 0: 217 | return altered_sorting + successful_diag_label 218 | 219 | # Add non duplicates to successes 220 | successful_diag_label.extend([diag for diag in altered_sorting if diag not in failed_diag_label]) 221 | 222 | # Remove duplicate results 223 | diags_with_labels, diags_without_labels = remove_duplicates(failed_diag_label, fig_bbox) 224 | 225 | return diags_with_labels + successful_diag_label 226 | 227 | 228 | def assign_label_to_diag(diag, labels, fig_bbox, rate=1): 229 | """ Iteratively expands the bounding box of diagram until it intersects a Label object 230 | 231 | :param diag: Input Diagram object to expand from 232 | :param labels: List of Label objects 233 | :param fig_bbox: Panel object representing the co-ordinates for the entire Figure 234 | :param rate: Number of pixels to expand by upon each iteration 235 | 236 | :return diag: Diagram with Label object assigned 237 | """ 238 | 239 | probe_rect = Rect(diag.left, diag.right, diag.top, diag.bottom) 240 | found = False 241 | max_threshold_width = fig_bbox.width 242 | max_threshold_height = fig_bbox.height 243 | 244 | while found is False and (probe_rect.width < max_threshold_width or probe_rect.height < max_threshold_height): 245 | # Increase border value each loop 246 | probe_rect.right = probe_rect.right + rate 247 | probe_rect.bottom = probe_rect.bottom + rate 248 | probe_rect.left = probe_rect.left - rate 249 | probe_rect.top = probe_rect.top - rate 250 | 251 | for label in labels: 252 
| if probe_rect.overlaps(label): 253 | found = True 254 | diag.label = label 255 | return diag 256 | 257 | 258 | def assign_label_to_diag_postprocessing(diag, labels, direction, fig_bbox, rate=1): 259 | """ Iteratively expands the bounding box of diagram in the specified compass direction 260 | 261 | :param diag: Input Diagram object to expand from 262 | :param labels: List of Label objects 263 | :param direction: String representing determined compass direction (allowed values: 'E', 'S', 'W', 'N') 264 | :param fig_bbox: Panel object representing the co-ordinates for the entire Figure 265 | :param rate: Number of pixels to expand by upon each iteration 266 | """ 267 | 268 | probe_rect = Rect(diag.left, diag.right, diag.top, diag.bottom) 269 | found = False 270 | 271 | def label_loop(): 272 | 273 | for label in labels: 274 | # Only accepting labels in the average direction 275 | if diag.compass_position(label) != direction: 276 | pass 277 | elif probe_rect.overlaps(label): 278 | diag.label = label 279 | return True 280 | 281 | return False 282 | 283 | # Increase border value each loop 284 | if direction == 'E': 285 | while found is False and probe_rect.right < fig_bbox.right: 286 | probe_rect.right = probe_rect.right + rate 287 | found = label_loop() 288 | 289 | elif direction == 'S': 290 | while found is False and probe_rect.bottom < fig_bbox.bottom: 291 | probe_rect.bottom = probe_rect.bottom + rate 292 | found = label_loop() 293 | 294 | elif direction == 'W': 295 | while found is False and probe_rect.left > fig_bbox.left: 296 | probe_rect.left = probe_rect.left - rate 297 | found = label_loop() 298 | 299 | elif direction == 'N': 300 | while found is False and probe_rect.top > fig_bbox.top: 301 | probe_rect.top = probe_rect.top - rate 302 | found = label_loop() 303 | else: 304 | return diag 305 | 306 | return diag 307 | 308 | 309 | def read_diagram_pyosra(diag, extension='jpg', debug=False, superatom_path=superatom_file, spelling_path=spelling_file): 310 | """ Converts a diagram to SMILES using pyosra 311 | 312 | :param diag: Diagram to be extracted 313 | :param extension: String file extension 314 | :param debug: Bool inicating debug mode 315 | 316 | :return smile: String of extracted chemical SMILE 317 | 318 | """ 319 | 320 | # Add some padding to image to help resolve characters on the edge 321 | padded_img = pad(diag.fig.img, ((5, 5), (5, 5), (0, 0)), mode='constant', constant_values=1) 322 | 323 | # Save a temp image 324 | temp_img_fname = 'osra_temp.' 
+ extension 325 | imsave(temp_img_fname, padded_img) 326 | 327 | # Run osra on temp image 328 | smile = osra_rgroup.read_diagram(temp_img_fname, debug=debug, superatom_file=superatom_path, spelling_file=spelling_path) 329 | 330 | if not smile: 331 | log.warning('No SMILES string was extracted for diagram %s' % diag.tag) 332 | 333 | if not debug: 334 | imdel(temp_img_fname) 335 | 336 | smile = clean_output(smile) 337 | return smile 338 | 339 | 340 | def remove_diag_pixel_islands(diags, fig): 341 | """ Removes small pixel islands from the diagram 342 | 343 | :param diags: List of input Diagrams 344 | :param fig: Figure object 345 | 346 | :return diags: List of Diagrams with small pixel islands removed 347 | 348 | """ 349 | 350 | for diag in diags: 351 | 352 | # Make a cleaned copy of image to be used when resolving diagrams 353 | clean_fig = copy.deepcopy(fig) 354 | 355 | diag_fig = Figure(crop(clean_fig.img, diag.left, diag.right, diag.top, diag.bottom)) 356 | seg_fig = Figure(crop(clean_fig.img, diag.left, diag.right, diag.top, diag.bottom)) 357 | sub_panels = segment(seg_fig) 358 | 359 | panel_areas = [panel.area for panel in sub_panels] 360 | diag_area = max(panel_areas) 361 | 362 | sub_panels = [panel for panel in sub_panels if panel.area != diag_area] 363 | 364 | sub_bbox = [(panel.left, panel.right, panel.top, panel.bottom) for panel in sub_panels] 365 | 366 | for bbox in sub_bbox: 367 | diag_fig.img[bbox[2]:bbox[3], bbox[0]:bbox[1]] = np.ones(3) 368 | 369 | diag.fig = diag_fig 370 | 371 | return diags 372 | 373 | 374 | def pixel_ratio(fig, diag): 375 | """ Calculates the ratio of 'on' pixels to bounding box area for binary figure 376 | 377 | :param fig : Input binary Figure 378 | :param diag : Area to calculate pixel ratio 379 | 380 | :return ratio: Float detailing ('on' pixels / bounding box area) 381 | """ 382 | 383 | cropped_img = crop(fig.img, diag.left, diag.right, diag.top, diag.bottom) 384 | ones = np.count_nonzero(cropped_img) 385 | all_pixels = np.size(cropped_img) 386 | ratio = ones / all_pixels 387 | return ratio 388 | 389 | 390 | def binary_tag(fig): 391 | """ Tag connected regions with pixel value of 1 392 | 393 | :param fig: Input Figure 394 | :returns fig: Connected Figure 395 | """ 396 | fig.img, no_tagged = ndi.label(fig.img) 397 | return fig 398 | 399 | 400 | def get_bounding_box(fig): 401 | """ Gets the bounding box of each segment 402 | 403 | :param fig: Input Figure 404 | :returns panels: List of panel objects 405 | """ 406 | panels = [] 407 | regions = regionprops(fig.img) 408 | for region in regions: 409 | y1, x1, y2, x2 = region.bbox 410 | panels.append(Panel(x1, x2, y1, y2, region.label - 1))# Sets tags to start from 0 411 | return panels 412 | 413 | 414 | def retag_panels(panels): 415 | """ Re-tag panels. 416 | 417 | :param panels: List of Panel objects 418 | :returns: List of re-tagged Panel objects 419 | """ 420 | 421 | for i, panel in enumerate(panels): 422 | panel.tag = i 423 | return panels 424 | 425 | 426 | def skeletonize_area_ratio(fig, panel): 427 | """ Calculates the ratio of skeletonized image pixels to total number of pixels 428 | 429 | :param fig: Input figure 430 | :param panel: Original panel object 431 | :return: Float : Ratio of skeletonized pixels to total area (see pixel_ratio) 432 | """ 433 | 434 | skel_fig = skeletonize(fig) 435 | return pixel_ratio(skel_fig, panel) 436 | 437 | 438 | def order_by_area(panels): 439 | """ Returns a list of panel objects ordered by area. 
440 | 441 | :param panels: Input list of Panels 442 | :return panels: Output list of sorted Panels 443 | """ 444 | 445 | def get_area(panel): 446 | return panel.area 447 | 448 | panels.sort(key=get_area) 449 | return panels 450 | 451 | 452 | def merge_label_horizontally(merge_candidates, fig): 453 | """ Iteratively attempt to merge horizontally 454 | 455 | :param merge_candidates: Input list of Panels to be merged 456 | :return merge_candidates: List of Panels after merging 457 | """ 458 | 459 | done = False 460 | 461 | # Identifies panels within horizontal merging criteria 462 | while done is False: 463 | ordered_panels = order_by_area(merge_candidates) 464 | merge_candidates, done = merge_loop_horizontal(ordered_panels, fig) 465 | 466 | merge_candidates, done = merge_all_overlaps(merge_candidates) 467 | return merge_candidates 468 | 469 | 470 | def merge_labels_vertically(merge_candidates): 471 | """ Iteratively attempt to merge vertically 472 | 473 | :param merge_candidates: Input list of Panels to be merged 474 | :return merge_candidates: List of Panels after merging 475 | """ 476 | 477 | # Identifies panels within horizontal merging criteria 478 | ordered_panels = order_by_area(merge_candidates) 479 | merge_candidates = merge_loop_vertical(ordered_panels) 480 | 481 | merge_candidates, done = merge_all_overlaps(merge_candidates) 482 | return merge_candidates 483 | 484 | 485 | def merge_loop_horizontal(panels, fig_input): 486 | """ Iteratively merges panels by relative proximity to each other along the x axis. 487 | This is repeated until no panels are merged by the algorithm 488 | 489 | :param panels: List of Panels to be merged. 490 | 491 | :return output_panels: List of merged panels 492 | :return done: Bool indicating whether a merge occurred 493 | """ 494 | 495 | output_panels = [] 496 | blacklisted_panels = [] 497 | done = True 498 | 499 | for a, b in itertools.combinations(panels, 2): 500 | 501 | # Check panels lie in roughly the same line, that they are of label size and similar height 502 | if abs(a.center[1] - b.center[1]) < 1.5 * a.height \ 503 | and abs(a.height - b.height) < min(a.height, b.height): 504 | 505 | # Check that the distance between the edges of panels is not too large 506 | if (0 <= a.left - b.right < (min(a.height, b.height) * 2)) or (0 <= (b.left - a.right) < (min(a.height, b.height) * 2)): 507 | 508 | merged_rect = merge_rect(a, b) 509 | merged_panel = Panel(merged_rect.left, merged_rect.right, merged_rect.top, merged_rect.bottom, 0) 510 | output_panels.append(merged_panel) 511 | blacklisted_panels.extend([a, b]) 512 | done = False 513 | 514 | log.debug('Length of blacklisted : %s' % len(blacklisted_panels)) 515 | log.debug('Length of output panels : %s' % len(output_panels)) 516 | 517 | for panel in panels: 518 | if panel not in blacklisted_panels: 519 | output_panels.append(panel) 520 | 521 | output_panels = retag_panels(output_panels) 522 | 523 | return output_panels, done 524 | 525 | 526 | def merge_loop_vertical(panels): 527 | """ Iteratively merges panels by relative proximity to each other along the y axis. 528 | This is repeated until no panels are merged by the algorithm 529 | 530 | :param panels: List of Panels to be merged. 
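Panels are merged when their left edges or horizontal centres roughly align and the vertical gap between them is less than twice the smaller panel height.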
531 | 532 | :return output_panels: List of merged panels 533 | :return done: Bool indicating whether a merge occurred 534 | """ 535 | 536 | output_panels = [] 537 | blacklisted_panels = [] 538 | 539 | # Merging labels that are in close proximity vertically 540 | for a, b in itertools.combinations(panels, 2): 541 | 542 | if (abs(a.left - b.left) < 3 * min(a.height, b.height) or abs(a.center[0] - b.center[0]) < 3 * min(a.height, b.height)) \ 543 | and abs(a.center[1] - b.center[1]) < 3 * min(a.height, b.height) \ 544 | and min(abs(a.top - b.bottom), abs(b.top - a.bottom)) < 2 * min(a.height, b.height): 545 | 546 | merged_rect = merge_rect(a, b) 547 | merged_panel = Panel(merged_rect.left, merged_rect.right, merged_rect.top, merged_rect.bottom, 0) 548 | output_panels.append(merged_panel) 549 | blacklisted_panels.extend([a, b]) 550 | 551 | for panel in panels: 552 | if panel not in blacklisted_panels: 553 | output_panels.append(panel) 554 | 555 | output_panels = retag_panels(output_panels) 556 | 557 | return output_panels 558 | 559 | 560 | def get_one_to_merge(all_combos, panels): 561 | """Merges the first overlapping set of panels found and an returns updated panel list 562 | 563 | :param all_combos: List of Tuple(Panel, Panel) objects of all possible combinations of the input 'panels' variable 564 | :param panels: List of input Panels 565 | 566 | :return panels: List of updated panels after one overlap is merged 567 | :return: Bool indicated whether all overlaps have been completed 568 | """ 569 | 570 | for a, b in all_combos: 571 | 572 | overlap_panel = merge_overlap(a, b) 573 | if overlap_panel is not None: 574 | merged_panel = Panel(overlap_panel.left, overlap_panel.right, overlap_panel.top, overlap_panel.bottom, 0) 575 | panels.remove(a) 576 | panels.remove(b) 577 | panels.append(merged_panel) 578 | return panels, False 579 | 580 | return panels, True 581 | 582 | 583 | def convert_panels_to_labels(panels): 584 | """ Converts a list of panels to a list of labels 585 | 586 | :param panels: Input list of Panels 587 | :return : List of Labels 588 | """ 589 | 590 | return [Label(panel.left, panel.right, panel.top, panel.bottom, panel.tag) for panel in panels] 591 | 592 | 593 | def merge_all_overlaps(panels): 594 | """ Merges all overlapping rectangles together 595 | 596 | :param panels : Input list of Panels 597 | :return output_panels: List of merged panels 598 | :return all_merged: Bool indicating whether all merges are completed 599 | """ 600 | 601 | all_merged = False 602 | 603 | while all_merged is False: 604 | all_combos = list(itertools.combinations(panels, 2)) 605 | panels, all_merged = get_one_to_merge(all_combos, panels) 606 | 607 | output_panels = retag_panels(panels) 608 | return output_panels, all_merged 609 | 610 | 611 | def get_duplicate_labelling(labelled_diags): 612 | """ Returns diagrams sharing a Label object with other diagrams. 
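Diagrams with no Label assigned, or whose Label is also assigned to another Diagram, are counted as failures.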
613 | 614 | :param labelled_diags: List of Diagrams with Label objects assigned 615 | :return failed_diag_label: List of Diagrams that share a Label object with another Diagram 616 | """ 617 | 618 | failed_diag_label = set(diag for diag in labelled_diags if not diag.label) 619 | filtered_labelled_diags = [diag for diag in labelled_diags if diag not in failed_diag_label] 620 | 621 | # Identifying cases with the same label: 622 | for a, b in itertools.combinations(filtered_labelled_diags, 2): 623 | if a.label == b.label: 624 | failed_diag_label.add(a) 625 | failed_diag_label.add(b) 626 | 627 | return failed_diag_label 628 | 629 | 630 | def remove_duplicates(diags, fig_bbox, rate=1): 631 | """ 632 | Removes the least likely of the duplicate results. 633 | Likeliness is determined from the distance from the bounding box 634 | 635 | :param diags: All detected diagrams with assigned labels 636 | :param fig_bbox: Panel object representing the co-ordinates for the entire Figure 637 | :param rate: Number of pixels to expand by upon each iteration 638 | 639 | :return output_diags : List of Diagrams with Labels 640 | :return output_labelless_diags : List of Diagrams with Labels removed due to duplicates 641 | """ 642 | 643 | output_diags = [] 644 | output_labelless_diags = [] 645 | 646 | # Unique labels 647 | unique_labels = set(diag.label for diag in diags if diag.label is not None) 648 | 649 | for label in unique_labels: 650 | 651 | diags_with_labels = [diag for diag in diags if diag.label is not None] 652 | diags_with_this_label = [diag for diag in diags_with_labels if diag.label.tag == label.tag] 653 | 654 | if len(diags_with_this_label) == 1: 655 | output_diags.append(diags_with_this_label[0]) 656 | continue 657 | 658 | diag_and_displacement = [] # List of diag-distance tuples 659 | 660 | for diag in diags_with_this_label: 661 | 662 | probe_rect = Rect(diag.left, diag.right, diag.top, diag.bottom) 663 | found = False 664 | max_threshold_width = fig_bbox.width 665 | max_threshold_height = fig_bbox.height 666 | rate_counter = 0 667 | 668 | while found is False and (probe_rect.width < max_threshold_width or probe_rect.height < max_threshold_height): 669 | # Increase border value each loop 670 | probe_rect.right = probe_rect.right + rate 671 | probe_rect.bottom = probe_rect.bottom + rate 672 | probe_rect.left = probe_rect.left - rate 673 | probe_rect.top = probe_rect.top - rate 674 | 675 | rate_counter += rate 676 | 677 | if probe_rect.overlaps(label): 678 | found = True 679 | diag_and_displacement.append((diag, rate_counter)) 680 | 681 | master_diag = min(diag_and_displacement, key=lambda x: x[1])[0] 682 | output_diags.append(master_diag) 683 | 684 | labelless_diags = [diag[0] for diag in diag_and_displacement if diag[0] is not master_diag] 685 | 686 | for diag in labelless_diags: 687 | diag.label = None 688 | 689 | output_labelless_diags.extend(labelless_diags) 690 | 691 | return output_diags, output_labelless_diags 692 | 693 | 694 | -------------------------------------------------------------------------------- /chemschematicresolver/clean.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Image Cleanup Operations 4 | ======================== 5 | 6 | Functions to clean data for improved extraction. 
7 | 8 | author: Ed Beard 9 | email: ejb207@cam.ac.uk 10 | 11 | """ 12 | 13 | import copy 14 | import numpy as np 15 | import warnings 16 | 17 | from .ocr import read_label, read_diag_text 18 | 19 | 20 | def find_repeating_unit(labels, diags, fig): 21 | """ Identifies 'n' labels as repeating unit identifiers. 22 | Removal only occurs when a label and diagram overlap 23 | 24 | :param labels: List of Label objects 25 | :param diags: List of Diagram objects 26 | :param fig: Input Figure 27 | :returns labels: List of cleaned label objects 28 | :returns diags: List of diagram objects (flagged as repeating) 29 | """ 30 | 31 | ns = [] 32 | 33 | for diag in diags: 34 | for cand in labels: 35 | if diag.overlaps(cand): 36 | with warnings.catch_warnings(): 37 | warnings.simplefilter("ignore") 38 | repeating_units = [token for sentence in read_label(fig, cand)[0].text for token in sentence.tokens if 'n' is token.text] 39 | if repeating_units: 40 | ns.append(cand) 41 | diag.repeating = True 42 | 43 | labels = [label for label in labels if label not in ns] 44 | return labels, diags 45 | 46 | 47 | def remove_diagram_numbers(diags, fig): 48 | """ Removes vertex numbers from diagrams for cleaner OSRA resolution""" 49 | 50 | num_bbox = [] 51 | for diag in diags: 52 | 53 | diag_text = read_diag_text(fig, diag) 54 | 55 | # Simplify into list comprehension when working... 56 | for token in diag_text: 57 | if token.text in '123456789': 58 | print("Numeral successfully extracted %s" % token.text) 59 | num_bbox.append((diag.left + token.left, diag.left + token.right, 60 | diag.top + token.top, diag.top + token.bottom)) 61 | 62 | # Make a cleaned copy of image to be used when resolving diagrams 63 | diag_fig = copy.deepcopy(fig) 64 | 65 | for bbox in num_bbox: 66 | diag_fig.img[bbox[2]:bbox[3], bbox[0]:bbox[1]] = np.ones(3) 67 | 68 | return diag_fig 69 | 70 | 71 | def clean_output(text): 72 | """ Remove whitespace and newline characters from input text.""" 73 | 74 | # text = text.replace(' ', '') 75 | return text.replace('\n', '') 76 | -------------------------------------------------------------------------------- /chemschematicresolver/decorators.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Decorators 4 | ========== 5 | 6 | Python decorators used throughout ChemSchematicResolver. 7 | 8 | From FigureDataExtractor () :- 9 | author: Matthew Swain 10 | email: m.swain@me.com 11 | 12 | """ 13 | 14 | from __future__ import absolute_import 15 | from __future__ import division 16 | from __future__ import print_function 17 | from __future__ import unicode_literals 18 | import logging 19 | import functools 20 | 21 | import six 22 | 23 | 24 | log = logging.getLogger(__name__) 25 | 26 | 27 | def memoized_property(fget): 28 | """Decorator to create memoized properties.""" 29 | attr_name = '_{}'.format(fget.__name__) 30 | 31 | @functools.wraps(fget) 32 | def fget_memoized(self): 33 | if not hasattr(self, attr_name): 34 | setattr(self, attr_name, fget(self)) 35 | return getattr(self, attr_name) 36 | return property(fget_memoized) 37 | 38 | 39 | def python_2_unicode_compatible(klass): 40 | """Fix ``__str__``, ``__unicode__`` and ``__repr__`` methods under Python 2. 41 | 42 | Add this decorator to a class, then define ``__str__`` and ``__repr__`` methods that both return unicode strings. 
43 | Under python 2, this will return encoded strings for ``__str__`` (utf-8) and ``__repr__`` (ascii), and add 44 | ``__unicode__`` and ``_unicode_repr`` to return the original unicode strings. Under python 3, this does nothing. 45 | """ 46 | if six.PY2: 47 | if '__str__' not in klass.__dict__: 48 | raise ValueError("Define __str__() on %s to use @python_2_unicode_compatible" % klass.__name__) 49 | if '__repr__' not in klass.__dict__: 50 | raise ValueError("Define __repr__() on %s to use @python_2_unicode_compatible" % klass.__name__) 51 | klass.__unicode__ = klass.__str__ 52 | klass._unicode_repr = klass.__repr__ 53 | klass.__str__ = lambda self: self.__unicode__().encode('utf-8') 54 | klass.__repr__ = lambda self: self._unicode_repr().encode('ascii', errors='backslashreplace') 55 | return klass 56 | -------------------------------------------------------------------------------- /chemschematicresolver/dict/spelling.txt: -------------------------------------------------------------------------------- 1 | # Spelling corrections for atom labels and abbreviations 2 | # that might not be correctly parsed by OCR engine 3 | # You can run osra with -d option to check the spelling correction 4 | # process - the output looks like: 5 | # OCR string --> Corrected String --> Final Output 6 | # Note that by default OSRA might try as many as 3 different resolutions 7 | # so you may have quite a bit of output to look through. Try a specific 8 | # resolution (with -r switch) or choose the best match. 9 | # Empty lines are ignored and lines starting with # are comments 10 | 11 | Ci Cl 12 | Cf Cl 13 | Cll Cl 14 | 15 | HN N 16 | NH N 17 | M N 18 | Hm N 19 | MN N 20 | N2 N 21 | NM N 22 | NH2 N 23 | H2N N 24 | NHZ N 25 | HZN N 26 | NH3 N 27 | nu N 28 | Hu N 29 | lU N 30 | HlU N 31 | lUH N 32 | H2Y N 33 | RN N 34 | MH N 35 | M2M N 36 | M N 37 | HnN N 38 | nIH N 39 | NHX N 40 | mH N 41 | NN N 42 | hN N 43 | Nh N 44 | 45 | OH O 46 | oH O 47 | Ho O 48 | HO O 49 | ol O 50 | On O 51 | on O 52 | no O 53 | nO O 54 | ON O 55 | oN O 56 | O4 O 57 | OM O 58 | Un O 59 | 4O O 60 | Mo O 61 | nU O 62 | UnH O 63 | nUH O 64 | 65 | Meo MeO 66 | oMe MeO 67 | oMg MeO 68 | omg MeO 69 | Mgo MeO 70 | leo MeO 71 | ohle MeO 72 | lleo MeO 73 | olllle MeO 74 | OMe MeO 75 | OM8 MeO 76 | OMo MeO 77 | OMB MeO 78 | OCH3 MeO 79 | OCHS MeO 80 | H3CO MeO 81 | OM4 MeO 82 | ocH MeO 83 | OM6 MeO 84 | M6O MeO 85 | OMR MeO 86 | OMoo MeO 87 | OmB MeO 88 | MoO MeO 89 | M*O MeO 90 | McoO MeO 91 | Ome MeO 92 | MgO MeO 93 | M8O MeO 94 | MBO MeO 95 | 96 | NC CN 97 | YC CN 98 | Nc CN 99 | cN CN 100 | 101 | nBU nBu 102 | neU nBu 103 | ngU nBu 104 | n8U nBu 105 | BU nBu 106 | 107 | Eto EtO 108 | oEt EtO 109 | Elo EtO 110 | oEl EtO 111 | ElO EtO 112 | OEl EtO 113 | OC2H EtO 114 | OCH2CH3 EtO 115 | CH3CH2O EtO 116 | 117 | olgU OiBu 118 | oleU OiBu 119 | OlBU OiBu 120 | 121 | npr iPr 122 | llpll iPr 123 | lpl iPr 124 | npl iPr 125 | lPl iPr 126 | nPl iPr 127 | 128 | tBU tBu 129 | llBU tBu 130 | lBU tBu 131 | 132 | CooH COOH 133 | HooC COOH 134 | Co2H COOH 135 | CO2H COOH 136 | HOOC COOH 137 | CO2n COOH 138 | co2H COOH 139 | CO2 COOH 140 | 141 | 142 | AC Ac 143 | pC Ac 144 | pc Ac 145 | 146 | ACo AcO 147 | opC AcO 148 | pcO AcO 149 | ACO AcO 150 | oCO AcO 151 | OoC AcO 152 | OpC AcO 153 | pCO AcO 154 | RCO AcO 155 | ORC AcO 156 | OnC AcO 157 | OAc AcO 158 | nCO AcO 159 | Rco AcO 160 | oRc AcO 161 | OAC AcO 162 | 163 | Bl Br 164 | el Br 165 | BC Br 166 | BF Br 167 | 168 | # Need to distinguish between a recognized label for methyl and 169 | # the default empty 
string for carbon 170 | CH3 Me 171 | H3C Me 172 | CH Me 173 | CH2 Me 174 | HC Me 175 | hle Me 176 | M8 Me 177 | MB Me 178 | MR Me 179 | Me Me 180 | H2C Me 181 | 3C Me 182 | 183 | pl Ar 184 | nl Ar 185 | 186 | oX Ox 187 | 188 | NoZ NO2 189 | o2N NO2 190 | No2 NO2 191 | No NO2 192 | O2N NO2 193 | NOZ NO2 194 | MO2 NO2 195 | 196 | F3C CF3 197 | CF CF3 198 | FC CF3 199 | Co CF3 200 | F8l CF3 201 | CFS CF3 202 | FSC CF3 203 | 204 | F3Co F3CN 205 | 206 | S3 S 207 | Se S 208 | lS S 209 | 8 S 210 | SH S 211 | HS S 212 | 5 S 213 | 214 | O2S SO2 215 | 216 | lH H 217 | 1H H 218 | 219 | AcNH NHAc 220 | AcHN NHAc 221 | ACNH NHAc 222 | NHnC NHAc 223 | pCNH NHAc 224 | NHpC NHAc 225 | lCnuH NHAc 226 | NHAC NHAc 227 | 228 | OlHP THPO 229 | lHPO THPO 230 | lNpo THPO 231 | olHp THPO 232 | 233 | 234 | NlOHCH3 NOHCH3 235 | 236 | HO3S SO3H 237 | so3H SO3H 238 | Ho3s SO3H 239 | SO3 SO3H 240 | SUn3 SO3H 241 | 242 | NMe MeN 243 | NHMe MeN 244 | NHMF MeN 245 | NHlME MeN 246 | 247 | RO OR 248 | oR OR 249 | Ro OR 250 | 251 | lHPO THPO 252 | OlHP THPO 253 | 254 | NCOlRlH3 N(OH)CH3 255 | 256 | pZO BzO 257 | p2O BzO 258 | OBX BzO 259 | BZO BzO 260 | B2O BzO 261 | OB2 BzO 262 | OBz BzO 263 | OBZ BzO 264 | Blo BzO 265 | BZC BzO 266 | EBZO BzO 267 | CBZ BzO 268 | B2C BzO 269 | 270 | Sl Si 271 | 272 | CO2El CO2Et 273 | COOEl CO2Et 274 | COOEt CO2Et 275 | COOC2H CO2Et 276 | CO2CH2CH3 CO2Et 277 | COOCH2CH3 CO2Et 278 | CO2C2H5 CO2Et 279 | COEl CO2Et 280 | COHEl CO2Et 281 | COOMe CO2Me 282 | COOCH3 CO2Me 283 | CO2CH3 CO2Me 284 | HUn2C COOH 285 | CO2E CO2Et 286 | EO EtO 287 | HO2C COOH 288 | CUn2H COOH 289 | CnU2H COOH 290 | MeHN MeN 291 | n2N N 292 | El Et 293 | CH2CH3 Et 294 | CH3CH2 Et 295 | C2H5 Et 296 | H5C2 Et 297 | OBn BnO 298 | HNZ ZNH 299 | ZHN ZNH 300 | HNAm AmNH 301 | OAm AmO 302 | AmOOC AmO2C 303 | CO2Am AmO2C 304 | COOAm AmO2C 305 | SAm AmS 306 | HNBn BnNH 307 | BnN BnNH 308 | CO2Bn BnO2C 309 | BnOOC BnO2C 310 | COOBn BnO2C 311 | SnBu3 Bu3Sn 312 | HNBu BuNH 313 | OBu BuO 314 | CO2Bu BuO2C 315 | COOBu BuO2C 316 | BuOOC BuO2C 317 | SBu BuS 318 | Br3C CBr3 319 | HNCbz CbzNH 320 | Cl3C CCl3 321 | OCH CHO 322 | OHC CHO 323 | O2SCl ClSO2 324 | SO2Cl ClSO2 325 | MeO2C CO2Me 326 | OSO2Me MeO2SO 327 | BrOC COBr 328 | BuOC COBu 329 | F3COC COCF3 330 | ClOC COCl 331 | OCOC COCO 332 | EtOC COEt 333 | FOC COF 334 | MeOC COMe 335 | H3COC COMe 336 | Et2NOC CONEt2 337 | NH2OC CONH2 338 | EtHNOC CONHEt 339 | MeHNOC CONHMe 340 | Me2NOC CONMe2 341 | HSOC COSH 342 | NEt2 Et2N 343 | El2N Et2N 344 | NE2 Et2N 345 | E2N Et2N 346 | NEt3 Et3N 347 | HNEt EtNH 348 | SO2NH2 H2NSO2 349 | H2NO2S H2NSO2 350 | SO2N H2NSO2 351 | HNOH HONH 352 | NMe2 Me2N 353 | MeNH MeN 354 | HNMe MeN 355 | MeOOC CO2Me 356 | OMs MsO 357 | OMS MsO 358 | OCN NCO 359 | SCN NCS 360 | AmHN NHAm 361 | BnHN NHBn 362 | BuHN NHBu 363 | EtHN NHEt 364 | HOHN NHOH 365 | PrHN NHPr 366 | ON NO 367 | Et2OP POEt2 368 | Et3OP POEt3 369 | Et2OOP POOEt2 370 | HNPr PrNH 371 | EtS SEt 372 | SMe MeS 373 | SCH3 MeS 374 | Pll Ph 375 | Pl Ph 376 | ElO2C CO2Et 377 | EtOOC CO2Et 378 | ElOOC CO2Et 379 | 380 | OlOS OTos 381 | CH2CH CH2CH3 382 | CHCH3 CH2CH3 383 | H3CHC CH2CH3 384 | H3CH2C CH2CH3 385 | CCH3CH2l2N N(CH2CH3)2 386 | NCCH2CH3l2 N(CH2CH3)2 387 | CCH3CHl2N N(CH2CH3)2 388 | NCCH2CH2CH3l2 N(CH2CH2CH3)2 389 | CCCH3l3 C(CH3)3 390 | CHCCH3l2 CH(CH3)2 391 | OCH2CO2El OCH2CO2Et 392 | CBOCl2N BOC2N 393 | NHBOC BOCHN 394 | NBOC BOCHN 395 | NHCRZ NHCbz 396 | NClZ NHCbz 397 | F3CO OCF3 398 | Cl3CO OCCl3 399 | NSO2BU NHSO2BU 400 | NHSO2CH3 NHSO2Me 401 | ElOCHN EtO2CHN 402 | ElO2CHl EtO2CHN 403 | 
NHCOOEl NHCOOEt 404 | OEl2 OEt2 405 | CH2OH HOCH2 406 | NHCH NHCH3 407 | NCH3 NHCH3 408 | NO3S H4NO3S 409 | NOOC H4NOOC 410 | C3H C3H7 411 | C2H C2H5 412 | NNH2 NHNH2 413 | H3CS MeS 414 | NHNHCOCH NHNHCOCH3 415 | NNCOCH3 NHNHCOCH3 416 | NHNHCOCF NHNHCOCF3 417 | NNCOCF3 NHNHCOCF3 418 | CO2CYSP CO2CysPr 419 | COCYSPl CO2CysPr 420 | CO2CYSPl CO2CysPr 421 | CF3CH CF3CH2 422 | PPll2 PPh2 423 | Pll2P PPh2 424 | Ph2P PPh2 425 | CO2M8 CO2Me 426 | OCH2Pll OCH2Ph 427 | PMoN PMBN 428 | lCO AcO 429 | XCO AcO 430 | OXC AcO 431 | CH3O MeO 432 | O3S SO3H 433 | NXeOH2C CH2OMe 434 | CH2ONXe CH2OMe 435 | CHOMe CH2OMe 436 | MeOHC CH2OMe 437 | CH3OCH2 CH2OMe 438 | CH2OCH3 CH2OMe 439 | N3Cl NH3Cl 440 | MeCHlN MeN 441 | NCHlMe MeN 442 | ElCHlN NHEt 443 | NCHlEl NHEt 444 | OCH2Pll OCH2Ph 445 | OCH2P OCH2Ph 446 | COOCH2P COOCH2Ph 447 | NeO MeO 448 | ONe MeO 449 | OCPh3 Ph3CO 450 | SO2CH3 SO2Me 451 | H3CO2S SO2Me 452 | CH3SO2 SO2Me 453 | POIOEII2 POOEt2 454 | SO3NI SO3Na 455 | OSO2Me MsO 456 | CH2I5Bl (CH2)5Br 457 | ICH2I5Bl (CH2)5Br 458 | CH2I5 (CH2)5 459 | TOS Tos 460 | PhO OPh 461 | PhS SPh 462 | PhHN NHPh 463 | 464 | Rl R1 465 | Rlo R10 466 | Rg R9 467 | Rp R4 468 | 2 Z 469 | RlO R10 470 | Y2 Y2 471 | 472 | PMRN PMBN 473 | 474 | * Xx 475 | ** Xx 476 | *** Xx 477 | 478 | -------------------------------------------------------------------------------- /chemschematicresolver/dict/superatom.txt: -------------------------------------------------------------------------------- 1 | Me C 2 | MeO OC 3 | MeS SC 4 | MeN NC 5 | CF CF 6 | CF3 C(F)(F)F 7 | CN C#N 8 | F3CN NC(F)(F)F 9 | Ph c1ccccc1 10 | NO N=O 11 | NO2 N(=O)=O 12 | N(OH)CH3 N(O)C 13 | SO3H S(=O)(=O)O 14 | COOH C(=O)O 15 | nBu CCCC 16 | EtO OCC 17 | OiBu OCC(C)C 18 | iPr CCC 19 | tBu C(C)(C)C 20 | Ac C(=O)C 21 | AcO OC(=O)C 22 | NHAc NC(=O)C 23 | OR O* 24 | BzO OC(=O)C1=CC=CC=C1 25 | THPO O[C@@H]1OCCCC1 26 | CHO C=O 27 | NOH NO 28 | CO2Et C(=O)OCC 29 | CO2Me C(=O)OC 30 | MeO2S S(=O)(=O)C 31 | NMe2 N(C)C 32 | CO2R C(=O)O* 33 | ZNH NC(=O)OCC1=CC=CC=C1 34 | HOCH2 CO 35 | H2NCH2 CN 36 | Et CC 37 | BnO OCC1=CC=CC=C1 38 | AmNH NCCCCC 39 | AmO OCCCCC 40 | AmO2C C(=O)OCCCCC 41 | AmS SCCCCC 42 | BnNH NCC1=CC=CC=C1 43 | BnO2C C(=O)OCC1=CC=CC=C1 44 | Bu3Sn [Sn](CCCC)(CCCC)CCCC 45 | BuNH NCCCC 46 | BuO OCCCC 47 | BuO2C C(=O)OCCCC 48 | BuS SCCCC 49 | CBr3 C(Br)(Br)Br 50 | CbzNH NC(=O)OCC1=CC=CC=C1 51 | CCl3 C(Cl)(Cl)Cl 52 | ClSO2 S(=O)(=O)Cl 53 | COBr C(=O)Br 54 | COBu C(=O)CCCC 55 | COCF3 C(=O)C(F)(F)F 56 | COCl C(=O)Cl 57 | COCO C(=O)C=O 58 | COEt C(=O)CC 59 | COF C(=O)F 60 | COMe C(=O)C 61 | OCOMe OC(=O)C 62 | CONH2 C(=O)N 63 | CONHEt C(=O)NCC 64 | CONHMe C(=O)NC 65 | COSH C(=O)S 66 | Et2N N(CC)CC 67 | Et3N N(CC)(CC)CC 68 | EtNH NCC 69 | H2NSO2 S(=O)(N)=O 70 | HONH ON 71 | Me2N N(C)C 72 | NCO N=C=O 73 | NCS N=C=S 74 | NHAm NCCCCC 75 | NHBn NCC1=CC=CC=C1 76 | NHBu NCCCC 77 | NHEt NCC 78 | NHOH NO 79 | NHPr NCCC 80 | NO N=O 81 | POEt2 P(OCC)OCC 82 | POEt3 P(OCC)(OCC)OCC 83 | POOEt2 P(=O)(OCC)OCC 84 | PrNH CCCN 85 | SEt SCC 86 | BOC C(=O)OC(C)(C)C 87 | MsO OS(=O)(=O)C 88 | OTos OS(=O)(=O)c1ccc(C)cc1 89 | Tos S(=O)(=O)c1ccc(C)cc1 90 | C8H CCCCCCCC 91 | C6H CCCCCC 92 | CH2CH3 CC 93 | N(CH2CH3)2 N(CC)CC 94 | N(CH2CH2CH3)2 N(CCC)CCC 95 | C(CH3)3 C(C)(C)C 96 | COCH3 C(=O)C 97 | CH(CH3)2 C(C)C 98 | OCF3 OC(F)(F)F 99 | OCCl3 OC(Cl)(Cl)Cl 100 | OCF2H OC(F)F 101 | SO2Me S(=O)(=O)C 102 | OCH2CO2H OCC(=O)O 103 | OCH2CO2Et OCC(=O)OCC 104 | BOC2N N(C(=O)OC(C)(C)C)C(=O)OC(C)(C)C 105 | BOCHN NC(=O)OC(C)(C)C 106 | NHCbz NC(=O)OCc1ccccc1 107 | OCH2CF3 OCC(F)(F)F 108 | NHSO2BU 
NS(=O)(=O)CCCC 109 | NHSO2Me NS(=O)(=O)C 110 | MeO2SO OS(=O)(=O)C 111 | NHCOOEt NC(=O)OCC 112 | NHCH3 NC 113 | H4NOOC C(=O)ON 114 | C3H7 CCC 115 | C2H5 CC 116 | NHNH2 NN 117 | OCH2CH2OH OCCO 118 | OCH2CHOHCH2OH OCC(O)CO 119 | OCH2CHOHCH2NH OCC(O)CN 120 | NHNHCOCH3 NNC(=O)C 121 | NHNHCOCF3 NNC(=O)C(F)(F)F 122 | NHCOCF3 NC(=O)C(F)(F)F 123 | CO2CysPr C(=O)ON[C@H](CS)C(=O)CCC 124 | HOH2C CO 125 | H3CHN NC 126 | H3CO2C C(=O)OC 127 | CF3CH2 CC(F)(F)F 128 | OBOC OC(=O)OC(C)(C)C 129 | Bn2N N(Cc1ccccc1)Cc1ccccc1 130 | F5S S(F)(F)(F)(F)F 131 | PPh2 P(c1ccccc1)c1ccccc1 132 | PPh3 P(c1ccccc1)(c1ccccc1)c1ccccc1 133 | OCH2Ph OCc1ccccc1 134 | CH2OMe COC 135 | PMBN NCc1ccc(OC)cc1 136 | SO2 S(=O)=O 137 | NH3Cl NCl 138 | CF2CF3 C(F)(F)C(F)(F)F 139 | CF2CF2H C(F)(F)C(F)(F) 140 | Bn Cc1ccccc1 141 | OCH2Ph OCc1ccccc1 142 | COOCH2Ph C(=O)OCc1ccccc1 143 | Ph3CO OC(c1ccccc1)(c1ccccc1)c1ccccc1 144 | Ph3C C(c1ccccc1)(c1ccccc1)c1ccccc1 145 | Me2NO2S S(C)(C)N(=O)=O 146 | SO3Na S(=O)(=O)(=O)[Na] 147 | OSO2Ph OS(=O)(=O)c1ccccc1 148 | (CH2)5Br CCCCCBr 149 | OPh Oc1ccccc1 150 | SPh Sc1ccccc1 151 | NHPh Nc1ccccc1 152 | CONEt2 C(=O)N(CC)CC 153 | CONMe2 C(=O)N(C)C 154 | EtO2CHN NC(=O)OCC 155 | H4NO3S S(=O)(=O)ON 156 | TMS [Si](C)(C)(C) 157 | COCOOCH2CH3 C(=O)C(=O)OCC 158 | OCH2CN OCC#N 159 | Xx [*] 160 | X [*] 161 | Y [*] 162 | Z [*] 163 | R [*] 164 | R1 [*] 165 | R2 [*] 166 | R3 [*] 167 | R4 [*] 168 | R5 [*] 169 | R6 [*] 170 | R7 [*] 171 | R8 [*] 172 | R9 [*] 173 | R10 [*] 174 | Y2 [*] 175 | D [*] 176 | C(c1ccccc1)c2ccccc2 C(c1ccccc1)c2ccccc2 177 | OC1c2ccccc2CCc3ccccc13 OC1c2ccccc2CCc3ccccc13 178 | [H] [H] 179 | N N 180 | CC CC 181 | C1NC=CO1 C1NC=CO1 182 | C C 183 | [Ac] [Ac] 184 | CI Cl 185 | F F 186 | Br Br 187 | -------------------------------------------------------------------------------- /chemschematicresolver/extract.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Extract 4 | ======= 5 | 6 | Functions to extract diagram-label pairs from schematic chemical diagrams. 7 | 8 | author: Ed Beard 9 | email: ejb207@cam.ac.uk 10 | 11 | """ 12 | from __future__ import absolute_import 13 | from __future__ import division 14 | from __future__ import print_function 15 | from __future__ import unicode_literals 16 | import logging 17 | 18 | from .io import imread 19 | from .actions import segment, classify_kmeans, preprocessing, label_diags, read_diagram_pyosra 20 | from .clean import clean_output 21 | from .ocr import read_label 22 | from .r_group import detect_r_group, get_rgroup_smiles 23 | from .validate import is_false_positive, remove_repeating 24 | 25 | from matplotlib import pyplot as plt 26 | import matplotlib.patches as mpatches 27 | import os 28 | import urllib 29 | import math 30 | import tarfile, zipfile 31 | 32 | from chemdataextractor import Document 33 | 34 | log = logging.getLogger(__name__) 35 | 36 | 37 | def extract_document(filename, extract_all=True, allow_wildcards=False, output=os.path.join(os.path.dirname(os.getcwd()), 'csd')): 38 | """ Extracts chemical records from a document and identifies chemical schematic diagrams. 
39 | Then substitutes in if the label was found in a record 40 | 41 | :param filename: Location of document to be extracted 42 | :param extract_all : Boolean to determine whether output is combined with chemical records 43 | :param allow_wildcards: Bool to indicate whether results containing wildcards are permitted 44 | :param output: Directory to store extracted images 45 | 46 | :return : Dictionary of chemical records with diagram SMILES strings, or List of label candidates and smiles 47 | """ 48 | 49 | log.info('Extracting from %s ...' % filename) 50 | 51 | # Extract the raw records from CDE 52 | doc = Document.from_file(filename) 53 | figs = doc.figures 54 | 55 | # Identify image candidates 56 | csds = find_image_candidates(figs, filename) 57 | 58 | # Download figures locally 59 | fig_paths = download_figs(csds, output) 60 | log.info("All relevant figures from %s downloaded successfully" % filename) 61 | 62 | # When diagrams are not found, return results without CSR extraction 63 | if extract_all and not fig_paths: 64 | log.info('No chemical diagrams detected. Returning chemical records.') 65 | return doc.records.serialize() 66 | elif not extract_all and not fig_paths: 67 | log.info('No chemical diagrams detected. Returning empty list.') 68 | return [] 69 | 70 | log.info('Chemical diagram(s) detected. Running ChemSchematicResolver...') 71 | # Run CSR 72 | results = [] 73 | for path in fig_paths: 74 | try: 75 | results.append(extract_image(path, allow_wildcards=allow_wildcards)) 76 | except: 77 | log.error('Could not extract image at %s' % path) 78 | pass 79 | 80 | if not extract_all: 81 | return results 82 | 83 | records = doc.records.serialize() 84 | 85 | # Substitute smiles for labels 86 | combined_results = substitute_labels(records, results) 87 | log.info('All diagram results extracted and combined with chemical records.') 88 | 89 | return combined_results 90 | 91 | 92 | def extract_documents(dirname, extract_all=True, allow_wildcards=False, output=os.path.join(os.path.dirname(os.getcwd()), 'csd')): 93 | """ Automatically identifies and extracts chemical schematic diagrams from all files in a directory of documents. 94 | 95 | :param dirname: Location of directory, with corpus to be extracted 96 | :param extract_all : Boolean indicating whether to extract all results (even those without chemical diagrams) 97 | :param allow_wildcards: Bool to indicate whether results containing wildcards are permitted 98 | :param output: Directory to store extracted images 99 | 100 | :return results: List of chemical record objects, enriched with chemical diagram information 101 | """ 102 | 103 | log.info('Extracting all documents at %s ...' 
% dirname) 104 | 105 | results = [] 106 | 107 | if os.path.isdir(dirname): 108 | # Extract from all files in directory 109 | for file in os.listdir(dirname): 110 | results.append(extract_document(os.path.join(dirname, file), extract_all, allow_wildcards, output)) 111 | 112 | elif os.path.isfile(dirname): 113 | 114 | # Unzipping compressed inputs 115 | if dirname.endswith('zip'): 116 | # Logic to unzip the file locally 117 | log.info('Opening zip file...') 118 | zip_ref = zipfile.ZipFile(dirname) 119 | extracted_path = os.path.join(os.path.dirname(dirname), 'extracted') 120 | if not os.path.exists(extracted_path): 121 | os.makedirs(extracted_path) 122 | zip_ref.extractall(extracted_path) 123 | zip_ref.close() 124 | 125 | elif dirname.endswith('tar.gz'): 126 | # Logic to unzip tarball locally 127 | log.info('Opening tarball file...') 128 | tar_ref = tarfile.open(dirname, 'r:gz') 129 | extracted_path = os.path.join(os.path.dirname(dirname), 'extracted') 130 | if not os.path.exists(extracted_path): 131 | os.makedirs(extracted_path) 132 | tar_ref.extractall(extracted_path) 133 | tar_ref.close() 134 | 135 | elif dirname.endswith('tar'): 136 | # Logic to unzip tarball locally 137 | log.info('Opening tarball file...') 138 | tar_ref = tarfile.open(dirname, 'r:') 139 | extracted_path = os.path.join(os.path.dirname(dirname), 'extracted') 140 | if not os.path.exists(extracted_path): 141 | os.makedirs(extracted_path) 142 | tar_ref.extractall(extracted_path) 143 | tar_ref.close() 144 | else: 145 | # Logic for wrong file type 146 | log.error('Input not a directory') 147 | raise NotADirectoryError 148 | 149 | docs = [os.path.join(extracted_path, doc) for doc in os.listdir(extracted_path)] 150 | for file in docs: 151 | results.append(extract_document(file, extract_all, allow_wildcards, output)) 152 | 153 | return results 154 | 155 | 156 | def substitute_labels(records, results): 157 | """ Looks for label candidates in the document records and substitutes where appropriate 158 | 159 | :param records: Serialized chemical records from chemdataextractor 160 | :param results: Results extracted from the diagram 161 | 162 | :returns: List of chemical records enriched with chemical diagram SMILES string 163 | """ 164 | 165 | docs_labelled_records = [] 166 | 167 | record_labels = [record for record in records if 'labels' in record.keys()] 168 | 169 | # Get all records that contain common labels 170 | for diag_result in results: 171 | for label_cands, smile in diag_result: 172 | for record_label in record_labels: 173 | overlap = [(record_label, label_cand, smile) for label_cand in label_cands if label_cand in record_label['labels']] 174 | docs_labelled_records += overlap 175 | 176 | log.debug(docs_labelled_records) 177 | 178 | # Adding data to the serialized ChemDataExtractor output 179 | for doc_record, diag_label, diag_smile in docs_labelled_records: 180 | for record in records: 181 | if record == doc_record: 182 | record['diagram'] = {'smiles': diag_smile, 'label': diag_label} 183 | 184 | return records 185 | 186 | 187 | def download_figs(figs, output): 188 | """ Downloads figures from url 189 | 190 | :param figs: List of tuples in form figure metadata (Filename, figure id, url to figure, caption) 191 | :param output: Location of output images 192 | """ 193 | 194 | if not os.path.exists(output): 195 | os.makedirs(output) 196 | 197 | fig_paths = [] 198 | 199 | for file, id, url, caption in figs: 200 | 201 | img_format = url.split('.')[-1] 202 | log.info('Downloading %s image from %s' % (img_format, url)) 203 | 
filename = file.split('/')[-1].rsplit('.', 1)[0] + '_' + id + '.' + img_format 204 | path = os.path.join(output, filename) 205 | 206 | log.debug("Downloading %s..." % filename) 207 | if not os.path.exists(path): 208 | urllib.request.urlretrieve(url, path) # Saves downloaded image to file 209 | else: 210 | log.debug("File exists! Going to next image") 211 | 212 | fig_paths.append(path) 213 | 214 | return fig_paths 215 | 216 | 217 | def find_image_candidates(figs, filename): 218 | """ Returns a list of csd figures 219 | 220 | :param figs: ChemDataExtractor figure objects 221 | :param filename: String of the file's name 222 | :return: List of figure metadata (Filename, figure id, url to figure, caption) 223 | :rtype: list[tuple[string, string, string, string]] 224 | """ 225 | csd_imgs = [] 226 | 227 | for fig in figs: 228 | detected = False # Used to avoid processing images twice 229 | records = fig.records 230 | caption = fig.caption 231 | for record in records: 232 | if detected: 233 | break 234 | 235 | rec = record.serialize() 236 | if 'figure' in rec.keys(): 237 | detected = True 238 | log.info('Chemical schematic diagram instance found!') 239 | csd_imgs.append((filename, fig.id, fig.url, caption.text.replace('\n', ' '))) 240 | 241 | return csd_imgs 242 | 243 | 244 | def extract_image(filename, debug=False, allow_wildcards=False): 245 | """ Converts a Figure containing chemical schematic diagrams to SMILES strings and extracted label candidates 246 | 247 | :param filename: Input file name for extraction 248 | :param debug: Bool to indicate debugging 249 | :param allow_wildcards: Bool to indicate whether results containing wildcards are permitted 250 | 251 | :return : List of label candidates and smiles 252 | :rtype : list[tuple[list[string],string]] 253 | """ 254 | 255 | # Output lists 256 | r_smiles = [] 257 | smiles = [] 258 | 259 | extension = filename.split('.')[-1] 260 | 261 | # Confidence threshold for OCR results 262 | confidence_threshold = 73.7620468139648 263 | 264 | # Read in float and raw pixel images 265 | fig = imread(filename) 266 | fig_bbox = fig.get_bounding_box() 267 | 268 | # Segment image into pixel islands 269 | panels = segment(fig) 270 | 271 | # Initial classify of images, to account for merging in segmentation 272 | labels, diags = classify_kmeans(panels, fig) 273 | 274 | # Preprocess image (eg merge labels that are small into larger labels) 275 | labels, diags = preprocessing(labels, diags, fig) 276 | 277 | # Re-cluster by height if there are more Diagram objects than Labels 278 | if len(labels) < len(diags): 279 | labels_h, diags_h = classify_kmeans(panels, fig, skel=False) 280 | labels_h, diags_h = preprocessing(labels_h, diags_h, fig) 281 | 282 | # Choose the fitting with the closest number of diagrams and labels 283 | if abs(len(labels_h) - len(diags_h)) < abs(len(labels) - len(diags)): 284 | labels = labels_h 285 | diags = diags_h 286 | 287 | if debug is True: 288 | # Create output image 289 | out_fig, ax = plt.subplots(figsize=(10, 6)) 290 | ax.imshow(fig.img) 291 | colours = iter( 292 | ['r', 'b', 'g', 'k', 'c', 'm', 'y', 'r', 'b', 'g', 'k', 'c', 'm', 'y', 'r', 'b', 'g', 'k', 'c', 'm', 'y']) 293 | 294 | # Add label information to the appropriate diagram by expanding bounding box 295 | labelled_diags = label_diags(labels, diags, fig_bbox) 296 | labelled_diags = remove_repeating(labelled_diags) 297 | 298 | for diag in labelled_diags: 299 | 300 | label = diag.label 301 | 302 | if debug is True: 303 | 304 | colour = next(colours) 305 | 306 | # Add diag bbox to 
debug image 307 | diag_rect = mpatches.Rectangle((diag.left, diag.top), diag.width, diag.height, 308 | fill=False, edgecolor=colour, linewidth=2) 309 | ax.text(diag.left, diag.top + diag.height / 4, '[%s]' % diag.tag, size=diag.height / 20, color='r') 310 | ax.add_patch(diag_rect) 311 | 312 | # Add label bbox to debug image 313 | label_rect = mpatches.Rectangle((label.left, label.top), label.width, label.height, 314 | fill=False, edgecolor=colour, linewidth=2) 315 | ax.text(label.left, label.top + label.height / 4, '[%s]' % label.tag, size=label.height / 5, color='r') 316 | ax.add_patch(label_rect) 317 | 318 | # Read the label 319 | diag.label, conf = read_label(fig, label) 320 | 321 | if not diag.label.text: 322 | log.warning('Text could not be resolved from label %s' % label.tag) 323 | 324 | # Only extract images where the confidence is sufficiently high 325 | if not math.isnan(conf) and conf > confidence_threshold: 326 | 327 | # Add r-group variables if detected 328 | diag = detect_r_group(diag) 329 | 330 | # Get SMILES for output 331 | smiles, r_smiles = get_smiles(diag, smiles, r_smiles, extension) 332 | 333 | else: 334 | log.warning('Confidence of label %s deemed too low for extraction' % diag.label.tag) 335 | 336 | log.info('The results are :') 337 | log.info('R-smiles %s' % r_smiles) 338 | log.info('Smiles %s' % smiles) 339 | if debug is True: 340 | ax.set_axis_off() 341 | plt.show() 342 | 343 | total_smiles = smiles + r_smiles 344 | 345 | # Removing false positives from lack of labels or wildcard smiles 346 | output = [smile for smile in total_smiles if is_false_positive(smile, allow_wildcards=allow_wildcards) is False] 347 | if len(total_smiles) != len(output): 348 | log.warning('Some SMILES strings were determined to be false positives and were removed from the output.') 349 | 350 | log.info('Final Results : ') 351 | for result in output: 352 | log.info(result) 353 | 354 | return output 355 | 356 | 357 | def extract_images(dirname, debug=False, allow_wildcards=False): 358 | """ Extracts the chemical schematic diagrams from a directory of input images 359 | 360 | :param dirname: Location of directory, with figures to be extracted 361 | :param debug: Boolean specifying verbose debug mode. 362 | :param allow_wildcards: Bool to indicate whether results containing wildcards are permitted 363 | 364 | :return results: List of chemical record objects, enriched with chemical diagram information 365 | """ 366 | 367 | log.info('Extracting all images at %s ...' 
% dirname) 368 | 369 | results = [] 370 | 371 | if os.path.isdir(dirname): 372 | # Extract from all files in directory 373 | for file in os.listdir(dirname): 374 | results.append(extract_image(os.path.join(dirname, file), debug, allow_wildcards)) 375 | 376 | elif os.path.isfile(dirname): 377 | 378 | # Unzipping compressed inputs 379 | if dirname.endswith('zip'): 380 | # Logic to unzip the file locally 381 | log.info('Opening zip file...') 382 | zip_ref = zipfile.ZipFile(dirname) 383 | extracted_path = os.path.join(os.path.dirname(dirname), 'extracted') 384 | if not os.path.exists(extracted_path): 385 | os.makedirs(extracted_path) 386 | zip_ref.extractall(extracted_path) 387 | zip_ref.close() 388 | 389 | elif dirname.endswith('tar.gz'): 390 | # Logic to unzip tarball locally 391 | log.info('Opening tarball file...') 392 | tar_ref = tarfile.open(dirname, 'r:gz') 393 | extracted_path = os.path.join(os.path.dirname(dirname), 'extracted') 394 | if not os.path.exists(extracted_path): 395 | os.makedirs(extracted_path) 396 | tar_ref.extractall(extracted_path) 397 | tar_ref.close() 398 | 399 | elif dirname.endswith('tar'): 400 | # Logic to unzip tarball locally 401 | log.info('Opening tarball file...') 402 | tar_ref = tarfile.open(dirname, 'r:') 403 | extracted_path = os.path.join(os.path.dirname(dirname), 'extracted') 404 | if not os.path.exists(extracted_path): 405 | os.makedirs(extracted_path) 406 | tar_ref.extractall(extracted_path) 407 | tar_ref.close() 408 | else: 409 | # Logic for wrong file type 410 | log.error('Input not a directory') 411 | raise NotADirectoryError 412 | 413 | imgs = [os.path.join(extracted_path, doc) for doc in os.listdir(extracted_path)] 414 | for file in imgs: 415 | results.append(extract_image(file, debug, allow_wildcards)) 416 | 417 | log.info('Results extracted sucessfully:') 418 | log.info(results) 419 | 420 | return results 421 | 422 | 423 | def get_smiles(diag, smiles, r_smiles, extension='jpg'): 424 | """ Extracts diagram information. 425 | 426 | :param diag: Input Diagram 427 | :param smiles: List of smiles from all diagrams up to 'diag' 428 | :param r_smiles: List of smiles extracted from R-Groups from all diagrams up to 'diag' 429 | :param extension: Format of image file 430 | 431 | :return smiles: List of smiles from all diagrams up to and including 'diag' 432 | :return r_smiles: List of smiles extracted from R-Groups from all diagrams up to and including 'diag' 433 | """ 434 | 435 | # Resolve R-groups if detected 436 | if len(diag.label.r_group) > 0: 437 | r_smiles_group = get_rgroup_smiles(diag, extension) 438 | for smile in r_smiles_group: 439 | label_cand_str = list(set([cand.text for cand in smile[0]])) 440 | r_smiles.append((label_cand_str, smile[1])) 441 | 442 | # Resolve diagram normally if no R-groups - should just be one smile 443 | else: 444 | smile = read_diagram_pyosra(diag, extension) 445 | label_raw = diag.label.text 446 | label_cand_str = list(set([clean_output(cand.text) for cand in label_raw])) 447 | 448 | smiles.append((label_cand_str, smile)) 449 | 450 | return smiles, r_smiles 451 | -------------------------------------------------------------------------------- /chemschematicresolver/io.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Image IO 4 | ======== 5 | Reading and writing images. 
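A minimal usage sketch (the file names below are hypothetical):

>>> from chemschematicresolver.io import imread, imsave
>>> fig = imread('scheme.png')           # returns a Figure wrapping an RGB float array
>>> shape = fig.img.shape                # pixel values scaled to the range 0-1
>>> imsave('scheme_copy.png', fig.img)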
6 | 7 | Module adapted by :- 8 | author: Ed Beard 9 | email: ejb207@cam.ac.uk 10 | 11 | from FigureDataExtractor () :- 12 | author: Matthew Swain 13 | email: m.swain@me.com 14 | 15 | """ 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | from __future__ import unicode_literals 21 | import logging 22 | 23 | import numpy as np 24 | from PIL import Image 25 | from skimage import img_as_float, img_as_ubyte, img_as_uint 26 | from skimage import io as skio 27 | from skimage.color import gray2rgb 28 | import os 29 | import csv 30 | 31 | import warnings 32 | 33 | from .model import Figure 34 | 35 | log = logging.getLogger(__name__) 36 | 37 | 38 | def imread(f, raw=False): 39 | """Read an image from a file, create Figure object 40 | :param string|file f: Filename or file-like object. 41 | :return: Figure object. 42 | :rtype: Figure 43 | """ 44 | 45 | with warnings.catch_warnings(record=True) as ws: 46 | img = skio.imread(f, plugin='pil') 47 | 48 | # Transform all images pixel values to be floating point values between 0 and 1 (i.e. not ints 0-255) 49 | # Recommended in skimage-tutorials "Images are numpy arrays" because this what scikit-image uses internally 50 | # Transform greyscale images to RGB 51 | if len(img.shape) == 2: 52 | log.debug('Converting greyscale image to RGB...') 53 | img = gray2rgb(img) 54 | 55 | # Transform all images pixel values to be floating point values between 0 and 1 (i.e. not ints 0-255) 56 | # Recommended in skimage-tutorials "Images are numpy arrays" because this what scikit-image uses internally 57 | if not raw: 58 | img = img_as_float(img) 59 | fig = Figure(img) 60 | 61 | return fig 62 | 63 | 64 | def imsave(f, img): 65 | """Save an image to file. 66 | :param string|file f: Filename or file-like object. 67 | :param numpy.ndarray img: Image to save. Of shape (M,N) or (M,N,3) or (M,N,4). 68 | """ 69 | with warnings.catch_warnings(record=True) as ws: 70 | # Ensure we use PIL so we can guarantee that imsave will accept file-like object as well as filename 71 | skio.imsave(f, img, plugin='pil', quality=100) 72 | 73 | 74 | def imdel(f): 75 | """ Delete an image file 76 | """ 77 | 78 | os.remove(f) 79 | 80 | 81 | def read_superatom(superatom_path): 82 | """ 83 | Reads the superatom file as a list of tuples 84 | :param superatom_path: The path to the file containng superatom info 85 | :return: list of abbreviation-smile tuples for superatoms 86 | """ 87 | 88 | with open(superatom_path, 'r') as inf: 89 | cleaned_lines = [' '.join(line.split()) for line in inf if not line.startswith('#')] 90 | cleaned_lines = [line for line in cleaned_lines if len(line) != 0] 91 | lines = [(line.split(' ')[0], line.split(' ')[1]) for line in cleaned_lines] 92 | 93 | return lines 94 | 95 | 96 | def write_to_superatom(sub_smile, superatom_path): 97 | """ 98 | Adds a smile string to the superatom.txt file, for resolution in pyosra 99 | :param sub_smile: The smile string to be added to the file 100 | :param: superatom_path: The path to the file containng superatom info 101 | """ 102 | 103 | lines = read_superatom(superatom_path) 104 | 105 | if (sub_smile, sub_smile) not in lines: 106 | lines.append((sub_smile, sub_smile)) 107 | with open(superatom_path, 'w') as outf: 108 | csvwriter = csv.writer(outf, delimiter=' ') 109 | csvwriter.writerows(lines) 110 | 111 | 112 | def img_as_pil(arr, format_str=None): 113 | """Convert an scikit-image image (ndarray) to a PIL object. 
114 | 115 | Derived from code in scikit-image PIL IO plugin. 116 | 117 | :param numpy.ndarray img: The image to convert. 118 | :return: PIL image. 119 | :rtype: Image 120 | """ 121 | if arr.ndim == 3: 122 | arr = img_as_ubyte(arr) 123 | mode = {3: 'RGB', 4: 'RGBA'}[arr.shape[2]] 124 | 125 | elif format_str in ['png', 'PNG']: 126 | mode = 'I;16' 127 | mode_base = 'I' 128 | 129 | if arr.dtype.kind == 'f': 130 | arr = img_as_uint(arr) 131 | 132 | elif arr.max() < 256 and arr.min() >= 0: 133 | arr = arr.astype(np.uint8) 134 | mode = mode_base = 'L' 135 | 136 | else: 137 | arr = img_as_uint(arr) 138 | 139 | else: 140 | arr = img_as_ubyte(arr) 141 | mode = 'L' 142 | mode_base = 'L' 143 | 144 | try: 145 | array_buffer = arr.tobytes() 146 | except AttributeError: 147 | array_buffer = arr.tostring() # Numpy < 1.9 148 | 149 | if arr.ndim == 2: 150 | im = Image.new(mode_base, arr.T.shape) 151 | try: 152 | im.frombytes(array_buffer, 'raw', mode) 153 | except AttributeError: 154 | im.fromstring(array_buffer, 'raw', mode) # PIL 1.1.7 155 | else: 156 | image_shape = (arr.shape[1], arr.shape[0]) 157 | try: 158 | im = Image.frombytes(mode, image_shape, array_buffer) 159 | except AttributeError: 160 | im = Image.fromstring(mode, image_shape, array_buffer) # PIL 1.1.7 161 | return im 162 | -------------------------------------------------------------------------------- /chemschematicresolver/model.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Model 4 | ===== 5 | 6 | Models created to identify different regions of a chemical schematic diagram. 7 | 8 | Module adapted by :- 9 | author: Ed Beard 10 | email: ejb207@cam.ac.uk 11 | 12 | from FigureDataExtractor () :- 13 | author: Matthew Swain 14 | email: m.swain@me.com 15 | 16 | """ 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | from __future__ import unicode_literals 22 | import logging 23 | 24 | from . import decorators 25 | import numpy as np 26 | 27 | log = logging.getLogger(__name__) 28 | 29 | 30 | @decorators.python_2_unicode_compatible 31 | class Rect(object): 32 | """A rectangular region.""" 33 | 34 | def __init__(self, left, right, top, bottom): 35 | """ 36 | 37 | :param int left: Left edge of rectangle. 38 | :param int right: Right edge of rectangle. 39 | :param int top: Top edge of rectangle. 40 | :param int bottom: Bottom edge of rectangle. 41 | """ 42 | self.left = left 43 | self.right = right 44 | self.top = top 45 | self.bottom = bottom 46 | 47 | @property 48 | def width(self): 49 | """Return width of rectangle in pixels. May be floating point value. 50 | 51 | :rtype: int 52 | """ 53 | return self.right - self.left 54 | 55 | @property 56 | def height(self): 57 | """Return height of rectangle in pixels. May be floating point value. 58 | 59 | :rtype: int 60 | """ 61 | return self.bottom - self.top 62 | 63 | @property 64 | def perimeter(self): 65 | """Return length of the perimeter around rectangle. 66 | 67 | :rtype: int 68 | """ 69 | return (2 * self.height) + (2 * self.width) 70 | 71 | @property 72 | def area(self): 73 | """Return area of rectangle in pixels. May be floating point values. 74 | 75 | :rtype: int 76 | """ 77 | return self.width * self.height 78 | 79 | @property 80 | def center(self): 81 | """Center point of rectangle. May be floating point values. 
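For example, a rectangle with left=0, right=4, top=0 and bottom=2 has its center at (2.0, 1.0).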
82 | 83 | :rtype: tuple(int|float, int|float) 84 | """ 85 | xcenter = (self.left + self.right) / 2 86 | ycenter = (self.bottom + self.top) / 2 87 | return xcenter, ycenter 88 | 89 | @property 90 | def center_px(self): 91 | """(x, y) coordinates of pixel nearest to center point. 92 | 93 | :rtype: tuple(int, int) 94 | """ 95 | xcenter, ycenter = self.center 96 | return np.around(xcenter), np.around(ycenter) 97 | 98 | def contains(self, other_rect): 99 | """Return true if ``other_rect`` is within this rect. 100 | 101 | :param Rect other_rect: Another rectangle. 102 | :return: Whether ``other_rect`` is within this rect. 103 | :rtype: bool 104 | """ 105 | return (other_rect.left >= self.left and other_rect.right <= self.right and 106 | other_rect.top >= self.top and other_rect.bottom <= self.bottom) 107 | 108 | def overlaps(self, other_rect): 109 | """Return true if ``other_rect`` overlaps this rect. 110 | 111 | :param Rect other_rect: Another rectangle. 112 | :return: Whether ``other_rect`` overlaps this rect. 113 | :rtype: bool 114 | """ 115 | return (min(self.right, other_rect.right) > max(self.left, other_rect.left) and 116 | min(self.bottom, other_rect.bottom) > max(self.top, other_rect.top)) 117 | 118 | def separation(self, other_rect): 119 | """ Returns the distance between the center of each graph 120 | 121 | :param Rect other_rect: Another rectangle 122 | :return: Distance between centoids of rectangle 123 | :rtype: float 124 | """ 125 | length = abs(self.center[0] - other_rect.center[0]) 126 | height = abs(self.center[1] - other_rect.center[1]) 127 | return np.hypot(length, height) 128 | 129 | 130 | def __repr__(self): 131 | return '%s(left=%s, right=%s, top=%s, bottom=%s)' % ( 132 | self.__class__.__name__, self.left, self.right, self.top, self.bottom 133 | ) 134 | 135 | def __str__(self): 136 | return '<%s (%s, %s, %s, %s)>' % (self.__class__.__name__, self.left, self.right, self.top, self.bottom) 137 | 138 | def __eq__(self, other): 139 | if self.left == other.left and self.right == other.right \ 140 | and self.top == other.top and self.bottom == other.bottom: 141 | return True 142 | else: 143 | return False 144 | 145 | def __hash__(self): 146 | return hash((self.left, self.right, self.top, self.bottom)) 147 | 148 | 149 | class Panel(Rect): 150 | """ Tagged section inside Figure""" 151 | 152 | def __init__(self, left, right, top, bottom, tag=0): 153 | super(Panel, self).__init__(left, right, top, bottom) 154 | self.tag = tag 155 | self._repeating = False 156 | self._pixel_ratio = None 157 | 158 | @property 159 | def repeating(self): 160 | return self._repeating 161 | 162 | @repeating.setter 163 | def repeating(self, repeating): 164 | self._repeating = repeating 165 | 166 | @property 167 | def pixel_ratio(self): 168 | return self._pixel_ratio 169 | 170 | @pixel_ratio.setter 171 | def pixel_ratio(self, pixel_ratio): 172 | self._pixel_ratio = pixel_ratio 173 | 174 | 175 | class Diagram(Panel): 176 | """ Chemical Schematic Diagram that is identified""" 177 | 178 | def __init__(self, *args, label=None, smile=None, fig=None): 179 | self._label = label 180 | self._smile = smile 181 | self._fig = fig 182 | super(Diagram, self).__init__(*args) 183 | 184 | @property 185 | def label(self): 186 | return self._label 187 | 188 | @label.setter 189 | def label(self, label): 190 | self._label = label 191 | 192 | @property 193 | def smile(self): 194 | return self._smile 195 | 196 | @smile.setter 197 | def smile(self, smile): 198 | self._smile = smile 199 | 200 | @property 201 | def fig(self): 202 | 
""" Cropped Figure object of the specific diagram""" 203 | return self._fig 204 | 205 | @fig.setter 206 | def fig(self, fig): 207 | self._fig = fig 208 | 209 | def compass_position(self, other): 210 | """ Determines the compass position (NSEW) of other relative to self""" 211 | 212 | length = other.center[0] - self.center[0] 213 | height = other.center[1] - self.center[1] 214 | 215 | if abs(length) > abs(height): 216 | if length > 0: 217 | return 'E' 218 | else: 219 | return 'W' 220 | elif abs(length) < abs(height): 221 | if height > 0: 222 | return 'S' 223 | else: 224 | return 'N' 225 | 226 | else: 227 | return None 228 | 229 | def __repr__(self): 230 | if self.label is not None: 231 | return '%s(label=%s, smile=%s)' % ( 232 | self.__class__.__name__, self.label.tag, self.smile 233 | ) 234 | else: 235 | return '%s(label=None, smile=%s)' % ( 236 | self.__class__.__name__, self.smile 237 | ) 238 | 239 | def __str__(self): 240 | if self.label is not None: 241 | return '<%s (%s, %s)>' % (self.__class__.__name__, self.label.tag, self.smile) 242 | else: 243 | return '<%s (%s, %s)' % (self.__class__.__name__, self.tag, self.smile) 244 | 245 | 246 | class Label(Panel): 247 | """ Label used as an identifier for the closest Chemical Schematic Diagram""" 248 | 249 | def __init__(self, *args): 250 | super(Label, self).__init__(*args) 251 | self.r_group = [] 252 | self.values = [] 253 | 254 | @property 255 | def text(self): 256 | return self._text 257 | 258 | @text.setter 259 | def text(self, text): 260 | self._text = text 261 | 262 | def r_group(self): 263 | """ List of lists of tuples containing variable-value-label triplets. 264 | Each list represents a particular combination of chemicals yielding a unique compound. 265 | 266 | :param : List(str,str,List(str)) : A list of variable-value pairs and their list of candidate labels 267 | """ 268 | return self.r_group 269 | 270 | def add_r_group_variables(self, var_value_label_tuples): 271 | """ Updates the R-groups for this label.""" 272 | 273 | self.r_group.append(var_value_label_tuples) 274 | 275 | 276 | class RGroup(object): 277 | """ Object containing all extracted information for an R-group result""" 278 | 279 | def __init__(self, var, value, label_candidates): 280 | self.var = var 281 | self.value = value 282 | self.label_candidates = label_candidates 283 | 284 | def __repr__(self): 285 | return '%s(variable=%s, value=%s, label_candidates=%s)' % ( 286 | self.__class__.__name__, self.var, self.value, self.label_candidates 287 | ) 288 | 289 | def __str__(self): 290 | return '%s(variable=%s, value=%s, label_candidates=%s)' % ( 291 | self.__class__.__name__, self.var, self.value, self.label_candidates 292 | ) 293 | 294 | def convert_to_tuple(self): 295 | """ Converts the r-group object to a usable a list of variable-value pairs and their list of candidate labels """ 296 | tuple_r_group = (self.var, self.value, self.label_candidates) 297 | return tuple_r_group 298 | 299 | 300 | @decorators.python_2_unicode_compatible 301 | class Figure(object): 302 | """A figure image.""" 303 | 304 | def __init__(self, img, panels=None, plots=None, photos=None): 305 | """ 306 | 307 | :param numpy.ndarray img: Figure image. 308 | :param list[Panel] panels: List of panels. 309 | :param list[Plot] plots: List of plots. 310 | :param list[Photo] photos: List of photos. 
311 | """ 312 | self.img = img 313 | self.width, self.height = img.shape[0], img.shape[1] 314 | self.center = (int(self.width * 0.5), int(self.height) * 0.5) 315 | self.panels = panels 316 | self.plots = plots 317 | self.photos = photos 318 | 319 | # TODO: Image metadata? 320 | 321 | def __repr__(self): 322 | return '<%s>' % self.__class__.__name__ 323 | 324 | def __str__(self): 325 | return '<%s>' % self.__class__.__name__ 326 | 327 | def get_bounding_box(self): 328 | """ Returns the Panel object for the extreme bounding box of the image 329 | 330 | :rtype: Panel()""" 331 | 332 | rows = np.any(self.img, axis=1) 333 | cols = np.any(self.img, axis=0) 334 | left, right = np.where(rows)[0][[0, -1]] 335 | top, bottom = np.where(cols)[0][[0, -1]] 336 | return Panel(left, right, top, bottom) 337 | -------------------------------------------------------------------------------- /chemschematicresolver/ocr.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Optical Character Recognition 4 | ============================= 5 | 6 | Extract text from images using Tesseract. 7 | 8 | Module adapted by :- 9 | author: Ed Beard 10 | email: ejb207@cam.ac.uk 11 | 12 | from FigureDataExtractor () :- 13 | author: Matthew Swain 14 | email: m.swain@me.com 15 | """ 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | from __future__ import unicode_literals 21 | import collections 22 | import enum 23 | import logging 24 | import warnings 25 | 26 | import numpy as np 27 | import tesserocr 28 | from chemdataextractor.doc.text import Sentence 29 | 30 | from . import decorators, io, model 31 | from .utils import convert_greyscale, crop, pad 32 | from .parse import ChemSchematicResolverTokeniser, LabelParser 33 | 34 | 35 | log = logging.getLogger(__name__) 36 | 37 | # Whitelist for labels 38 | ALPHABET_UPPER = 'ABCDEFGHIJKLMNOPQRSTUVWXYZ' 39 | ALPHABET_LOWER = ALPHABET_UPPER.lower() 40 | DIGITS = '0123456789' 41 | ASSIGNMENT = ':=-' 42 | CONCENTRATION = '%()<>' 43 | SEPERATORS = ',' 44 | OTHER = '\'`/' 45 | LABEL_WHITELIST = ASSIGNMENT + DIGITS + ALPHABET_UPPER + ALPHABET_LOWER + CONCENTRATION + SEPERATORS + OTHER 46 | 47 | 48 | def read_diag_text(fig, diag, whitelist=LABEL_WHITELIST): 49 | """ Reads a diagram using OCR and returns the textual OCR objects""" 50 | img = convert_greyscale(fig.img) 51 | cropped_img = crop(img, diag.left, diag.right, diag.top, diag.bottom) 52 | text = get_text(cropped_img, x_offset=diag.left, y_offset=diag.top, psm=PSM.SINGLE_BLOCK, whitelist=whitelist) 53 | tokens = get_words(text) 54 | return tokens 55 | 56 | 57 | def read_label(fig, label, whitelist=LABEL_WHITELIST): 58 | """ Reads a label paragraph objects using ocr 59 | 60 | :param numpy.ndarray img: Input unprocessedimage 61 | :param Label label: Label object containing appropriate bounding box 62 | 63 | :rtype List[List[str]] 64 | """ 65 | 66 | size = 5 67 | img = convert_greyscale(fig.img) 68 | cropped_img = crop(img, label.left, label.right, label.top, label.bottom) 69 | padded_img = pad(cropped_img, size, mode='constant', constant_values=(1, 1)) 70 | text = get_text(padded_img, x_offset=label.left, y_offset=label.top, psm=PSM.SINGLE_BLOCK, whitelist=whitelist) 71 | if not text: 72 | label.text = [] 73 | return label, 0 74 | raw_sentences = get_sentences(text) 75 | 76 | if len(raw_sentences) is not 0: 77 | # Tag each sentence 78 | tagged_sentences = [Sentence(sentence, 
word_tokenizer=ChemSchematicResolverTokeniser(), 79 | parsers=[LabelParser()]) for sentence in raw_sentences] 80 | else: 81 | tagged_sentences = [] 82 | label.text = tagged_sentences 83 | 84 | # Calculating average confidence for the block 85 | confidences = [t.confidence for t in text] 86 | avg_conf = np.mean(confidences) 87 | log.info('Confidence in OCR: %s' % avg_conf) 88 | 89 | return label, avg_conf 90 | 91 | 92 | # These enums just wrap tesserocr functionality, so we can return proper enum members instead of ints. 93 | 94 | class Orientation(enum.IntEnum): 95 | """Text element orientations enum.""" 96 | #: Up orientation. 97 | PAGE_UP = tesserocr.Orientation.PAGE_UP 98 | #: Right orientation. 99 | PAGE_RIGHT = tesserocr.Orientation.PAGE_RIGHT 100 | #: Down orientation. 101 | PAGE_DOWN = tesserocr.Orientation.PAGE_DOWN 102 | #: Left orientation. 103 | PAGE_LEFT = tesserocr.Orientation.PAGE_LEFT 104 | 105 | 106 | class WritingDirection(enum.IntEnum): 107 | """Text element writing directions enum.""" 108 | #: Left to right. 109 | LEFT_TO_RIGHT = tesserocr.WritingDirection.LEFT_TO_RIGHT 110 | #: Right to left. 111 | RIGHT_TO_LEFT = tesserocr.WritingDirection.RIGHT_TO_LEFT 112 | #: Top to bottom. 113 | TOP_TO_BOTTOM = tesserocr.WritingDirection.TOP_TO_BOTTOM 114 | 115 | 116 | class TextlineOrder(enum.IntEnum): 117 | """Text line order enum.""" 118 | #: Left to right. 119 | LEFT_TO_RIGHT = tesserocr.TextlineOrder.LEFT_TO_RIGHT 120 | #: Right to left. 121 | RIGHT_TO_LEFT = tesserocr.TextlineOrder.RIGHT_TO_LEFT 122 | #: Top to bottom. 123 | TOP_TO_BOTTOM = tesserocr.TextlineOrder.TOP_TO_BOTTOM 124 | 125 | 126 | class Justification(enum.IntEnum): 127 | """Justification enum.""" 128 | #: Unknown justification. 129 | UNKNOWN = tesserocr.Justification.UNKNOWN 130 | #: Left justified 131 | LEFT = tesserocr.Justification.LEFT 132 | #: Center justified 133 | CENTER = tesserocr.Justification.CENTER 134 | #: Right justified 135 | RIGHT = tesserocr.Justification.RIGHT 136 | 137 | 138 | class PSM(enum.IntEnum): 139 | """Page Segmentation Mode enum.""" 140 | OSD_ONLY = tesserocr.PSM.OSD_ONLY 141 | AUTO_OSD = tesserocr.PSM.AUTO_OSD 142 | AUTO_ONLY = tesserocr.PSM.AUTO_ONLY 143 | AUTO = tesserocr.PSM.AUTO 144 | SINGLE_COLUMN = tesserocr.PSM.SINGLE_COLUMN 145 | SINGLE_BLOCK_VERT_TEXT = tesserocr.PSM.SINGLE_BLOCK_VERT_TEXT 146 | SINGLE_BLOCK = tesserocr.PSM.SINGLE_BLOCK 147 | SINGLE_LINE = tesserocr.PSM.SINGLE_LINE 148 | SINGLE_WORD = tesserocr.PSM.SINGLE_WORD 149 | CIRCLE_WORD = tesserocr.PSM.CIRCLE_WORD 150 | SINGLE_CHAR = tesserocr.PSM.SINGLE_CHAR 151 | SPARSE_TEXT = tesserocr.PSM.SPARSE_TEXT 152 | SPARSE_TEXT_OSD = tesserocr.PSM.SPARSE_TEXT_OSD 153 | RAW_LINE = tesserocr.PSM.RAW_LINE 154 | COUNT = tesserocr.PSM.COUNT 155 | 156 | 157 | class RIL(enum.IntEnum): 158 | """Page Iterator Level enum.""" 159 | BLOCK = tesserocr.RIL.BLOCK 160 | PARA = tesserocr.RIL.PARA 161 | SYMBOL = tesserocr.RIL.SYMBOL 162 | TEXTLINE = tesserocr.RIL.TEXTLINE 163 | WORD = tesserocr.RIL.WORD 164 | 165 | 166 | def get_words(blocks): 167 | """Convert list of text blocks into a flat list of the contained words. 168 | 169 | :param list[TextBlock] blocks: List of text blocks. 170 | :return: Flat list of text words. 
171 | :rtype: list[TextWord] 172 | """ 173 | words = [] 174 | for block in blocks: 175 | for para in block: 176 | for line in para: 177 | for word in line: 178 | words.append(word) 179 | return words 180 | 181 | 182 | def get_lines(blocks): 183 | """Convert list of text blocks into a nested list of lines, each of which contains a list of words. 184 | 185 | :param list[TextBlock] blocks: List of text blocks. 186 | :return: List of sentences 187 | :rtype: list[list[TextWord]] 188 | """ 189 | lines = [] 190 | for block in blocks: 191 | for para in block: 192 | for line in para: 193 | words = [] 194 | for word in line: 195 | words.append(word) 196 | lines.append(words) 197 | return lines 198 | 199 | 200 | def get_sentences(blocks): 201 | """Convert list of text blocks into a nested list of lines, each of which contains a list of words. 202 | 203 | :param list[TextBlock] blocks: List of text blocks. 204 | :return: List of sentences 205 | :rtype: list[list[TextWord]] 206 | """ 207 | sentences = [] 208 | for block in blocks: 209 | for para in block: 210 | for line in para: 211 | # sentences.append(line.text.replace(',', ' ')) # NB - commas switched for spaces to improve tokenization 212 | sentences.append(line.text) 213 | return sentences 214 | 215 | 216 | def get_text(img, x_offset=0, y_offset=0, psm=PSM.AUTO, padding=0, whitelist=None, img_orientation=None): 217 | """Get text elements in image. 218 | 219 | When passing a cropped image to this function, use ``x_offset`` and ``y_offset`` to ensure the coordinate positions 220 | of the extracted text elements are relative to the original uncropped image. 221 | 222 | :param numpy.ndarray img: Input image. 223 | :param int x_offset: Offset to add to the horizontal coordinates of the returned text elements. 224 | :param int y_offset: Offset to add to the vertical coordinates of the returned text elements. 225 | :param PSM psm: Page segmentation mode. 226 | :param int padding: Padding to add to text element bounding boxes. 227 | :param string whitelist: String containing allowed characters. e.g. Use '0123456789' for digits. 228 | :param Orientation img_orientation: Main orientation of text in image, if known. 229 | :return: List of text blocks. 230 | :rtype: list[TextBlock] 231 | """ 232 | log.debug( 233 | 'get_text: %s x_offset=%s, y_offset=%s, padding=%s, whitelist=%s', 234 | img.shape, x_offset, y_offset, padding, whitelist 235 | ) 236 | 237 | # Add a buffer around the entire input image to ensure no text is too close to edges 238 | img_padding = 3 239 | if img.ndim == 3: 240 | npad = ((img_padding, img_padding), (img_padding, img_padding), (0, 0)) 241 | elif img.ndim == 2: 242 | npad = ((img_padding, img_padding), (img_padding, img_padding)) 243 | else: 244 | raise ValueError('Unexpected image dimensions') 245 | img = np.pad(img, pad_width=npad, mode='constant', constant_values=1) 246 | shape = img.shape 247 | 248 | # Rotate img before sending to tesseract if an img_orientation has been given 249 | if img_orientation == Orientation.PAGE_LEFT: 250 | img = np.rot90(img, k=3, axes=(0, 1)) 251 | elif img_orientation == Orientation.PAGE_RIGHT: 252 | img = np.rot90(img, k=1, axes=(0, 1)) 253 | elif img_orientation is not None: 254 | raise NotImplementedError('Unsupported img_orientation') 255 | 256 | def _get_common_props(it, ril): 257 | """Get the properties that apply to all text elements.""" 258 | # Important: Call GetUTF8Text() before Orientation(). Former raises RuntimeError if no text, latter Segfaults. 
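# The iterator reports coordinates in the frame of the padded (and possibly rotated) image
# passed to Tesseract; the branches below translate the bounding box and orientation back to
# the original frame, and the x/y offsets re-anchor results from cropped regions onto the
# uncropped figure.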
259 | text = it.GetUTF8Text(ril) 260 | orientation, writing_direction, textline_order, deskew_angle = it.Orientation() 261 | bb = it.BoundingBox(ril, padding=padding) 262 | 263 | # Translate bounding box and orientation if img was previously rotated 264 | if img_orientation == Orientation.PAGE_LEFT: 265 | orientation = { 266 | Orientation.PAGE_UP: Orientation.PAGE_LEFT, 267 | Orientation.PAGE_LEFT: Orientation.PAGE_DOWN, 268 | Orientation.PAGE_DOWN: Orientation.PAGE_RIGHT, 269 | Orientation.PAGE_RIGHT: Orientation.PAGE_UP 270 | }[orientation] 271 | left, right, top, bottom = bb[1], bb[3], shape[0] - bb[2], shape[0] - bb[0] 272 | elif img_orientation == Orientation.PAGE_RIGHT: 273 | orientation = { 274 | Orientation.PAGE_UP: Orientation.PAGE_RIGHT, 275 | Orientation.PAGE_LEFT: Orientation.PAGE_UP, 276 | Orientation.PAGE_DOWN: Orientation.PAGE_LEFT, 277 | Orientation.PAGE_RIGHT: Orientation.PAGE_DOWN 278 | }[orientation] 279 | left, right, top, bottom = shape[1] - bb[3], shape[1] - bb[1], bb[0], bb[2] 280 | else: 281 | left, right, top, bottom = bb[0], bb[2], bb[1], bb[3] 282 | 283 | common_props = { 284 | 'text': text, 285 | 'left': left + x_offset - img_padding, 286 | 'right': right + x_offset - img_padding, 287 | 'top': top + y_offset - img_padding, 288 | 'bottom': bottom + y_offset - img_padding, 289 | 'confidence': it.Confidence(ril), 290 | 'orientation': Orientation(orientation), # TODO 291 | 'writing_direction': WritingDirection(writing_direction), 292 | 'textline_order': TextlineOrder(textline_order), 293 | 'deskew_angle': deskew_angle 294 | } 295 | return common_props 296 | 297 | blocks = [] 298 | with tesserocr.PyTessBaseAPI(psm=psm) as api: 299 | # Convert image to PIL to load into tesseract (suppress precision loss warning) 300 | with warnings.catch_warnings(record=True) as ws: 301 | pil_img = io.img_as_pil(img) 302 | api.SetImage(pil_img) 303 | if whitelist is not None: 304 | api.SetVariable('tessedit_char_whitelist', whitelist) 305 | # TODO: api.SetSourceResolution if we want correct pointsize on output? 
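# Run recognition once, then walk the result iterator one symbol at a time; the
# IsAtBeginningOf() checks below open a new TextBlock / TextParagraph / TextLine / TextWord
# whenever the iterator crosses the corresponding boundary, so the nested output mirrors
# Tesseract's page layout.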
306 | api.Recognize() 307 | it = api.GetIterator() 308 | block = None 309 | para = None 310 | line = None 311 | word = None 312 | it.Begin() 313 | 314 | while True: 315 | try: 316 | if it.IsAtBeginningOf(RIL.BLOCK): 317 | common_props = _get_common_props(it, RIL.BLOCK) 318 | block = TextBlock(**common_props) 319 | blocks.append(block) 320 | 321 | if it.IsAtBeginningOf(RIL.PARA): 322 | common_props = _get_common_props(it, RIL.PARA) 323 | justification, is_list_item, is_crown, first_line_indent = it.ParagraphInfo() 324 | para = TextParagraph( 325 | is_ltr=it.ParagraphIsLtr(), 326 | justification=Justification(justification), 327 | is_list_item=is_list_item, 328 | is_crown=is_crown, 329 | first_line_indent=first_line_indent, 330 | **common_props 331 | ) 332 | if block is not None: 333 | block.paragraphs.append(para) 334 | 335 | if it.IsAtBeginningOf(RIL.TEXTLINE): 336 | common_props = _get_common_props(it, RIL.TEXTLINE) 337 | line = TextLine(**common_props) 338 | if para is not None: 339 | para.lines.append(line) 340 | 341 | if it.IsAtBeginningOf(RIL.WORD): 342 | common_props = _get_common_props(it, RIL.WORD) 343 | wfa = it.WordFontAttributes() 344 | if wfa: 345 | common_props.update(wfa) 346 | word = TextWord( 347 | language=it.WordRecognitionLanguage(), 348 | from_dictionary=it.WordIsFromDictionary(), 349 | numeric=it.WordIsNumeric(), 350 | **common_props 351 | ) 352 | if line is not None: 353 | line.words.append(word) 354 | 355 | # Beware: Character level coordinates do not seem to be accurate in Tesseact 4!! 356 | common_props = _get_common_props(it, RIL.SYMBOL) 357 | symbol = TextSymbol( 358 | is_dropcap=it.SymbolIsDropcap(), 359 | is_subscript=it.SymbolIsSubscript(), 360 | is_superscript=it.SymbolIsSuperscript(), 361 | **common_props 362 | ) 363 | word.symbols.append(symbol) 364 | except RuntimeError as e: 365 | # Happens if no text was detected 366 | log.info(e) 367 | 368 | if not it.Next(RIL.SYMBOL): 369 | break 370 | return blocks 371 | 372 | 373 | @decorators.python_2_unicode_compatible 374 | class TextElement(model.Rect): 375 | """Abstract base class for all text elements.""" 376 | 377 | def __init__(self, text, left, right, top, bottom, orientation, writing_direction, textline_order, deskew_angle, 378 | confidence): 379 | """ 380 | 381 | :param string text: Recognized text content. 382 | :param int left: Left edge of bounding box. 383 | :param int right: Right edge of bounding box. 384 | :param int top: Top edge of bounding box. 385 | :param int bottom: Bottom edge of bounding box. 386 | :param Orientation orientation: Orientation of this element. 387 | :param WritingDirection writing_direction: Writing direction of this element. 388 | :param TextlineOrder textline_order: Text line order of this element. 389 | :param float deskew_angle: Angle required to make text upright in radians. 390 | :param float confidence: Mean confidence for the text in this element. Probability 0-100%. 
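Illustrative sketch (arbitrary values; real instances come from the concrete
subclasses built by get_text):

>>> el = TextElement('1a', 10, 50, 5, 20, Orientation.PAGE_UP,
...                  WritingDirection.LEFT_TO_RIGHT, TextlineOrder.TOP_TO_BOTTOM,
...                  0.0, 95.0)
>>> el.width, el.height                   # geometry inherited from model.Rect
(40, 15)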
391 | """ 392 | super(TextElement, self).__init__(left, right, top, bottom) 393 | self.text = text 394 | self.orientation = orientation 395 | self.writing_direction = writing_direction 396 | self.textline_order = textline_order 397 | self.deskew_angle = deskew_angle 398 | self.confidence = confidence 399 | 400 | def __repr__(self): 401 | return '<%s: %r>' % (self.__class__.__name__, self.text) 402 | 403 | def __str__(self): 404 | return '<%s: %r>' % (self.__class__.__name__, self.text) 405 | 406 | 407 | class TextBlock(TextElement, collections.MutableSequence): 408 | """Text block.""" 409 | 410 | def __init__(self, text, left, right, top, bottom, orientation, writing_direction, textline_order, deskew_angle, 411 | confidence): 412 | """ 413 | 414 | :param string text: Recognized text content. 415 | :param int left: Left edge of bounding box. 416 | :param int right: Right edge of bounding box. 417 | :param int top: Top edge of bounding box. 418 | :param int bottom: Bottom edge of bounding box. 419 | :param Orientation orientation: Orientation of this element. 420 | :param WritingDirection writing_direction: Writing direction of this element. 421 | :param TextlineOrder textline_order: Text line order of this element. 422 | :param float deskew_angle: Angle required to make text upright in radians. 423 | :param float confidence: Mean confidence for the text in this element. Probability 0-100%. 424 | """ 425 | super(TextBlock, self).__init__(text, left, right, top, bottom, orientation, writing_direction, textline_order, 426 | deskew_angle, confidence) 427 | self.paragraphs = [] 428 | 429 | def __getitem__(self, index): 430 | return self.paragraphs[index] 431 | 432 | def __setitem__(self, index, value): 433 | self.paragraphs[index] = value 434 | 435 | def __delitem__(self, index): 436 | del self.paragraphs[index] 437 | 438 | def __len__(self): 439 | return len(self.paragraphs) 440 | 441 | def insert(self, index, value): 442 | self.paragraphs.insert(index, value) 443 | 444 | 445 | class TextParagraph(TextElement, collections.MutableSequence): 446 | """Text paragraph. 447 | 448 | :param string text: Recognized text content. 449 | :param int left: Left edge of bounding box. 450 | :param int right: Right edge of bounding box. 451 | :param int top: Top edge of bounding box. 452 | :param int bottom: Bottom edge of bounding box. 453 | :param Orientation orientation: Orientation of this element. 454 | :param WritingDirection writing_direction: Writing direction of this element. 455 | :param TextlineOrder textline_order: Text line order of this element. 456 | :param float deskew_angle: Angle required to make text upright in radians. 457 | :param float confidence: Mean confidence for the text in this element. Probability 0-100%. 458 | :param bool is_ltr: Whether this paragraph text is left to right. 459 | :param Justification justification: Paragraph justification. 460 | :param bool is_list_item: Whether this paragraph is part of a list. 461 | :param bool is_crown: Whether the first line is aligned with the subsequent lines yet other paragraphs are indented. 462 | :param int first_line_indent: Indent of first line in pixels. 
463 | """ 464 | 465 | def __init__(self, text, left, right, top, bottom, orientation, writing_direction, textline_order, deskew_angle, 466 | confidence, is_ltr, justification, is_list_item, is_crown, first_line_indent): 467 | super(TextParagraph, self).__init__(text, left, right, top, bottom, orientation, writing_direction, 468 | textline_order, deskew_angle, confidence) 469 | self.lines = [] 470 | self.is_ltr = is_ltr 471 | self.justification = justification 472 | self.is_list_item = is_list_item 473 | self.is_crown = is_crown 474 | self.first_line_indent = first_line_indent 475 | 476 | def __getitem__(self, index): 477 | return self.lines[index] 478 | 479 | def __setitem__(self, index, value): 480 | self.lines[index] = value 481 | 482 | def __delitem__(self, index): 483 | del self.lines[index] 484 | 485 | def __len__(self): 486 | return len(self.lines) 487 | 488 | def insert(self, index, value): 489 | self.lines.insert(index, value) 490 | 491 | 492 | class TextLine(TextElement, collections.MutableSequence): 493 | """Text line.""" 494 | 495 | def __init__(self, text, left, right, top, bottom, orientation, writing_direction, textline_order, deskew_angle, 496 | confidence): 497 | """ 498 | 499 | :param string text: Recognized text content. 500 | :param int left: Left edge of bounding box. 501 | :param int right: Right edge of bounding box. 502 | :param int top: Top edge of bounding box. 503 | :param int bottom: Bottom edge of bounding box. 504 | :param Orientation orientation: Orientation of this element. 505 | :param WritingDirection writing_direction: Writing direction of this element. 506 | :param TextlineOrder textline_order: Text line order of this element. 507 | :param float deskew_angle: Angle required to make text upright in radians. 508 | :param float confidence: Mean confidence for the text in this element. Probability 0-100%. 509 | """ 510 | super(TextLine, self).__init__(text, left, right, top, bottom, orientation, writing_direction, textline_order, 511 | deskew_angle, confidence) 512 | self.words = [] 513 | 514 | def __getitem__(self, index): 515 | return self.words[index] 516 | 517 | def __setitem__(self, index, value): 518 | self.words[index] = value 519 | 520 | def __delitem__(self, index): 521 | del self.words[index] 522 | 523 | def __len__(self): 524 | return len(self.words) 525 | 526 | def insert(self, index, value): 527 | self.words.insert(index, value) 528 | 529 | 530 | class TextWord(TextElement, collections.MutableSequence): 531 | """Text word.""" 532 | 533 | def __init__(self, text, left, right, top, bottom, orientation, writing_direction, textline_order, deskew_angle, 534 | confidence, language, from_dictionary, numeric, font_name=None, bold=None, italic=None, 535 | underlined=None, monospace=None, serif=None, smallcaps=None, pointsize=None, font_id=None): 536 | """ 537 | 538 | :param string text: Recognized text content. 539 | :param int left: Left edge of bounding box. 540 | :param int right: Right edge of bounding box. 541 | :param int top: Top edge of bounding box. 542 | :param int bottom: Bottom edge of bounding box. 543 | :param Orientation orientation: Orientation of this element. 544 | :param WritingDirection writing_direction: Writing direction of this element. 545 | :param TextlineOrder textline_order: Text line order of this element. 546 | :param float deskew_angle: Angle required to make text upright in radians. 547 | :param float confidence: Mean confidence for the text in this element. Probability 0-100%. 
548 | :param language: Language used to recognize this word. 549 | :param from_dictionary: Whether this word was found in a dictionary. 550 | :param numeric: Whether this word is numeric. 551 | :param string font_name: Font name. 552 | :param bool bold: Whether this word is bold. 553 | :param bool italic: Whether this word is italic. 554 | :param underlined: Whether this word is underlined. 555 | :param monospace: Whether this word is in a monospace font. 556 | :param serif: Whethet this word is in a serif font. 557 | :param smallcaps: Whether this word is in small caps. 558 | :param pointsize: Font size in points (1/72 inch). 559 | :param font_id: Font ID. 560 | """ 561 | super(TextWord, self).__init__(text, left, right, top, bottom, orientation, writing_direction, textline_order, 562 | deskew_angle, confidence) 563 | self.symbols = [] 564 | self.font_name = font_name 565 | self.bold = bold 566 | self.italic = italic 567 | self.underlined = underlined 568 | self.monospace = monospace 569 | self.serif = serif 570 | self.smallcaps = smallcaps 571 | self.pointsize = pointsize 572 | self.font_id = font_id 573 | self.language = language 574 | self.from_dictionary = from_dictionary 575 | self.numeric = numeric 576 | 577 | def __getitem__(self, index): 578 | return self.symbols[index] 579 | 580 | def __setitem__(self, index, value): 581 | self.symbols[index] = value 582 | 583 | def __delitem__(self, index): 584 | del self.symbols[index] 585 | 586 | def __len__(self): 587 | return len(self.symbols) 588 | 589 | def insert(self, index, value): 590 | self.symbols.insert(index, value) 591 | 592 | 593 | class TextSymbol(TextElement): 594 | """Text symbol.""" 595 | 596 | def __init__(self, text, left, right, top, bottom, orientation, writing_direction, textline_order, deskew_angle, 597 | confidence, is_dropcap, is_subscript, is_superscript): 598 | """ 599 | 600 | :param string text: Recognized text content. 601 | :param int left: Left edge of bounding box. 602 | :param int right: Right edge of bounding box. 603 | :param int top: Top edge of bounding box. 604 | :param int bottom: Bottom edge of bounding box. 605 | :param Orientation orientation: Orientation of this element. 606 | :param WritingDirection writing_direction: Writing direction of this element. 607 | :param TextlineOrder textline_order: Text line order of this element. 608 | :param float deskew_angle: Angle required to make text upright in radians. 609 | :param float confidence: Mean confidence for the text in this element. Probability 0-100%. 610 | :param bool is_dropcap: Whether this symbol is a dropcap. 611 | :param bool is_subscript: Whether this symbol is subscript. 612 | :param bool is_superscript: Whether this symbol is superscript. 
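Note: character-level coordinates reported by Tesseract 4 were observed to be
unreliable (see the corresponding comment in get_text), so the bounding box of a
TextSymbol should be treated with caution.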
613 | """ 614 | super(TextSymbol, self).__init__(text, left, right, top, bottom, orientation, writing_direction, textline_order, 615 | deskew_angle, confidence) 616 | self.is_dropcap = is_dropcap 617 | self.is_subscript = is_subscript 618 | self.is_superscript = is_superscript 619 | 620 | 621 | -------------------------------------------------------------------------------- /chemschematicresolver/parse.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Classes for parsing relevant info 4 | 5 | ======== 6 | author: Ed Beard 7 | email: ejb207@cam.ac.uk 8 | 9 | """ 10 | 11 | from chemdataextractor.parse.cem import BaseParser, lenient_chemical_label 12 | from chemdataextractor.nlp.tokenize import WordTokenizer 13 | from chemdataextractor.model import Compound 14 | 15 | 16 | class LabelParser(BaseParser): 17 | 18 | root = lenient_chemical_label 19 | 20 | def interpret(self, result, start, end): 21 | for label in result.xpath('./text()'): 22 | yield Compound(labels=[label]) 23 | 24 | 25 | class ChemSchematicResolverTokeniser(WordTokenizer): 26 | """ Bespoke version of ChemDiagramExtractor's word tokenizer that doesn't split on prime characters""" 27 | 28 | #: Split before and after these sequences, wherever they occur, unless entire token is one of these sequences 29 | SPLIT = [ 30 | ' ', # Specific whitespace characters 31 | '----', 32 | '––––', # \u2013 en dash 33 | '————', # \u2014 em dash 34 | '<--->', 35 | '---', 36 | '–––', # \u2013 en dash 37 | '———', # \u2014 em dash 38 | '<-->', 39 | '-->', 40 | '...', 41 | '--', 42 | '––', # \u2013 en dash 43 | '——', # \u2014 em dash 44 | '``', 45 | # "''", 46 | '->', 47 | '<', 48 | '>', 49 | '–', # \u2013 en dash 50 | '—', # \u2014 em dash 51 | '―', # \u2015 horizontal bar 52 | '~', # \u007e Tilde 53 | '⁓', # \u2053 Swung dash 54 | '∼', # \u223c Tilde operator 55 | '°', # \u00b0 Degrees 56 | ';', 57 | '@', 58 | '#', 59 | '$', 60 | '£', # \u00a3 61 | '€', # \u20ac 62 | '%', 63 | '&', 64 | '?', 65 | '!', 66 | '™', # \u2122 67 | '®', # \u00ae 68 | '…', # \u2026 69 | '⋯', # \u22ef Mid-line ellipsis 70 | '†', # \u2020 Dagger 71 | '‡', # \u2021 Double dagger 72 | '§', # \u00a7 Section sign 73 | '¶' # \u00b6 Pilcrow sign 74 | '≠', # \u2260 75 | '≡', # \u2261 76 | '≢', # \u2262 77 | '≣', # \u2263 78 | '≤', # \u2264 79 | '≥', # \u2265 80 | '≦', # \u2266 81 | '≧', # \u2267 82 | '≨', # \u2268 83 | '≩', # \u2269 84 | '≪', # \u226a 85 | '≫', # \u226b 86 | '≈', # \u2248 87 | '=', 88 | '÷', # \u00f7 89 | '×', # \u00d7 90 | '→', # \u2192 91 | '⇄', # \u21c4 92 | # '"', # \u0022 Quote mark 93 | # '“', # \u201c 94 | # '”', # \u201d 95 | '„', # \u201e 96 | #'‟', # \u201f 97 | # '‘', # \u2018 Left single quote 98 | # '’', # \u2019 Right single quote - Regularly used as an apostrophe, so don't always split 99 | '‚', # \u201a Single low quote 100 | # '‛', # \u201b Single reversed quote 101 | # '`', # \u0060 102 | # '´', # \u00b4 103 | # Brackets 104 | '(', 105 | '[', 106 | '{', 107 | '}', 108 | ']', 109 | ')', 110 | '+', # \u002b Plus 111 | '±', # \u00b1 Plus/Minus 112 | ] 113 | 114 | SPLIT_START_WORD = [] 115 | SPLIT_END_WORD = [] 116 | 117 | # def __init__(self): 118 | # super().__init__(self) 119 | 120 | def _subspan(self, s, span, nextspan): 121 | """Recursively subdivide spans based on a series of rules.""" 122 | text = s[span[0]:span[1]] 123 | lowertext = text.lower() 124 | 125 | # Skip if only a single character or a split sequence 126 | if span[1] - span[ 127 | 0] < 2 or text in self.SPLIT or 
text in self.SPLIT_END_WORD or text in self.SPLIT_START_WORD or lowertext in self.NO_SPLIT: 128 | return [span] 129 | 130 | # Skip if it looks like URL 131 | if text.startswith('http://') or text.startswith('ftp://') or text.startswith('www.'): 132 | return [span] 133 | 134 | # Split full stop at end of final token (allow certain characters to follow) unless ellipsis 135 | if self.split_last_stop and nextspan is None and text not in self.NO_SPLIT_STOP and not text[-3:] == '...': 136 | if text[-1] == '.': 137 | return self._split_span(span, -1) 138 | 139 | # Split off certain sequences at the end of a word 140 | for spl in self.SPLIT_END_WORD: 141 | if text.endswith(spl) and len(text) > len(spl) and text[-len(spl) - 1].isalpha(): 142 | return self._split_span(span, -len(spl), 0) 143 | 144 | # Split off certain sequences at the start of a word 145 | for spl in self.SPLIT_START_WORD: 146 | if text.startswith(spl) and len(text) > len(spl) and text[-len(spl) - 1].isalpha(): 147 | return self._split_span(span, len(spl), 0) 148 | 149 | # Split around certain sequences 150 | for spl in self.SPLIT: 151 | ind = text.find(spl) 152 | if ind > -1: 153 | return self._split_span(span, ind, len(spl)) 154 | 155 | # Split around certain sequences unless followed by a digit 156 | for spl in self.SPLIT_NO_DIGIT: 157 | ind = text.rfind(spl) 158 | if ind > -1 and (len(text) <= ind + len(spl) or not text[ind + len(spl)].isdigit()): 159 | return self._split_span(span, ind, len(spl)) 160 | 161 | # Characters to split around, but with exceptions 162 | for i, char in enumerate(text): 163 | if char == '-': 164 | before = lowertext[:i] 165 | after = lowertext[i + 1:] 166 | # By default we split on hyphens 167 | split = True 168 | if before in self.NO_SPLIT_PREFIX or after in self.NO_SPLIT_SUFFIX: 169 | split = False # Don't split if prefix or suffix in list 170 | elif not before.strip(self.NO_SPLIT_CHARS) or not after.strip(self.NO_SPLIT_CHARS): 171 | split = False # Don't split if prefix or suffix entirely consist of certain characters 172 | if split: 173 | return self._split_span(span, i, 1) 174 | 175 | # Split contraction words 176 | for contraction in self.CONTRACTIONS: 177 | if lowertext == contraction[0]: 178 | return self._split_span(span, contraction[1]) 179 | return [span] 180 | 181 | -------------------------------------------------------------------------------- /chemschematicresolver/r_group.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | R-Group 4 | ======= 5 | 6 | Scripts for identifying R-Group structure diagrams 7 | 8 | author: Ed Beard 9 | email: ejb207@cam.ac.uk 10 | 11 | """ 12 | 13 | from __future__ import absolute_import 14 | from __future__ import division 15 | from __future__ import print_function 16 | from __future__ import unicode_literals 17 | import logging 18 | 19 | import osra_rgroup 20 | import cirpy 21 | import itertools 22 | import os 23 | 24 | from . import io 25 | from . 
import actions 26 | from .model import RGroup 27 | from .ocr import ASSIGNMENT, SEPERATORS, CONCENTRATION 28 | 29 | import re 30 | from skimage.util import pad 31 | from urllib.error import URLError 32 | 33 | from chemdataextractor.doc.text import Token 34 | 35 | log = logging.getLogger(__name__) 36 | 37 | BLACKLIST_CHARS = ASSIGNMENT + SEPERATORS + CONCENTRATION 38 | 39 | # Regular Expressions 40 | NUMERIC_REGEX = re.compile('^\d{1,3}$') 41 | ALPHANUMERIC_REGEX = re.compile('^((d-)?(\d{1,2}[A-Za-z]{1,2}[′″‴‶‷⁗]?)(-d))|(\d{1,3})?$') 42 | 43 | # Commonly occuring tokens for R-Groups: 44 | r_group_indicators = ['R', 'X', 'Y', 'Z', 'R1', 'R2', 'R3', 'R4', 'R5', 'R6', 'R7', 'R8', 'R9', 'R10', 'Y2', 'D', "R'", "R''", "R'''", "R''''"] 45 | r_group_indicators = r_group_indicators + [val.lower() for val in r_group_indicators] 46 | 47 | # Standard path to superatom dictionary file 48 | parent_dir = os.path.dirname(os.path.abspath(__file__)) 49 | superatom_file = os.path.join(parent_dir, 'dict', 'superatom.txt') 50 | spelling_file = os.path.join(parent_dir, 'dict', 'spelling.txt') 51 | 52 | 53 | def detect_r_group(diag): 54 | """ Determines whether a label represents an R-Group structure, and if so gives the variable and value. 55 | 56 | :param diag: Diagram object to search for R-Group indicators 57 | :return diag: Diagram object with R-Group variable and value candidates assigned. 58 | """ 59 | 60 | sentences = diag.label.text 61 | first_sentence_tokens = [token.text.replace(' ', '').replace('\n', '') for token in sentences[0].tokens] 62 | 63 | if sentences == []: 64 | pass 65 | # # Identifies grid labels from the presence of only variable tokens in the first line 66 | elif all([True if token in r_group_indicators else False for token in first_sentence_tokens]): 67 | 68 | r_groups = resolve_r_group_grid(sentences) 69 | r_groups_list = separate_duplicate_r_groups(r_groups) 70 | for r_groups in r_groups_list: 71 | diag.label.add_r_group_variables(convert_r_groups_to_tuples(r_groups)) 72 | 73 | # Otherwise looks for indicative R-Group characters (=, :) 74 | else: 75 | 76 | for sentence in sentences: 77 | 78 | all_sentence_text = [token.text for token in sentence.tokens] 79 | 80 | if '=' in all_sentence_text: 81 | var_value_pairs = detect_r_group_from_sentence(sentence, indicator='=') 82 | elif ':' in all_sentence_text: 83 | var_value_pairs = detect_r_group_from_sentence(sentence, indicator=':') 84 | else: 85 | var_value_pairs = [] 86 | 87 | # Process R-group values from '=' 88 | r_groups = get_label_candidates(sentence, var_value_pairs) 89 | r_groups = standardize_values(r_groups) 90 | 91 | # Resolving positional labels where possible for 'or' cases 92 | r_groups = filter_repeated_labels(r_groups) 93 | 94 | # Separate duplicate variables into separate lists 95 | r_groups_list = separate_duplicate_r_groups(r_groups) 96 | 97 | for r_groups in r_groups_list: 98 | diag.label.add_r_group_variables(convert_r_groups_to_tuples(r_groups)) 99 | 100 | return diag 101 | 102 | 103 | def detect_r_group_from_sentence(sentence, indicator='='): 104 | """ Detects an R-Group from the presence of an input character 105 | 106 | :param sentence: A chemdataextractor.doc.text.Sentence object containing tokens to be probed for R-Groups 107 | :param indicator: String used to identify R-Groups 108 | 109 | :return var_value_pairs: A list of RGroup objects, containing the variable, value and label candidates 110 | :rtype: List[chemschematicresolver.model.RGroup] 111 | """ 112 | 113 | var_value_pairs = [] # Used to find variable 
114 | 115 | for i, token in enumerate(sentence.tokens): 116 | if token.text == indicator: 117 | log.info('Found R-Group descriptor %s' % token.text) 118 | if i > 0: 119 | log.info('Variable candidate is %s' % sentence.tokens[i - 1]) 120 | if i < len(sentence.tokens) - 1: 121 | log.info('Value candidate is %s' % sentence.tokens[i + 1]) 122 | 123 | if 0 < i < len(sentence.tokens) - 1: 124 | variable = sentence.tokens[i - 1] 125 | value = sentence.tokens[i + 1] 126 | var_value_pairs.append(RGroup(variable, value, [])) 127 | 128 | elif token.text == 'or' and var_value_pairs: 129 | log.info('"or" keyword detected. Assigning value to previous R-group variable...') 130 | 131 | # Identify the most recent var_value pair 132 | variable = var_value_pairs[-1].var 133 | value = sentence.tokens[i + 1] 134 | var_value_pairs.append(RGroup(variable, value, [])) 135 | 136 | return var_value_pairs 137 | 138 | 139 | def resolve_r_group_grid(sentences): 140 | """Resolves the special grid case, where data is organised into label-value columns for a specific variable. 141 | 142 | Please note that this only extracts simple tables, where the column indicators are contained in the list of 143 | r_group_indicators 144 | 145 | :param sentences: A list of chemdataextractor.doc.text.Sentence objects containing tokens to be probed for R-Groups 146 | :return var_value_pairs: A list of RGroup objects, containing the variable, value and label candidates 147 | :rtype: List[chemschematicresolver.model.RGroup] 148 | """ 149 | 150 | var_value_pairs = [] # Used to find variable - value pairs for extraction 151 | table_identifier, table_rows = sentences[0], sentences[1:] 152 | 153 | variables = table_identifier.tokens 154 | log.info('R-Group table format detected. Variable candidates are %s' % variables) 155 | 156 | # Check that every table row is one token longer than the table_identifier row (the extra token being the label) 157 | correct_row_lengths = [len(row.tokens) == len(variables) + 1 for row in table_rows] 158 | if not all(correct_row_lengths): 159 | return [] 160 | 161 | for row in table_rows: 162 | tokens = row.tokens 163 | label_candidates = [tokens[0]] 164 | values = tokens[1:] 165 | for i, value in enumerate(values): 166 | var_value_pairs.append(RGroup(variables[i], value, label_candidates)) 167 | 168 | return var_value_pairs 169 | 170 | 171 | def get_label_candidates(sentence, r_groups, blacklist_chars=BLACKLIST_CHARS, blacklist_words=['or']): 172 | """Assign label candidates from a sentence that contains known R-Group variables 173 | 174 | :param sentence: Sentence to probe for label candidates 175 | :param r_groups: A list of R-Group objects with variable-value pairs assigned 176 | :param blacklist_chars: String of disallowed characters 177 | :param blacklist_words: List of disallowed words 178 | 179 | :return r_groups: List of R-Group objects with assigned label candidates 180 | """ 181 | 182 | # Remove irrelevant characters and blacklisted words 183 | candidates = [token for token in sentence.tokens if token.text not in blacklist_chars] 184 | candidates = [token for token in candidates if token.text not in blacklist_words] 185 | 186 | r_group_vars_and_values = [] 187 | for r_group in r_groups: 188 | r_group_vars_and_values.append(r_group.var) 189 | r_group_vars_and_values.append(r_group.value) 190 | 191 | candidates = [token for token in candidates if token not in r_group_vars_and_values] 192 | 193 | r_groups = assign_label_candidates(r_groups, candidates) 194 | 195 | return r_groups 196 | 197 | 198 | def
assign_label_candidates(r_groups, candidates): 199 | """ Gets label candidates for cases where the same variable appears twice in one sentence. 200 | This is typically indicative of cases where 2 R-Groups are defined on the same line 201 | """ 202 | 203 | # Check - are there repeated variables? 204 | var_text = [r_group.var.text for r_group in r_groups] 205 | duplicate_r_groups = [r_group for r_group in r_groups if var_text.count(r_group.var.text) > 1] 206 | 207 | # Check that ALL r_group values have this duplicity (ie has every r_group got a duplicate variable?) 208 | if len(duplicate_r_groups) == len(r_groups) and len(r_groups) != 0: 209 | 210 | # Now go through r_groups getting positions of tokens 211 | for i, r_group in enumerate(r_groups): 212 | if i == 0: 213 | end_index = r_group.var.end 214 | r_group.label_candidates = [cand for cand in candidates if cand.start < end_index] 215 | elif i == len(r_groups) - 1: 216 | start_index = r_groups[i - 1].value.end 217 | end_index = r_group.var.end 218 | r_group.label_candidates = [cand for cand in candidates if (start_index< cand.start < end_index) or cand.start > r_group.value.end] 219 | else: 220 | start_index = r_groups[i - 1].value.end 221 | end_index = r_group.var.end 222 | r_group.label_candidates = [cand for cand in candidates if start_index< cand.start < end_index] 223 | 224 | return r_groups 225 | 226 | else: 227 | for r_group in r_groups: 228 | var = r_group.var 229 | value = r_group.value 230 | label_cands = [candidate for candidate in candidates if candidate not in [var, value]] 231 | r_group.label_candidates = label_cands 232 | 233 | return r_groups 234 | 235 | 236 | def filter_repeated_labels(r_groups): 237 | """ 238 | Detects repeated variable values. 239 | When found, this is determined to be an 'or' case so relative label assignment ensues. 
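    For example, a label reading ``4, 5: R = Me or Et`` yields two R-groups that share the variable R; the intention is that label 4 is paired with R = Me and label 5 with R = Et (a hypothetical illustration of this 'or' handling).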
240 | 241 | :param r_groups: Input list of R-Group objects 242 | :return output_r_groups: R-Group objects corrected for 'or' statements 243 | 244 | """ 245 | 246 | or_vars = [] 247 | vars = [r_group.var for r_group in r_groups] 248 | unique_vars = set(vars) 249 | for test_var in unique_vars: 250 | if vars.count(test_var) > 1: 251 | log.debug('Identified "or" variable') 252 | or_vars.append(test_var) 253 | 254 | # Get label candidates for r_groups containing this: 255 | filtered_r_groups = [r_group for r_group in r_groups if r_group.var in or_vars] 256 | 257 | # If no duplicate r_group variables, exit function 258 | if len(filtered_r_groups) == 0: 259 | return r_groups 260 | 261 | remaining_r_groups = [r_group for r_group in r_groups if r_group.var not in or_vars] 262 | label_cands = filtered_r_groups[0].label_candidates # Get the label candidates for these vars (should be the same) 263 | 264 | # Prioritizing alphanumerics for relative label assignment 265 | alphanumeric_labels = [label for label in label_cands if ALPHANUMERIC_REGEX.match(label.text)] 266 | 267 | output_r_groups = [] 268 | 269 | # First check if the normal number of labels is the same 270 | if len(filtered_r_groups) == len(label_cands): 271 | for i in range(len(filtered_r_groups)): 272 | altered_r_group = filtered_r_groups[i] 273 | altered_r_group.label_candidates = [label_cands[i]] 274 | output_r_groups.append(altered_r_group) 275 | output_r_groups = output_r_groups + remaining_r_groups 276 | 277 | # Otherwise, check if alphanumerics match 278 | elif len(filtered_r_groups) == len(alphanumeric_labels): 279 | for i in range(len(filtered_r_groups)): 280 | altered_r_group = filtered_r_groups[i] 281 | altered_r_group.label_candidates = [alphanumeric_labels[i]] 282 | output_r_groups.append(altered_r_group) 283 | output_r_groups = output_r_groups + remaining_r_groups 284 | 285 | # Otherwise return with all labels 286 | else: 287 | output_r_groups = r_groups 288 | 289 | return output_r_groups 290 | 291 | 292 | def get_rgroup_smiles(diag, extension='jpg', debug=False, superatom_path=superatom_file, spelling_path=spelling_file): 293 | """ Extract SMILES from a chemical diagram (powered by pyosra) 294 | 295 | :param diag: Input Diagram 296 | :param extension: String indicating format of input file 297 | :param debug: Bool to indicate debugging 298 | 299 | :return labels_and_smiles: List of Tuple(List of label candidates, SMILES) objects 300 | """ 301 | 302 | # Add some padding to image to help resolve characters on the edge 303 | padded_img = pad(diag.fig.img, ((5, 5), (5, 5), (0, 0)), mode='constant', constant_values=1) 304 | 305 | # Save a temp image 306 | img_name = 'r_group_temp.' + extension 307 | io.imsave(img_name, padded_img) 308 | 309 | osra_input = [] 310 | label_cands = [] 311 | 312 | # Format the extracted rgroup 313 | for tokens in diag.label.r_group: 314 | token_dict = {} 315 | for token in tokens: 316 | token_dict[token[0].text] = token[1].text 317 | 318 | # Assigning var-var cases to true value if found (eg. 
R1=R2=H) 319 | for a, b in itertools.combinations(token_dict.keys(), 2): 320 | if token_dict[a] == b: 321 | token_dict[a] = token_dict[b] 322 | 323 | osra_input.append(token_dict) 324 | label_cands.append(tokens[0][2]) 325 | 326 | # Run osra on temp image 327 | smiles = osra_rgroup.read_rgroup(osra_input, input_file=img_name, verbose=False, debug=debug, superatom_file=superatom_path, spelling_file=spelling_path) 328 | 329 | if not smiles: 330 | log.warning('No SMILES strings were extracted for diagram %s' % diag.tag) 331 | 332 | if not debug: 333 | io.imdel(img_name) 334 | 335 | smiles = [actions.clean_output(smile) for smile in smiles] 336 | 337 | labels_and_smiles = [] 338 | for i, smile in enumerate(smiles): 339 | labels_and_smiles.append((label_cands[i], smile)) 340 | 341 | return labels_and_smiles 342 | 343 | 344 | def clean_chars(value, cleanchars): 345 | """ Remove chars for cleaning 346 | :param value: String to be cleaned 347 | :param cleanchars: Characters to remove from value 348 | 349 | :return value: Cleaned string 350 | """ 351 | 352 | for char in cleanchars: 353 | value = value.replace(char, '') 354 | 355 | return value 356 | 357 | 358 | def resolve_structure(compound): 359 | """ Resolves a compound structure using CIRPY """ 360 | 361 | try: 362 | smiles = cirpy.resolve(compound, 'smiles') 363 | return smiles 364 | except URLError: 365 | log.warning('Cannot connect to Chemical Identify Resolver - chemical names may not be resolved.') 366 | return compound 367 | 368 | 369 | def convert_r_groups_to_tuples(r_groups): 370 | """ Converts a list of R-Group model objects to R-Group tuples""" 371 | 372 | return [r_group.convert_to_tuple() for r_group in r_groups] 373 | 374 | 375 | def standardize_values(r_groups, superatom_path=superatom_file): 376 | """ Converts values to a format compatible with diagram extraction""" 377 | 378 | # List of tuples pairing multiple definitions to the appropriate SMILES string 379 | alkyls = [('CH', ['methyl']), 380 | ('C2H', ['ethyl']), 381 | ('C3H', ['propyl']), 382 | ('C4H', ['butyl']), 383 | ('C5H', ['pentyl']), 384 | ('C6H', ['hexyl']), 385 | ('C7H', ['heptyl']), 386 | ('C8H', ['octyl']), 387 | ('C9H', ['nonyl']), 388 | ('C1OH', ['decyl'])] 389 | 390 | for r_group in r_groups: 391 | # Convert 0's in value field to O 392 | r_group.value = Token(r_group.value.text.replace('0', 'O'), r_group.value.start, r_group.value.end, r_group.value.lexicon) 393 | 394 | # Check if r_group value is in the superatom file 395 | exisiting_abbreviations = [line[0] for line in io.read_superatom(superatom_path)] 396 | if r_group.value.text not in exisiting_abbreviations: 397 | sub_smile = resolve_structure(r_group.value.text) 398 | 399 | if sub_smile is not None: 400 | # Add the smile to the superatom.txt dictionary for resolution in pyosra 401 | io.write_to_superatom(sub_smile, superatom_path) 402 | r_group.value = Token(sub_smile, r_group.value.start, r_group.value.end, r_group.value.lexicon) 403 | 404 | # Resolve commone alkyls 405 | # value = r_group.value.text 406 | # for alkyl in alkyls: 407 | # if value.lower() in alkyl[1]: 408 | # r_group.value = Token(alkyl[0], r_group.value.start, r_group.value.end, r_group.value.lexicon) 409 | 410 | return r_groups 411 | 412 | 413 | def separate_duplicate_r_groups(r_groups): 414 | """ 415 | Separate duplicate R-group variables into separate lists 416 | 417 | :param r_groups: List of input R-Group objects to be tested for duplicates 418 | :return output: List of R-Groups with duplicates separated 419 | """ 420 | 421 | if 
len(r_groups) == 0: 422 | return r_groups 423 | 424 | # Getting only the variables with unique text value 425 | vars = [r_group.var for r_group in r_groups] 426 | vars_text = [var.text for var in vars] 427 | unique_vars, unique_vars_text = [], [] 428 | for i, var in enumerate(vars): 429 | if vars_text[i] not in unique_vars_text: 430 | unique_vars.append(var) 431 | unique_vars_text.append(vars_text[i]) 432 | 433 | var_quantity_tuples = [] 434 | vars_dict = {} 435 | output = [] 436 | 437 | for var in unique_vars: 438 | var_quantity_tuples.append((var, vars_text.count(var.text))) 439 | vars_dict[var.text] = [] 440 | 441 | equal_length = all(elem[1] == var_quantity_tuples[0][1] for elem in var_quantity_tuples) 442 | 443 | # If the counts are irregular, the default behaviour is to keep all R-groups together as a single group 444 | if not equal_length: 445 | return [r_groups] 446 | 447 | # Populate dictionary for each unique variable 448 | for var in unique_vars: 449 | for r_group in r_groups: 450 | if var.text == r_group.var.text: 451 | vars_dict[var.text].append(r_group) 452 | 453 | for i in range(len(vars_dict[var.text])): 454 | temp = [] 455 | for var in unique_vars: 456 | try: 457 | temp.append(vars_dict[var.text][i]) 458 | except Exception as e: 459 | log.error("An error occurred while attempting to separate duplicate r-groups") 460 | log.error(e) 461 | output.append(temp) 462 | 463 | # Ensure that each complete set contains all label candidates 464 | for r_groups_output in output: 465 | total_cands = [] 466 | for r_group in r_groups_output: 467 | for cand in r_group.label_candidates: 468 | total_cands.append(cand) 469 | 470 | for r_group in r_groups_output: 471 | r_group.label_candidates = total_cands 472 | 473 | return output 474 | -------------------------------------------------------------------------------- /chemschematicresolver/utils.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Image processing utilities 4 | ========================== 5 | 6 | A toolkit of image processing operations. 7 | 8 | author: Ed Beard 9 | email: ejb207@cam.ac.uk 10 | 11 | """ 12 | 13 | from __future__ import absolute_import 14 | from __future__ import division 15 | from __future__ import print_function 16 | from __future__ import unicode_literals 17 | import logging 18 | 19 | import copy 20 | 21 | from skimage.color import rgb2gray 22 | from skimage.morphology import binary_closing, disk 23 | from skimage.util import pad 24 | from skimage.util import crop as crop_skimage 25 | from skimage.morphology import skeletonize as skeletonize_skimage 26 | 27 | from scipy import ndimage as ndi 28 | 29 | from .model import Rect 30 | 31 | log = logging.getLogger(__name__) 32 | 33 | 34 | def crop(img, left=None, right=None, top=None, bottom=None): 35 | """Crop image. 36 | 37 | Automatically limits the crop if bounds are outside the image. 38 | 39 | :param numpy.ndarray img: Input image. 40 | :param int left: Left crop. 41 | :param int right: Right crop. 42 | :param int top: Top crop. 43 | :param int bottom: Bottom crop. 44 | :return: Cropped image.
45 | :rtype: numpy.ndarray 46 | """ 47 | height, width = img.shape[:2] 48 | 49 | left = max(0, 0 if left is None else left) 50 | right = min(width, width if right is None else right) 51 | top = max(0, 0 if top is None else top) 52 | bottom = min(height, height if bottom is None else bottom) 53 | out_img = img[top: bottom, left: right] 54 | return out_img 55 | 56 | 57 | def binarize(fig, threshold=0.85): 58 | """ Converts image to binary 59 | 60 | RGB images are converted to greyscale using :class:`skimage.color.rgb2gray` before binarizing. 61 | 62 | :param fig: Input figure 63 | :param float|numpy.ndarray threshold: Threshold to use. 64 | :return: Figure containing the binary image. 65 | :rtype: Figure 66 | """ 67 | bin_fig = copy.deepcopy(fig) 68 | img = bin_fig.img 69 | 70 | # Skip if already binary 71 | if img.ndim <= 2 and img.dtype == bool: 72 | return bin_fig 73 | 74 | img = convert_greyscale(img) 75 | 76 | # Binarize with threshold (default of 0.85 empirically determined) 77 | binary = img < threshold 78 | bin_fig.img = binary 79 | return bin_fig 80 | 81 | 82 | def binary_close(fig, size=20): 83 | """ Joins unconnected pixels by dilation and erosion""" 84 | selem = disk(size) 85 | 86 | fig.img = pad(fig.img, size, mode='constant') 87 | fig.img = binary_closing(fig.img, selem) 88 | fig.img = crop_skimage(fig.img, size) 89 | return fig 90 | 91 | 92 | def binary_floodfill(fig): 93 | """ Converts all pixels inside closed contour to 1""" 94 | log.debug('Binary floodfill initiated...') 95 | fig.img = ndi.binary_fill_holes(fig.img) 96 | return fig 97 | 98 | 99 | def convert_greyscale(img): 100 | """ Converts to greyscale if RGB""" 101 | 102 | # Convert to greyscale if needed 103 | if img.ndim == 3 and img.shape[-1] in [3, 4]: 104 | grey_img = rgb2gray(img) 105 | else: 106 | grey_img = img 107 | return grey_img 108 | 109 | 110 | def skeletonize(fig): 111 | """ 112 | Erode pixels down to the skeleton of a figure's img object 113 | :param fig: Input figure 114 | :return: Figure with a skeletonized image 115 | """ 116 | 117 | skel_fig = binarize(fig) 118 | skel_fig.img = skeletonize_skimage(skel_fig.img) 119 | 120 | return skel_fig 121 | 122 | 123 | def merge_rect(rect1, rect2): 124 | """ Merges rectangle with another, such that the bounding box encloses both 125 | 126 | :param Rect rect1: A rectangle 127 | :param Rect rect2: Another rectangle 128 | :return: Merged rectangle 129 | """ 130 | 131 | left = min(rect1.left, rect2.left) 132 | right = max(rect1.right, rect2.right) 133 | top = min(rect1.top, rect2.top) 134 | bottom = max(rect1.bottom, rect2.bottom) 135 | return Rect(left, right, top, bottom) 136 | 137 | 138 | def merge_overlap(a, b): 139 | """ Checks whether panels a and b overlap. If they do, returns the new merged panel; otherwise returns None""" 140 | 141 | if a.overlaps(b) or b.overlaps(a): 142 | return merge_rect(a, b) 143 | -------------------------------------------------------------------------------- /chemschematicresolver/validate.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Image Processing Validation Metrics 4 | =================================== 5 | 6 | A toolkit of validation metrics for determining reliability of output.
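For example, based on the checks implemented below, ``is_false_positive((['1a'], 'c1ccccc1*'))`` would be treated as a false positive (wildcard in the SMILES string), while ``is_false_positive((['1a'], 'c1ccccc1'))`` would not; the label and SMILES values here are hypothetical and used purely for illustration.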
7 | 8 | author: Ed Beard 9 | email: ejb207@cam.ac.uk 10 | 11 | """ 12 | 13 | from __future__ import absolute_import 14 | from __future__ import division 15 | from __future__ import print_function 16 | from __future__ import unicode_literals 17 | 18 | 19 | def is_false_positive(label_smile_tuple, allow_wildcards=False): 20 | """ Identifies failures from absence of labels and incomplete / invalid smiles 21 | 22 | :rtype bool 23 | :returns : True if result is a false positive 24 | """ 25 | 26 | label_candidates, smile = label_smile_tuple[0], label_smile_tuple[1] 27 | # Remove results without a label 28 | if len(label_candidates) == 0: 29 | return True 30 | 31 | # Remove results containing the wildcard character in the SMILE 32 | if '*' in smile and not allow_wildcards: 33 | return True 34 | 35 | # Remove results where no SMILE was returned 36 | if smile == '': 37 | return True 38 | 39 | return False 40 | 41 | 42 | def remove_repeating(diags): 43 | """ Removes any diagrams containing a repeating element 44 | 45 | :param diags: List of labelled diagrams 46 | :returns: diags: List of labelled diagtrams, where those with repeating structures are removed... 47 | """ 48 | 49 | diags = [diag for diag in diags if not diag.repeating] 50 | return diags 51 | 52 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | import os 5 | from setuptools import setup, find_packages 6 | 7 | 8 | if os.path.exists('README.md'): 9 | long_description = open('README.md').read() 10 | else: 11 | long_description = '''A toolkit for resolving chemical SMILES from structural diagrams.''' 12 | 13 | setup( 14 | name='ChemSchematicResolver', 15 | version='0.0.1', 16 | author='Edward Beard', 17 | author_email='ejb207@cam.ac.uk', 18 | license='MIT', 19 | url='https://github.com/edbeard/ChemSchematicResolver', 20 | packages=find_packages(), 21 | description='A toolkit for resolving chemical SMILES from structural diagrams.', 22 | long_description=long_description, 23 | keywords='image-mining mining chemistry cheminformatics OSR structure diagram html xml science scientific', 24 | zip_safe=False, 25 | include_package_data=True, 26 | package_data={ 27 | '': ['dict/*txt'], 28 | }, 29 | tests_require=['pytest'], 30 | install_requires=[ 31 | 'pillow', 'tesserocr', 'matplotlib==2.2.4', 'scikit-learn', 'scikit-image<0.15', 'numpy', 'scipy', 32 | ], 33 | classifiers=[ 34 | 'Intended Audience :: Developers', 35 | 'Intended Audience :: Science/Research', 36 | 'License :: OSI Approved :: MIT License', 37 | 'Operating System :: OS Independent', 38 | 'Programming Language :: Python :: 3', 39 | 'Programming Language :: Python :: 3.6', 40 | 'Topic :: Internet :: WWW/HTTP :: Indexing/Search', 41 | 'Topic :: Scientific/Engineering', 42 | 'Topic :: Scientific/Engineering :: Bio-Informatics', 43 | 'Topic :: Scientific/Engineering :: Chemistry', 44 | ], 45 | ) 46 | -------------------------------------------------------------------------------- /tests/data/S014372081630122X_gr1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/edbeard/ChemSchematicResolver/3ed519c83173cb561cffe9546918c014cc5bb792/tests/data/S014372081630122X_gr1.jpg -------------------------------------------------------------------------------- /tests/eval_markush.py: 
-------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | eval_markush 4 | ======== 5 | 6 | Used to test accuracy of training samples in semi-automatic way 7 | 8 | """ 9 | 10 | import unittest 11 | import os 12 | import copy 13 | import chemschematicresolver as csr 14 | from matplotlib import pyplot as plt 15 | import matplotlib.patches as mpatches 16 | 17 | from chemschematicresolver.ocr import LABEL_WHITELIST 18 | 19 | 20 | tests_dir = os.path.dirname(os.path.abspath(__file__)) 21 | train_dir = os.path.join(os.path.dirname(tests_dir), 'train') 22 | raw_train_data = os.path.join(train_dir, 'train_markush_small') 23 | 24 | 25 | class TestMarkush(unittest.TestCase): 26 | ''' 27 | Test the stages in the Markush (R-Group) detection and resolution pipeline 28 | ''' 29 | 30 | def find_labels_from_img(self, filename): 31 | 32 | train_img = os.path.join(raw_train_data, filename) 33 | # Read in float and raw pixel images 34 | fig = csr.io.imread(train_img) 35 | raw_fig = csr.io.imread(train_img, raw=True) 36 | 37 | # Create unreferenced binary copy 38 | bin_fig = copy.deepcopy(fig) 39 | 40 | panels = csr.actions.segment(bin_fig) 41 | panels = csr.actions.preprocessing(panels, fig) 42 | 43 | # Create output image 44 | out_fig, ax = plt.subplots(figsize=(10, 6)) 45 | ax.imshow(fig.img) 46 | 47 | diags, labels = csr.actions.classify_kruskal(panels) 48 | labelled_diags = csr.actions.label_kruskal(diags, labels) 49 | 50 | colours = iter( 51 | ['r', 'b', 'g', 'k', 'c', 'm', 'y', 'r', 'b', 'g', 'k', 'c', 'm', 'y', 'r', 'b', 'g', 'k', 'c', 'm', 52 | 'y']) 53 | 54 | for diag in labelled_diags: 55 | colour = next(colours) 56 | 57 | diag_rect = mpatches.Rectangle((diag.left, diag.top), diag.width, diag.height, 58 | fill=False, edgecolor=colour, linewidth=2) 59 | ax.text(diag.left, diag.top + diag.height / 4, '[%s]' % diag.tag, size=diag.height / 20, color='r') 60 | ax.add_patch(diag_rect) 61 | 62 | label = diag.label 63 | label_rect = mpatches.Rectangle((label.left, label.top), label.width, label.height, 64 | fill=False, edgecolor=colour, linewidth=2) 65 | ax.text(label.left, label.top + label.height / 4, '[%s]' % label.tag, size=label.height / 5, color='r') 66 | ax.add_patch(label_rect) 67 | 68 | ax.set_axis_off() 69 | plt.show() 70 | 71 | 72 | return fig, labelled_diags 73 | 74 | def test_markush_candidate_detection(self): 75 | 76 | fig, labelled_diags = self.find_labels_from_img('S0143720816300286_gr1.jpg') 77 | test_diag = labelled_diags[1] 78 | print(csr.actions.read_label(fig, test_diag.label)) 79 | 80 | print(labelled_diags) 81 | 82 | def test_general_ocr(self): 83 | 84 | train_img = os.path.join(raw_train_data, 'S0143720816301681_gr1.jpg') 85 | # Read in float and raw pixel images 86 | fig = csr.io.imread(train_img) 87 | txt = csr.ocr.get_text(fig.img, whitelist=LABEL_WHITELIST) 88 | print(txt) 89 | 90 | 91 | 92 | -------------------------------------------------------------------------------- /tests/eval_osra.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | eval_osra 4 | ======== 5 | 6 | Used to test accuracy of training samples in semi-automatic way 7 | 8 | """ 9 | 10 | 11 | from __future__ import absolute_import 12 | from __future__ import division 13 | from __future__ import print_function 14 | from __future__ import unicode_literals 15 | import logging 16 | 17 | log = logging.getLogger(__name__) 18 | 19 | import chemschematicresolver as csr 20 | import os 21 
| import unittest 22 | import copy 23 | from matplotlib import pyplot as plt 24 | import csv 25 | 26 | # Paths used in training: 27 | 28 | tests_dir = os.path.dirname(os.path.abspath(__file__)) 29 | train_dir = os.path.join(os.path.dirname(tests_dir), 'train') 30 | raw_train_data = os.path.join(os.path.dirname(train_dir), 'train_imgs') 31 | seg_train_dir = os.path.join(train_dir, 'train_osra_small') 32 | seg_train_csv = os.path.join(train_dir, 'train_osra.csv') 33 | 34 | 35 | def split_raw_train_data(): 36 | """ Splits the raw training data into separate images""" 37 | 38 | for train_fig in os.listdir(raw_train_data): 39 | 40 | train_path = os.path.join(raw_train_data, train_fig) 41 | # Read in float and raw pixel images 42 | fig = csr.io.imread(train_path) 43 | raw_fig = csr.io.imread(train_path, raw=True) 44 | 45 | # Create unreferenced binary copy 46 | bin_fig = copy.deepcopy(fig) 47 | 48 | # Segment images 49 | panels = csr.actions.segment(bin_fig) 50 | panels = csr.actions.preprocessing(panels, fig) 51 | 52 | # Classify diagrams and their labels using kruskal 53 | diags, labels = csr.actions.classify_kruskal(panels) 54 | labelled_diags = csr.actions.label_kruskal(diags, labels) 55 | 56 | # Save the segmented diagrams 57 | for diag in labelled_diags: 58 | # Save the segmented diagrams 59 | l, r, t, b = diag.left, diag.right, diag.top, diag.bottom 60 | cropped_img = csr.actions.crop(raw_fig.img, l, r, t, b) 61 | out_path = os.path.join(seg_train_dir, train_fig[:-4] + '_' + str(diag.tag)) 62 | csr.io.imsave(out_path + '.png', cropped_img) 63 | 64 | # TODO : Create image with regions and labels superimposed 65 | 66 | 67 | 68 | 69 | # class TestOsra(unittest.TestCase): 70 | # """ Class tests whether output of OSRA is correct, with human input. 71 | # Tests pass if > 80% of results after filtering are correct. 72 | # """ 73 | # 74 | # def test_train_data(self): 75 | # """ Looks in the training data to get smiles """ 76 | # 77 | # tps, fps = [], [] # Define true and false positive counters 78 | # 79 | # # Create file if it doesn't exist 80 | # if not os.path.isfile(seg_train_csv): 81 | # open(seg_train_csv, 'w') 82 | # 83 | # with open(seg_train_csv, "r") as f: 84 | # csv_reader = csv.reader(f) 85 | # auto_results = list(csv_reader) 86 | # tp_prev = [res for res in auto_results if res[2] is 'y'] 87 | # fp_prev = [res for res in auto_results if res[2] is 'n'] 88 | # 89 | # for train_fig in os.listdir(seg_train_dir): 90 | # 91 | # 92 | # # if train_fig in tp_filenames: 93 | # # tps.append([train_fig, auto_results[tp_filenames.index(train_fig)], 'y']) 94 | # # else: 95 | # 96 | # img_path = os.path.join(seg_train_dir, train_fig) 97 | # 98 | # # Load in cropped diagram 99 | # fig = csr.io.imread(img_path) 100 | # l, r, t, b = csr.actions.get_img_boundaries(fig.img) 101 | # diag = csr.model.Diagram(l, r, t, b, 0) # Using throwaway tag 0 102 | # 103 | # # Get the SMILES and the confidence 104 | # smile, confidence = csr.actions.read_diagram(fig, diag) 105 | # 106 | # if '*' in smile: 107 | # pass # Remove wildcard results 108 | # elif [train_fig, smile, 'y'] in tp_prev: 109 | # tps.append([train_fig, smile, 'y']) 110 | # elif [train_fig, smile, 'n'] in fp_prev: 111 | # fps.append([train_fig, smile, 'n']) 112 | # else: 113 | # while True: 114 | # inp = str(input('Filename : %s , smile: %s ; Correct? 
[y/n]\n' % (train_fig, smile))) 115 | # if inp.lower() in ['y', '']: 116 | # tps.append([train_fig, smile, 'y']) 117 | # break 118 | # elif inp.lower() in ['n']: 119 | # fps.append([train_fig, smile, 'n']) 120 | # break 121 | # else : 122 | # print("Invalid response. Please try again.") 123 | # 124 | # print("Precision : %s" % str(float(len(tps))/ (float(len(tps)) + float(len(fps))))) 125 | # 126 | # with open(seg_train_csv, "w") as f: 127 | # csv_writer = csv.writer(f) 128 | # csv_writer.writerows(tps) 129 | # csv_writer.writerows(fps) 130 | 131 | def eval_train_data(): 132 | """ Looks in the training data to get smiles """ 133 | 134 | tps, fps = [], [] # Define true and false positive counters 135 | 136 | # Create file if it doesn't exist 137 | if not os.path.isfile(seg_train_csv): 138 | open(seg_train_csv, 'w') 139 | 140 | with open(seg_train_csv, "r") as f: 141 | csv_reader = csv.reader(f) 142 | auto_results = list(csv_reader) 143 | tp_prev = [res for res in auto_results if res[2] is 'y'] 144 | fp_prev = [res for res in auto_results if res[2] is 'n'] 145 | 146 | for train_fig in os.listdir(seg_train_dir): 147 | 148 | 149 | # if train_fig in tp_filenames: 150 | # tps.append([train_fig, auto_results[tp_filenames.index(train_fig)], 'y']) 151 | # else: 152 | 153 | img_path = os.path.join(seg_train_dir, train_fig) 154 | 155 | # Load in cropped diagram 156 | fig = csr.io.imread(img_path) 157 | l, r, t, b = csr.actions.get_img_boundaries(fig.img) 158 | diag = csr.model.Diagram(l, r, t, b, 0) # Using throwaway tag 0 159 | 160 | # Get the SMILES and the confidence 161 | smile, confidence = csr.actions.read_diagram(fig, diag) 162 | 163 | if '*' in smile: 164 | pass # Remove wildcard results 165 | elif [train_fig, smile, 'y'] in tp_prev: 166 | tps.append([train_fig, smile, 'y']) 167 | elif [train_fig, smile, 'n'] in fp_prev: 168 | fps.append([train_fig, smile, 'n']) 169 | else: 170 | while True: 171 | inp = str(input('Filename : %s , smile: %s ; Correct? [y/n]\n' % (train_fig, smile))) 172 | if inp.lower() in ['y', '']: 173 | tps.append([train_fig, smile, 'y']) 174 | break 175 | elif inp.lower() in ['n']: 176 | fps.append([train_fig, smile, 'n']) 177 | break 178 | else : 179 | print("Invalid response. Please try again.") 180 | 181 | print("Precision : %s" % str(float(len(tps))/ (float(len(tps)) + float(len(fps))))) 182 | 183 | with open(seg_train_csv, "w") as f: 184 | csv_writer = csv.writer(f) 185 | csv_writer.writerows(tps) 186 | csv_writer.writerows(fps) 187 | 188 | 189 | 190 | 191 | if __name__ == '__main__': 192 | eval_train_data() 193 | 194 | -------------------------------------------------------------------------------- /tests/test_actions.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | test_actions 4 | ======== 5 | 6 | Test image processing actions. 
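    These tests read sample images from ``train/train_markush_small`` and display intermediate results with matplotlib, so they are intended for interactive, visual inspection and assume the training images are available locally.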
7 | 8 | """ 9 | 10 | from __future__ import absolute_import 11 | from __future__ import division 12 | from __future__ import print_function 13 | from __future__ import unicode_literals 14 | import logging 15 | 16 | log = logging.getLogger(__name__) 17 | 18 | import chemschematicresolver as csr 19 | import os 20 | from pathlib import Path 21 | import unittest 22 | import copy 23 | import numpy 24 | 25 | from skimage import img_as_float 26 | from matplotlib import pyplot as plt 27 | import matplotlib.patches as mpatches 28 | 29 | 30 | tests_dir = os.path.dirname(os.path.abspath(__file__)) 31 | train_dir = os.path.join(os.path.dirname(tests_dir), 'train') 32 | markush_dir = os.path.join(train_dir, 'train_markush_small') 33 | sample_diag = os.path.join(markush_dir, 'S014372081630119X_gr1.jpg') 34 | 35 | class TestActions(unittest.TestCase): 36 | 37 | def test_binarization(self): 38 | ''' Tests binarization of image''' 39 | 40 | fig = csr.io.imread(sample_diag) 41 | bin = csr.actions.binarize(fig) 42 | self.assertTrue(True in bin.img and False in bin.img) 43 | 44 | def test_segement(self): 45 | ''' Tests segmentation of image''' 46 | 47 | fig = csr.io.imread(sample_diag) 48 | raw_fig = csr.io.imread(sample_diag, raw=True) # Reads in version of pure pixels 49 | 50 | bin_fig = copy.deepcopy(fig) # Image copy to be binarized 51 | 52 | float_fig = copy.deepcopy(fig) # Image copy to be converted to float 53 | # float_fig.img = img_as_float(float_fig.img) 54 | 55 | #bin_fig = csr.actions.binarize(bin_fig) # Might not need binary version? 56 | #bin_fig.img = img_as_float(bin_fig.img) 57 | panels = csr.actions.segment(bin_fig) 58 | 59 | # Create debugging image 60 | out_fig, ax = plt.subplots(figsize=(10, 6)) 61 | ax.imshow(fig.img) 62 | #train_dir = os.path.join(os.path.dirname(tests_dir), 'train') 63 | 64 | diags, labels = csr.actions.classify(panels) 65 | # 66 | for panel in diags: 67 | rect = mpatches.Rectangle((panel.left, panel.top), panel.width, panel.height, 68 | fill=False, edgecolor='red', linewidth=2) 69 | ax.add_patch(rect) 70 | ax.text(panel.left, panel.top + panel.height / 4, '[%s]' % panel.tag, size=panel.height / 20, color='r') 71 | 72 | for panel in labels: 73 | rect = mpatches.Rectangle((panel.left, panel.top), panel.width, panel.height, 74 | fill=False, edgecolor='yellow', linewidth=2) 75 | ax.text(panel.left, panel.top + panel.height / 4, '[%s]' % panel.tag, size=panel.height / 5, color='r') 76 | ax.add_patch(rect) 77 | 78 | ax.set_axis_off() 79 | plt.show() 80 | # for diag in diags: 81 | # csr.actions.assign_label_to_diag(diag, labels) 82 | labelled_diags = csr.actions.label_diags(diags, labels) 83 | tagged_diags = csr.actions.read_all_labels(fig, labelled_diags) 84 | tagged_resolved_diags = csr.actions.read_all_diags(raw_fig, tagged_diags) 85 | 86 | def test_kruskal(self): 87 | 88 | p1 = csr.model.Panel(-1, 1, -1, 1, 0) 89 | p2 = csr.model.Panel(2, 4, 3, 5, 1) 90 | p3 = csr.model.Panel(6, 8, 23, 25, 2) 91 | 92 | panels = [p1, p2, p3] 93 | 94 | sorted_edges = csr.actions.kruskal(panels) 95 | print(sorted_edges) 96 | self.assertEqual(sorted_edges[0][2], 5.) 
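        # Panel centres are (0, 0), (3, 4) and (7, 24) given Panel(left, right, top, bottom, tag), so the two shortest centre-to-centre separations are 5.0 (p1-p2) and sqrt(416) ~ 20.4 (p2-p3); the assertions assume the separation is the Euclidean distance between centres.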
97 | self.assertEqual(round(sorted_edges[1][2]), 20) 98 | 99 | def test_merge_rect(self): 100 | 101 | r1 = csr.model.Rect(0, 10, 0, 20) 102 | r2 = csr.model.Rect(5, 15, 5, 15) 103 | 104 | merged_r = csr.actions.merge_rect(r1, r2) 105 | self.assertEqual(merged_r.left, 0) 106 | self.assertEqual(merged_r.right, 15) 107 | self.assertEqual(merged_r.top, 0) 108 | self.assertEqual(merged_r.bottom, 20) 109 | 110 | def test_horizontal_merging(self): 111 | ''' Tests the horizontal merging is behaving''' 112 | 113 | test_markush = os.path.join(markush_dir, 'S0143720816301681_gr1.jpg') 114 | fig = csr.io.imread(test_markush) 115 | raw_fig = copy.deepcopy(fig) # Create unreferenced binary copy 116 | 117 | panels = csr.actions.segment(raw_fig) 118 | print('Segmented panel number : %s ' % len(panels)) 119 | 120 | # Crete output image (post-merging) 121 | merged_panels = csr.actions.merge_label_horizontally_repeats(panels) 122 | 123 | out_fig2, ax2 = plt.subplots(figsize=(10, 6)) 124 | ax2.imshow(fig.img) 125 | 126 | for panel in merged_panels: 127 | diag_rect = mpatches.Rectangle((panel.left, panel.top), panel.width, panel.height, 128 | fill=False, edgecolor='r', linewidth=2) 129 | ax2.add_patch(diag_rect) 130 | 131 | plt.show() 132 | -------------------------------------------------------------------------------- /tests/test_io.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | test_io 4 | ======== 5 | 6 | Test io of images. 7 | 8 | """ 9 | 10 | from __future__ import absolute_import 11 | from __future__ import division 12 | from __future__ import print_function 13 | from __future__ import unicode_literals 14 | import logging 15 | 16 | import chemschematicresolver as csr 17 | import os 18 | from pathlib import Path 19 | import unittest 20 | 21 | log = logging.getLogger(__name__) 22 | 23 | tests_dir = os.path.abspath(__file__) 24 | data_dir = os.path.join(os.path.dirname(tests_dir), 'data') 25 | sample_diag = os.path.join(data_dir, 'S014372081630122X_gr1.jpg') 26 | 27 | 28 | class TestImportAndSave(unittest.TestCase): 29 | """ Tests importing and saving of relevant image types.""" 30 | 31 | def test_import_jpg(self): 32 | """ Tests import of jpg file""" 33 | 34 | fig = csr.io.imread(sample_diag) 35 | 36 | output_path = os.path.join(data_dir, 'test_import_and_save.jpg') 37 | csr.io.imsave(output_path,fig.img) 38 | f = Path(output_path) 39 | is_file = f.is_file() 40 | os.remove(output_path) 41 | 42 | self.assertTrue(is_file) 43 | 44 | 45 | 46 | -------------------------------------------------------------------------------- /tests/test_model.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | test_model 4 | ======== 5 | 6 | Test model functionality 7 | 8 | """ 9 | 10 | from __future__ import absolute_import 11 | from __future__ import division 12 | from __future__ import print_function 13 | from __future__ import unicode_literals 14 | import logging 15 | 16 | import chemschematicresolver.model as mod 17 | import unittest 18 | 19 | log = logging.getLogger(__name__) 20 | 21 | 22 | class TestModel(unittest.TestCase): 23 | 24 | def test_separation(self): 25 | r1 = mod.Rect(-1, 1, -1, 1) 26 | r2 = mod.Rect(2, 4, 3, 5) 27 | self.assertEqual(r1.separation(r2), 5.) 
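        # With Rect(left, right, top, bottom), the rectangle centres are (0, 0) and (3, 4), so the expected Euclidean separation is sqrt(3**2 + 4**2) = 5.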
28 | 29 | def test_panel_equality(self): 30 | p1 = mod.Panel(1, 2, 3, 4, 0) 31 | p2 = mod.Panel(1, 2, 3, 4, 0) 32 | self.assertEqual(p1, p2) 33 | 34 | def test_pairs_of_panels(self): 35 | tuple1 = (mod.Panel(1, 2, 3, 4, 0), mod.Panel(5, 6, 7, 8, 0)) 36 | tuple2 = (mod.Panel(5, 6, 7, 8, 0), mod.Panel(1, 2, 3, 4, 0)) 37 | 38 | list1 = [tuple1, tuple2] 39 | 40 | self.assertTrue(tuple1 in list1) 41 | -------------------------------------------------------------------------------- /tests/test_ocr.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | test_ocr 4 | ======== 5 | 6 | Test optical character recognition 7 | TODO : Choose a working OCR example... OR add neural net to improve 8 | """ 9 | 10 | import unittest 11 | import os 12 | import chemschematicresolver as csr 13 | import copy 14 | 15 | from matplotlib import pyplot as plt 16 | import matplotlib.patches as mpatches 17 | 18 | tests_dir = os.path.abspath(__file__) 19 | test_ocr_dir = os.path.join(os.path.dirname(tests_dir), 'data', 'ocr') 20 | 21 | 22 | class TestOcr(unittest.TestCase): 23 | 24 | def test_ocr_all_imgs(self): 25 | """ 26 | Uses the OCR module on the whole image to identify text blocks 27 | """ 28 | test_imgs = [os.path.join(test_ocr_dir, file) for file in os.listdir(test_ocr_dir)] 29 | 30 | for img_path in test_imgs: 31 | fig = csr.io.imread(img_path) # Read in float and raw pixel images 32 | text_blocks = csr.ocr.get_text(fig.img) 33 | 34 | # Create output image 35 | out_fig, ax = plt.subplots(figsize=(10, 6)) 36 | ax.imshow(fig.img) 37 | 38 | self.assert_equal(text_blocks[0].text, '1: R1=R2=H:TQEN\n2:R1=H,R2=OMe:T(MQ)EN\n3: R1=R2=OMe:T(TMQ)EN\n\n') 39 | 40 | def test_ocr_r_group(self): 41 | """ 42 | Used to test different functions on OCR recognition""" 43 | 44 | path = os.path.join(test_ocr_dir, 'S0143720816301115_gr1_text.jpg') 45 | 46 | fig = csr.io.imread(path) 47 | copy_fig = copy.deepcopy(fig) 48 | 49 | bin_fig = copy_fig 50 | 51 | text_blocks = csr.ocr.get_text(bin_fig.img) 52 | 53 | out_fig, ax = plt.subplots(figsize=(10, 6)) 54 | ax.imshow(bin_fig.img) 55 | plt.show() 56 | 57 | print(text_blocks) 58 | 59 | 60 | 61 | 62 | 63 | 64 | -------------------------------------------------------------------------------- /tests/test_parse.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | test_parse 4 | ======== 5 | 6 | Test parsing operations. 7 | 8 | """ 9 | 10 | from chemdataextractor.doc.text import Sentence 11 | from chemschematicresolver.parse import LabelParser 12 | 13 | import unittest 14 | 15 | 16 | class TestParse(unittest.TestCase): 17 | """ Checks that the chemical parser extraction logic is working""" 18 | 19 | def test_label_parsing(self): 20 | 21 | test_sentence = Sentence('3', parsers=[LabelParser()]) 22 | self.assertEqual(test_sentence.records.serialize(), [{'labels': ['3']}]) 23 | -------------------------------------------------------------------------------- /tests/test_r_group.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | test_parse 4 | ======== 5 | 6 | Test R-group resolution operations. 
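    Note that the resolve_structure tests resolve compound names through the Chemical Identifier Resolver (via CIRpy), so a working network connection is expected to be needed for them to pass.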
7 | 8 | """ 9 | 10 | from chemschematicresolver import r_group 11 | from chemschematicresolver.model import RGroup 12 | 13 | from chemdataextractor.doc.text import Sentence, Token, ChemSentenceTokenizer, ChemWordTokenizer, ChemLexicon, ChemAbbreviationDetector, ChemCrfPosTagger, CemTagger 14 | 15 | import unittest 16 | 17 | 18 | def do_resolve(comp): 19 | raw_smile = r_group.resolve_structure(comp) 20 | return raw_smile 21 | 22 | 23 | class TestRgroup(unittest.TestCase): 24 | """ Test functios from the r_group.py module""" 25 | 26 | def test_resolve_structure_1(self): 27 | 28 | comp = '4-nitrophenyl' 29 | gold = '[O-][N+](=O)c1ccccc1' 30 | result = do_resolve(comp) 31 | self.assertEqual(gold, result) 32 | 33 | def test_resolve_structure_2(self): 34 | 35 | comp = '2-chloro-4-nitrophenol' 36 | gold = 'Oc1ccc(cc1Cl)[N+]([O-])=O' 37 | result = do_resolve(comp) 38 | self.assertEqual(gold, result) 39 | 40 | def test_resolve_structure_4(self): 41 | 42 | comp = 'Hexyl' 43 | gold = '[O-][N+](=O)c1cc(c(Nc2c(cc(cc2[N+]([O-])=O)[N+]([O-])=O)[N+]([O-])=O)c(c1)[N+]([O-])=O)[N+]([O-])=O' 44 | result = do_resolve(comp) 45 | self.assertEqual(gold, result) 46 | 47 | def test_duplicate_r_group_vars_in_one_sentence(self): 48 | 49 | sent = Sentence('A R1=H R2=NH B R1=H R2=C') 50 | 51 | # sent = Sentence(text=[Token('A', 0, 1), Token('R1', 2, 3), Token('=', 4, 5), Token('H', 6, 7), 52 | # Token('R2', 8, 9), Token('=', 10, 11), Token('NH', 12, 13), 53 | # Token('B', 14, 15), Token('R1', 16, 17), Token('=', 18, 19), Token('H', 20, 21), 54 | # Token('R2', 21, 22), Token('=', 23, 24), Token('H', 25, 26)], 55 | # start=0, 56 | # end=26, 57 | # sentence_tokenizer=ChemSentenceTokenizer(), 58 | # word_tokenizer=ChemWordTokenizer(), 59 | # lexicon=ChemLexicon(), 60 | # abbreviation_detector=ChemAbbreviationDetector(), 61 | # pos_tagger=ChemCrfPosTagger(), # ChemPerceptronTagger() 62 | # ner_tagger=CemTagger() 63 | # ) 64 | 65 | var_value_pairs = r_group.detect_r_group_from_sentence(sent) 66 | r_groups = r_group.get_label_candidates(sent, var_value_pairs) 67 | r_groups = r_group.standardize_values(r_groups) 68 | 69 | # Resolving positional labels where possible for 'or' cases 70 | r_groups = r_group.filter_repeated_labels(r_groups) 71 | 72 | # Separate duplicate variables into separate lists 73 | r_groups_list = r_group.separate_duplicate_r_groups(r_groups) 74 | 75 | output = [] 76 | for r_groups in r_groups_list: 77 | output.append(r_group.convert_r_groups_to_tuples(r_groups)) 78 | 79 | def test_r_group_simple_table(self): 80 | 81 | # Define a simple table structure 82 | table = [Sentence('R'), Sentence('1a CH3'), Sentence('1b Me')] 83 | 84 | output = r_group.resolve_r_group_grid(table) 85 | var, value, labels = output[0].convert_to_tuple() 86 | var2, value2, labels2 = output[1].convert_to_tuple() 87 | # tuple_output = [ (var.text, value.text, labels.text) for var, value, labels in output[0].convert_to_tuple()] 88 | self.assertEqual(var.text, 'R') 89 | self.assertEqual(value.text, 'CH3') 90 | self.assertEqual(labels[0].text, '1a') 91 | self.assertEqual(var2.text, 'R') 92 | self.assertEqual(value2.text, 'Me') 93 | self.assertEqual(labels2[0].text, '1b') 94 | 95 | def test_r_group_table(self): 96 | 97 | # Define a simple table structure 98 | table = [Sentence('R1 R2'), Sentence('1a CH3 C'), Sentence('1b Me Br')] 99 | 100 | output = r_group.resolve_r_group_grid(table) 101 | r_groups_list = r_group.separate_duplicate_r_groups(output) 102 | 103 | # Test the first r_group pair 104 | var1, value1, labels1 = 
r_groups_list[0][0].convert_to_tuple() 105 | var2, value2, labels2 = r_groups_list[0][1].convert_to_tuple() 106 | 107 | self.assertEqual(var1.text, 'R1') 108 | self.assertEqual(value1.text, 'CH3') 109 | self.assertEqual(labels1[0].text, '1a') 110 | self.assertEqual(var2.text, 'R2') 111 | self.assertEqual(value2.text, 'C') 112 | self.assertEqual(labels2[0].text, '1a') 113 | 114 | # Test the second r_group pair 115 | var1, value1, labels1 = r_groups_list[1][0].convert_to_tuple() 116 | var2, value2, labels2 = r_groups_list[1][1].convert_to_tuple() 117 | 118 | self.assertEqual(var1.text, 'R1') 119 | self.assertEqual(value1.text, 'Me') 120 | self.assertEqual(labels1[0].text, '1b') 121 | self.assertEqual(var2.text, 'R2') 122 | self.assertEqual(value2.text, 'Br') 123 | self.assertEqual(labels2[0].text, '1b') 124 | 125 | def test_r_group_assignment(self): 126 | """ 127 | Test assignment of multiple lines 128 | """ 129 | 130 | sentences = [Sentence('R1 = R2 = H'), Sentence('R1 = R2 = Ac')] 131 | out = [] 132 | for sentence in sentences: 133 | r_groups = r_group.detect_r_group_from_sentence(sentence, indicator='=') 134 | r_groups = r_group.standardize_values(r_groups) 135 | 136 | # Resolving positional labels where possible for 'or' cases 137 | r_groups = r_group.filter_repeated_labels(r_groups) 138 | 139 | # Separate duplicate variables into separate lists 140 | r_groups_list = r_group.separate_duplicate_r_groups(r_groups) 141 | 142 | out.append(r_groups_list) 143 | 144 | self.assertEqual(out[0][0][0].var.text, 'R1') 145 | self.assertEqual(out[0][0][0].value.text, 'R2') 146 | 147 | self.assertEqual(out[0][0][1].var.text, 'R2') 148 | self.assertEqual(out[0][0][1].value.text, '[H]') 149 | 150 | self.assertEqual(out[1][0][0].var.text, 'R1') 151 | self.assertEqual(out[1][0][0].value.text, 'R2') 152 | 153 | self.assertEqual(out[1][0][1].var.text, 'R2') 154 | self.assertEqual(out[1][0][1].value.text, 'Ac') 155 | -------------------------------------------------------------------------------- /tests/test_system.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | test_system 4 | ======== 5 | 6 | Test image processing on images from examples 7 | 8 | """ 9 | 10 | from __future__ import absolute_import 11 | from __future__ import division 12 | from __future__ import print_function 13 | from __future__ import unicode_literals 14 | import logging 15 | 16 | log = logging.getLogger(__name__) 17 | 18 | import chemschematicresolver as csr 19 | import os 20 | import unittest 21 | import copy 22 | 23 | from skimage import img_as_float 24 | from matplotlib import pyplot as plt 25 | import matplotlib.patches as mpatches 26 | 27 | 28 | tests_dir = os.path.dirname(os.path.abspath(__file__)) 29 | train_dir = os.path.join(os.path.dirname(tests_dir), 'train') 30 | examples_dir = os.path.join(train_dir, 'train_imgs') 31 | markush_dir = os.path.join(train_dir, 'train_markush_small') 32 | r_group_diags_dir = os.path.join(train_dir, 'r_group_diags') 33 | labelled_output_dir = os.path.join(train_dir, 'output') 34 | 35 | class TestSystem(unittest.TestCase): 36 | 37 | # Testing that the images are being cleaned of floating pixels 38 | def do_diag_clean(self, filename, filedir=examples_dir): 39 | """ 40 | Tests that rouge pixel islands are removed for all diagrams in filename. 
41 | Displays the individual diagram areas for human inspection 42 | :param filename 43 | :return: 44 | """ 45 | 46 | test_diag = os.path.join(filedir, filename) 47 | 48 | fig = csr.io.imread(test_diag) # Read in float and raw pixel images 49 | raw_fig = copy.deepcopy(fig) # Create unreferenced binary copy 50 | 51 | panels = csr.actions.segment(raw_fig) 52 | print('Segmented panel number : %s ' % len(panels)) 53 | 54 | labels, diags = csr.actions.classify_kmeans(panels) 55 | labels, diags = csr.actions.preprocessing(labels, diags, fig) 56 | all_panels = labels + diags 57 | print('After processing : %s' % len(all_panels)) 58 | 59 | # Show diagrams in blue 60 | for panel in diags: 61 | 62 | # Create output image 63 | out_fig, ax = plt.subplots(figsize=(10, 6)) 64 | ax.imshow(panel.fig.img) 65 | 66 | ax.set_axis_off() 67 | plt.show() 68 | 69 | def test_diag_clean_all(self): 70 | """ 71 | Test all diagrams in train_imgs 72 | :return: 73 | """ 74 | 75 | test_path = examples_dir 76 | test_imgs = os.listdir(test_path) 77 | for img_path in test_imgs: 78 | self.do_diag_clean(img_path, filedir=test_path) 79 | 80 | def test_diag_clean_1(self): 81 | self.do_diag_clean('S0143720816301681_gr1.jpg') 82 | 83 | def test_diag_clean_2(self): 84 | self.do_diag_clean('S014372081630122X_gr1.jpg') 85 | 86 | def test_diag_clean_3(self): 87 | self.do_diag_clean('S0143720816301565_gr1.jpg', filedir=r_group_diags_dir) 88 | 89 | 90 | # Testing sementation is sucessful 91 | def do_segmentation(self, filename, filedir=examples_dir): 92 | ''' 93 | Tests bounding box assignment for filename, and kmeans classification into diagrams (blue) and labels (red) 94 | 95 | :param filename: 96 | :return: 97 | ''' 98 | 99 | test_diag = os.path.join(filedir, filename) 100 | 101 | fig = csr.io.imread(test_diag) # Read in float and raw pixel images 102 | raw_fig = copy.deepcopy(fig) # Create unreferenced binary copy 103 | 104 | panels = csr.actions.segment(raw_fig, size=3) 105 | print('Segmented panel number : %s ' % len(panels)) 106 | 107 | labels, diags = csr.actions.classify_kmeans(panels) 108 | labels, diags = csr.actions.preprocessing(labels, diags, fig) 109 | all_panels = labels + diags 110 | print('After processing : %s' % len(all_panels)) 111 | 112 | # Create output image 113 | out_fig, ax = plt.subplots(figsize=(10, 6)) 114 | ax.imshow(fig.img) 115 | 116 | # Show diagrams in blue 117 | for panel in diags: 118 | 119 | diag_rect = mpatches.Rectangle((panel.left, panel.top), panel.width, panel.height, 120 | fill=False, edgecolor='b', linewidth=2) 121 | ax.text(panel.left, panel.top + panel.height / 4, '[%s]' % panel.tag, size=panel.height / 20, color='r') 122 | ax.add_patch(diag_rect) 123 | 124 | # Show labels in red 125 | for panel in labels: 126 | 127 | diag_rect = mpatches.Rectangle((panel.left, panel.top), panel.width, panel.height, 128 | fill=False, edgecolor='r', linewidth=2) 129 | ax.text(panel.left, panel.top + panel.height / 4, '[%s]' % panel.tag, size=panel.height / 20, color='r') 130 | ax.add_patch(diag_rect) 131 | 132 | ax.set_axis_off() 133 | plt.show() 134 | 135 | 136 | def test_segmentation_all(self): 137 | 138 | test_path = examples_dir 139 | test_imgs = os.listdir(test_path) 140 | for img_path in test_imgs: 141 | self.do_segmentation(img_path, filedir=test_path) 142 | 143 | def test_variable_cases(self): 144 | 145 | self.do_segmentation('S0143720816300286_gr1.jpg') 146 | self.do_segmentation('S0143720816301115_gr1.jpg') 147 | self.do_segmentation('S0143720816301115_gr4.jpg') 148 | 
self.do_segmentation('S014372081630167X_sc1.jpg') 149 | 150 | def test_segmentation1(self): 151 | 152 | self.do_segmentation('S014372081630119X_gr1.jpg') 153 | 154 | def test_segmentation2(self): 155 | self.do_segmentation('S014372081630122X_gr1.jpg') 156 | 157 | def test_segmentation3(self): 158 | # TODO : noise remover? Get rid of connected components a few pixels in size? 159 | self.do_segmentation('S014372081630167X_sc1.jpg') # This one isn't identifying the repeating unit, or label 160 | 161 | def test_segmentation4(self): 162 | self.do_segmentation('S014372081730116X_gr8.jpg') 163 | 164 | def test_segmentation5(self): 165 | self.do_segmentation('S0143720816300201_sc2.jpg') 166 | 167 | def test_segmentation6(self): 168 | self.do_segmentation('S0143720816300274_gr1.jpg') 169 | 170 | def test_segmentation7(self): 171 | self.do_segmentation('S0143720816300419_sc1.jpg') 172 | 173 | def test_segmentation8(self): 174 | self.do_segmentation('S0143720816300559_sc2.jpg') 175 | 176 | def test_segmentation9(self): 177 | self.do_segmentation('S0143720816300821_gr2.jpg') 178 | 179 | def test_segmentation10(self): 180 | self.do_segmentation('S0143720816300900_gr2.jpg') 181 | 182 | def test_segmentation11(self): 183 | self.do_segmentation('S0143720816301115_gr1.jpg') 184 | 185 | def test_segmentation12(self): 186 | self.do_segmentation('S0143720816301115_gr4.jpg') 187 | 188 | def test_segmentation_markush_img(self): 189 | self.do_segmentation('S0143720816301115_r75.jpg') 190 | 191 | def test_segmentation_markush_img2(self): 192 | self.do_segmentation('S0143720816300286_gr1.jpg') 193 | 194 | def test_segmentation_markush_img3(self): 195 | self.do_segmentation('S0143720816301681_gr1.jpg') 196 | 197 | def test_segmentation_r_group_diags_img1(self): 198 | self.do_segmentation('S0143720816301565_gr1.jpg', r_group_diags_dir) 199 | 200 | def test_segmentation_r_group_diags_img2(self): 201 | self.do_segmentation('S0143720816302054_sc1.jpg', filedir=r_group_diags_dir) 202 | 203 | def test_segmentation_r_group_diags_img3(self): 204 | 205 | self.do_segmentation('S0143720816301401_gr5.jpg', r_group_diags_dir) 206 | 207 | def test_segmentation_13(self): 208 | self.do_segmentation('10.1039_C4TC01753F_fig1.gif', filedir='/home/edward/github/csr-development/csd') 209 | 210 | # Testing grouping of diagram - label pairs is correct 211 | def do_grouping(self, filename, filedir=examples_dir): 212 | ''' 213 | Tests grouping of label-diagram pairs, where label and diagram have the same coloured bbox 214 | To be checked by a human 215 | 216 | :param filename: 217 | :return: 218 | ''' 219 | 220 | test_diag = os.path.join(filedir, filename) 221 | 222 | # Read in float and raw pixel images 223 | fig = csr.io.imread(test_diag) 224 | fig_bbox = fig.get_bounding_box() 225 | 226 | # Create unreferenced binary copy 227 | bin_fig = copy.deepcopy(fig) 228 | 229 | # Segment and classify diagrams 230 | panels = csr.actions.segment(bin_fig) 231 | labels, diags = csr.actions.classify_kmeans(panels) 232 | 233 | # Preprocessing cleaning and merging 234 | labels, diags = csr.actions.preprocessing(labels, diags, fig) 235 | 236 | # Create output image 237 | out_fig, ax = plt.subplots(figsize=(10, 6)) 238 | ax.imshow(fig.img) 239 | 240 | # Assign labels to diagrams 241 | labelled_diags = csr.actions.label_diags(labels, diags, fig_bbox) 242 | 243 | colours = iter(['r', 'b', 'g', 'k', 'c', 'm', 'y', 'r', 'b', 'g', 'k', 'c', 'm', 'y', 'r', 'b', 'g', 'k', 'c', 'm', 'y']) 244 | 245 | for diag in labelled_diags: 246 | colour = next(colours) 247 | 
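            # Outline each diagram and its assigned label in the same colour so the pairing can be verified by eye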
248 | diag_rect = mpatches.Rectangle((diag.left, diag.top), diag.width, diag.height, 249 | fill=False, edgecolor=colour, linewidth=2) 250 | ax.add_patch(diag_rect) 251 | 252 | label = diag.label 253 | label_rect = mpatches.Rectangle((label.left, label.top), label.width, label.height, 254 | fill=False, edgecolor=colour, linewidth=2) 255 | ax.add_patch(label_rect) 256 | 257 | ax.set_axis_off() 258 | plt.show() 259 | 260 | def test_grouping_all(self): 261 | test_path = examples_dir 262 | test_imgs = os.listdir(test_path) 263 | for img_path in test_imgs: 264 | print(img_path) 265 | self.do_grouping(img_path, filedir=test_path) 266 | 267 | def test_grouping1(self): 268 | self.do_grouping('S014372081630119X_gr1.jpg') 269 | 270 | def test_grouping2(self): 271 | self.do_grouping('S014372081630122X_gr1.jpg') 272 | 273 | def test_grouping3(self): 274 | self.do_grouping('S014372081630167X_sc1.jpg') 275 | 276 | def test_grouping4(self): 277 | self.do_grouping('S014372081730116X_gr8.jpg') 278 | 279 | def test_grouping5(self): 280 | self.do_grouping('S0143720816300201_sc2.jpg') 281 | 282 | def test_grouping6(self): 283 | self.do_grouping('S0143720816300274_gr1.jpg') 284 | 285 | def test_grouping7(self): 286 | self.do_grouping('S0143720816300419_sc1.jpg') 287 | 288 | def test_grouping8(self): 289 | self.do_grouping('S0143720816300559_sc2.jpg') 290 | 291 | def test_grouping9(self): 292 | self.do_grouping('S0143720816300821_gr2.jpg') 293 | 294 | def test_grouping10(self): 295 | self.do_grouping('S0143720816300900_gr2.jpg') 296 | 297 | def test_grouping_r_group_diags(self): 298 | self.do_grouping('S0143720816302054_sc1.jpg', filedir=r_group_diags_dir) 299 | 300 | def test_grouping_markush(self): 301 | self.do_grouping('S0143720816300286_gr1.jpg') 302 | 303 | def do_ocr(self, filename, filedir=examples_dir): 304 | """ Tests the OCR recognition of labels.""" 305 | 306 | test_diag = os.path.join(filedir, filename) 307 | 308 | # Read in float and raw pixel images 309 | fig = csr.io.imread(test_diag) 310 | fig_bbox = fig.get_bounding_box() 311 | 312 | # Create unreferenced binary copy 313 | bin_fig = copy.deepcopy(fig) 314 | 315 | # Segment and classify diagrams and labels 316 | panels = csr.actions.segment(bin_fig) 317 | labels, diags = csr.actions.classify_kmeans(panels) 318 | labels, diags = csr.actions.preprocessing(labels, diags, fig) 319 | 320 | # Create output image 321 | out_fig, ax = plt.subplots(figsize=(10, 6)) 322 | ax.imshow(fig.img) 323 | 324 | # Assign labels to diagrams 325 | labelled_diags = csr.actions.label_diags(labels, diags, fig_bbox) 326 | 327 | colours = iter( 328 | ['r', 'b', 'g', 'k', 'c', 'm', 'y', 'r', 'b', 'g', 'k', 'c', 'm', 'y', 'r', 'b', 'g', 'k', 'c', 'm', 'y']) 329 | 330 | labels_text = [] 331 | 332 | for diag in labelled_diags: 333 | colour = next(colours) 334 | 335 | diag_rect = mpatches.Rectangle((diag.left, diag.top), diag.width, diag.height, 336 | fill=False, edgecolor=colour, linewidth=2) 337 | ax.text(diag.left, diag.top + diag.height / 4, '[%s]' % diag.tag, size=diag.height / 20, color='r') 338 | ax.add_patch(diag_rect) 339 | 340 | label = diag.label 341 | label_rect = mpatches.Rectangle((label.left, label.top), label.width, label.height, 342 | fill=False, edgecolor=colour, linewidth=2) 343 | ax.text(label.left, label.top + label.height / 4, '[%s]' % label.tag, size=label.height / 5, color='r') 344 | ax.add_patch(label_rect) 345 | 346 | label = csr.actions.read_label(fig, label) 347 | label_strings = [csr.actions.clean_output(sentence.text) for sentence in label.text] 
348 | labels_text.append(label_strings) 349 | print("Label %s : %s " % (label.tag, labels_text)) 350 | 351 | ax.set_axis_off() 352 | plt.show() 353 | 354 | return labels_text 355 | 356 | def test_ocr_all(self): 357 | 358 | test_path = examples_dir 359 | test_imgs = os.listdir(test_path) 360 | for img_path in test_imgs: 361 | self.do_ocr(img_path, filedir=test_path) 362 | 363 | def test_ocr1(self): 364 | labels_text = self.do_ocr('S014372081630119X_gr1.jpg') 365 | gold = [['MeNAPH:R=CH3', 'MeONAPH:R=OCH3'], ['EtNAPH']] 366 | self.assertListEqual(gold, labels_text) 367 | 368 | # TODO : Update all OCR tests to reflect new format 369 | 370 | def test_ocr2(self): 371 | labels_text = self.do_ocr('S014372081630122X_gr1.jpg') 372 | gold = [['Q4'], ['Q1'], ['Q2'], ['Q3']] 373 | self.assertEqual(gold, labels_text) 374 | 375 | def test_ocr3(self): 376 | # Currently failing - not getting B in the second label 377 | labels_text = self.do_ocr('S014372081630167X_sc1.jpg') 378 | gold = [['TPE-SQ'], ['PC71BM']] 379 | self.assertEqual(gold, labels_text) 380 | 381 | def test_ocr4(self): 382 | labels_text = self.do_ocr('S014372081730116X_gr8.jpg') 383 | gold = [['J51'], ['PDBT-T1'], ['J61'], ['R=2-ethylhexyl', 'PBDB-T'], ['R=2-ethylhexyl', 'PBDTTT-E-T'], ['R=2-ethylhexyl', 'PTB7-Th'], 384 | ['R=2-ethylhexyl', 'PBDTTT-C-T']] 385 | for item in gold: 386 | self.assertIn(item, labels_text) 387 | 388 | def test_ocr5(self): 389 | labels_text = self.do_ocr('S0143720816300201_sc2.jpg') 390 | gold = [['9(>99%)'], ['1(82%)'], ['3(86%)'], ['7(94%)'], 391 | ['4(78%)'], ['5(64%)'], ['2(78%)'], ['6(75%)'], ['8(74%)']] 392 | for x in gold: 393 | self.assertIn(x, labels_text) 394 | 395 | def test_ocr6(self): 396 | labels = self.do_ocr('S0143720816300274_gr1.jpg') 397 | gold = [['8c'], ['8b'], ['8a'], ['7c'], ['7b'], ['7a']] 398 | for x in gold: 399 | self.assertIn(x, labels) 400 | 401 | def test_ocr7(self): 402 | labels = self.do_ocr('S0143720816300419_sc1.jpg') 403 | gold = [['DDOF'], ['DPF'], ['NDOF'], ['PDOF']] 404 | for x in gold: 405 | self.assertIn(x, labels) 406 | 407 | def test_ocr8(self): 408 | labels = self.do_ocr('S0143720816300559_sc2.jpg') 409 | gold = [['1'], ['2'], ['3']] 410 | for x in gold: 411 | self.assertIn(x, labels) 412 | 413 | def test_ocr9(self): 414 | labels = self.do_ocr('S0143720816300821_gr2.jpg') # Need to add greyscale 415 | gold = [['9'], ['10']] 416 | for x in gold: 417 | self.assertIn(x, labels) 418 | 419 | def test_ocr10(self): 420 | # IR dye doesn't work 421 | labels = self.do_ocr('S0143720816300900_gr2.jpg') 422 | gold = [['ICG'], ['Compound10'], ['Compound13'], ['Compound11'], ['ZW800-1'], ['Compound12']] 423 | for x in gold: 424 | self.assertIn(x, labels) 425 | 426 | def test_ocr11(self): 427 | # Currently failing - can't detect some : 428 | labels = self.do_ocr('S0143720816301115_gr1.jpg') 429 | gold = [['1:R1=R2=H:TQEN'], ['2:R1=H,R2=OMe:T(MQ)EN'], ['3:R1=R2=OMe:T(TMQ)EN']] 430 | for x in gold: 431 | self.assertIn(x, labels) 432 | 433 | def do_r_group(self, filename, debug=False, filedir=examples_dir): 434 | """ Tests the R-group detection and recognition """ 435 | 436 | test_diag = os.path.join(filedir, filename) 437 | 438 | # Read in float and raw pixel images 439 | fig = csr.io.imread(test_diag) 440 | fig_bbox = fig.get_bounding_box() 441 | 442 | # Create unreferenced binary copy 443 | bin_fig = copy.deepcopy(fig) 444 | 445 | # Segment and classify diagrams and labels 446 | panels = csr.actions.segment(bin_fig) 447 | labels, diags = csr.actions.classify_kmeans(panels) 448 | labels, 
diags = csr.actions.preprocessing(labels, diags, fig) 449 | 450 | # Create output image 451 | if debug is True: 452 | out_fig, ax = plt.subplots(figsize=(10, 6)) 453 | ax.imshow(fig.img) 454 | colours = iter( 455 | ['r', 'b', 'g', 'k', 'c', 'm', 'y', 'r', 'b', 'g', 'k', 'c', 'm', 'y', 'r', 'b', 'g', 'k', 'c', 'm', 456 | 'y']) 457 | 458 | labelled_diags = csr.actions.label_diags(labels, diags, fig_bbox) 459 | 460 | for diag in labelled_diags: 461 | 462 | label = diag.label 463 | diag.label = csr.actions.read_label(fig, label) 464 | diag = csr.r_group.detect_r_group(diag) 465 | 466 | if debug is True: 467 | 468 | colour = next(colours) 469 | 470 | diag_rect = mpatches.Rectangle((diag.left, diag.top), diag.width, diag.height, 471 | fill=False, edgecolor=colour, linewidth=2) 472 | ax.text(diag.left, diag.top + diag.height / 4, '[%s]' % diag.tag, size=diag.height / 20, color='r') 473 | ax.add_patch(diag_rect) 474 | 475 | label_rect = mpatches.Rectangle((label.left, label.top), label.width, label.height, 476 | fill=False, edgecolor=colour, linewidth=2) 477 | ax.text(label.left, label.top + label.height / 4, '[%s]' % label.tag, size=label.height / 5, color='r') 478 | ax.add_patch(label_rect) 479 | 480 | print(diag.label.r_group) 481 | 482 | if debug is True: 483 | ax.set_axis_off() 484 | plt.show() 485 | 486 | return labelled_diags 487 | 488 | def test_r_group_detection_all(self): 489 | 490 | test_path = examples_dir 491 | test_imgs = os.listdir(test_path) 492 | for img_path in test_imgs: 493 | self.do_r_group(img_path, filedir=test_path) 494 | 495 | def test_r_group1(self): 496 | labelled_diags = self.do_r_group('S014372081630119X_gr1.jpg') 497 | all_detected_r_groups_values = [(token[0].text, token[1].text) for diag in labelled_diags for tokens 498 | in diag.label.r_group for token in tokens] 499 | gold = [('R', 'OCH3'), ('R', 'CH3')] 500 | for x in gold: 501 | self.assertIn(x, all_detected_r_groups_values) 502 | 503 | def test_r_group2(self): 504 | """ Test included to check that the r_group variable is left unpopulated """ 505 | labelled_diags = self.do_r_group('S014372081630122X_gr1.jpg') 506 | all_detected_r_groups_values = [token[1].text for diag in labelled_diags for tokens in diag.label.r_group for token in tokens] 507 | self.assertTrue(len(all_detected_r_groups_values) == 0) 508 | 509 | def test_r_group3(self): 510 | labelled_diags = self.do_r_group('S014372081630167X_sc1.jpg') 511 | all_detected_r_groups_values = [token[1].text for diag in labelled_diags for tokens in diag.label.r_group for token in tokens] 512 | self.assertTrue(len(all_detected_r_groups_values) == 0) 513 | 514 | def test_r_group4(self): 515 | labelled_diags = self.do_r_group('S014372081730116X_gr8.jpg') 516 | all_detected_r_groups_values = [tokens for diag in labelled_diags for tokens in diag.label.r_group] 517 | unique_combos = [] 518 | for tokens in all_detected_r_groups_values: 519 | tuple_set = set() 520 | for token in tokens: 521 | tuple_set.add((token[0].text, token[1].text)) 522 | unique_combos.append(tuple_set) 523 | 524 | gold = {('R', '2-ethylhexyl')} # All R-groups give this for this example 525 | for diag in unique_combos: 526 | self.assertEqual(diag, gold) 527 | 528 | self.assertEqual(len(unique_combos), 4) 529 | 530 | def test_r_group5(self): 531 | labelled_diags = self.do_r_group('S0143720816300201_sc2.jpg') 532 | all_detected_r_groups_values = [token[1].text for diag in labelled_diags for tokens in diag.label.r_group for token in tokens] 533 | self.assertTrue(len(all_detected_r_groups_values) == 0) 534 | 535 | def
test_r_group6(self): 536 | labelled_diags = self.do_r_group('S0143720816300274_gr1.jpg') 537 | all_detected_r_groups_values = [token[1].text for diag in labelled_diags for tokens in diag.label.r_group for token in tokens] 538 | self.assertTrue(len(all_detected_r_groups_values) == 0) 539 | 540 | def test_r_group7(self): 541 | labelled_diags = self.do_r_group('S0143720816300419_sc1.jpg') 542 | all_detected_r_groups_values = [token[1].text for diag in labelled_diags for tokens in diag.label.r_group for token in tokens] 543 | self.assertTrue(len(all_detected_r_groups_values) == 0) 544 | 545 | def test_r_group8(self): 546 | labelled_diags = self.do_r_group('S0143720816300559_sc2.jpg') 547 | all_detected_r_groups_values = [token[1].text for diag in labelled_diags for tokens in diag.label.r_group for token in tokens] 548 | self.assertTrue(len(all_detected_r_groups_values) == 0) 549 | 550 | def test_r_group9(self): 551 | labelled_diags = self.do_r_group('S0143720816300821_gr2.jpg') 552 | all_detected_r_groups_values = [token[1].text for diag in labelled_diags for tokens in diag.label.r_group for token in tokens] 553 | self.assertTrue(len(all_detected_r_groups_values) == 0) 554 | 555 | def test_r_group10(self): 556 | labelled_diags = self.do_r_group('S0143720816300900_gr2.jpg') 557 | all_detected_r_groups_values = [token[1].text for diag in labelled_diags for tokens in diag.label.r_group for token in tokens] 558 | self.assertTrue(len(all_detected_r_groups_values) == 0) 559 | 560 | def test_r_group11(self): 561 | # Currently failing : OCR performing poorly on semicolons 562 | labelled_diags = self.do_r_group('S0143720816301115_r75.jpg') 563 | all_detected_r_groups_values = [(token[0].text, token[1].text) for diag in labelled_diags for tokens in diag.label.r_group for token in tokens] 564 | gold = [('X', 'S'), ('X', 'O')] 565 | for x in gold: 566 | self.assertIn(x, all_detected_r_groups_values) 567 | 568 | def test_r_group12(self): 569 | labelled_diags = self.do_r_group('S0143720816301681_gr1.jpg') 570 | all_detected_r_groups_values = [tokens for diag in labelled_diags for tokens in diag.label.r_group] 571 | unique_combos = [] 572 | for tokens in all_detected_r_groups_values: 573 | tuple_set = set() 574 | for token in tokens: 575 | tuple_set.add((token[0].text, token[1].text)) 576 | unique_combos.append(tuple_set) 577 | 578 | gold = [{('R', 'CN'), ('X', 'NH')}, {('R', 'CN'), ('X', 'NC(O)OEt')}, {('R', 'CN'), ('X', 'O')}, {('R', 'H'), ('X', 'O')}] 579 | for diag in unique_combos: 580 | self.assertIn(diag, gold) 581 | 582 | def do_label_smile_resolution(self, filename, debug=True, filedir=examples_dir): 583 | """ Tests the R-group detection, recognition and resolution (using OSRA) 584 | NB : This is very similar to extract.extract_diagram, but it does not filter out the wildcard results 585 | This can be helpful to identify where OSRA is failing 586 | """ 587 | 588 | r_smiles = [] 589 | smiles = [] 590 | 591 | test_diag = os.path.join(filedir, filename) 592 | 593 | # Read in float and raw pixel images 594 | fig = csr.io.imread(test_diag) 595 | fig_bbox = fig.get_bounding_box() 596 | 597 | # Create unreferenced binary copy 598 | bin_fig = copy.deepcopy(fig) 599 | 600 | # Segment and classify diagrams and labels 601 | panels = csr.actions.segment(bin_fig) 602 | labels, diags = csr.actions.classify_kmeans(panels) 603 | labels, diags = csr.actions.preprocessing(labels, diags, fig) 604 | 605 | # Create output image 606 | if debug is True: 607 | out_fig, ax = plt.subplots(figsize=(10, 6)) 608 |
ax.imshow(fig.img) 609 | colours = iter( 610 | ['r', 'b', 'g', 'k', 'c', 'm', 'y', 'r', 'b', 'g', 'k', 'c', 'm', 'y', 'r', 'b', 'g', 'k', 'c', 'm', 611 | 'y']) 612 | 613 | labelled_diags = csr.actions.label_diags(labels, diags, fig_bbox) 614 | 615 | for diag in labelled_diags: 616 | 617 | label = diag.label 618 | diag.label = csr.actions.read_label(fig, label) 619 | diag = csr.r_group.detect_r_group(diag) 620 | csr.extract.get_smiles(diag, smiles, r_smiles) 621 | 622 | if debug is True: 623 | colour = next(colours) 624 | 625 | diag_rect = mpatches.Rectangle((diag.left, diag.top), diag.width, diag.height, 626 | fill=False, edgecolor=colour, linewidth=2) 627 | ax.text(diag.left, diag.top + diag.height / 4, '[%s]' % diag.tag, size=diag.height / 20, color='r') 628 | ax.add_patch(diag_rect) 629 | 630 | label_rect = mpatches.Rectangle((label.left, label.top), label.width, label.height, 631 | fill=False, edgecolor=colour, linewidth=2) 632 | ax.text(label.left, label.top + label.height / 4, '[%s]' % label.tag, size=label.height / 5, color='r') 633 | ax.add_patch(label_rect) 634 | 635 | if debug is True: 636 | ax.set_axis_off() 637 | plt.show() 638 | 639 | total_smiles = r_smiles + smiles 640 | 641 | return total_smiles 642 | 643 | def test_label_smile_resolution1(self): 644 | smiles = self.do_label_smile_resolution('S014372081630119X_gr1.jpg') 645 | print('extracted Smiles are : %s' % smiles) 646 | 647 | gold = [(['MeNAPH'], 648 | 'c1c(ccc(c1)N(c1ccc(/C=C/c2c3c(c(cc2)/C=C/c2ccc(N(c4ccc(C)cc4)c4ccc(cc4)C)cc2)cccc3)cc1)c1ccc(C)cc1)C'), 649 | (['MeONAPH'], 650 | 'c1c(ccc(c1)N(c1ccc(/C=C/c2c3c(c(cc2)/C=C/c2ccc(N(c4ccc(OC)cc4)c4ccc(cc4)OC)cc2)cccc3)cc1)c1ccc(OC)cc1)OC'), 651 | (['EtNAPH'], 'c1c2n(c3c(c2ccc1)cc(cc3)/C=C/c1ccc(/C=C/c2ccc3n(c4ccccc4c3c2)CC)c2c1cccc2)CC')] 652 | 653 | self.assertEqual(gold, smiles) 654 | 655 | def test_label_smile_resolution2(self): 656 | smiles = self.do_label_smile_resolution('S014372081630122X_gr1.jpg', debug=True) 657 | print('extracted Smiles are : %s' % smiles) 658 | 659 | # TODO : Try this with tesseract (label resolution is poor) - currently broken 660 | 661 | gold = [ 662 | ['c1c(ccc(c1)N(c1ccc(/C=C/c2c3c(c(cc2)/C=C/c2ccc(N(c4ccc(C)cc4)c4ccc(cc4)C)cc2)cccc3)cc1)c1ccc(C)cc1)C', 663 | 'c1c(ccc(c1)N(c1ccc(/C=C/c2c3c(c(cc2)/C=C/c2ccc(N(c4ccc(OC)cc4)c4ccc(cc4)OC)cc2)cccc3)cc1)c1ccc(OC)cc1)OC']] 664 | 665 | self.assertEqual(gold, smiles) 666 | 667 | def test_label_smile_resolution5(self): 668 | smiles = self.do_label_smile_resolution('S0143720816300201_sc2.jpg') 669 | print('extracted Smiles are : %s' % smiles) 670 | 671 | gold = [(['5(64%)'], 'n1(cccc1C=C(C#N)C#N)C'), 672 | (['8(74%)'], 'C1COc2c(O1)csc2C=C(C#N)C#N'), 673 | (['2(78%)'], 'n1(c2c(c3c1cccc3)cc(C=C(C#N)C#N)cc2)*'), 674 | (['3(86%)'], 'c1(N(C)C)ccc(cc1)C=C(C#N)C#N'), 675 | (['9(>99%)'], 'C[Fe]C1(C*CCC1)C'), # this will never pass as uses incompatible notation 676 | (['1(82%)'], 'c1c2Cc3cc(C=C(C#N)C#N)ccc3c2ccc1'), 677 | (['7(94%)'], 'c1ccc(s1)c1sc(cc1)C=C(C#N)C#N'), 678 | (['4(78%)'], 'o1cccc1C=C(C#N)C#N'), 679 | (['6(75%)'], 's1cccc1C=C(C#N)C#N')] 680 | 681 | self.assertEqual(gold, smiles) 682 | 683 | def test_label_smile_resolution6(self): 684 | smiles = self.do_label_smile_resolution('S0143720816300274_gr1.jpg') 685 | print('extracted Smiles are : %s' % smiles) 686 | 687 | gold = [(['7a'], 'CC(c1ccc(c2sc(c3c2[nH]c(n3)c2ccc(c3ccc(/C=C(\\C#N)/C(=O)O)cc3)cc2)c2ccc(C(C)(C)C)cc2)cc1)(C)C'), 688 | (['7b'], 'CC(c1ccc(c2sc(c3nc([nH]c23)c2ccc(c3sc(/C=C(\\C#N)/C(=O)O)cc3)cc2)c2ccc(C(C)(C)C)cc2)cc1)(C)C'), 689 | 
(['7c'], 'CC(c1ccc(c2sc(c3c2[nH]c(n3)c2sc(cc2)c2ccc(/C=C(\\C#N)/C(=O)O)s2)c2ccc(C(C)(C)C)cc2)cc1)(C)C'), 690 | (['8a'], 'c1(cccs1)c1sc(c2nc(n(c12)CCCC)c1ccc(c2ccc(/C=C(\\C#N)/C(=O)O)cc2)cc1)c1sccc1'), 691 | (['8b'], 'c1cc(sc1)c1sc(c2c1n(CCCC)c(n2)c1ccc(c2ccc(/C=C(\\C#N)/C(=O)O)s2)cc1)c1sccc1'), 692 | (['8c'], 's1c(c2sc(c3nc(n(c23)CCCC)c2sc(cc2)c2sc(cc2)/C=C(\\C#N)/C(=O)O)c2sccc2)ccc1')] 693 | 694 | self.assertEqual(gold, smiles) 695 | 696 | def test_label_smile_resolution7(self): 697 | smiles = self.do_label_smile_resolution('S0143720816300286_gr1.jpg') 698 | print('extracted Smiles are : %s' % smiles) 699 | 700 | gold = ['c1(N(C)C)ccc(cc1)C=C(C#N)C#N', 'n1(c2c(c3c1cccc3)cc(C=C(C#N)C#N)cc2)*', 701 | 'c1c2Cc3cc(C=C(C#N)C#N)ccc3c2ccc1', 'o1cccc1C=C(C#N)C#N', 'n1(cccc1C=C(C#N)C#N)C', 702 | 's1cccc1C=C(C#N)C#N', 'C[Fe]C1(C*CCC1)C', 'c1ccc(s1)c1sc(cc1)C=C(C#N)C#N'] 703 | 704 | self.assertEqual(gold, smiles) 705 | 706 | def test_label_smile_resolution8(self): 707 | smiles = self.do_label_smile_resolution('S0143720816300419_sc1.jpg') 708 | print('extracted Smiles are : %s' % smiles) 709 | 710 | gold = [(['PDOF'], 'c12c3ccc(cc3C(=O)c2cc(cc1)c1ccc(C(=O)C)cc1)c1ccc(C(=O)C)cc1'), 711 | (['NDOF'], 'c12c3c(C(=O)c2cc(cc1)c1ccc(C(=O)OC)cc1)cc(c1ccc(C(=O)OC)cc1)cc3'), 712 | (['DDOF'], 'c12c3c(C(=O)c2cc(cc1)c1ccc(C=O)cc1)cc(c1ccc(C=O)cc1)cc3'), 713 | (['DPF'], 'c12c3c(C(=O)c2cc(c2ccccc2)cc1)cc(c1ccccc1)cc3')] 714 | 715 | self.assertEqual(gold, smiles) 716 | 717 | def test_label_smile_resolution9(self): 718 | smiles = self.do_label_smile_resolution('S0143720816300559_sc2.jpg') 719 | print('extracted Smiles are : %s' % smiles) 720 | 721 | gold = [(['3'], 'c1ccc2c(c1)nc(o2)c1c(O)c(OCC)c(c2oc3c(n2)cccc3)s1'), 722 | (['2'], 'c1ccc2c(c1)nc(o2)c1c(OCC)c(OCC)c(c2oc3c(n2)cccc3)s1'), 723 | (['1'], 'c1ccc2nc(oc2c1)c1c(O)c(O)c(c2oc3c(n2)cccc3)s1')] 724 | 725 | self.assertEqual(gold, smiles) 726 | 727 | def test_label_smile_resolution10(self): 728 | smiles = self.do_label_smile_resolution('S0143720816300821_gr2.jpg') 729 | print('extracted Smiles are : %s' % smiles) 730 | 731 | gold = [(['10'], 'c1(ccccc1)Nc1c2C(=O)c3ccccc3C(=O)c2c(Nc2cc(c(N)c3C(=O)c4c(C(=O)c23)cccc4)C)c(C)c1'), 732 | (['9'], 'c12ccccc1C(=O)c1c(Nc3c4C(=O)c5ccccc5C(=O)c4c(N)c(C)c3)c(C)ccc1C2=O')] 733 | 734 | self.assertEqual(gold, smiles) 735 | 736 | def test_label_smile_resolution11(self): 737 | smiles = self.do_label_smile_resolution('S0143720816300900_gr2.jpg') 738 | 739 | # TODO : Currently broken. Likely due to difficul-to-parse + and - signs in circles. 
Output contains wildcards and failures 740 | print('extracted Smiles are : %s' % smiles) 741 | 742 | gold = [(['lRDyeSOOCW'], 'C(C(=O)O)CCCCN1/C(=C/C=C\\2/C(=C(CCC2)/C=C/C2=[N](CCCCS(=O)(=O)O)c3c(C2(C)C)cc(cc3)S(=O)(=O)O)Oc2ccc(S(=O)(=O)OC)cc2)/C(C)(C)c2cc(S(=O)(=O)O)ccc12'), 743 | (['ZW800-1'], 'C(CC[N]1=C(/C=C/C2=C(Oc3ccc(CCC(=O)O)cc3)/C(=C/C=C\\3/N(CCC[N](=O)(C)(C)C)c4c(C3(C)C)cc(S(=O)(=O)O)cc4)/CCC2)C(c2c1ccc(S(=O)(=O)*)c2)(C)C)[N](C)(C)C'), 744 | (['Compound13'], 'S(=O)(=O)(CCC[N]1=C(/C=C/C2=C(c3ccc(CCC(=O)O)cc3)/C(=C/C=C\\3/N(CCCS(=O)(=O)O)c4c(C3(C)C)cc(S(=O)(=O)O)cc4)/CCC2)C(c2c1ccc(S(=O)(=O)O)c2)(C)C)O'), 745 | (['Compound10'], 'C(CCN1C(C(c2c1ccc1ccccc21)(C)C)/C=C/C1=C(c2ccc(CCC(=O)O)cc2)/C(=C/C=C\\2/N(CCCS(=O)(=O)O)c3c(C2(C)C)c2ccccc2cc3)/CCC1)S(=O)(=O)O'), 746 | (['Compound11'], 'S(=O)(=O)(CCC[N]1=C(/C=C/C2=C(c3ccc(CCC(=O)O)cc3)/C(=C/C=C\\3/N(CCCS(=O)(=O)*)c4c(C3(C)C)cccc4)/CCC2)C(c2c1cccc2)(C)C)O'), 747 | (['Compound12'], 'C1C(=C(c2ccc(CCC(=O)O)cc2)/C(=C/C=C/2\\C(C)(C)c3cc(S(=O)(=O)*)ccc3N2C)/CC1)/C=C/C1=[N](c2c(C1(C)C)cc(cc2)S(=O)(=O)*)C'), 748 | (['ICG'], 'C(CCC[N]1=C(C(c2c3c(ccc12)cccc3)(C)C)/C=C/CC(/C=C/C=C\\1/N(CCCCS(=O)(=O)O)c2c(C1(C)C)c1ccccc1cc2)C)S(=O)(=O)O')] 749 | 750 | self.assertEqual(gold, smiles) 751 | 752 | def test_label_smile_resolution12(self): 753 | # TODO : This diagram still fails in resolving the : in the label of R-group diagram 754 | smiles = self.do_label_smile_resolution('S0143720816301115_r75.jpg') 755 | print('extracted Smiles are : %s' % smiles) 756 | 757 | gold = [([], 'C1=C2C3C(=CC=C2)C=Cc2cc(c(O)c(C1)c32)/C=N\\CC*CC/N=C\\c1cc2ccc3c4c(ccc3)ccc(c1O)c24'), 758 | ([], 'C1=C2C3C(=CC=C2)C=Cc2cc(c(O)c(C1)c32)/C=N\\CC*CC/N=C\\c1cc2ccc3c4c(ccc3)ccc(c1O)c24'), 759 | (['107'], 'c1c2ccc3cccc4c3c2c(cc4)cc1/C=N/CCSSCC/N=C/c1cc2ccc3cccc4ccc(c1)c2c34'), 760 | (['106'], 'c1c2cccc3ccc4c(c(cc(c1)c4c23)/C=N/c1ccccc1O)O')] 761 | 762 | self.assertEqual(gold, smiles) 763 | 764 | def test_label_smile_resolution13(self): 765 | smiles = self.do_label_smile_resolution('S0143720816301681_gr1.jpg') 766 | print('extracted Smiles are : %s' % smiles) 767 | 768 | gold = [(['1'], 'N(CC)(CC)c1ccc2cc(C#N)c(=N)oc2c1'), 769 | (['2'], 'N(CC)(CC)c1ccc2cc(C#N)/c(=N/C(=O)OCC)/oc2c1'), 770 | (['3'], 'N(CC)(CC)c1ccc2cc(C#N)c(=O)oc2c1'), 771 | (['4'], 'N(CC)(CC)c1ccc2ccc(=O)oc2c1')] 772 | 773 | self.assertEqual(gold, smiles) 774 | 775 | def test_label_smile_resolution14(self): 776 | smiles = self.do_label_smile_resolution('S0143720816301115_gr4.jpg') 777 | print('extracted Smiles are : %s' % smiles) 778 | 779 | gold = [(['14:1-isoTQTACN'], 'n1ccc2c(c1CN1CCN(Cc3nccc4ccccc34)CCN(CC1)Cc1c3ccccc3ccn1)cccc2'), 780 | (['11'], 'c1cc(nc2c1cccc2)CN1CCN(Cc2nc3c(cc2)cccc3)CCN(CC1)Cc1ccc2ccccc2n1'), 781 | (['15:6-MeOTQTACN'], 'c1c(cc2c(c1)nc(cc2)CN1CCN(Cc2nc3c(cc2)cc(OC)cc3)CCN(CC1)Cc1ccc2cc(OC)ccc2n1)OC'), 782 | (['12'], 'N1CCNCCN(CC1)Cc1ccc2ccccc2n1'), 783 | (['13'], 'C(N1CCN(Cc2nc3ccccc3cc2)CCSCC1)c1nc2c(cc1)cccc2')] 784 | 785 | self.assertEqual(gold, smiles) 786 | 787 | def test_label_smile_resolution15(self): 788 | # Currently broken - need to improve resolution of colons 789 | 790 | smiles = self.do_label_smile_resolution('S0143720816301115_gr1.jpg') 791 | print('extracted Smiles are : %s' % smiles) 792 | 793 | gold = [(['1'], '*c1c2ccc(CN(CCN(Cc3nc4cc(c(c(c4cc3)*)*)*)Cc3nc4c(cc3)c(c(c(c4)*)*)*)Cc3ccc4c(*)c(*)c(*)cc4n3)nc2cc(*)c1*') 794 | ([], '*c1c2ccc(CN(CCN(Cc3nc4cc(c(c(c4cc3)*)*)*)Cc3nc4c(cc3)c(c(c(c4)*)*)*)Cc3ccc4c(*)c(*)c(*)cc4n3)nc2cc(*)c1*'), 795 | (['3', 'T(TMQ)EN'], 
'*c1c2ccc(CN(CCN(Cc3nc4cc(c(c(c4cc3)*)OC)*)Cc3nc4c(cc3)c(c(c(c4)*)OC)*)Cc3ccc4c(*)c(OC)c(*)cc4n3)nc2cc(*)c1OC')] 796 | 797 | self.assertEqual(gold, smiles) 798 | 799 | def do_osra(self, filename): 800 | """ Tests the OSRA chemical diagram recognition """ 801 | 802 | test_diag = os.path.join(examples_dir, filename) 803 | 804 | # Read in float and raw pixel images 805 | fig = csr.io.imread(test_diag) 806 | raw_fig = csr.io.imread(test_diag, raw=True) 807 | 808 | # Create unreferenced binary copy 809 | bin_fig = copy.deepcopy(fig) 810 | 811 | panels = csr.actions.segment(bin_fig) 812 | panels = csr.actions.preprocessing(panels, bin_fig) 813 | 814 | # Create output image 815 | out_fig, ax = plt.subplots(figsize=(10, 6)) 816 | ax.imshow(fig.img) 817 | 818 | diags, labels = csr.actions.classify_kruskal(panels) 819 | labelled_diags = csr.actions.label_kruskal(diags, labels) 820 | 821 | colours = iter( 822 | ['r', 'b', 'g', 'k', 'c', 'm', 'y', 'r', 'b', 'g', 'k', 'c', 'm', 'y', 'r', 'b', 'g', 'k', 'c', 'm', 823 | 'y']) 824 | 825 | smiles = [] 826 | 827 | for diag in labelled_diags: 828 | colour = next(colours) 829 | 830 | diag_rect = mpatches.Rectangle((diag.left, diag.top), diag.width, diag.height, 831 | fill=False, edgecolor=colour, linewidth=2) 832 | ax.text(diag.left, diag.top + diag.height / 4, '[%s]' % diag.tag, size=diag.height / 20, color='r') 833 | ax.add_patch(diag_rect) 834 | 835 | label = diag.label 836 | label_rect = mpatches.Rectangle((label.left, label.top), label.width, label.height, 837 | fill=False, edgecolor=colour, linewidth=2) 838 | ax.text(label.left, label.top + label.height / 4, '[%s]' % label.tag, size=label.height / 5, color='r') 839 | ax.add_patch(label_rect) 840 | 841 | smile, confidence = csr.actions.read_diagram(fig, diag) 842 | if '*' not in smile: 843 | print(smile, confidence) 844 | smiles.append(smile) 845 | print("Label {} ({}): {} ".format(diag.tag, confidence, smile)) 846 | 847 | 848 | ax.set_axis_off() 849 | plt.savefig(os.path.join(labelled_output_dir, filename)) 850 | 851 | return smiles 852 | 853 | def test_osra2(self): 854 | smiles = self.do_osra('S014372081630122X_gr1.jpg') 855 | print(smiles) 856 | 857 | """ Output with .jpg on 300dpi 858 | Label 4 (7.9463): N#C/C(=C\c1ccc(s1)c1ccc(c2c1*=C(*)C(=*2)**)c1ccc(s1)c1ccc(cc1)*(c1ccc(cc1)*)c1ccc(cc1)*)/C(=O)O 859 | Label 1 (5.9891): N#C/C(=C\c1ccc(s1)c1ccc(c2c1nc(**)c(n2)*)c1ccc(s1)c1ccc(cc1)*(c1ccccc1)c1ccccc1)/C(=O)O 860 | Label 5 (6.1918): N#C/C(=C\c1ccc(s1)c1cc(*)c(cc1*)c1ccc(s1)c1ccc(cc1)N(c1ccccc1)c1ccccc1)/C(=O)O 861 | Label 0 (9.4724): CCCCCCc1nc2c(c3ccc(s3)/C=C(/C(=O)O)\C#N)c3*=C(**)C(=*c3c(c2nc1*)c1ccc(s1)c1ccc(cc1)N(c1ccccc1)c1ccccc1)** """ 862 | 863 | def test_osra3(self): 864 | smiles = self.do_osra('S014372081630167X_sc1.jpg') 865 | print(smiles) 866 | 867 | def test_osra5(self): 868 | smiles = self.do_osra('S0143720816300201_sc2.jpg') 869 | print(smiles) 870 | 871 | 872 | def test_osra6(self): 873 | smiles = self.do_osra('S0143720816300274_gr1.jpg') 874 | print(smiles) 875 | 876 | def test_osra7(self): 877 | smiles = self.do_osra('S0143720816300419_sc1.jpg') 878 | print(smiles) 879 | 880 | def test_osra8(self): 881 | smiles = self.do_osra('S0143720816300559_sc2.jpg') 882 | print(smiles) 883 | 884 | def test_osra9(self): 885 | smiles = self.do_osra('S0143720816300821_gr2.jpg') 886 | print(smiles) 887 | 888 | def test_osra10(self): 889 | # IR dye doesn't work 890 | smiles= self.do_osra('S0143720816300900_gr2.jpg') 891 | print(smiles) 892 | 893 | 894 | class TestValidation(unittest.TestCase): 895 | 
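# These tests print the heuristic validation metrics (black-pixel ratio, diagram-to-image area ratios) alongside the OSRA output; results are inspected manually rather than asserted.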
896 | def do_metrics(self, filename): 897 | """ Used to identify correlations between metrics and output validity""" 898 | 899 | test_fig = os.path.join(examples_dir, filename) 900 | 901 | # Read in float and raw pixel images 902 | fig = csr.io.imread(test_fig) 903 | raw_fig = csr.io.imread(test_fig, raw=True) 904 | 905 | # Create unreferenced binary copy 906 | bin_fig = copy.deepcopy(fig) 907 | 908 | panels = csr.actions.segment(bin_fig) 909 | panels = csr.actions.preprocessing(panels, fig) 910 | 911 | # Create output image 912 | out_fig, ax = plt.subplots(figsize=(10, 6)) 913 | ax.imshow(fig.img) 914 | 915 | diags, labels = csr.actions.classify_kruskal(panels) 916 | labelled_diags = csr.actions.label_kruskal(diags, labels) 917 | 918 | colours = iter( 919 | ['r', 'b', 'g', 'k', 'c', 'm', 'y', 'r', 'b', 'g', 'k', 'c', 'm', 'y', 'r', 'b', 'g', 'k', 'c', 'm', 920 | 'y']) 921 | 922 | smiles = [] 923 | 924 | avg_pixel_ratio = csr.validate.total_pixel_ratio(bin_fig, labelled_diags) 925 | diags_to_image_ratio = csr.validate.diagram_to_image_area_ratio(bin_fig, labelled_diags) 926 | avg_diag_area_to_total_img_ratio = csr.validate.avg_diagram_area_to_image_area(bin_fig, labelled_diags) 927 | 928 | for diag in labelled_diags: 929 | colour = next(colours) 930 | 931 | diag_rect = mpatches.Rectangle((diag.left, diag.top), diag.width, diag.height, 932 | fill=False, edgecolor=colour, linewidth=2) 933 | ax.text(diag.left, diag.top + diag.height / 4, '[%s]' % diag.tag, size=diag.height / 20, color='r') 934 | ax.add_patch(diag_rect) 935 | 936 | label = diag.label 937 | label_rect = mpatches.Rectangle((label.left, label.top), label.width, label.height, 938 | fill=False, edgecolor=colour, linewidth=2) 939 | ax.text(label.left, label.top + label.height / 4, '[%s]' % label.tag, size=label.height / 5, color='r') 940 | ax.add_patch(label_rect) 941 | 942 | smile, confidence = csr.actions.read_diagram(fig, diag) 943 | smiles.append(smile) 944 | print("Label {} ({}): {} ".format(diag.tag, confidence, smile)) 945 | print("Black pixel ratio : %s " % csr.validate.pixel_ratio(bin_fig, diag)) 946 | 947 | print('Overall diagram metrics:') 948 | print('Average 1 / all ratio: %s' % avg_pixel_ratio) 949 | print('Average diag / fig area ratio: %s' % avg_diag_area_to_total_img_ratio) 950 | print('Diag number to fig area ratio: %s' % diags_to_image_ratio) 951 | 952 | 953 | ax.set_axis_off() 954 | plt.savefig(os.path.join(labelled_output_dir, filename)) 955 | 956 | return smiles 957 | 958 | def test_validation2(self): 959 | smiles = self.do_metrics('S014372081630122X_gr1.jpg') 960 | 961 | def test_validation3(self): 962 | smiles = self.do_metrics('S014372081630167X_sc1.jpg') 963 | 964 | def test_validation5(self): 965 | smiles = self.do_metrics('S0143720816300201_sc2.jpg') 966 | 967 | def test_validation6(self): 968 | smiles = self.do_metrics('S0143720816300274_gr1.jpg') 969 | 970 | def test_validation7(self): 971 | smiles = self.do_metrics('S0143720816300419_sc1.jpg') 972 | 973 | def test_validation8(self): 974 | smiles = self.do_metrics('S0143720816300559_sc2.jpg') 975 | 976 | def test_validation9(self): 977 | smiles = self.do_metrics('S0143720816300821_gr2.jpg') 978 | 979 | def test_validation10(self): 980 | # IR dye doesn't work 981 | smiles= self.do_metrics('S0143720816300900_gr2.jpg') 982 | 983 | 984 | class TestFiltering(unittest.TestCase): 985 | """ Tests the results filtering via wildcard removal and pybel validation """ 986 | 987 | def do_filtering(self, filename): 988 | """ Used to identify correlations between 
metrics and output validity""" 989 | test_fig = os.path.join(examples_dir, filename) 990 | 991 | # Read in float and raw pixel images 992 | fig = csr.io.imread(test_fig) 993 | raw_fig = csr.io.imread(test_fig, raw=True) 994 | 995 | # Create unreferenced binary copy 996 | bin_fig = copy.deepcopy(fig) 997 | 998 | # Preprocessing steps 999 | panels = csr.actions.segment(bin_fig) 1000 | panels = csr.actions.preprocessing(panels, fig) 1001 | 1002 | # Create output image 1003 | out_fig, ax = plt.subplots(figsize=(10, 6)) 1004 | ax.imshow(fig.img) 1005 | 1006 | # Get label pairs 1007 | diags, labels = csr.actions.classify_kruskal(panels) 1008 | labelled_diags = csr.actions.label_kruskal(diags, labels) 1009 | 1010 | colours = iter( 1011 | ['r', 'b', 'g', 'k', 'c', 'm', 'y', 'r', 'b', 'g', 'k', 'c', 'm', 'y', 'r', 'b', 'g', 'k', 'c', 'm', 1012 | 'y']) 1013 | 1014 | diags_with_smiles = [] 1015 | 1016 | for diag in labelled_diags: 1017 | colour = next(colours) 1018 | 1019 | diag_rect = mpatches.Rectangle((diag.left, diag.top), diag.width, diag.height, 1020 | fill=False, edgecolor=colour, linewidth=2) 1021 | ax.text(diag.left, diag.top + diag.height / 4, '[%s]' % diag.tag, size=diag.height / 20, color='r') 1022 | ax.add_patch(diag_rect) 1023 | 1024 | label = diag.label 1025 | label_rect = mpatches.Rectangle((label.left, label.top), label.width, label.height, 1026 | fill=False, edgecolor=colour, linewidth=2) 1027 | ax.text(label.left, label.top + label.height / 4, '[%s]' % label.tag, size=label.height / 5, color='r') 1028 | ax.add_patch(label_rect) 1029 | 1030 | smile, confidence = csr.actions.read_diagram(fig, diag) 1031 | diag.smile = smile 1032 | diags_with_smiles.append(diag) 1033 | 1034 | # Run post-processing: 1035 | formatted_smiles = csr.validate.format_all_smiles(diags_with_smiles) 1036 | print(formatted_smiles) 1037 | return formatted_smiles 1038 | 1039 | def test_filtering2(self): 1040 | smiles = self.do_filtering('S014372081630122X_gr1.jpg') 1041 | 1042 | def test_filtering3(self): 1043 | smiles = self.do_filtering('S014372081630167X_sc1.jpg') 1044 | 1045 | def test_filtering5(self): 1046 | smiles = self.do_filtering('S0143720816300201_sc2.jpg') 1047 | 1048 | def test_filtering6(self): 1049 | smiles = self.do_filtering('S0143720816300274_gr1.jpg') 1050 | 1051 | def test_filtering7(self): 1052 | smiles = self.do_filtering('S0143720816300419_sc1.jpg') 1053 | 1054 | def test_filtering8(self): 1055 | smiles = self.do_filtering('S0143720816300559_sc2.jpg') 1056 | 1057 | def test_filtering9(self): 1058 | smiles = self.do_filtering('S0143720816300821_gr2.jpg') 1059 | 1060 | def test_filtering10(self): 1061 | # IR dye doesn't work 1062 | smiles= self.do_filtering('S0143720816300900_gr2.jpg') -------------------------------------------------------------------------------- /tests/test_validate.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | test_validation 4 | ======== 5 | 6 | Test image processing on images from examples 7 | 8 | """ 9 | 10 | from __future__ import absolute_import 11 | from __future__ import division 12 | from __future__ import print_function 13 | from __future__ import unicode_literals 14 | import logging 15 | 16 | log = logging.getLogger(__name__) 17 | 18 | import chemschematicresolver.validate as val 19 | import unittest 20 | 21 | 22 | class TestValidation(unittest.TestCase): 23 | 24 | def test_remove_false_positives(self): 25 | 26 | self.assertTrue(val.is_false_positive(([], 'C1CCCCC1C2CCCCC2'))) 27 | 
self.assertTrue(val.is_false_positive((['3a'], 'C1CC*CC1C2CCCCC2'))) 28 | self.assertFalse(val.is_false_positive((['3a'], 'C1CCCCC1C2CCCCC2'))) 29 | --------------------------------------------------------------------------------