├── .gitignore ├── .travis.yml ├── CHANGELOG.md ├── LICENSE ├── Makefile ├── README.md ├── codecov.yml ├── deep_reference_parser ├── __init__.py ├── __main__.py ├── __version__.py ├── common.py ├── configs │ ├── 2020.3.19_multitask.ini │ ├── 2020.3.6_splitting.ini │ ├── 2020.3.8_parsing.ini │ └── 2020.4.5_multitask.ini ├── deep_reference_parser.py ├── io │ ├── __init__.py │ └── io.py ├── logger.py ├── model_utils.py ├── parse.py ├── prodigy │ ├── README.md │ ├── __init__.py │ ├── __main__.py │ ├── labels_to_prodigy.py │ ├── misc.py │ ├── numbered_reference_annotator.py │ ├── prodigy_to_tsv.py │ ├── reach_to_prodigy.py │ ├── reference_to_token_annotations.py │ └── spacy_doc_to_prodigy.py ├── reference_utils.py ├── split.py ├── split_parse.py ├── tokens_to_references.py └── train.py ├── pytest.ini ├── requirements.txt ├── requirements_test.txt ├── setup.py ├── tests ├── __init__.py ├── common.py ├── prodigy │ ├── __init__.py │ ├── common.py │ ├── test_data │ │ ├── test_numbered_references.jsonl │ │ ├── test_reach.jsonl │ │ ├── test_reference_to_token_expected.jsonl │ │ ├── test_reference_to_token_spans.jsonl │ │ ├── test_reference_to_token_tokens.jsonl │ │ ├── test_token_labelled_references.jsonl │ │ ├── test_token_labelled_references.tsv │ │ ├── test_tokens_to_tsv_spans.jsonl │ │ └── test_tokens_to_tsv_tokens.jsonl │ ├── test_labels_to_prodigy.py │ ├── test_misc.py │ ├── test_numbered_reference_annotator.py │ ├── test_prodigy_entrypoints.py │ ├── test_prodigy_to_tsv.py │ ├── test_reach_to_prodigy.py │ ├── test_reference_to_token_annotations.py │ └── test_spacy_doc_to_prodigy.py ├── test_data │ ├── test_config.ini │ ├── test_config_multitask.ini │ ├── test_jsonl.jsonl │ ├── test_load_tsv.tsv │ ├── test_references.txt │ ├── test_tsv_predict.tsv │ └── test_tsv_train.tsv ├── test_deep_reference_parser.py ├── test_deep_reference_parser_entrypoints.py ├── test_io.py ├── test_model_utils.py └── test_reference_utils.py └── tox.ini /.gitignore: -------------------------------------------------------------------------------- 1 | .mypy_cache/ 2 | build/ 3 | data/ 4 | dist/ 5 | deep_reference_parser.egg-info/ 6 | */__pycache__/ 7 | deep_reference_parser/embeddings/ 8 | deep_reference_parser/models/ 9 | *.whl 10 | embeddings/ 11 | models/ 12 | .tox/ 13 | *__pycache__/ 14 | .coverage 15 | 16 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: python 2 | 3 | python: 4 | - 3.7 5 | 6 | install: 7 | - pip install -r requirements_test.txt 8 | - pip install tox-travis 9 | 10 | script: 11 | - tox 12 | 13 | cache: 14 | directories: 15 | - $HOME/.cache/pip 16 | 17 | branches: 18 | only: 19 | - master 20 | 21 | after_success: 22 | - python -m codecov 23 | -------------------------------------------------------------------------------- /CHANGELOG.md: -------------------------------------------------------------------------------- 1 | # Changelog 2 | 3 | ## 2020.4.23 - Pre-release 4 | 5 | * Add multitask split_parse command and tests, called with python -m deep_reference_parser split_parse 6 | * Fix issues with training data creation 7 | * Output predictions of validation data by default 8 | * Various improvements - using tox for testing, refactoring, improving error messages, README and tests 9 | 10 | ## 2020.3.3 - Pre-release 11 | 12 | NOTE: This version includes changes to both the way that model artefacts are packaged and saved, and the way that data are laded and parsed 
from tsv files. This results in a significantly faster training time (c.14 hours -> c.0.5 hours), but older models will no longer be compatible. For compatibility you must use multitask models > 2020.3.19, splitting models > 2020.3.6, and parsing models > 2020.3.8. These models currently perform less well than previous versions, but performance is expected to improve with more data and experimentation, predominantly around sequence length. 13 | 14 | * Adds support for a multitask model as in the original Rodrigues paper 15 | * Combines artefacts into a single `indices.pickle` rather than the several previous pickles. Now the model just requires the embedding, `indices.pickle`, and `weights.h5`. 16 | * Updates load_tsv to better handle quoting. 17 | 18 | 19 | ## 2020.3.2 - Pre-release 20 | 21 | * Adds a parse command that can be called with `python -m deep_reference_parser parse` 22 | * Renames the predict command to 'split', which can be called with `python -m deep_reference_parser split` 23 | * Squashes most `tensorflow`, `keras_contrib`, and `numpy` warnings in `__init__.py` resulting from old versions and soon-to-be deprecated functions. 24 | * Reduces verbosity of logging, improving CLI clarity. 25 | 26 | ## 2020.2.0 - Pre-release 27 | 28 | First release. Features train and predict functions tested mainly on the task of labelling reference spans (e.g. academic references) in policy documents (e.g. documents produced by government, NGOs, etc.). 29 | 30 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (C) 2020 The Wellcome Trust 4 | Copyright (C) 2016-2019 ExplosionAI GmbH, 2016 spaCy GmbH, 2015 Matthew Honnibal 5 | Copyright (C) 2018 Digital Humanities Laboratory 6 | 7 | 8 | Permission is hereby granted, free of charge, to any person obtaining a copy 9 | of this software and associated documentation files (the "Software"), to deal 10 | in the Software without restriction, including without limitation the rights 11 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 12 | copies of the Software, and to permit persons to whom the Software is 13 | furnished to do so, subject to the following conditions: 14 | 15 | The above copyright notice and this permission notice shall be included in all 16 | copies or substantial portions of the Software. 17 | 18 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 19 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 20 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 21 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 22 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 23 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 24 | SOFTWARE.
25 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | .DEFAULT_GOAL := all 2 | 3 | # Determine OS (from https://gist.github.com/sighingnow/deee806603ec9274fd47) 4 | UNAME_S := $(shell uname -s) 5 | ifeq ($(UNAME_S),Linux) 6 | OSFLAG := linux 7 | endif 8 | ifeq ($(UNAME_S),Darwin) 9 | OSFLAG := macosx 10 | endif 11 | 12 | # 13 | # Set file and version for embeddings and model, plus local paths 14 | # 15 | 16 | NAME := deep_reference_parser 17 | 18 | EMBEDDING_PATH := embeddings 19 | WORD_EMBEDDING := 2020.1.1-wellcome-embeddings-300 20 | WORD_EMBEDDING_TEST := 2020.1.1-wellcome-embeddings-10-test 21 | 22 | MODEL_PATH := models 23 | MODEL_VERSION := multitask/2020.4.5_multitask 24 | 25 | # 26 | # S3 Bucket 27 | # 28 | 29 | S3_BUCKET := s3://datalabs-public/deep_reference_parser 30 | S3_BUCKET_HTTP := https://datalabs-public.s3.eu-west-2.amazonaws.com/deep_reference_parser 31 | 32 | # 33 | # Create a virtualenv for local dev 34 | # 35 | 36 | VIRTUALENV := build/virtualenv 37 | 38 | $(VIRTUALENV)/.installed: requirements.txt 39 | @if [ -d $(VIRTUALENV) ]; then rm -rf $(VIRTUALENV); fi 40 | @mkdir -p $(VIRTUALENV) 41 | virtualenv --python python3 $(VIRTUALENV) 42 | $(VIRTUALENV)/bin/pip3 install -r requirements.txt 43 | $(VIRTUALENV)/bin/pip3 install -r requirements_test.txt 44 | $(VIRTUALENV)/bin/pip3 install -e . 45 | touch $@ 46 | 47 | $(VIRTUALENV)/.en_core_web_sm: 48 | $(VIRTUALENV)/bin/python -m spacy download en_core_web_sm 49 | touch $@ 50 | 51 | 52 | .PHONY: virtualenv 53 | virtualenv: $(VIRTUALENV)/.installed $(VIRTUALENV)/.en_core_web_sm 54 | 55 | # 56 | # Get the word embedding 57 | # 58 | 59 | # Set the tar.gz as intermediate so it will be removed automatically 60 | .INTERMEDIATE: $(EMBEDDINGS_PATH)/$(WORD_EMBEDDING).tar.gz 61 | 62 | $(EMBEDDING_PATH)/$(WORD_EMBEDDING).tar.gz: 63 | @mkdir -p $(@D) 64 | curl $(S3_BUCKET_HTTP)/embeddings/$(@F) --output $@ 65 | 66 | $(EMBEDDING_PATH)/$(WORD_EMBEDDING).txt: $(EMBEDDING_PATH)/$(WORD_EMBEDDING).tar.gz 67 | tar -zxvf $< vectors.txt 68 | tail -n +2 vectors.txt > $@ 69 | rm vectors.txt 70 | 71 | embeddings: $(EMBEDDING_PATH)/$(WORD_EMBEDDING).txt 72 | 73 | # 74 | # Get the model artefacts and weights 75 | # 76 | 77 | artefact_targets = char2ind.pickle ind2label.pickle ind2word.pickle \ 78 | label2ind.pickle maxes.pickle word2ind.pickle \ 79 | weights.h5 80 | 81 | artefacts = $(addprefix $(MODEL_PATH)/$(MODEL_VERSION)/, $(artefact_targets)) 82 | 83 | $(artefacts): 84 | @mkdir -p $(@D) 85 | aws s3 cp $(S3_BUCKET)/models/$(MODEL_VERSION)/$(@F) $@ 86 | 87 | models: $(artefacts) 88 | 89 | 90 | datasets = data/splitting/2019.12.0_splitting_train.tsv \ 91 | data/splitting/2019.12.0_splitting_test.tsv \ 92 | data/splitting/2019.12.0_splitting_valid.tsv \ 93 | data/parsing/2020.3.2_parsing_train.tsv \ 94 | data/parsing/2020.3.2_parsing_test.tsv \ 95 | data/parsing/2020.3.2_parsing_valid.tsv \ 96 | data/multitask/2020.3.18_multitask_train.tsv \ 97 | data/multitask/2020.3.18_multitask_test.tsv \ 98 | data/multitask/2020.3.18_multitask_valid.tsv 99 | 100 | 101 | rodrigues_datasets = data/rodrigues/clean_train.txt \ 102 | data/rodrigues/clean_test.txt \ 103 | data/rodrigues/clean_valid.txt 104 | 105 | RODRIGUES_DATA_URL = https://github.com/dhlab-epfl/LinkedBooksDeepReferenceParsing/raw/master/dataset/ 106 | 107 | $(datasets): 108 | @ mkdir -p $(@D) 109 | curl -s $(S3_BUCKET_HTTP)/$@ --output $@ 110 | 111 | 
$(rodrigues_datasets): 112 | @ mkdir -p data/rodrigues 113 | curl -sL $(RODRIGUES_DATA_URL)/$(@F) --output $@ 114 | 115 | data: $(datasets) $(rodrigues_datasets) 116 | 117 | # 118 | # Add model artefacts to s3 119 | # 120 | 121 | sync_model_to_s3: 122 | aws s3 sync --acl public-read $(MODEL_PATH)/$(MODEL_VERSION) \ 123 | $(S3_BUCKET)/models/$(MODEL_VERSION) 124 | 125 | # 126 | # Ship a new wheel to public s3 bucket, containing model weights 127 | # 128 | 129 | # Ship the wheel to the datalabs-public s3 bucket. Need to remove these build 130 | # artefacts otherwise they can make a mess of your build! Public access to 131 | # the wheel is granted with the --acl public-read flag. 132 | 133 | 134 | .PHONY: dist 135 | dist: 136 | -rm build/lib build/bin build/bdist.$(OSFLAG)* -r 137 | -rm deep_reference_parser-20* -r 138 | -rm deep_reference_parser.egg-info -r 139 | -rm dist/* 140 | $(VIRTUALENV)/bin/python3 setup.py sdist bdist_wheel 141 | aws s3 cp --recursive --exclude "*" --include "*.whl" --acl public-read dist/ $(S3_BUCKET) 142 | 143 | # 144 | # Tests 145 | # 146 | 147 | $(EMBEDDING_PATH)/$(WORD_EMBEDDING_TEST).txt: 148 | @mkdir -p $(@D) 149 | curl $(S3_BUCKET_HTTP)/embeddings/$(@F) --output $@ 150 | 151 | test_embedding: $(EMBEDDING_PATH)/$(WORD_EMBEDDING_TEST).txt 152 | 153 | test_artefacts = $(addprefix $(MODEL_PATH)/test/, $(artefact_targets)) 154 | 155 | $(test_artefacts): 156 | @mkdir -p $(@D) 157 | curl $(S3_BUCKET_HTTP)/models/test/$(@F) --output $@ 158 | 159 | .PHONY: test 160 | test: $(test_artefacts) test_embedding 161 | $(VIRTUALENV)/bin/tox 162 | 163 | all: virtualenv models embeddings test 164 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | [![Build Status](https://travis-ci.org/wellcometrust/deep_reference_parser.svg?branch=master)](https://travis-ci.org/wellcometrust/deep_reference_parser)[![codecov](https://codecov.io/gh/wellcometrust/deep_reference_parser/branch/master/graph/badge.svg)](https://codecov.io/gh/wellcometrust/deep_reference_parser) 2 | 3 | # Deep Reference Parser 4 | 5 | Deep Reference Parser is a deep learning model for recognising references in free text. In this context we mean references to other works, for example an academic paper or a book. Given an arbitrary block of text (nominally a section containing references), the model will extract the limits of the individual references, and identify key information such as the authors, year of publication, and title. 6 | 7 | The model itself is a Bi-directional Long Short Term Memory (BiLSTM) Deep Neural Network with a stacked Conditional Random Field (CRF). It is designed to be used in the [Reach](https://github.com/wellcometrust/reach) application to replace a number of existing machine learning models which find references, and extract the constituent parts. 8 | 9 | The BiLSTM model is based on [Rodrigues et al. (2018)](https://github.com/dhlab-epfl/LinkedBooksDeepReferenceParsing), who developed a model to find (split) references, parse them into constituent parts, and classify them according to the type of reference (e.g. primary reference, secondary reference, etc.). This implementation covers the first two tasks and is intended for use in the medical field. Three models are implemented here: individual splitting and parsing models, and a combined multitask model which both splits and parses.
We have not yet attempted to include reference type classification, but this may be done in the future. 10 | 11 | ### Current status: 12 | 13 | |Component|Individual|MultiTask| 14 | |---|---|---| 15 | |Spans (splitting)|✔️ Implemented|✔️ Implemented| 16 | |Components (parsing)|✔️ Implemented|✔️ Implemented| 17 | |Type (classification)|❌ Not Implemented|❌ Not Implemented| 18 | 19 | ### The model 20 | 21 | The model itself is based on the work of [Rodrigues et al. (2018)](https://github.com/dhlab-epfl/LinkedBooksDeepReferenceParsing), although the implementation here differs significantly. The main differences are: 22 | 23 | * We use a combination of the training data used by Rodrigues et al. (2018) in addition to data that we have annotated ourselves. No Rodrigues et al. data are included in the test and validation sets. 24 | * We also use a new word embedding that has been trained on documents relevant to the field of medicine. 25 | * Whereas Rodrigues et al. split documents on lines, and sent the lines to the model, we combine the lines of the document together, and then send larger chunks to the model, giving it more context to work with when training and predicting. 26 | * Whilst the splitter model makes predictions at the token level, it outputs references by naively splitting on these tokens ([source](https://github.com/wellcometrust/deep_reference_parser/blob/master/deep_reference_parser/tokens_to_references.py)). 27 | * Hyperparameters are passed to the model in a config (.ini) file. This is to keep track of experiments, but also because it is difficult to save the model with the CRF architecture, so it is necessary to rebuild (not re-train!) the model object each time you want to use it. Storing the hyperparameters in a config file makes this easier. 28 | * The package ships with a [config file](https://github.com/wellcometrust/deep_reference_parser/blob/master/deep_reference_parser/configs/2020.3.19_multitask.ini) which defines the latest, highest-performing model. The config file defines where to find the various objects required to build the model (index dictionaries, weights, embeddings), and will automatically fetch them when run, if they are not found locally. 29 | * The package includes a command line interface inspired by [SpaCy](https://github.com/explosion/spaCy); functions can be called from the command line with `python -m deep_reference_parser` ([source](https://github.com/wellcometrust/deep_reference_parser/blob/master/deep_reference_parser/predict.py)). 30 | * Python version updated to 3.7, along with dependencies (although more remains to do). 31 | 32 | ### Performance 33 | 34 | On the validation set.
35 | 36 | #### Finding reference spans (splitting) 37 | 38 | Current model version: *2020.3.6_splitting* 39 | 40 | |token|f1| 41 | |---|---| 42 | |b-r|0.8146| 43 | |e-r|0.7075| 44 | |i-r|0.9623| 45 | |o|0.8463| 46 | |weighted avg|0.9326| 47 | 48 | #### Identifying reference components (parsing) 49 | 50 | Current model version: *2020.3.8_parsing* 51 | 52 | |token|f1| 53 | |---|---| 54 | |author|0.9053| 55 | |title|0.8607| 56 | |year|0.8639| 57 | |o|0.9340| 58 | |weighted avg|0.9124| 59 | 60 | #### Multitask model (splitting and parsing) 61 | 62 | Current model version: *2020.4.5_multitask* 63 | 64 | |token|f1| 65 | |---|---| 66 | |author|0.9458| 67 | |title|0.9002| 68 | |year|0.8704| 69 | |o|0.9407| 70 | |parsing weighted avg|0.9285| 71 | |b-r|0.9111| 72 | |e-r|0.8788| 73 | |i-r|0.9726| 74 | |o|0.9332| 75 | |weighted avg|0.9591| 76 | 77 | #### Computing requirements 78 | 79 | Models are trained on AWS instances using CPU only. 80 | 81 | |Model|Time Taken|Instance type|Instance cost (p/h)|Total cost| 82 | |---|---|---|---|---| 83 | |Span detection|00:26:41|m4.4xlarge|$0.88|$0.39| 84 | |Components|00:17:22|m4.4xlarge|$0.88|$0.25| 85 | |MultiTask|00:42:43|c4.4xlarge|$0.91|$0.63| 86 | 87 | ## tl;dr: Just get me to the references! 88 | 89 | ``` 90 | # Install from github 91 | 92 | pip install git+git://github.com/wellcometrust/deep_reference_parser.git#egg=deep_reference_parser 93 | 94 | 95 | # Create references.txt with some references in it 96 | 97 | cat > references.txt < 252 | ``` 253 | 254 | If you wish to use a custom model that you have trained, you must specify the config file which defines the hyperparameters for that model using the `-c` flag: 255 | 256 | ``` 257 | python -m deep_reference_parser split -c new_model.ini 258 | ``` 259 | 260 | Use the `-t` flag to return the raw token predictions, and the `-v` flag to return everything in a much more user-friendly format. 261 | 262 | Note that the model makes predictions at the token level, but a naive splitting is performed by simply splitting on the `b-r` tags. 263 | 264 | ### Developing the package further 265 | 266 | To create a local virtual environment and activate it: 267 | 268 | ``` 269 | make virtualenv 270 | 271 | # to activate 272 | 273 | source ./build/virtualenv/bin/activate 274 | ``` 275 | 276 | ## Get the data, models, and embeddings 277 | 278 | ``` 279 | make data models embeddings 280 | ``` 281 | 282 | ## Testing 283 | 284 | The package uses pytest: 285 | 286 | ``` 287 | make test 288 | ``` 289 | 290 | ## References 291 | 292 | Rodrigues Alves, D., Colavizza, G., & Kaplan, F. (2018). Deep Reference Mining From Scholarly Literature in the Arts and Humanities. Frontiers in Research Metrics and Analytics, 3(July), 1–13. https://doi.org/10.3389/frma.2018.00021 293 | -------------------------------------------------------------------------------- /codecov.yml: -------------------------------------------------------------------------------- 1 | ignore: 2 | - "**/__*__.py" 3 | - "**/common.py" 4 | -------------------------------------------------------------------------------- /deep_reference_parser/__init__.py: -------------------------------------------------------------------------------- 1 | # Tensorflow and Keras emit a very large number of warnings that are very 2 | # distracting on the command line. These lines here (while undesirable) 3 | # reduce the level of verbosity.
4 | 5 | import os 6 | import sys 7 | import warnings 8 | 9 | os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2" 10 | 11 | if not sys.warnoptions: 12 | warnings.filterwarnings("ignore", category=FutureWarning) 13 | warnings.filterwarnings("ignore", category=DeprecationWarning) 14 | warnings.filterwarnings("ignore", category=UserWarning) 15 | 16 | import tensorflow as tf 17 | 18 | tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR) 19 | 20 | from .common import download_model_artefact 21 | from .deep_reference_parser import DeepReferenceParser 22 | from .io import ( 23 | load_tsv, 24 | read_jsonl, 25 | read_pickle, 26 | write_jsonl, 27 | write_pickle, 28 | write_to_csv, 29 | write_tsv, 30 | ) 31 | from .logger import logger 32 | from .model_utils import get_config 33 | from .reference_utils import break_into_chunks 34 | from .tokens_to_references import tokens_to_references 35 | -------------------------------------------------------------------------------- /deep_reference_parser/__main__.py: -------------------------------------------------------------------------------- 1 | # coding: utf8 2 | 3 | """ 4 | Modified from https://github.com/explosion/spaCy/blob/master/spacy/__main__.py 5 | 6 | """ 7 | 8 | if __name__ == "__main__": 9 | import plac 10 | import sys 11 | from wasabi import msg 12 | from .train import train 13 | from .split import split 14 | from .parse import parse 15 | from .split_parse import split_parse 16 | 17 | commands = { 18 | "split": split, 19 | "parse": parse, 20 | "train": train, 21 | "split_parse": split_parse, 22 | } 23 | 24 | if len(sys.argv) == 1: 25 | msg.info("Available commands", ", ".join(commands), exits=1) 26 | command = sys.argv.pop(1) 27 | sys.argv[0] = "deep_reference_parser %s" % command 28 | 29 | if command in commands: 30 | plac.call(commands[command], sys.argv[1:]) 31 | else: 32 | available = "Available: {}".format(", ".join(commands)) 33 | msg.fail("Unknown command: {}".format(command), available, exits=1) 34 | -------------------------------------------------------------------------------- /deep_reference_parser/__version__.py: -------------------------------------------------------------------------------- 1 | __name__ = "deep_reference_parser" 2 | __version__ = "2020.4.5" 3 | __description__ = "Deep learning model for finding and parsing references" 4 | __url__ = "https://github.com/wellcometrust/deep_reference_parser" 5 | __author__ = "Wellcome Trust DataLabs Team" 6 | __author_email__ = "Grp_datalabs-datascience@Wellcomecloud.onmicrosoft.com" 7 | __license__ = "MIT" 8 | __splitter_model_version__ = "2020.3.6_splitting" 9 | __parser_model_version__ = "2020.3.8_parsing" 10 | __splitparser_model_version__ = "2020.4.5_multitask" 11 | -------------------------------------------------------------------------------- /deep_reference_parser/common.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # coding: utf-8 3 | 4 | import os 5 | from logging import getLogger 6 | from urllib import parse, request 7 | 8 | from .__version__ import ( 9 | __parser_model_version__, 10 | __splitparser_model_version__, 11 | __splitter_model_version__, 12 | ) 13 | from .logger import logger 14 | 15 | 16 | def get_path(path): 17 | return os.path.join(os.path.dirname(__file__), path) 18 | 19 | 20 | SPLITTER_CFG = get_path(f"configs/{__splitter_model_version__}.ini") 21 | PARSER_CFG = get_path(f"configs/{__parser_model_version__}.ini") 22 | MULTITASK_CFG = get_path(f"configs/{__splitparser_model_version__}.ini") 
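# Illustrative sketch of how these default config paths are consumed (this
# assumes the get_config helper from model_utils and the DeepReferenceParser
# class, used in the same way as the Parser class in parse.py):
#
#     from deep_reference_parser import DeepReferenceParser
#     from deep_reference_parser.model_utils import get_config
#
#     cfg = get_config(MULTITASK_CFG)
#     drp = DeepReferenceParser(output_path=cfg["build"]["output_path"])
#     drp.load_data(cfg["build"]["output_path"])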
23 | 24 | 25 | def download_model_artefact(artefact, s3_slug): 26 | """ Checks if model artefact exists and downloads if not 27 | 28 | Args: 29 | artefact (str): File to be downloaded 30 | s3_slug (str): http uri to latest model dir on s3, e.g.: 31 | https://datalabs-public.s3.eu-west-2.amazonaws.com/deep_reference_parser 32 | /models/latest 33 | """ 34 | 35 | path, _ = os.path.split(artefact) 36 | os.makedirs(path, exist_ok=True) 37 | 38 | if os.path.exists(artefact): 39 | logger.debug(f"{artefact} exists, nothing to be done...") 40 | else: 41 | logger.debug("Could not find %s. Downloading...", artefact) 42 | 43 | url = parse.urljoin(s3_slug, artefact) 44 | 45 | request.urlretrieve(url, artefact) 46 | 47 | 48 | def download_model_artefacts(model_dir, s3_slug, artefacts=None): 49 | """ 50 | """ 51 | 52 | if not artefacts: 53 | 54 | artefacts = [ 55 | "indices.pickle" "maxes.pickle", 56 | "weights.h5", 57 | ] 58 | 59 | for artefact in artefacts: 60 | artefact = os.path.join(model_dir, artefact) 61 | download_model_artefact(artefact, s3_slug) 62 | -------------------------------------------------------------------------------- /deep_reference_parser/configs/2020.3.19_multitask.ini: -------------------------------------------------------------------------------- 1 | [DEFAULT] 2 | version = 2020.3.19_multitask 3 | description = Same as 2020.3.13 but with adam rather than rmsprop 4 | deep_reference_parser_version = b61de984f95be36445287c40af4e65a403637692 5 | 6 | [data] 7 | # Note that test and valid proportion are only used for data creation steps, 8 | # not when running the train command. 9 | test_proportion = 0.25 10 | valid_proportion = 0.25 11 | data_path = data/ 12 | respect_line_endings = 0 13 | respect_doc_endings = 1 14 | line_limit = 150 15 | policy_train = data/multitask/2020.3.19_multitask_train.tsv 16 | policy_test = data/multitask/2020.3.19_multitask_test.tsv 17 | policy_valid = data/multitask/2020.3.19_multitask_valid.tsv 18 | s3_slug = https://datalabs-public.s3.eu-west-2.amazonaws.com/deep_reference_parser/ 19 | 20 | [build] 21 | output_path = models/multitask/2020.3.19_multitask/ 22 | output = crf 23 | word_embeddings = embeddings/2020.1.1-wellcome-embeddings-300.txt 24 | pretrained_embedding = 0 25 | dropout = 0.5 26 | lstm_hidden = 400 27 | word_embedding_size = 300 28 | char_embedding_size = 100 29 | char_embedding_type = BILSTM 30 | optimizer = rmsprop 31 | 32 | [train] 33 | epochs = 60 34 | batch_size = 100 35 | early_stopping_patience = 5 36 | metric = val_f1 37 | 38 | -------------------------------------------------------------------------------- /deep_reference_parser/configs/2020.3.6_splitting.ini: -------------------------------------------------------------------------------- 1 | [DEFAULT] 2 | version = 2020.3.6_splitting 3 | description = Splitting model trained on a combination of Reach and Rodrigues 4 | data. The Rodrigues data have been concatenated into a single continuous 5 | document and then cut into sequences of length=line_length, so that the 6 | Rodrigues data and Reach data have the same lengths without need for much 7 | padding or truncating. 
8 | deep_reference_parser_version = e489f7efa31072b95175be8f728f1fcf03a4cabb 9 | 10 | [data] 11 | test_proportion = 0.25 12 | valid_proportion = 0.25 13 | data_path = data/ 14 | respect_line_endings = 0 15 | respect_doc_endings = 1 16 | line_limit = 250 17 | policy_train = data/splitting/2020.3.6_splitting_train.tsv 18 | policy_test = data/splitting/2020.3.6_splitting_test.tsv 19 | policy_valid = data/splitting/2020.3.6_splitting_valid.tsv 20 | s3_slug = https://datalabs-public.s3.eu-west-2.amazonaws.com/deep_reference_parser/ 21 | 22 | [build] 23 | output_path = models/splitting/2020.3.6_splitting/ 24 | output = crf 25 | word_embeddings = embeddings/2020.1.1-wellcome-embeddings-300.txt 26 | pretrained_embedding = 0 27 | dropout = 0.5 28 | lstm_hidden = 400 29 | word_embedding_size = 300 30 | char_embedding_size = 100 31 | char_embedding_type = BILSTM 32 | optimizer = rmsprop 33 | 34 | [train] 35 | epochs = 30 36 | batch_size = 100 37 | early_stopping_patience = 5 38 | metric = val_f1 39 | 40 | -------------------------------------------------------------------------------- /deep_reference_parser/configs/2020.3.8_parsing.ini: -------------------------------------------------------------------------------- 1 | [DEFAULT] 2 | version = 2020.3.8_parsing 3 | description = Parsing model trained on a combination of Reach and Rodrigues 4 | data. The Rodrigues data have been concatenated into a single continuous 5 | document and then cut into sequences of length=line_length, so that the 6 | Rodrigues data and Reach data have the same lengths without need for much 7 | padding or truncating. 8 | deep_reference_parser_version = e489f7efa31072b95175be8f728f1fcf03a4cabb 9 | 10 | [data] 11 | test_proportion = 0.25 12 | valid_proportion = 0.25 13 | data_path = data/ 14 | respect_line_endings = 0 15 | respect_doc_endings = 1 16 | line_limit = 100 17 | policy_train = data/parsing/2020.3.8_parsing_train.tsv 18 | policy_test = data/parsing/2020.3.8_parsing_test.tsv 19 | policy_valid = data/parsing/2020.3.8_parsing_valid.tsv 20 | s3_slug = https://datalabs-public.s3.eu-west-2.amazonaws.com/deep_reference_parser/ 21 | 22 | [build] 23 | output_path = models/parsing/2020.3.8_parsing/ 24 | output = crf 25 | word_embeddings = embeddings/2020.1.1-wellcome-embeddings-300.txt 26 | pretrained_embedding = 0 27 | dropout = 0.5 28 | lstm_hidden = 400 29 | word_embedding_size = 300 30 | char_embedding_size = 100 31 | char_embedding_type = BILSTM 32 | optimizer = rmsprop 33 | 34 | [train] 35 | epochs = 30 36 | batch_size = 100 37 | early_stopping_patience = 5 38 | metric = val_f1 39 | -------------------------------------------------------------------------------- /deep_reference_parser/configs/2020.4.5_multitask.ini: -------------------------------------------------------------------------------- 1 | [DEFAULT] 2 | version = 2020.4.5_multitask 3 | description = Uses 2020.3.18 data 4 | deep_reference_parser_version = 9432b6e 5 | 6 | [data] 7 | # Note that test and valid proportion are only used for data creation steps, 8 | # not when running the train command. 
9 | test_proportion = 0.25 10 | valid_proportion = 0.25 11 | data_path = data/ 12 | respect_line_endings = 0 13 | respect_doc_endings = 1 14 | line_limit = 150 15 | policy_train = data/processed/annotated/deep_reference_parser/multitask/2020.3.18_multitask_train.tsv 16 | policy_test = data/processed/annotated/deep_reference_parser/multitask/2020.3.18_multitask_test.tsv 17 | policy_valid = data/processed/annotated/deep_reference_parser/multitask/2020.3.18_multitask_valid.tsv 18 | s3_slug = https://datalabs-public.s3.eu-west-2.amazonaws.com/deep_reference_parser/ 19 | 20 | [build] 21 | output_path = models/multitask/2020.4.5_multitask/ 22 | output = crf 23 | word_embeddings = embeddings/2020.1.1-wellcome-embeddings-300.txt 24 | pretrained_embedding = 0 25 | dropout = 0.5 26 | lstm_hidden = 400 27 | word_embedding_size = 300 28 | char_embedding_size = 100 29 | char_embedding_type = BILSTM 30 | optimizer = adam 31 | 32 | [train] 33 | epochs = 60 34 | batch_size = 100 35 | early_stopping_patience = 5 36 | metric = val_f1 37 | -------------------------------------------------------------------------------- /deep_reference_parser/io/__init__.py: -------------------------------------------------------------------------------- 1 | from .io import (load_tsv, read_jsonl, read_pickle, write_jsonl, write_pickle, 2 | write_to_csv, write_tsv) 3 | -------------------------------------------------------------------------------- /deep_reference_parser/io/io.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # coding: utf-8 3 | 4 | """ 5 | Utilities for loading and saving data from various formats 6 | """ 7 | 8 | import json 9 | import pickle 10 | import csv 11 | import os 12 | import pandas as pd 13 | 14 | from ..logger import logger 15 | 16 | def _unpack(tuples): 17 | """Convert list of tuples into the correct format: 18 | 19 | From: 20 | 21 | [ 22 | ( 23 | (token0, token1, token2, token3), 24 | (label0, label1, label2, label3), 25 | ), 26 | ( 27 | (token0, token1, token2), 28 | (label0, label1, label2), 29 | ), 30 | ) 31 | 32 | to: 33 | ] 34 | ( 35 | (token0, token1, token2, token3), 36 | (token0, token1, token2), 37 | ), 38 | ( 39 | (label0, label1, label2, label3), 40 | (label0, label1, label2), 41 | ), 42 | ] 43 | """ 44 | return list(zip(*list(tuples))) 45 | 46 | def _split_list_by_linebreaks(rows): 47 | """Cycle through a list of tokens (or labels) and split them into lists 48 | based on the presence of Nones or more likely math.nan caused by converting 49 | pd.DataFrame columns to lists. 50 | """ 51 | out = [] 52 | rows_gen = iter(rows) 53 | while True: 54 | try: 55 | row = next(rows_gen) 56 | token = row[0] 57 | # Check whether there are missing labels that have been converted 58 | # to float('nan') 59 | if isinstance(token, str) and any([not isinstance(label, str) for label in row]): 60 | pass 61 | elif isinstance(token, str) and token: 62 | out.append(row) 63 | else: 64 | yield out 65 | out = [] 66 | except StopIteration: 67 | if out: 68 | yield out 69 | break 70 | 71 | def load_tsv(filepath, split_char="\t"): 72 | """ 73 | Load and return the data stored in the given path. 74 | 75 | Expects data in the following format (tab separations). 76 | 77 | References o o 78 | 1 o o 79 | . 
o o 80 | WHO title b-r 81 | treatment title i-r 82 | guidelines title i-r 83 | for title i-r 84 | drug title i-r 85 | - title i-r 86 | resistant title i-r 87 | tuberculosis title i-r 88 | , title i-r 89 | 2016 title i-r 90 | 91 | Args: 92 | filepath (str): Path to the data. 93 | split_char(str): Character to be used to split each line of the 94 | document. 95 | 96 | Returns: 97 | a series of lists depending on the number of label columns provided in 98 | filepath. 99 | 100 | """ 101 | df = pd.read_csv(filepath, delimiter=split_char, header=None, skip_blank_lines=False, encoding="utf-8", quoting=csv.QUOTE_NONE, engine="python") 102 | 103 | tuples = _split_list_by_linebreaks(df.to_records(index=False)) 104 | 105 | # Remove leading empty lists if found 106 | 107 | tuples = list(filter(None, tuples)) 108 | 109 | unpacked_tuples = list(map(_unpack, tuples)) 110 | 111 | out = _unpack(unpacked_tuples) 112 | 113 | logger.debug("Loaded %s training examples", len(out[0])) 114 | 115 | return tuple(out) 116 | 117 | def write_jsonl(input_data, output_file): 118 | """ 119 | Write a dict to jsonl (line delimited json) 120 | 121 | Output format will look like: 122 | 123 | ``` 124 | {'a': 0} 125 | {'b': 1} 126 | {'c': 2} 127 | {'d': 3} 128 | ``` 129 | 130 | Args: 131 | input_data(dict): A dict to be written to json. 132 | output_file(str): Filename to which the jsonl will be saved. 133 | """ 134 | 135 | with open(output_file, "w") as fb: 136 | 137 | # Check if a dict (and convert to list if so) 138 | 139 | if isinstance(input_data, dict): 140 | input_data = [value for key, value in input_data.items()] 141 | 142 | # Write out to jsonl file 143 | 144 | logger.debug("Writing %s lines to %s", len(input_data), output_file) 145 | 146 | for i in input_data: 147 | json_ = json.dumps(i) + "\n" 148 | fb.write(json_) 149 | 150 | 151 | def _yield_jsonl(file_name): 152 | for row in open(file_name, "r"): 153 | yield json.loads(row) 154 | 155 | 156 | def read_jsonl(input_file): 157 | """Create a list from a jsonl file 158 | 159 | Args: 160 | input_file(str): File to be loaded. 161 | """ 162 | 163 | out = list(_yield_jsonl(input_file)) 164 | 165 | logger.debug("Read %s lines from %s", len(out), input_file) 166 | 167 | return out 168 | 169 | 170 | def write_to_csv(filename, columns, rows): 171 | """ 172 | Create a .csv file from data given as columns and rows 173 | 174 | Args: 175 | filename(str): Path and name of the .csv file, without csv extension 176 | columns(list): Columns of the csv file (First row of the file) 177 | rows: Data to write into the csv file, given per row 178 | """ 179 | 180 | with open(filename, "w") as csvfile: 181 | wr = csv.writer(csvfile, quoting=csv.QUOTE_ALL) 182 | wr.writerow(columns) 183 | 184 | for i, row in enumerate(rows): 185 | wr.writerow(row) 186 | logger.info("Wrote results to %s", filename) 187 | 188 | 189 | def write_pickle(input_data, output_file, path=None): 190 | """ 191 | Write an object to pickle 192 | 193 | Args: 194 | input_data(dict): A dict to be written to json. 195 | output_file(str): A filename or path to which the json will be saved. 196 | path(str): A string which will be prepended onto `output_file` with 197 | `os.path.join()`. Obviates the need for lengthy `os.path.join` 198 | statements each time this function is called. 
199 | """ 200 | 201 | if path: 202 | 203 | output_file = os.path.join(path, output_file) 204 | 205 | with open(output_file, "wb") as fb: 206 | pickle.dump(input_data, fb) 207 | 208 | 209 | def read_pickle(input_file, path=None): 210 | """Create a list from a jsonl file 211 | 212 | Args: 213 | input_file(str): File to be loaded. 214 | path(str): A string which will be prepended onto `input_file` with 215 | `os.path.join()`. Obviates the need for lengthy `os.path.join` 216 | statements each time this function is called. 217 | """ 218 | 219 | if path: 220 | input_file = os.path.join(path, input_file) 221 | 222 | with open(input_file, "rb") as fb: 223 | out = pickle.load(fb) 224 | 225 | logger.debug("Read data from %s", input_file) 226 | 227 | return out 228 | 229 | def write_tsv(token_label_pairs, output_path): 230 | """ 231 | Write tsv files to disk 232 | """ 233 | with open(output_path, "w") as fb: 234 | writer = csv.writer(fb, delimiter="\t") 235 | writer.writerows(token_label_pairs) 236 | -------------------------------------------------------------------------------- /deep_reference_parser/logger.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # coding: utf-8 3 | 4 | """ 5 | """ 6 | 7 | import logging 8 | import os 9 | 10 | LOGGING_LEVEL = os.getenv("LOGGING_LEVEL") 11 | 12 | if isinstance(LOGGING_LEVEL, str): 13 | numeric_level = getattr(logging, LOGGING_LEVEL.upper(), 10) 14 | else: 15 | numeric_level = 20 16 | 17 | logger = logging.getLogger(__name__) 18 | 19 | logging.basicConfig( 20 | format="%(asctime)s %(name)s %(levelname)s: %(message)s", 21 | datefmt="%Y-%m-%d %H:%M:%S", 22 | level=numeric_level, 23 | ) 24 | -------------------------------------------------------------------------------- /deep_reference_parser/parse.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # coding: utf-8 3 | """ 4 | Run predictions from a pre-trained model 5 | """ 6 | 7 | 8 | import itertools 9 | import os 10 | import json 11 | 12 | import en_core_web_sm 13 | import plac 14 | import spacy 15 | import wasabi 16 | 17 | from deep_reference_parser import __file__ 18 | from deep_reference_parser.common import download_model_artefact, PARSER_CFG 19 | from deep_reference_parser.deep_reference_parser import DeepReferenceParser 20 | from deep_reference_parser.logger import logger 21 | from deep_reference_parser.model_utils import get_config 22 | from deep_reference_parser.reference_utils import break_into_chunks 23 | from deep_reference_parser.tokens_to_references import tokens_to_references 24 | from deep_reference_parser.__version__ import __parser_model_version__ 25 | 26 | msg = wasabi.Printer(icons={"check": "\u2023"}) 27 | 28 | 29 | class Parser: 30 | def __init__(self, config_file): 31 | 32 | msg.info(f"Using config file: {config_file}") 33 | 34 | cfg = get_config(config_file) 35 | 36 | # Build config 37 | try: 38 | OUTPUT_PATH = cfg["build"]["output_path"] 39 | S3_SLUG = cfg["data"]["s3_slug"] 40 | except KeyError: 41 | config_dir, missing_config = os.path.split(config_file) 42 | files = os.listdir(config_dir) 43 | other_configs = [f for f in os.listdir(config_dir) if os.path.isfile(os.path.join(config_dir, f))] 44 | msg.fail(f"Could not find config {missing_config}, perhaps you meant one of {other_configs}") 45 | 46 | msg.info( 47 | f"Attempting to download model artefacts if they are not found locally in {cfg['build']['output_path']}. This may take some time..." 
48 | ) 49 | 50 | # Check whether the necessary artefacts exists and download them if 51 | # not. 52 | 53 | artefacts = [ 54 | "indices.pickle", 55 | "weights.h5", 56 | ] 57 | 58 | for artefact in artefacts: 59 | with msg.loading(f"Could not find {artefact} locally, downloading..."): 60 | try: 61 | artefact = os.path.join(OUTPUT_PATH, artefact) 62 | download_model_artefact(artefact, S3_SLUG) 63 | msg.good(f"Found {artefact}") 64 | except: 65 | msg.fail(f"Could not download {S3_SLUG}{artefact}") 66 | logger.exception("Could not download %s%s", S3_SLUG, artefact) 67 | 68 | # Check on word embedding and download if not exists 69 | 70 | WORD_EMBEDDINGS = cfg["build"]["word_embeddings"] 71 | 72 | with msg.loading(f"Could not find {WORD_EMBEDDINGS} locally, downloading..."): 73 | try: 74 | download_model_artefact(WORD_EMBEDDINGS, S3_SLUG) 75 | msg.good(f"Found {WORD_EMBEDDINGS}") 76 | except: 77 | msg.fail(f"Could not download {S3_SLUG}{WORD_EMBEDDINGS}") 78 | logger.exception("Could not download %s", WORD_EMBEDDINGS) 79 | 80 | OUTPUT = cfg["build"]["output"] 81 | PRETRAINED_EMBEDDING = cfg["build"]["pretrained_embedding"] 82 | DROPOUT = float(cfg["build"]["dropout"]) 83 | LSTM_HIDDEN = int(cfg["build"]["lstm_hidden"]) 84 | WORD_EMBEDDING_SIZE = int(cfg["build"]["word_embedding_size"]) 85 | CHAR_EMBEDDING_SIZE = int(cfg["build"]["char_embedding_size"]) 86 | 87 | self.MAX_WORDS = int(cfg["data"]["line_limit"]) 88 | 89 | # Evaluate config 90 | 91 | self.drp = DeepReferenceParser(output_path=OUTPUT_PATH) 92 | 93 | # Encode data and load required mapping dicts. Note that the max word and 94 | # max char lengths will be loaded in this step. 95 | 96 | self.drp.load_data(OUTPUT_PATH) 97 | 98 | # Build the model architecture 99 | 100 | self.drp.build_model( 101 | output=OUTPUT, 102 | word_embeddings=WORD_EMBEDDINGS, 103 | pretrained_embedding=PRETRAINED_EMBEDDING, 104 | dropout=DROPOUT, 105 | lstm_hidden=LSTM_HIDDEN, 106 | word_embedding_size=WORD_EMBEDDING_SIZE, 107 | char_embedding_size=CHAR_EMBEDDING_SIZE, 108 | ) 109 | 110 | def parse(self, text, verbose=False): 111 | 112 | nlp = en_core_web_sm.load() 113 | doc = nlp(text) 114 | chunks = break_into_chunks(doc, max_words=self.MAX_WORDS) 115 | tokens = [[token.text for token in chunk] for chunk in chunks] 116 | 117 | preds = self.drp.predict(tokens, load_weights=True) 118 | 119 | flat_predictions = list(itertools.chain.from_iterable(preds))[0] 120 | flat_X = list(itertools.chain.from_iterable(tokens)) 121 | rows = [i for i in zip(flat_X, flat_predictions)] 122 | 123 | if verbose: 124 | 125 | msg.divider("Token Results") 126 | 127 | header = ("token", "label") 128 | aligns = ("r", "l") 129 | formatted = wasabi.table(rows, header=header, divider=True, aligns=aligns) 130 | print(formatted) 131 | 132 | out = rows 133 | 134 | return out 135 | 136 | 137 | @plac.annotations( 138 | text=("Plaintext from which to extract references", "positional", None, str), 139 | config_file=("Path to config file", "option", "c", str), 140 | outfile=("Path to json file to which results will be written", "option", "o", str), 141 | ) 142 | def parse(text, config_file=PARSER_CFG, outfile=None): 143 | """ 144 | Runs the default parsing model and pretty prints results to console unless 145 | --outfile is parsed with a path. Output written to the the path specified in 146 | --outfile will be a valid json. 147 | 148 | NOTE: that this function is provided for examples only and should not be used 149 | in production as the model is instantiated each time the command is run. 
To 150 | use in a production setting, a more sensible approach would be to replicate 151 | the split or parse functions within your own logic. 152 | """ 153 | parser = Parser(config_file) 154 | if outfile: 155 | out = parser.parse(text, verbose=False) 156 | 157 | try: 158 | with open(outfile, "w") as fb: 159 | json.dump(out, fb) 160 | msg.good(f"Wrote model output to {outfile}") 161 | except: 162 | msg.fail(f"Failed to write output to {outfile}") 163 | 164 | else: 165 | out = parser.parse(text, verbose=True) 166 | -------------------------------------------------------------------------------- /deep_reference_parser/prodigy/README.md: -------------------------------------------------------------------------------- 1 | # Prodigy utilities 2 | 3 | The `deep_reference_parser.prodigy` module contains a number of utility functions for working with annotations created in [prodi.gy](http://prodi.gy). 4 | 5 | The individual functions can be access with the usual `import deep_reference_parser.prodigy` logic, but can also be accessed on the command line with: 6 | 7 | ``` 8 | $ python -m deep_reference_parser.prodigy 9 | Using TensorFlow backend. 10 | 11 | ℹ Available commands 12 | annotate_numbered_refs, prodigy_to_tsv, reach_to_prodigy, 13 | refs_to_token_annotations 14 | ``` 15 | 16 | |Name|Description| 17 | |---|---| 18 | |reach_to_prodigy|Converts a jsonl of reference sections output by reach into a jsonl containing prodigy format documents.| 19 | |annotate_numbered_refs|Takes numbered reference sections extract by Reach, and roughly annotates the references by splitting the reference lines apart on the numbers.| 20 | |prodigy_to_tsv|Converts a jsonl file of prodigy documents to a tab separated values (tsv) file where each token and its associated label occupy a line.| 21 | |refs_to_token_annotations|Takes a jsonl of annotated reference sections in prodigy format that have been manually annotated to the reference level, and converts the references into token level annotations based on the IOBE schema, saving a new file or prodigy documents to jsonl.| 22 | 23 | Help for each of these commands can be sought with the `--help` flag, e.g.: 24 | 25 | ``` 26 | $ python -m deep_reference_parser.prodigy prodigy_to_tsv --help 27 | Using TensorFlow backend. 28 | usage: deep_reference_parser prodigy_to_tsv [-h] input_file output_file 29 | 30 | Convert token annotated jsonl to token annotated tsv ready for use in the 31 | Rodrigues model. 32 | 33 | 34 | positional arguments: 35 | input_file Path to jsonl file containing prodigy docs. 36 | output_file Path to output tsv file. 
37 | 38 | optional arguments: 39 | -h, --help show this help message and exit 40 | 41 | ``` 42 | 43 | -------------------------------------------------------------------------------- /deep_reference_parser/prodigy/__init__.py: -------------------------------------------------------------------------------- 1 | from .numbered_reference_annotator import ( 2 | NumberedReferenceAnnotator, 3 | annotate_numbered_references, 4 | ) 5 | from .prodigy_to_tsv import TokenLabelPairs, prodigy_to_tsv 6 | from .reach_to_prodigy import ReachToProdigy, reach_to_prodigy 7 | from .reference_to_token_annotations import TokenTagger, reference_to_token_annotations 8 | from .spacy_doc_to_prodigy import SpacyDocToProdigy 9 | from .misc import prodigy_to_conll 10 | from .labels_to_prodigy import labels_to_prodigy 11 | -------------------------------------------------------------------------------- /deep_reference_parser/prodigy/__main__.py: -------------------------------------------------------------------------------- 1 | # coding: utf8 2 | 3 | """ 4 | Modified from https://github.com/explosion/spaCy/blob/master/spacy/__main__.py 5 | 6 | """ 7 | 8 | if __name__ == "__main__": 9 | import plac 10 | import sys 11 | from wasabi import msg 12 | from .numbered_reference_annotator import annotate_numbered_references 13 | from .prodigy_to_tsv import prodigy_to_tsv 14 | from .reach_to_prodigy import reach_to_prodigy 15 | from .reference_to_token_annotations import reference_to_token_annotations 16 | 17 | commands = { 18 | "annotate_numbered_refs": annotate_numbered_references, 19 | "prodigy_to_tsv": prodigy_to_tsv, 20 | "reach_to_prodigy": reach_to_prodigy, 21 | "refs_to_token_annotations": reference_to_token_annotations, 22 | } 23 | 24 | if len(sys.argv) == 1: 25 | msg.info("Available commands", ", ".join(commands), exits=1) 26 | command = sys.argv.pop(1) 27 | sys.argv[0] = "deep_reference_parser %s" % command 28 | 29 | if command in commands: 30 | plac.call(commands[command], sys.argv[1:]) 31 | else: 32 | available = "Available: {}".format(", ".join(commands)) 33 | msg.fail("Unknown command: {}".format(command), available, exits=1) 34 | -------------------------------------------------------------------------------- /deep_reference_parser/prodigy/labels_to_prodigy.py: -------------------------------------------------------------------------------- 1 | def labels_to_prodigy(tokens, labels): 2 | """ 3 | Converts a list of tokens and labels like those used by Rodrigues et al, 4 | and converts to prodigy format dicts. 5 | 6 | Args: 7 | tokens (list): A list of tokens. 8 | labels (list): A list of labels relating to `tokens`. 9 | 10 | Returns: 11 | A list of prodigy format dicts containing annotated data. 
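    Example (a minimal, illustrative call; the tokens and labels below are
    made up rather than taken from real annotations):

        >>> tokens = [["References", "1", "WHO", "guidelines"]]
        >>> labels = [["o", "o", "b-r", "i-r"]]
        >>> docs = labels_to_prodigy(tokens, labels)
        >>> docs[0]["text"]
        'References 1 WHO guidelines'
        >>> docs[0]["spans"][2]["label"]
        'b-r'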
12 | """ 13 | 14 | prodigy_data = [] 15 | 16 | all_token_index = 0 17 | 18 | for line_index, line in enumerate(tokens): 19 | prodigy_example = {} 20 | 21 | tokens = [] 22 | spans = [] 23 | token_start_offset = 0 24 | 25 | for token_index, token in enumerate(line): 26 | 27 | token_end_offset = token_start_offset + len(token) 28 | 29 | tokens.append( 30 | { 31 | "text": token, 32 | "id": token_index, 33 | "start": token_start_offset, 34 | "end": token_end_offset, 35 | } 36 | ) 37 | 38 | spans.append( 39 | { 40 | "label": labels[line_index][token_index : token_index + 1][0], 41 | "start": token_start_offset, 42 | "end": token_end_offset, 43 | "token_start": token_index, 44 | "token_end": token_index, 45 | } 46 | ) 47 | 48 | prodigy_example["text"] = " ".join(line) 49 | prodigy_example["tokens"] = tokens 50 | prodigy_example["spans"] = spans 51 | prodigy_example["meta"] = {"line": line_index} 52 | 53 | token_start_offset = token_end_offset + 1 54 | 55 | prodigy_data.append(prodigy_example) 56 | 57 | return prodigy_data 58 | -------------------------------------------------------------------------------- /deep_reference_parser/prodigy/misc.py: -------------------------------------------------------------------------------- 1 | import spacy 2 | 3 | 4 | def _join_prodigy_tokens(text): 5 | """Return all prodigy tokens in a single string 6 | """ 7 | 8 | return "\n".join([str(i) for i in text]) 9 | 10 | 11 | def prodigy_to_conll(docs): 12 | """ 13 | Expect list of jsons loaded from a jsonl 14 | """ 15 | 16 | nlp = spacy.load("en_core_web_sm") 17 | texts = [doc["text"] for doc in docs] 18 | docs = list(nlp.tokenizer.pipe(texts)) 19 | 20 | out = [_join_prodigy_tokens(i) for i in docs] 21 | 22 | out_str = "DOCSTART\n\n" + "\n\n".join(out) 23 | 24 | return out_str 25 | 26 | 27 | def prodigy_to_lists(docs): 28 | """ 29 | Expect list of jsons loaded from a jsonl 30 | """ 31 | 32 | nlp = spacy.load("en_core_web_sm") 33 | texts = [doc["text"] for doc in docs] 34 | docs = list(nlp.tokenizer.pipe(texts)) 35 | 36 | out = [[str(token) for token in doc] for doc in docs] 37 | 38 | return out 39 | -------------------------------------------------------------------------------- /deep_reference_parser/prodigy/numbered_reference_annotator.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | #!/usr/bin/env python3 3 | 4 | import re 5 | 6 | import plac 7 | 8 | from ..io import read_jsonl, write_jsonl 9 | from ..logger import logger 10 | 11 | REGEX = r"\n{1,2}(?:(?:\s)|(?:\(|\[))?(?:\d{1,2})(?:(?:\.\)|\.\]|\]\n|\.|\s)|(?:\]|\)))(\s+)?(?:\n)?(?:\s+)?(?!Jan|Feb|Mar|Apr|May|Jun|Jul|Aug|Sep|Oct|Nov|Dec)" 12 | 13 | 14 | class NumberedReferenceAnnotator: 15 | """ 16 | Takes reference sections with numeric labelling scraped by Reach in prodigy 17 | format, and labels the references as spans by splitting them using regex. 18 | 19 | Note that you must identify numbered reference section first. This can be 20 | done with a simple textcat model trained in prodigy. 
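    Example (illustrative usage; the input file name is hypothetical):

        >>> from deep_reference_parser.io import read_jsonl
        >>> docs = read_jsonl("numbered_reference_sections.jsonl")
        >>> nra = NumberedReferenceAnnotator()
        >>> annotated = list(nra.run(docs))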
21 | """ 22 | 23 | def __init__(self): 24 | 25 | self.regex = r"" 26 | 27 | def run(self, docs, regex=REGEX): 28 | 29 | self.regex = regex 30 | 31 | for doc in docs: 32 | 33 | spans = self.label_numbered_references(doc["text"], doc["tokens"]) 34 | doc["spans"] = spans 35 | 36 | yield doc 37 | 38 | def label_numbered_references(self, text, tokens): 39 | 40 | # Search for number reference using regex 41 | 42 | splits = list(re.finditer(self.regex, text)) 43 | spans = [] 44 | 45 | for index in range(0, len(splits) - 1): 46 | 47 | # Calculate the approximate start and end of the reference using 48 | # the character offsets returned by re.finditer. 49 | 50 | start = splits[index].end() 51 | end = splits[index + 1].start() 52 | 53 | # Calculate which is the closest token to the character offset 54 | # returned above. 55 | 56 | token_start = self._find_closest_token(tokens, start, "start") 57 | token_end = self._find_closest_token(tokens, end, "end") 58 | 59 | # To avoid the possibility of mismatches between the character 60 | # offset and the token offset, reset the character offsets 61 | # based on the token offsets. 62 | 63 | start = self._get_token_offset(tokens, token_start, "start") 64 | end = self._get_token_offset(tokens, token_end, "end") 65 | 66 | # Create dict and append 67 | 68 | span = { 69 | "start": start, 70 | "end": end, 71 | "token_start": token_start, 72 | "token_end": token_end, 73 | "label": "BE", 74 | } 75 | 76 | spans.append(span) 77 | 78 | return spans 79 | 80 | def _find_closest_token(self, tokens, char_offset, pos_string): 81 | """ 82 | Find the token start/end closest to "number" 83 | 84 | Args: 85 | tokens: A list of token dicts from a prodigy document. 86 | char_offset(int): A character offset relating to either the start or the 87 | end of a token. 88 | pos_string(str): One of ["start", "end"] denoting whether `char_offset` 89 | is a start or the end of a token 90 | """ 91 | token_map = self._token_start_mapper(tokens, pos_string) 92 | token_key = self._find_closest_number(token_map.keys(), char_offset) 93 | 94 | return token_map[token_key] 95 | 96 | def _get_token_offset(self, tokens, token_id, pos_string): 97 | """ 98 | Return the character offset for the token with id == token_id 99 | """ 100 | 101 | token_match = (token[pos_string] for token in tokens if token["id"] == token_id) 102 | 103 | return next(token_match, None) 104 | 105 | def _find_closest_number(self, numbers, number): 106 | """ Find the closest match in a list of numbers when presented with 107 | a number 108 | """ 109 | 110 | return min(numbers, key=lambda x: abs(x - number)) 111 | 112 | def _token_start_mapper(self, tokens, pos_string): 113 | """ Map token id by the token start/end position 114 | """ 115 | 116 | return {token[pos_string]: token["id"] for token in tokens} 117 | 118 | 119 | @plac.annotations( 120 | input_file=( 121 | "Path to jsonl file containing numbered reference sections as docs.", 122 | "positional", 123 | None, 124 | str, 125 | ), 126 | output_file=( 127 | "Path to output jsonl file containing prodigy docs with numbered references labelled.", 128 | "positional", 129 | None, 130 | str, 131 | ), 132 | ) 133 | def annotate_numbered_references(input_file, output_file): 134 | """ 135 | Takes reference sections with numeric labelling scraped by Reach in prodigy 136 | format, and labels the references as spans by splitting them using regex. 
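    For example, from the command line (the file names here are illustrative):

        python -m deep_reference_parser.prodigy annotate_numbered_refs \
            scraped_reference_sections.jsonl annotated_reference_sections.jsonl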
137 | """ 138 | 139 | numbered_reference_sections = read_jsonl(input_file) 140 | 141 | logger.info("Loaded %s prodigy docs", len(numbered_reference_sections)) 142 | 143 | nra = NumberedReferenceAnnotator() 144 | docs = list(nra.run(numbered_reference_sections)) 145 | 146 | write_jsonl(docs, output_file) 147 | 148 | logger.info("Wrote %s annotated references to %s", len(docs), output_file) 149 | -------------------------------------------------------------------------------- /deep_reference_parser/prodigy/prodigy_to_tsv.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # coding: utf-8 3 | 4 | """ 5 | Class used in scripts/prodigy_to_tsv.py which converts token annotated jsonl 6 | files to tab-separated-values files for use in the deep reference parser 7 | """ 8 | 9 | import csv 10 | import re 11 | import sys 12 | from functools import reduce 13 | 14 | import numpy as np 15 | import plac 16 | from wasabi import Printer, table 17 | 18 | from ..io import read_jsonl, write_tsv 19 | from ..logger import logger 20 | 21 | msg = Printer() 22 | 23 | ROWS_TO_PRINT = 15 24 | 25 | 26 | class TokenLabelPairs: 27 | """ 28 | Convert prodigy format docs or list of lists into tuples of (token, label). 29 | """ 30 | 31 | def __init__( 32 | self, line_limit=250, respect_line_endings=False, respect_doc_endings=True 33 | ): 34 | """ 35 | Args: 36 | line_limit(int): Maximum number of tokens allowed per training 37 | example. If you are planning to use this data for making 38 | predictions, then this should correspond to the max_words 39 | attribute for the DeepReferenceParser class used to train the 40 | model. 41 | respect_line_endings(bool): If true, line endings appearing in the 42 | text will be respected, leading to much shorter line lengths 43 | usually <10. Typically this results in a much worser performing 44 | model, but follows the convention set by Rodrigues et al. 45 | respect_doc_endings(bool): If true, a line ending is added at the 46 | end of each document. If false, then the end of a document flows 47 | into the beginning of the next document. 48 | """ 49 | 50 | self.line_count = 0 51 | self.line_lengths = [] 52 | self.line_limit = line_limit 53 | self.respect_doc_endings = respect_doc_endings 54 | self.respect_line_endings = respect_line_endings 55 | 56 | def run(self, datasets): 57 | """ 58 | 59 | Args: 60 | datasets (list): An arbitrary number of lists containing an 61 | arbitrary number of prodigy docs (dicts) which will be combined 62 | in to a single list of tokens based on the arguments provided 63 | when instantiating the TokenLabelPairs class, e.g.: 64 | 65 | [ 66 | (token0, label0_0, label0_1, label0_2), 67 | (token1, label1_0, label1_1, label1_2), 68 | (token2, label2_0, label2_1, label2_2), 69 | (None,) # blank line 70 | (token3, label3_0, label3_1, label3_2), 71 | (token4, label4_0, label4_1, label4_2), 72 | (token5, label5_0, label5_1, label5_2), 73 | 74 | ] 75 | 76 | """ 77 | 78 | out = [] 79 | 80 | input_hashes = list(map(get_input_hashes, datasets)) 81 | 82 | # Check that datasets are compatible by comparing the _input_hash of 83 | # each document across the list of datasets. 84 | 85 | if not check_all_equal(input_hashes): 86 | msg.fail("Some documents missing from one of the input datasets") 87 | 88 | # If this is the case, also output some useful information for 89 | # determining which dataset is at fault. 
90 | 91 | for i in range(len(input_hashes)): 92 | for j in range(i + 1, len(input_hashes)): 93 | diff = set(input_hashes[i]) ^ set(input_hashes[j]) 94 | 95 | if diff: 96 | msg.fail( 97 | f"Docs {diff} unequal between dataset {i} and {j}", exits=1 98 | ) 99 | 100 | # Now that we know the input_hashes are equal, cycle through the first 101 | # one, and compare the tokens across the documents in each dataset from 102 | # datasets. 103 | 104 | for input_hash in input_hashes[0]: 105 | 106 | # Create list of docs whose _input_hash matches _input_hash. 107 | # len(matched_docs) == len(datasets) 108 | 109 | matched_docs = list(map(lambda x: get_doc_by_input_hash(x, input_hash), datasets)) 110 | 111 | # Create a list of tokens from input_hash_matches 112 | 113 | tokens = list(map(get_sorted_tokens, matched_docs)) 114 | 115 | # All the tokens should match because they have the same _input_hash 116 | # but lets check just be sure... 117 | 118 | if check_all_equal(tokens): 119 | tokens_and_labels = [tokens[0]] 120 | else: 121 | msg.fail(f"Token mismatch for document {input_hash}", exits=1) 122 | 123 | # Create a list of spans from input_hash_matches 124 | 125 | spans = list(map(get_sorted_labels, matched_docs)) 126 | 127 | # Create a list of lists like: 128 | # [[token0, token1, token2],[label0, label1, label2],...]. Sometimes 129 | # this will just be [None] if there were no spans in the documents, 130 | # so check for this. 131 | 132 | def all_nones(spans): 133 | return all(i is None for i in spans) 134 | 135 | if not all_nones(spans): 136 | tokens_and_labels.extend(spans) 137 | 138 | # Flatten the list of lists to give: 139 | # [(token0, label0, ...), (token1, label1, ...), (token2, label2, ...)] 140 | 141 | flattened_tokens_and_labels = list(zip(*tokens_and_labels)) 142 | 143 | out.extend(list(self.yield_token_label_pair(flattened_tokens_and_labels))) 144 | 145 | # Print some statistics about the data. 146 | 147 | self.stats() 148 | 149 | return out 150 | 151 | def stats(self): 152 | 153 | avg_line_len = np.round(np.mean(self.line_lengths), 2) 154 | 155 | msg.info(f"Returning {self.line_count} examples") 156 | msg.info(f"Average line length: {avg_line_len}") 157 | 158 | def yield_token_label_pair(self, flattened_tokens_and_labels): 159 | """ 160 | Args: 161 | flattened_tokens_and_labels (list): List of tuples relating to the 162 | tokens and labels of a given document. 163 | 164 | NOTE: Makes the assumption that every token has been labelled in spans. This 165 | assumption will be true if the data has been labelled with prodigy, then 166 | spans covering entire references have been converted to token spans. OR that 167 | there are no spans at all, and this is being used to prepare data for 168 | prediction. 169 | """ 170 | 171 | # Set a token counter that is used to limit the number of tokens to 172 | # line_limit. 173 | 174 | token_counter = int(0) 175 | 176 | doc_len = len(flattened_tokens_and_labels) 177 | 178 | for i, token_and_labels in enumerate(flattened_tokens_and_labels, 1): 179 | 180 | token = token_and_labels[0] 181 | labels = token_and_labels[1:] 182 | blank = tuple([None] * (len(labels) + 1)) 183 | 184 | # If the token is just spaces even if it has been labelled, pass it. 185 | 186 | if re.search(r"^[ ]+$", token): 187 | 188 | pass 189 | 190 | # If the token is a newline and we want to respect line endings in 191 | # the text, then yield None which will be converted to a blank line 192 | # when the resulting tsv file is read. 
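The symmetric difference used above to report documents missing from one of the datasets can be pictured with a couple of invented hash sets:

```
# Invented _input_hash values; the check mirrors the set comparison above.
splitting_hashes = {101, 102, 103}  # hashes from dataset 0
parsing_hashes = {101, 103}         # hashes from dataset 1
diff = splitting_hashes ^ parsing_hashes
print(diff)  # {102} -> this doc is present in one dataset but missing from the other
```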
193 | 194 | elif re.search(r"\n", token) and self.respect_line_endings and i != doc_len: 195 | 196 | # Is it blank after whitespace is removed? 197 | 198 | if token.strip() == "": 199 | yield blank 200 | 201 | self.line_lengths.append(token_counter) 202 | self.line_count += 1 203 | token_counter = 0 204 | 205 | # Was it a \n combined with another token? If so, return the 206 | # stripped token. 207 | 208 | else: 209 | yield (token.strip(), *labels) 210 | self.line_lengths.append(token_counter) 211 | self.line_count += 1 212 | token_counter = 1 213 | 214 | 215 | # Skip new lines if respect_line_endings not set and not the end 216 | # of a doc. 217 | 218 | elif re.search(r"\n", token) and i != doc_len: 219 | 220 | pass 221 | 222 | elif token_counter == self.line_limit: 223 | 224 | # Yield blank to signify a line ending, then yield the next 225 | # token. 226 | 227 | yield blank 228 | yield (token.strip(), *labels) 229 | 230 | # Set to one to account for the first token being added. 231 | 232 | self.line_lengths.append(token_counter) 233 | self.line_count += 1 234 | 235 | token_counter = 1 236 | 237 | elif i == doc_len and self.respect_doc_endings: 238 | 239 | # Case when the end of the document has been reached, but it is 240 | # less than self.line_limit. This assumes that we want to retain 241 | # a line ending which denotes the end of a document, and the 242 | # start of a new one. 243 | 244 | if token.strip(): 245 | yield (token.strip(), *labels) 246 | yield blank 247 | 248 | self.line_lengths.append(token_counter) 249 | self.line_count += 1 250 | 251 | else: 252 | # Return the stripped token. 253 | 254 | yield (token.strip(), *labels) 255 | 256 | token_counter += 1 257 | 258 | 259 | def get_input_hashes(dataset): 260 | """Get the hashes for every doc in a dataset and return as set 261 | """ 262 | return set([doc["_input_hash"] for doc in dataset]) 263 | 264 | 265 | def check_all_equal(lst): 266 | """Check that all items in a list are equal and return True or False 267 | """ 268 | return not lst or lst.count(lst[0]) == len(lst) 269 | 270 | 271 | def hash_matches(doc, hash): 272 | """Check whether the hash of the passed doc matches the passed hash 273 | """ 274 | return doc["_input_hash"] == hash 275 | 276 | 277 | def get_doc_by_input_hash(dataset, hash): 278 | """Return a doc from a dataset where hash matches doc["_input_hash"] 279 | Assumes there will only be one match! 280 | """ 281 | return [doc for doc in dataset if doc["_input_hash"] == hash][0] 282 | 283 | 284 | def get_sorted_tokens(doc): 285 | tokens = sorted(doc["tokens"], key=lambda k: k["id"]) 286 | return [token["text"] for token in tokens] 287 | 288 | def get_sorted_labels(doc): 289 | if doc.get("spans"): 290 | spans = sorted(doc["spans"], key=lambda k: k["token_start"]) 291 | return [span["label"] for span in spans] 292 | 293 | def sort_docs_list(lst): 294 | """Sort a list of prodigy docs by input hash 295 | """ 296 | return sorted(lst, key=lambda k: k["_input_hash"]) 297 | 298 | 299 | def combine_token_label_pairs(pairs): 300 | """Combines a list of [(token, label), (token, label)] to give 301 | (token,label,label). 302 | """ 303 | return pairs[0][0:] + tuple(pair[1] for pair in pairs[1:]) 304 | 305 | 306 | @plac.annotations( 307 | input_files=( 308 | "Comma separated list of paths to jsonl files containing prodigy docs.", 309 | "positional", 310 | None, 311 | str, 312 | ), 313 | output_file=("Path to output tsv file.", "positional", None, str), 314 | respect_lines=( 315 | "Respect line endings?
Or parse entire document in a single string?", 316 | "flag", 317 | "r", 318 | bool, 319 | ), 320 | respect_docs=( 321 | "Respect doc endings or parse corpus in single string?", 322 | "flag", 323 | "d", 324 | bool, 325 | ), 326 | line_limit=("Number of tokens to include on a line", "option", "l", int), 327 | ) 328 | def prodigy_to_tsv( 329 | input_files, output_file, respect_lines, respect_docs, line_limit=250 330 | ): 331 | """ 332 | Convert token annotated jsonl to token annotated tsv ready for use in the 333 | deep_reference_parser model. 334 | 335 | Will combine annotations from two jsonl files containing the same docs and 336 | the same tokens by comparing the "_input_hash" and token texts. If they are 337 | compatible, the output file will contain both labels ready for use in a 338 | multi-task model, for example: 339 | 340 | token label label 341 | ------------ ----- ----- 342 | References o o 343 | 1 o o 344 | . o o 345 | WHO title b-r 346 | treatment title i-r 347 | guidelines title i-r 348 | for title i-r 349 | drug title i-r 350 | - title i-r 351 | resistant title i-r 352 | tuberculosis title i-r 353 | , title i-r 354 | 2016 title i-r 355 | 356 | Multiple files must be passed as a comma separated list e.g. 357 | 358 | python -m deep_reference_parser.prodigy prodigy_to_tsv file1.jsonl,file2.jsonl out.tsv 359 | 360 | """ 361 | 362 | input_files = input_files.split(",") 363 | 364 | msg.info(f"Loading annotations from {len(input_files)} datasets") 365 | msg.info(f"Respect line endings: {respect_lines}") 366 | msg.info(f"Respect doc endings: {respect_docs}") 367 | msg.info(f"Target example length (n tokens): {line_limit}") 368 | 369 | # Read the input_files. Note the use of map here, because we don't know 370 | # how many sets of annotations are being passed in the list. It could be 2 371 | # but in future it may be more. 372 | 373 | annotated_data = list(map(read_jsonl, input_files)) 374 | 375 | # Sort the docs so that they are in the same order before converting to 376 | # token label pairs. 377 | 378 | tlp = TokenLabelPairs( 379 | respect_doc_endings=respect_docs, 380 | respect_line_endings=respect_lines, 381 | line_limit=line_limit, 382 | ) 383 | 384 | pairs_list = tlp.run(annotated_data) 385 | 386 | write_tsv(pairs_list, output_file) 387 | 388 | # Print out the first few rows as a sense check 389 | 390 | msg.divider("Example output") 391 | header = ["token"] + ["label"] * len(annotated_data) 392 | aligns = ["r"] + ["l"] * len(annotated_data) 393 | formatted = table(pairs_list[0:ROWS_TO_PRINT], header=header, divider=True, aligns=aligns) 394 | print(formatted) 395 | 396 | msg.good(f"Wrote token/label pairs to {output_file}") 397 | -------------------------------------------------------------------------------- /deep_reference_parser/prodigy/reach_to_prodigy.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # coding: utf-8 3 | 4 | import copy 5 | import itertools 6 | 7 | import en_core_web_sm as model 8 | import plac 9 | 10 | from ..io import read_jsonl, write_jsonl 11 | from ..logger import logger 12 | 13 | 14 | class ReachToProdigy: 15 | """ 16 | Converts json of scraped reference section into prodigy style json. 17 | 18 | The resulting json can then be loaded into prodigy if required. 19 | 20 | Expects dict in the following format: 21 | 22 | ``` 23 | { 24 | ..., 25 | "sections": { 26 | "Reference": "References\n1. Upson. M. (2018) ..."
27 | } 28 | } 29 | 30 | ``` 31 | 32 | Returns references in the following format: 33 | 34 | ``` 35 | [{ 36 | 'text': ' This is an example with a linebreak\n', 37 | 'meta': {'doc_hash': None, 'provider': None, 'line_number': 3}, 38 | 'tokens': [ 39 | {'text': ' ', 'start': 0, 'end': 1, 'id': 0}, 40 | {'text': 'This', 'start': 1, 'end': 5, 'id': 1}, 41 | {'text': 'is', 'start': 6, 'end': 8, 'id': 2}, 42 | {'text': 'an', 'start': 9, 'end': 11, 'id': 3}, 43 | {'text': 'example', 'start': 12, 'end': 19, 'id': 4}, 44 | {'text': 'with', 'start': 20, 'end': 24, 'id': 5}, 45 | {'text': 'a', 'start': 25, 'end': 26, 'id': 6}, 46 | {'text': 'linebreak', 'start': 27, 'end': 36, 'id': 7}, 47 | {'text': '\n', 'start': 36, 'end': 37, 'id': 8}] 48 | }, 49 | ... 50 | ] 51 | 52 | ``` 53 | """ 54 | 55 | def __init__( 56 | self, ref_sections, lines=10, split_char="\n", add_linebreak=True, join_char=" " 57 | ): 58 | """ 59 | Args: 60 | ref_sections(list): List of dicts extracted in scrape. 61 | lines(int): Number of lines to combine into one chunk 62 | split_char(str): Character to split lines on. 63 | add_linebreak(bool): Should a linebreak be re-added so that it is 64 | clear where a break was made? 65 | join_chars(str): Which character will be used to join lines at the 66 | point which they are merged into a chunk. 67 | """ 68 | 69 | self.ref_sections = ref_sections 70 | self.lines = lines 71 | self.split_char = split_char 72 | self.add_linebreak = add_linebreak 73 | self.join_char = join_char 74 | 75 | self.nlp = model.load() 76 | 77 | def run(self): 78 | """ 79 | Main method of the class 80 | """ 81 | 82 | prodigy_format = [] 83 | 84 | for i, refs in enumerate(self.ref_sections): 85 | 86 | one_record = self.one_record_to_prodigy_format( 87 | refs, 88 | self.nlp, 89 | self.lines, 90 | self.split_char, 91 | self.add_linebreak, 92 | self.join_char, 93 | ) 94 | 95 | # If something is returned (i.e. there is a ref section) 96 | # then append to prodigy_format. 97 | 98 | if one_record: 99 | 100 | prodigy_format.append(one_record) 101 | 102 | out = list(itertools.chain.from_iterable(prodigy_format)) 103 | 104 | logger.info("Returned %s reference sections", len(out)) 105 | 106 | return out 107 | 108 | def one_record_to_prodigy_format( 109 | self, 110 | input_dict, 111 | nlp, 112 | lines=10, 113 | split_char="\n", 114 | add_linebreak=True, 115 | join_char=" ", 116 | ): 117 | """ 118 | Convert one dict produced by the scrape to a list of prodigy dicts 119 | 120 | Args: 121 | input_dict(dict): One reference section dict from the scrape 122 | nlp: A spacy model, for example loaded with spacy.load("en_core_web_sm") 123 | lines(int): Number of lines to combine into one chunk 124 | split_char(str): Character to split lines on. 125 | add_linebreak(bool): Should a linebreak be re-added so that it is 126 | clear where a break was made? 127 | join_chars(str): Which character will be used to join lines at the 128 | point which they are merged into a chunk. 
129 | """ 130 | 131 | out = [] 132 | 133 | # Only continue if references are found 134 | 135 | if input_dict: 136 | 137 | sections = input_dict.get("sections") 138 | 139 | # If there is something in sections: this will be a keyword for example 140 | # reference, or bibliography, etc 141 | 142 | if sections: 143 | 144 | # In case there are more than one keyword, cycle through them 145 | 146 | for _, refs in sections.items(): 147 | 148 | # Refs will be a list, so cycle through it in case there was 149 | # more than one section found with the same keyword 150 | 151 | for ref in refs: 152 | 153 | if refs: 154 | 155 | refs_lines = self.split_lines( 156 | ref, split_char=split_char, add_linebreak=add_linebreak 157 | ) 158 | refs_grouped = self.combine_n_rows( 159 | refs_lines, n=lines, join_char=join_char 160 | ) 161 | 162 | _meta = { 163 | "doc_hash": input_dict.get("file_hash"), 164 | "provider": input_dict.get("provider"), 165 | } 166 | 167 | for i, lines in enumerate(refs_grouped): 168 | 169 | meta = copy.deepcopy(_meta) 170 | 171 | meta["line_number"] = i 172 | 173 | tokens = nlp.tokenizer(lines) 174 | formatted_tokens = [ 175 | self.format_token(i) for i in tokens 176 | ] 177 | 178 | out.append( 179 | { 180 | "text": lines, 181 | "meta": meta, 182 | "tokens": formatted_tokens, 183 | } 184 | ) 185 | 186 | return out 187 | 188 | def format_token(self, token): 189 | """ 190 | Converts prodigy token to dict of format: 191 | 192 | {"text":"of","start":32,"end":34,"id":5} 193 | """ 194 | out = dict() 195 | out["text"] = token.text 196 | out["start"] = token.idx 197 | out["end"] = token.idx + len(token) 198 | out["id"] = token.i 199 | 200 | return out 201 | 202 | def combine_n_rows(self, doc, n=5, join_char=" "): 203 | """ 204 | Splits a document into chunks of length `n` lines. 205 | 206 | Args: 207 | doc(str): A document as a string. 208 | n(int): The number of lines allowed in each chunk. 209 | join_char(str): The character used to join lines within a chunk. 210 | 211 | Returns: 212 | list: A list of chunks containing `n` lines. 213 | """ 214 | 215 | indices = list(range(len(doc))) 216 | 217 | # Split the document into blocks 218 | 219 | groups = list(zip(indices[0::n], indices[n::n])) 220 | 221 | # Iterate through each group of n rows, convert all the items 222 | # to str, and concatenate into a single string 223 | 224 | out = [join_char.join([str(j) for j in doc[beg:end]]) for beg, end in groups] 225 | 226 | # Check whether there is a remainder and concatenate if so 227 | 228 | max_index = len(groups) * n 229 | 230 | last_group = join_char.join([str(j) for j in doc[max_index : len(doc)]]) 231 | 232 | out.append(last_group) 233 | 234 | return out 235 | 236 | def split_lines(self, doc, split_char="\\n", add_linebreak=True): 237 | """ 238 | Split a document by `split_char` 239 | 240 | Args: 241 | doc(str): A document containing references 242 | split_char(str): Character by which `doc` will be split 243 | add_linebreak(bool): If `True`, re-adds the linebreak character to the 244 | end of each line that is split. 245 | 246 | Returns: 247 | (list): List of split lines (str). 
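A quick illustration of the chunking that `combine_n_rows` above performs; the input lines are invented:

```
# Invented input; with n=2 and join_char=" " the method above groups the five
# lines into two full chunks plus the remainder.
refs_lines = ["ref 1\n", "ref 2\n", "ref 3\n", "ref 4\n", "ref 5\n"]
# combine_n_rows(refs_lines, n=2, join_char=" ") ->
# ["ref 1\n ref 2\n", "ref 3\n ref 4\n", "ref 5\n"]
```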
248 | 249 | """ 250 | 251 | lines = doc.split(split_char) 252 | 253 | if add_linebreak: 254 | lines = [i + split_char for i in lines] 255 | 256 | return lines 257 | 258 | 259 | @plac.annotations( 260 | input_file=( 261 | "Path to jsonl file produced by the scraper and containing reference sections.", 262 | "positional", 263 | None, 264 | str, 265 | ), 266 | output_file=( 267 | "Path to jsonl file into which prodigy format references will be saved.", 268 | "positional", 269 | None, 270 | str, 271 | ), 272 | lines=("How many lines to include in an annotation example.", "option", "l", int), 273 | split_char=("Which character to split lines on.", "option", "s", str), 274 | no_linebreak=( 275 | "Don't re-add linebreaks to the annotation examples after splitting.", 276 | "flag", 277 | "n", 278 | str, 279 | ), 280 | join_char=( 281 | "Which character should be used to join lines into an annotation example.", 282 | "option", 283 | "j", 284 | str, 285 | ), 286 | ) 287 | def reach_to_prodigy( 288 | input_file, 289 | output_file, 290 | lines=10, 291 | split_char="\\n", 292 | no_linebreak=False, 293 | join_char=" ", 294 | ): 295 | 296 | print(split_char) 297 | 298 | scraped_json = read_jsonl(input_file) 299 | 300 | logger.info("Loaded %s scraped examples", len(scraped_json)) 301 | 302 | if no_linebreak: 303 | add_linebreak = False 304 | else: 305 | add_linebreak = True 306 | 307 | prodigy_format_references = ReachToProdigy( 308 | scraped_json, 309 | lines=lines, 310 | split_char=split_char, 311 | add_linebreak=add_linebreak, 312 | join_char=join_char, 313 | ) 314 | 315 | references = prodigy_format_references.run() 316 | 317 | write_jsonl(references, output_file=output_file) 318 | 319 | logger.info("Prodigy format written to %s", output_file) 320 | -------------------------------------------------------------------------------- /deep_reference_parser/prodigy/reference_to_token_annotations.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # coding: utf-8 3 | 4 | import itertools 5 | 6 | import plac 7 | 8 | from ..io import read_jsonl, write_jsonl 9 | from ..logger import logger 10 | 11 | 12 | class TokenTagger: 13 | def __init__(self, task="splitting", lowercase=True, text=True): 14 | """ 15 | Converts data in prodigy format with full reference spans to per-token 16 | spans 17 | 18 | Args: 19 | task (str): One of ["parsing", "splitting"]. See below for further 20 | explanation. 21 | lowercase (bool): Automatically convert upper case annotations to 22 | lowercase under the parsing scenario. 23 | text (bool): Include the token text in the output span (very useful 24 | for debugging). 25 | 26 | Since the parsing, splitting, and classification tasks have quite 27 | different labelling requirements, this class behaves differently 28 | depending on which task is specified in the task argument. 29 | 30 | For splitting: 31 | 32 | Expects one of four labels for the spans: 33 | 34 | * BE: A complete reference 35 | * BI: A fragment of a reference that captures the beginning but not the end 36 | * IE: A fragment of a reference that captures the end but not the beginning 37 | * II: A fragment of a reference that captures neither the beginning nor the 38 | end. 39 | 40 | Depending on which label is applied the tokens within the span will be 41 | labelled differently as one of ["b-r", "i-r", "e-r", "o"]. 42 | 43 | For parsing: 44 | 45 | Expects any arbitrary label for spans. All tokens within that span will 46 | be labelled with the same label.
47 | 48 | """ 49 | 50 | self.out = [] 51 | self.task = task 52 | self.lowercase = lowercase 53 | self.text = text 54 | 55 | def tag_doc(self, doc): 56 | """ 57 | Tags a document with appropriate labels for the parsing task 58 | 59 | Args: 60 | doc(dict): A single document in prodigy dict format to be labelled. 61 | """ 62 | 63 | bie_spans = self.reference_spans(doc["spans"], doc["tokens"], task=self.task) 64 | o_spans = self.outside_spans(bie_spans, doc["tokens"]) 65 | 66 | # Flatten into one list. 67 | 68 | spans = itertools.chain(bie_spans, o_spans) 69 | 70 | # Sort by token id to ensure it is ordered. 71 | 72 | spans = sorted(spans, key=lambda k: k["token_start"]) 73 | 74 | doc["spans"] = spans 75 | 76 | return doc 77 | 78 | def run(self, docs): 79 | """ 80 | Main class method for tagging multiple documents. 81 | 82 | Args: 83 | docs(dict): A list of docs in prodigy dict format to be labelled. 84 | """ 85 | 86 | for doc in docs: 87 | 88 | self.out.append(self.tag_doc(doc)) 89 | 90 | return self.out 91 | 92 | def reference_spans(self, spans, tokens, task): 93 | """ 94 | Given a whole reference span as labelled in prodigy, break this into 95 | appropriate single token spans depending on the label that was applied to 96 | the whole reference span. 97 | """ 98 | split_spans = [] 99 | 100 | if task == "splitting": 101 | 102 | for span in spans: 103 | if span["label"] in ["BE", "be"]: 104 | 105 | split_spans.extend( 106 | self.split_long_span(tokens, span, "b-r", "e-r", "i-r") 107 | ) 108 | 109 | elif span["label"] in ["BI", "bi"]: 110 | 111 | split_spans.extend( 112 | self.split_long_span(tokens, span, "b-r", "i-r", "i-r") 113 | ) 114 | 115 | elif span["label"] in ["IE", "ie"]: 116 | 117 | split_spans.extend( 118 | self.split_long_span(tokens, span, "i-r", "e-r", "i-r") 119 | ) 120 | 121 | elif span["label"] in ["II", "ii"]: 122 | 123 | split_spans.extend( 124 | self.split_long_span(tokens, span, "i-r", "i-r", "i-r") 125 | ) 126 | 127 | elif task == "parsing": 128 | 129 | for span in spans: 130 | if self.lowercase: 131 | label = span["label"].lower() 132 | else: 133 | label = span["label"] 134 | split_spans.extend( 135 | self.split_long_span(tokens, span, label, label, label) 136 | ) 137 | 138 | return split_spans 139 | 140 | def outside_spans(self, spans, tokens): 141 | """ 142 | Label tokens with `o` if they are outside a reference 143 | 144 | Args: 145 | spans(list): Spans in prodigy format. 146 | tokens(list): Tokens in prodigy format. 147 | 148 | Returns: 149 | list: A list of spans in prodigy format that comprises the tokens which 150 | are outside of a reference. 151 | """ 152 | # Get the diff between inside and outside tokens 153 | 154 | span_indices = set([span["token_start"] for span in spans]) 155 | token_indices = set([token["id"] for token in tokens]) 156 | 157 | outside_indices = token_indices - span_indices 158 | 159 | outside_spans = [] 160 | 161 | for index in outside_indices: 162 | outside_spans.append(self.create_span(tokens, index, "o")) 163 | 164 | return outside_spans 165 | 166 | def create_span(self, tokens, index, label): 167 | """ 168 | Given a list of tokens, (in prodigy format) and an index relating to one of 169 | those tokens, and a new label: create a single token span using the new 170 | label, and the token selected by `index`. 
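A compact picture of the splitting-task conversion performed by `reference_spans` above for a single BE-labelled reference covering three tokens; the character offsets are invented:

```
# Invented offsets; a BE span over tokens 0-2 is handed to the split_long_span
# helper (called above, defined just below) with the labels "b-r", "e-r", "i-r".
reference_span = {"start": 0, "end": 14, "token_start": 0, "token_end": 2, "label": "BE"}
# The result is one single-token span per token:
#   token 0 -> "b-r", token 1 -> "i-r", token 2 -> "e-r"
# and outside_spans then labels every remaining token id with "o".
```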
171 | """ 172 | 173 | token = tokens[index] 174 | 175 | span = { 176 | "start": token["start"], 177 | "end": token["end"], 178 | "token_start": token["id"], 179 | "token_end": token["id"], 180 | "label": label, 181 | } 182 | 183 | if self.text: 184 | span["text"] = token["text"] 185 | 186 | return span 187 | 188 | def split_long_span(self, tokens, span, start_label, end_label, inside_label): 189 | """ 190 | Split a multi-token span into `n` spans of length `1`, where `n` is the number of tokens covered by the span 191 | """ 192 | spans = [] 193 | start = span["token_start"] 194 | end = span["token_end"] 195 | 196 | span_size = end - start 197 | 198 | # Case when there is only one token in the span 199 | if span_size == 0: 200 | spans.append(self.create_span(tokens, start, start_label)) 201 | # Case when there are two or more tokens in the span 202 | else: 203 | spans.append(self.create_span(tokens, start, start_label)) 204 | spans.append(self.create_span(tokens, end, end_label)) 205 | 206 | if span_size > 1: 207 | 208 | for index in range(start + 1, end): 209 | spans.append(self.create_span(tokens, index, inside_label)) 210 | 211 | spans = sorted(spans, key=lambda k: k["token_start"]) 212 | 213 | return spans 214 | 215 | 216 | @plac.annotations( 217 | input_file=( 218 | "Path to jsonl file containing chunks of references in prodigy format.", 219 | "positional", 220 | None, 221 | str, 222 | ), 223 | output_file=( 224 | "Path to jsonl file into which fully annotated docs will be saved.", 225 | "positional", 226 | None, 227 | str, 228 | ), 229 | task=( 230 | "Which task is being performed. Either splitting or parsing.", 231 | "positional", 232 | None, 233 | str, 234 | ), 235 | lowercase=( 236 | "Convert UPPER case reference labels to lower case token labels?", 237 | "flag", 238 | "f", 239 | bool, 240 | ), 241 | text=( 242 | "Output the token text in the span (useful for debugging).", 243 | "flag", 244 | "t", 245 | bool, 246 | ), 247 | ) 248 | def reference_to_token_annotations( 249 | input_file, output_file, task="splitting", lowercase=False, text=False 250 | ): 251 | """ 252 | Creates a span for every token from existing multi-token spans 253 | 254 | Converts a jsonl file output by prodigy (using prodigy db-out) with spans 255 | extending over more than a single token to individual token level spans. 256 | 257 | The rationale for this is that reference level annotations are much easier 258 | for humans to do, but not useful when training a token level model. 259 | 260 | This command functions in two ways: 261 | 262 | * task=splitting: For the splitting task where we are interested in 263 | labelling the beginning (b-r) and end (e-r) of references, reference 264 | spans are labelled with one of BI, BE, IE, II. These are then converted 265 | to token level spans b-r, i-r, e-r, and o using the following logic. Symbolically: 266 | * BE: [BE, BE, BE] becomes [b-r][i-r][e-r] 267 | * BI: [BI, BI, BI] becomes [b-r][i-r][i-r] 268 | * IE: [IE, IE, IE] becomes [i-r][i-r][e-r] 269 | * II: [II, II, II] becomes [i-r][i-r][i-r] 270 | * All other tokens become [o] 271 | 272 | * task=parsing: For the parsing task, multi-token annotations are much 273 | simpler and would tend to be just 'author', or 'title'. These simple 274 | labels can be applied directly to the individual tokens contained within 275 | these multi-token spans; for each token in the multi-token span, a span 276 | is created with the same label.
Symbolically: 277 | * [author author author] becomes [author][author][author] 278 | """ 279 | 280 | ref_annotated_docs = read_jsonl(input_file) 281 | 282 | # Only run the tagger on annotated examples. 283 | 284 | ref_annotated_docs = [doc for doc in ref_annotated_docs if doc.get("spans")] 285 | 286 | logger.info( 287 | "Loaded %s documents with reference annotations", len(ref_annotated_docs) 288 | ) 289 | 290 | annotator = TokenTagger(task=task, lowercase=lowercase, text=text) 291 | 292 | token_annotated_docs = annotator.run(ref_annotated_docs) 293 | 294 | write_jsonl(token_annotated_docs, output_file=output_file) 295 | 296 | logger.info( 297 | "Wrote %s docs with token annotations to %s", 298 | len(token_annotated_docs), 299 | output_file, 300 | ) 301 | -------------------------------------------------------------------------------- /deep_reference_parser/prodigy/spacy_doc_to_prodigy.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # coding: utf-8 3 | 4 | import spacy 5 | 6 | 7 | class SpacyDocToProdigy: 8 | """Convert spacy documents into prodigy format 9 | """ 10 | 11 | def run(self, docs): 12 | """ 13 | Cycle through docs and return prodigy docs. 14 | """ 15 | 16 | return list(self.return_one_prodigy_doc(doc) for doc in docs) 17 | 18 | def return_one_prodigy_doc(self, doc): 19 | """Given one spacy document, yield a prodigy style dict 20 | 21 | Args: 22 | doc (spacy.tokens.doc.Doc): A spacy document 23 | 24 | Returns: 25 | dict: Prodigy style document 26 | 27 | """ 28 | 29 | if not isinstance(doc, spacy.tokens.doc.Doc): 30 | raise TypeError("doc must be of type spacy.tokens.doc.Doc") 31 | 32 | text = doc.text 33 | spans = [] 34 | tokens = [] 35 | 36 | for token in doc: 37 | tokens.append( 38 | { 39 | "text": token.text, 40 | "start": token.idx, 41 | "end": token.idx + len(token.text), 42 | "id": token.i, 43 | } 44 | ) 45 | 46 | for ent in doc.ents: 47 | spans.append( 48 | { 49 | "token_start": ent.start, 50 | "token_end": ent.end, 51 | "start": ent.start_char, 52 | "end": ent.end_char, 53 | "label": ent.label_, 54 | } 55 | ) 56 | 57 | out = { 58 | "text": text, 59 | "spans": spans, 60 | "tokens": tokens, 61 | } 62 | 63 | return out 64 | -------------------------------------------------------------------------------- /deep_reference_parser/reference_utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # coding: utf-8 3 | 4 | """ 5 | """ 6 | 7 | from .logger import logger 8 | 9 | 10 | def yield_token_label_pairs(tokens, labels): 11 | """ 12 | Convert matching lists of tokens and labels to tuples of (token, label) but 13 | preserving the nexted list boundaries as (None, None). 14 | 15 | Args: 16 | tokens(list): list of tokens. 17 | labels(list): list of labels corresponding to tokens. 
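A small sketch of the stream this generator produces for two nested examples; the tokens and labels are invented:

```
# Invented nested tokens and labels.
tokens = [["A", "ref", "."], ["Another", "ref", "."]]
labels = [["b-r", "i-r", "e-r"], ["b-r", "i-r", "e-r"]]
# list(yield_token_label_pairs(tokens, labels)) ->
# [('A', 'b-r'), ('ref', 'i-r'), ('.', 'e-r'), (None, None),
#  ('Another', 'b-r'), ('ref', 'i-r'), ('.', 'e-r'), (None, None)]
```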
18 | """ 19 | 20 | for tokens, labels in zip(tokens, labels): 21 | if tokens and labels: 22 | for token, label in zip(tokens, labels): 23 | yield (token, label) 24 | yield (None, None) 25 | else: 26 | yield (None, None) 27 | 28 | 29 | def break_into_chunks(doc, max_words=250): 30 | """ 31 | Breaks a list into lists of lists of length max_words 32 | Also works on lists: 33 | 34 | >>> doc = ["a", "b", "c", "d", "e"] 35 | >>> break_into_chunks(doc, max_words=2) 36 | [['a', 'b'], ['c', 'd'], ['e']] 37 | """ 38 | out = [] 39 | chunk = [] 40 | for i, token in enumerate(doc, 1): 41 | chunk.append(token) 42 | if (i > 0 and i % max_words == 0) or i == len(doc): 43 | out.append(chunk) 44 | chunk = [] 45 | return out 46 | -------------------------------------------------------------------------------- /deep_reference_parser/split.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # coding: utf-8 3 | """ 4 | Run predictions from a pre-trained model 5 | """ 6 | 7 | import itertools 8 | import json 9 | import os 10 | 11 | import en_core_web_sm 12 | import plac 13 | import spacy 14 | import wasabi 15 | 16 | import warnings 17 | 18 | with warnings.catch_warnings(): 19 | warnings.filterwarnings("ignore", category=DeprecationWarning) 20 | 21 | from deep_reference_parser import __file__ 22 | from deep_reference_parser.__version__ import __splitter_model_version__ 23 | from deep_reference_parser.common import SPLITTER_CFG, download_model_artefact 24 | from deep_reference_parser.deep_reference_parser import DeepReferenceParser 25 | from deep_reference_parser.logger import logger 26 | from deep_reference_parser.model_utils import get_config 27 | from deep_reference_parser.reference_utils import break_into_chunks 28 | from deep_reference_parser.tokens_to_references import tokens_to_references 29 | 30 | msg = wasabi.Printer(icons={"check": "\u2023"}) 31 | 32 | 33 | class Splitter: 34 | def __init__(self, config_file): 35 | 36 | msg.info(f"Using config file: {config_file}") 37 | 38 | cfg = get_config(config_file) 39 | 40 | # Build config 41 | try: 42 | OUTPUT_PATH = cfg["build"]["output_path"] 43 | S3_SLUG = cfg["data"]["s3_slug"] 44 | except KeyError: 45 | config_dir, missing_config = os.path.split(config_file) 46 | files = os.listdir(config_dir) 47 | other_configs = [f for f in os.listdir(config_dir) if os.path.isfile(os.path.join(config_dir, f))] 48 | msg.fail(f"Could not find config {missing_config}, perhaps you meant one of {other_configs}") 49 | 50 | msg.info( 51 | f"Attempting to download model artefacts if they are not found locally in {cfg['build']['output_path']}. This may take some time..." 52 | ) 53 | 54 | # Check whether the necessary artefacts exists and download them if 55 | # not. 
56 | 57 | artefacts = [ 58 | "indices.pickle", 59 | "weights.h5", 60 | ] 61 | 62 | for artefact in artefacts: 63 | with msg.loading(f"Could not find {artefact} locally, downloading..."): 64 | try: 65 | artefact = os.path.join(OUTPUT_PATH, artefact) 66 | download_model_artefact(artefact, S3_SLUG) 67 | msg.good(f"Found {artefact}") 68 | except: 69 | msg.fail(f"Could not download {S3_SLUG}{artefact}") 70 | logger.exception("Could not download %s%s", S3_SLUG, artefact) 71 | 72 | # Check on word embedding and download if not exists 73 | 74 | WORD_EMBEDDINGS = cfg["build"]["word_embeddings"] 75 | 76 | with msg.loading(f"Could not find {WORD_EMBEDDINGS} locally, downloading..."): 77 | try: 78 | download_model_artefact(WORD_EMBEDDINGS, S3_SLUG) 79 | msg.good(f"Found {WORD_EMBEDDINGS}") 80 | except: 81 | msg.fail(f"Could not download {S3_SLUG}{WORD_EMBEDDINGS}") 82 | logger.exception("Could not download %s", WORD_EMBEDDINGS) 83 | 84 | OUTPUT = cfg["build"]["output"] 85 | PRETRAINED_EMBEDDING = cfg["build"]["pretrained_embedding"] 86 | DROPOUT = float(cfg["build"]["dropout"]) 87 | LSTM_HIDDEN = int(cfg["build"]["lstm_hidden"]) 88 | WORD_EMBEDDING_SIZE = int(cfg["build"]["word_embedding_size"]) 89 | CHAR_EMBEDDING_SIZE = int(cfg["build"]["char_embedding_size"]) 90 | 91 | self.MAX_WORDS = int(cfg["data"]["line_limit"]) 92 | 93 | # Evaluate config 94 | 95 | self.drp = DeepReferenceParser(output_path=OUTPUT_PATH) 96 | 97 | # Encode data and load required mapping dicts. Note that the max word and 98 | # max char lengths will be loaded in this step. 99 | 100 | self.drp.load_data(OUTPUT_PATH) 101 | 102 | # Build the model architecture 103 | 104 | self.drp.build_model( 105 | output=OUTPUT, 106 | word_embeddings=WORD_EMBEDDINGS, 107 | pretrained_embedding=PRETRAINED_EMBEDDING, 108 | dropout=DROPOUT, 109 | lstm_hidden=LSTM_HIDDEN, 110 | word_embedding_size=WORD_EMBEDDING_SIZE, 111 | char_embedding_size=CHAR_EMBEDDING_SIZE, 112 | ) 113 | 114 | def split(self, text, return_tokens=False, verbose=False): 115 | 116 | nlp = en_core_web_sm.load() 117 | doc = nlp(text) 118 | chunks = break_into_chunks(doc, max_words=self.MAX_WORDS) 119 | tokens = [[token.text for token in chunk] for chunk in chunks] 120 | 121 | preds = self.drp.predict(tokens, load_weights=True) 122 | 123 | # If tokens argument passed, return the labelled tokens 124 | 125 | if return_tokens: 126 | 127 | flat_predictions = list(itertools.chain.from_iterable(preds))[0] 128 | flat_X = list(itertools.chain.from_iterable(tokens)) 129 | rows = [i for i in zip(flat_X, flat_predictions)] 130 | 131 | if verbose: 132 | 133 | msg.divider("Token Results") 134 | 135 | header = ("token", "label") 136 | aligns = ("r", "l") 137 | formatted = wasabi.table( 138 | rows, header=header, divider=True, aligns=aligns 139 | ) 140 | print(formatted) 141 | 142 | out = rows 143 | 144 | else: 145 | 146 | # Otherwise convert the tokens into references and return 147 | 148 | refs = tokens_to_references(tokens, preds[0]) 149 | 150 | if verbose: 151 | 152 | msg.divider("Results") 153 | 154 | if refs: 155 | 156 | msg.good(f"Found {len(refs)} references.") 157 | msg.info("Printing found references:") 158 | 159 | for ref in refs: 160 | msg.text(ref, icon="check", spaced=True) 161 | 162 | else: 163 | 164 | msg.fail("Failed to find any references.") 165 | 166 | out = refs 167 | 168 | return out 169 | 170 | 171 | @plac.annotations( 172 | text=("Plaintext from which to extract references", "positional", None, str), 173 | config_file=("Path to config file", "option", "c", str), 174 | 
tokens=("Output tokens instead of complete references", "flag", "t", str), 175 | outfile=("Path to json file to which results will be written", "option", "o", str), 176 | ) 177 | def split(text, config_file=SPLITTER_CFG, tokens=False, outfile=None): 178 | """ 179 | Runs the default splitting model and pretty prints results to console unless 180 | --outfile is parsed with a path. Files output to the path specified in 181 | --outfile will be a valid json. Can output either tokens (with -t|--tokens) 182 | or split naively into references based on the b-r tag (default). 183 | 184 | NOTE: that this function is provided for examples only and should not be used 185 | in production as the model is instantiated each time the command is run. To 186 | use in a production setting, a more sensible approach would be to replicate 187 | the split or parse functions within your own logic. 188 | """ 189 | splitter = Splitter(config_file) 190 | if outfile: 191 | out = splitter.split(text, return_tokens=tokens, verbose=False) 192 | 193 | try: 194 | with open(outfile, "w") as fb: 195 | json.dump(out, fb) 196 | msg.good(f"Wrote model output to {outfile}") 197 | except: 198 | msg.fail(f"Failed to write output to {outfile}") 199 | 200 | else: 201 | out = splitter.split(text, return_tokens=tokens, verbose=True) 202 | -------------------------------------------------------------------------------- /deep_reference_parser/split_parse.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # coding: utf-8 3 | """ 4 | Run predictions from a pre-trained model 5 | """ 6 | 7 | import itertools 8 | import json 9 | import os 10 | 11 | import en_core_web_sm 12 | import plac 13 | import spacy 14 | import wasabi 15 | 16 | import warnings 17 | 18 | with warnings.catch_warnings(): 19 | warnings.filterwarnings("ignore", category=DeprecationWarning) 20 | 21 | from deep_reference_parser import __file__ 22 | from deep_reference_parser.__version__ import __splitter_model_version__ 23 | from deep_reference_parser.common import MULTITASK_CFG, download_model_artefact 24 | from deep_reference_parser.deep_reference_parser import DeepReferenceParser 25 | from deep_reference_parser.logger import logger 26 | from deep_reference_parser.model_utils import get_config 27 | from deep_reference_parser.reference_utils import break_into_chunks 28 | from deep_reference_parser.tokens_to_references import tokens_to_reference_lists 29 | 30 | msg = wasabi.Printer(icons={"check": "\u2023"}) 31 | 32 | 33 | class SplitParser: 34 | def __init__(self, config_file): 35 | 36 | msg.info(f"Using config file: {config_file}") 37 | 38 | cfg = get_config(config_file) 39 | 40 | try: 41 | OUTPUT_PATH = cfg["build"]["output_path"] 42 | S3_SLUG = cfg["data"]["s3_slug"] 43 | except KeyError: 44 | config_dir, missing_config = os.path.split(config_file) 45 | files = os.listdir(config_dir) 46 | other_configs = [f for f in os.listdir(config_dir) if os.path.isfile(os.path.join(config_dir, f))] 47 | msg.fail(f"Could not find config {missing_config}, perhaps you meant one of {other_configs}") 48 | 49 | # Check whether the necessary artefacts exists and download them if 50 | # not. 
51 | 52 | artefacts = [ 53 | "indices.pickle", 54 | "weights.h5", 55 | ] 56 | 57 | for artefact in artefacts: 58 | with msg.loading(f"Could not find {artefact} locally, downloading..."): 59 | try: 60 | artefact = os.path.join(OUTPUT_PATH, artefact) 61 | download_model_artefact(artefact, S3_SLUG) 62 | msg.good(f"Found {artefact}") 63 | except: 64 | msg.fail(f"Could not download {S3_SLUG}{artefact}") 65 | logger.exception("Could not download %s%s", S3_SLUG, artefact) 66 | 67 | # Check on word embedding and download if not exists 68 | 69 | WORD_EMBEDDINGS = cfg["build"]["word_embeddings"] 70 | 71 | with msg.loading(f"Could not find {WORD_EMBEDDINGS} locally, downloading..."): 72 | try: 73 | download_model_artefact(WORD_EMBEDDINGS, S3_SLUG) 74 | msg.good(f"Found {WORD_EMBEDDINGS}") 75 | except: 76 | msg.fail(f"Could not download {S3_SLUG}{WORD_EMBEDDINGS}") 77 | logger.exception("Could not download %s", WORD_EMBEDDINGS) 78 | 79 | OUTPUT = cfg["build"]["output"] 80 | PRETRAINED_EMBEDDING = cfg["build"]["pretrained_embedding"] 81 | DROPOUT = float(cfg["build"]["dropout"]) 82 | LSTM_HIDDEN = int(cfg["build"]["lstm_hidden"]) 83 | WORD_EMBEDDING_SIZE = int(cfg["build"]["word_embedding_size"]) 84 | CHAR_EMBEDDING_SIZE = int(cfg["build"]["char_embedding_size"]) 85 | 86 | self.MAX_WORDS = int(cfg["data"]["line_limit"]) 87 | 88 | # Evaluate config 89 | 90 | self.drp = DeepReferenceParser(output_path=OUTPUT_PATH) 91 | 92 | # Encode data and load required mapping dicts. Note that the max word and 93 | # max char lengths will be loaded in this step. 94 | 95 | self.drp.load_data(OUTPUT_PATH) 96 | 97 | # Build the model architecture 98 | 99 | self.drp.build_model( 100 | output=OUTPUT, 101 | word_embeddings=WORD_EMBEDDINGS, 102 | pretrained_embedding=PRETRAINED_EMBEDDING, 103 | dropout=DROPOUT, 104 | lstm_hidden=LSTM_HIDDEN, 105 | word_embedding_size=WORD_EMBEDDING_SIZE, 106 | char_embedding_size=CHAR_EMBEDDING_SIZE, 107 | ) 108 | 109 | def split_parse(self, text, return_tokens=False, verbose=False): 110 | 111 | nlp = en_core_web_sm.load() 112 | doc = nlp(text) 113 | chunks = break_into_chunks(doc, max_words=self.MAX_WORDS) 114 | tokens = [[token.text for token in chunk] for chunk in chunks] 115 | 116 | preds = self.drp.predict(tokens, load_weights=True) 117 | 118 | # If tokens argument passed, return the labelled tokens 119 | 120 | if return_tokens: 121 | 122 | flat_preds_list = list(map(itertools.chain.from_iterable,preds)) 123 | flat_X = list(itertools.chain.from_iterable(tokens)) 124 | rows = [i for i in zip(*[flat_X] + flat_preds_list)] 125 | 126 | if verbose: 127 | 128 | msg.divider("Token Results") 129 | 130 | header = tuple(["token"] + ["label"] * len(flat_preds_list)) 131 | aligns = tuple(["r"] + ["l"] * len(flat_preds_list)) 132 | formatted = wasabi.table( 133 | rows, header=header, divider=True, aligns=aligns 134 | ) 135 | print(formatted) 136 | 137 | out = rows 138 | 139 | else: 140 | 141 | # Return references with attributes (author, title, year) 142 | # in json format. 143 | # List of lists for each reference - each reference list contains all token attributes predictions 144 | # [[(token, attribute), ... 
, (token, attribute)], ..., [(token, attribute), ...]] 145 | 146 | references_components = tokens_to_reference_lists(tokens, spans=preds[1], components=preds[0]) 147 | if verbose: 148 | 149 | msg.divider("Results") 150 | 151 | if references_components: 152 | 153 | msg.good(f"Found {len(references_components)} references.") 154 | msg.info("Printing found references:") 155 | 156 | for ref in references_components: 157 | msg.text(ref['Reference'], icon="check", spaced=True) 158 | 159 | else: 160 | 161 | msg.fail("Failed to find any references.") 162 | 163 | out = references_components 164 | 165 | return out 166 | 167 | 168 | @plac.annotations( 169 | text=("Plaintext from which to extract references", "positional", None, str), 170 | config_file=("Path to config file", "option", "c", str), 171 | tokens=("Output tokens instead of complete references", "flag", "t", str), 172 | outfile=("Path to json file to which results will be written", "option", "o", str), 173 | ) 174 | def split_parse(text, config_file=MULTITASK_CFG, tokens=False, outfile=None): 175 | """ 176 | Runs the default splitting model and pretty prints results to console unless 177 | --outfile is parsed with a path. Files output to the path specified in 178 | --outfile will be a valid json. Can output either tokens (with -t|--tokens) 179 | or split naively into references based on the b-r tag (default). 180 | 181 | NOTE: that this function is provided for examples only and should not be used 182 | in production as the model is instantiated each time the command is run. To 183 | use in a production setting, a more sensible approach would be to replicate 184 | the split or parse functions within your own logic. 185 | """ 186 | mt = SplitParser(config_file) 187 | if outfile: 188 | out = mt.split_parse(text, return_tokens=tokens, verbose=True) 189 | 190 | try: 191 | with open(outfile, "w") as fb: 192 | json.dump(out, fb) 193 | msg.good(f"Wrote model output to {outfile}") 194 | except: 195 | msg.fail(f"Failed to write output to {outfile}") 196 | 197 | else: 198 | out = mt.split_parse(text, return_tokens=tokens, verbose=True) 199 | -------------------------------------------------------------------------------- /deep_reference_parser/tokens_to_references.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # coding: utf-8 3 | 4 | """ 5 | Converts a list of tokens and labels into a list of human readable references 6 | """ 7 | 8 | import itertools 9 | 10 | from .deep_reference_parser import logger 11 | 12 | 13 | def get_reference_spans(tokens, spans): 14 | 15 | # Flatten the lists of tokens and predictions into a single list. 16 | 17 | flat_tokens = list(itertools.chain.from_iterable(tokens)) 18 | flat_predictions = list(itertools.chain.from_iterable(spans)) 19 | 20 | # Find all b-r and e-r tokens. 21 | 22 | ref_starts = [ 23 | index for index, label in enumerate(flat_predictions) if label == "b-r" 24 | ] 25 | 26 | ref_ends = [index for index, label in enumerate(flat_predictions) if label == "e-r"] 27 | 28 | logger.debug("Found %s b-r tokens", len(ref_starts)) 29 | logger.debug("Found %s e-r tokens", len(ref_ends)) 30 | 31 | n_refs = len(ref_starts) 32 | 33 | # Split on each b-r. 
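A toy illustration of the b-r driven split performed below; the tokens and labels are invented:

```
# Invented tokens and labels; each b-r tag starts a new reference.
tokens = [["Smith", "J.", "2019", ".", "Jones", "A.", "2020", "."]]
labels = [["b-r", "i-r", "i-r", "e-r", "b-r", "i-r", "i-r", "e-r"]]
# tokens_to_references(tokens, labels) ->
# ["Smith J. 2019 .", "Jones A. 2020 ."]
```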
34 | 35 | token_starts = [] 36 | token_ends = [] 37 | for i in range(0, n_refs): 38 | token_starts.append(ref_starts[i]) 39 | if i + 1 < n_refs: 40 | token_ends.append(ref_starts[i + 1] - 1) 41 | else: 42 | token_ends.append(len(flat_tokens)) 43 | 44 | return token_starts, token_ends, flat_tokens 45 | 46 | 47 | def tokens_to_references(tokens, labels): 48 | """ 49 | Given a corresponding list of tokens and a list of labels: split the tokens 50 | and return a list of references. 51 | 52 | Args: 53 | tokens(list): A list of tokens. 54 | labels(list): A corresponding list of labels. 55 | 56 | """ 57 | 58 | token_starts, token_ends, flat_tokens = get_reference_spans(tokens, labels) 59 | 60 | references = [] 61 | for token_start, token_end in zip(token_starts, token_ends): 62 | ref = flat_tokens[token_start : token_end + 1] 63 | flat_ref = " ".join(ref) 64 | references.append(flat_ref) 65 | 66 | return references 67 | 68 | def tokens_to_reference_lists(tokens, spans, components): 69 | """ 70 | Given a corresponding list of tokens, a list of 71 | reference spans (e.g. 'b-r') and components (e.g. 'author;): 72 | split the tokens according to the spans and return a 73 | list of reference components for each reference. 74 | 75 | Args: 76 | tokens(list): A list of tokens. 77 | spans(list): A corresponding list of reference spans. 78 | components(list): A corresponding list of reference components. 79 | 80 | """ 81 | 82 | token_starts, token_ends, flat_tokens = get_reference_spans(tokens, spans) 83 | flat_components = list(itertools.chain.from_iterable(components)) 84 | 85 | references_components = [] 86 | for token_start, token_end in zip(token_starts, token_ends): 87 | 88 | ref_tokens = flat_tokens[token_start : token_end + 1] 89 | ref_components = flat_components[token_start : token_end + 1] 90 | flat_ref = " ".join(ref_tokens) 91 | 92 | references_components.append({'Reference': flat_ref, 'Attributes': list(zip(ref_tokens, ref_components))}) 93 | 94 | return references_components 95 | -------------------------------------------------------------------------------- /deep_reference_parser/train.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # coding: utf-8 3 | """ 4 | Runs the model using configuration defined in a config file. This is suitable for 5 | running model versions < 2019.10.8 6 | """ 7 | 8 | import plac 9 | import wasabi 10 | 11 | from deep_reference_parser import load_tsv 12 | from deep_reference_parser.common import download_model_artefact 13 | from deep_reference_parser.deep_reference_parser import DeepReferenceParser 14 | from deep_reference_parser.logger import logger 15 | from deep_reference_parser.model_utils import get_config 16 | 17 | msg = wasabi.Printer() 18 | 19 | 20 | @plac.annotations(config_file=("Path to config file", "positional", None, str),) 21 | def train(config_file): 22 | 23 | # Load variables from config files. Config files are used instead of ENV 24 | # vars due to the relatively large number of hyper parameters, and the need 25 | # to load these configs in both the train and predict moduldes. 
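The config sections read below can be sketched as follows; the section and key names mirror the `cfg[...]` look-ups in `train()`, but every value here is an invented placeholder rather than the contents of the shipped config files:

```
# Illustrative only: all values are placeholders, not the shipped configs.
import configparser

cfg = configparser.ConfigParser()
cfg.read_dict({
    "data": {
        "policy_train": "data/processed/train.tsv",
        "policy_test": "data/processed/test.tsv",
        "policy_valid": "data/processed/valid.tsv",
        "s3_slug": "s3://example-bucket/models/",
        "line_limit": "250",
    },
    "build": {
        "output_path": "models/example/",
        "output": "crf",
        "word_embeddings": "embeddings/example_vectors.txt",
        "pretrained_embedding": "True",
        "dropout": "0.5",
        "lstm_hidden": "400",
        "word_embedding_size": "300",
        "char_embedding_size": "100",
    },
    "train": {
        "epochs": "60",
        "batch_size": "100",
        "early_stopping_patience": "5",
        "metric": "val_f1",
    },
})
```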
26 | 27 | cfg = get_config(config_file) 28 | 29 | # Data config 30 | 31 | POLICY_TRAIN = cfg["data"]["policy_train"] 32 | POLICY_TEST = cfg["data"]["policy_test"] 33 | POLICY_VALID = cfg["data"]["policy_valid"] 34 | 35 | # Build config 36 | 37 | OUTPUT_PATH = cfg["build"]["output_path"] 38 | S3_SLUG = cfg["data"]["s3_slug"] 39 | 40 | # Check on word embedding and download if not exists 41 | 42 | WORD_EMBEDDINGS = cfg["build"]["word_embeddings"] 43 | 44 | with msg.loading(f"Could not find {WORD_EMBEDDINGS} locally, downloading..."): 45 | try: 46 | download_model_artefact(WORD_EMBEDDINGS, S3_SLUG) 47 | msg.good(f"Found {WORD_EMBEDDINGS}") 48 | except: 49 | msg.fail(f"Could not download {WORD_EMBEDDINGS}") 50 | logger.exception("Could not download %s", WORD_EMBEDDINGS) 51 | 52 | OUTPUT = cfg["build"]["output"] 53 | WORD_EMBEDDINGS = cfg["build"]["word_embeddings"] 54 | PRETRAINED_EMBEDDING = cfg["build"]["pretrained_embedding"] 55 | DROPOUT = float(cfg["build"]["dropout"]) 56 | LSTM_HIDDEN = int(cfg["build"]["lstm_hidden"]) 57 | WORD_EMBEDDING_SIZE = int(cfg["build"]["word_embedding_size"]) 58 | CHAR_EMBEDDING_SIZE = int(cfg["build"]["char_embedding_size"]) 59 | MAX_LEN = int(cfg["data"]["line_limit"]) 60 | 61 | # Train config 62 | 63 | EPOCHS = int(cfg["train"]["epochs"]) 64 | BATCH_SIZE = int(cfg["train"]["batch_size"]) 65 | EARLY_STOPPING_PATIENCE = int(cfg["train"]["early_stopping_patience"]) 66 | METRIC = cfg["train"]["metric"] 67 | 68 | # Load policy data 69 | 70 | train_data = load_tsv(POLICY_TRAIN) 71 | test_data = load_tsv(POLICY_TEST) 72 | valid_data = load_tsv(POLICY_VALID) 73 | 74 | X_train, y_train = train_data[0], train_data[1:] 75 | X_test, y_test = test_data[0], test_data[1:] 76 | X_valid, y_valid = valid_data[0], valid_data[1:] 77 | 78 | import statistics 79 | 80 | logger.debug("Max token length %s", max([len(i) for i in X_train])) 81 | logger.debug("Min token length %s", min([len(i) for i in X_train])) 82 | logger.debug("Mean token length %s", statistics.median([len(i) for i in X_train])) 83 | 84 | logger.debug("Max token length %s", max([len(i) for i in X_test])) 85 | logger.debug("Min token length %s", min([len(i) for i in X_test])) 86 | logger.debug("Mean token length %s", statistics.median([len(i) for i in X_test])) 87 | 88 | logger.debug("Max token length %s", max([len(i) for i in X_valid])) 89 | logger.debug("Min token length %s", min([len(i) for i in X_valid])) 90 | logger.debug("Mean token length %s", statistics.median([len(i) for i in X_valid])) 91 | 92 | logger.info("X_train, y_train examples: %s, %s", len(X_train), list(map(len, y_train))) 93 | logger.info("X_test, y_test examples: %s, %s", len(X_test), list(map(len, y_test))) 94 | logger.info("X_valid, y_valid examples: %s, %s", len(X_valid), list(map(len, y_valid))) 95 | 96 | drp = DeepReferenceParser( 97 | X_train=X_train, 98 | X_test=X_test, 99 | X_valid=X_valid, 100 | y_train=y_train, 101 | y_test=y_test, 102 | y_valid=y_valid, 103 | max_len=MAX_LEN, 104 | output_path=OUTPUT_PATH, 105 | ) 106 | 107 | ## Encode data and create required mapping dicts 108 | 109 | drp.prepare_data(save=True) 110 | 111 | ## Build the model architecture 112 | 113 | drp.build_model( 114 | output=OUTPUT, 115 | word_embeddings=WORD_EMBEDDINGS, 116 | pretrained_embedding=PRETRAINED_EMBEDDING, 117 | dropout=DROPOUT, 118 | lstm_hidden=LSTM_HIDDEN, 119 | word_embedding_size=WORD_EMBEDDING_SIZE, 120 | char_embedding_size=CHAR_EMBEDDING_SIZE, 121 | ) 122 | 123 | ## Train the model. 
Not required if downloading weights from s3 124 | 125 | drp.train_model( 126 | epochs=EPOCHS, 127 | batch_size=BATCH_SIZE, 128 | early_stopping_patience=EARLY_STOPPING_PATIENCE, 129 | metric=METRIC, 130 | ) 131 | 132 | # Evaluate the model. Confusion matrices etc will be stored in 133 | # data/model_output 134 | 135 | drp.evaluate( 136 | load_weights=True, 137 | test_set=True, 138 | validation_set=True, 139 | print_padding=False, 140 | ) 141 | -------------------------------------------------------------------------------- /pytest.ini: -------------------------------------------------------------------------------- 1 | [pytest] 2 | markers = 3 | slow: marks tests as slow (deselect with '-m "not slow"') 4 | integration: Depends on downloading data from S3 5 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==0.9.0 2 | astor==0.8.1 3 | attrs==19.3.0 4 | blis==0.2.4 5 | certifi==2019.11.28 6 | chardet==3.0.4 7 | cycler==0.10.0 8 | cymem==2.0.3 9 | gast==0.3.3 10 | google-pasta==0.1.8 11 | grpcio==1.26.0 12 | h5py==2.10.0 13 | idna==2.8 14 | importlib-metadata==1.5.0 15 | joblib==0.14.1 16 | Keras==2.2.5 17 | Keras-Applications==1.0.8 18 | git+https://www.github.com/keras-team/keras-contrib.git@5ffab172661411218e517a50170bb97760ea567b 19 | Keras-Preprocessing==1.1.0 20 | kiwisolver==1.1.0 21 | Markdown==3.1.1 22 | matplotlib==3.1.1 23 | more-itertools==8.2.0 24 | murmurhash==1.0.2 25 | numpy==1.18.1 26 | packaging==20.1 27 | plac==0.9.6 28 | pluggy==0.13.1 29 | preshed==2.0.1 30 | protobuf==3.11.3 31 | py==1.8.1 32 | pyparsing==2.4.6 33 | pytest==5.3.5 34 | python-crfsuite==0.9.6 35 | python-dateutil==2.8.1 36 | PyYAML==5.3 37 | requests==2.22.0 38 | scikit-learn==0.21.3 39 | scipy==1.4.1 40 | six==1.14.0 41 | sklearn-crfsuite==0.3.6 42 | spacy==2.1.7 43 | srsly==1.0.1 44 | tabulate==0.8.6 45 | tensorboard==1.15.0 46 | tensorflow==1.15.4 47 | tensorflow-estimator==1.15.1 48 | termcolor==1.1.0 49 | thinc==7.0.8 50 | tqdm==4.42.1 51 | urllib3==1.25.8 52 | wasabi==0.6.0 53 | wcwidth==0.1.8 54 | Werkzeug==0.16.1 55 | wrapt==1.11.2 56 | zipp==2.1.0 57 | -------------------------------------------------------------------------------- /requirements_test.txt: -------------------------------------------------------------------------------- 1 | pytest 2 | tox 3 | codecov 4 | pytest-cov 5 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import setuptools 4 | 5 | here = os.path.abspath(os.path.dirname(__file__)) 6 | 7 | # Load data from the__versions__.py module. Change version, etc in 8 | # that module, and it will be automatically populated here. 
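A purely illustrative sketch of the attributes the `exec()` call below expects `__version__.py` to define; the attribute names are the ones read from the `about` dict (and imported by the split/parse modules), while the values are placeholders:

```
# Placeholder values only; the attribute names match the about[...] reads below.
__name__ = "deep_reference_parser"
__version__ = "2020.x.x"
__description__ = "Deep learning model for finding and labelling references"
__url__ = "https://github.com/wellcometrust/deep_reference_parser"
__author__ = "Example Author"
__author_email__ = "author@example.org"
__splitter_model_version__ = "2020.x.x"
__parser_model_version__ = "2020.x.x"
__splitparser_model_version__ = "2020.x.x"
```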
9 | 10 | about = {} # type: dict 11 | version_path = os.path.join(here, "deep_reference_parser", "__version__.py") 12 | with open(version_path, "r") as f: 13 | exec(f.read(), about) 14 | 15 | with open("README.md", "r") as f: 16 | long_description = f.read() 17 | 18 | setuptools.setup( 19 | name=about["__name__"], 20 | version=about["__version__"], 21 | author=about["__author__"], 22 | author_email=about["__author_email__"], 23 | description=about["__description__"], 24 | long_description=long_description, 25 | long_description_content_type="text/markdown", 26 | url=about["__url__"], 27 | license=["__license__"], 28 | packages=[ 29 | "deep_reference_parser", 30 | "deep_reference_parser/prodigy", 31 | "deep_reference_parser/io", 32 | ], 33 | package_dir={"deep_reference_parser": "deep_reference_parser"}, 34 | package_data={ 35 | "deep_reference_parser": [ 36 | f"configs/{about['__splitter_model_version__']}.ini", 37 | f"configs/{about['__parser_model_version__']}.ini", 38 | f"configs/{about['__splitparser_model_version__']}.ini", 39 | ] 40 | }, 41 | classifiers=[ 42 | "Programming Language :: Python :: 3", 43 | "Operating System :: OS Independent", 44 | ], 45 | install_requires=[ 46 | "spacy<2.2.0", 47 | "pandas", 48 | "tensorflow==1.15.4", 49 | "keras==2.2.5", 50 | "keras-contrib @ https://github.com/keras-team/keras-contrib/tarball/5ffab172661411218e517a50170bb97760ea567b", 51 | "en_core_web_sm @ https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-2.1.0/en_core_web_sm-2.1.0.tar.gz#egg=en_core_web_sm==2.1.0", 52 | "sklearn", 53 | "sklearn_crfsuite", 54 | "matplotlib", 55 | ], 56 | tests_require=["pytest", "pytest-cov"], 57 | ) 58 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wellcometrust/deep_reference_parser/b58e4616f4de9bfe18ab41e90f696f80ab876245/tests/__init__.py -------------------------------------------------------------------------------- /tests/common.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # coding: utf-8 3 | 4 | import os 5 | 6 | 7 | def get_path(p): 8 | return os.path.join(os.path.dirname(__file__), p) 9 | 10 | 11 | TEST_CFG = get_path("test_data/test_config.ini") 12 | TEST_CFG_MULTITASK = get_path("test_data/test_config_multitask.ini") 13 | TEST_JSONL = get_path("test_data/test_jsonl.jsonl") 14 | TEST_REFERENCES = get_path("test_data/test_references.txt") 15 | TEST_TSV_PREDICT = get_path("test_data/test_tsv_predict.tsv") 16 | TEST_TSV_TRAIN = get_path("test_data/test_tsv_train.tsv") 17 | TEST_LOAD_TSV = get_path("test_data/test_load_tsv.tsv") 18 | -------------------------------------------------------------------------------- /tests/prodigy/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wellcometrust/deep_reference_parser/b58e4616f4de9bfe18ab41e90f696f80ab876245/tests/prodigy/__init__.py -------------------------------------------------------------------------------- /tests/prodigy/common.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # coding: utf-8 3 | 4 | import os 5 | 6 | 7 | def get_path(p): 8 | return os.path.join(os.path.dirname(__file__), p) 9 | 10 | 11 | TEST_TOKENS = get_path("test_data/test_tokens_to_tsv_tokens.jsonl") 12 | TEST_SPANS = 
get_path("test_data/test_tokens_to_tsv_spans.jsonl") 13 | TEST_REF_TOKENS = get_path("test_data/test_reference_to_token_tokens.jsonl") 14 | TEST_REF_SPANS = get_path("test_data/test_reference_to_token_spans.jsonl") 15 | TEST_REF_EXPECTED_SPANS = get_path("test_data/test_reference_to_token_expected.jsonl") 16 | 17 | # Prodigy format document containing numbered reference section 18 | 19 | TEST_NUMBERED_REFERENCES = get_path("test_data/test_numbered_references.jsonl") 20 | 21 | # Prodigy format document with spans annotating every token in the document 22 | 23 | TEST_TOKEN_LABELLED = get_path("test_data/test_token_labelled_references.jsonl") 24 | 25 | # tsv data created from the above file 26 | 27 | TEST_TOKEN_LABELLED_TSV = get_path("test_data/test_token_labelled_references.tsv") 28 | 29 | # Reference section in Reach format 30 | 31 | TEST_REACH = get_path("test_data/test_reach.jsonl") 32 | -------------------------------------------------------------------------------- /tests/prodigy/test_data/test_reference_to_token_spans.jsonl: -------------------------------------------------------------------------------- 1 | [{"start":21,"end":64,"token_start":5,"token_end":11,"label":"author"},{"start":68,"end":161,"token_start":14,"token_end":26,"label":"title"},{"start":182,"end":186,"token_start":34,"token_end":34,"label":"year"},{"start":483,"end":519,"token_start":102,"token_end":108,"label":"author"},{"start":521,"end":550,"token_start":110,"token_end":113,"label":"title"},{"start":674,"end":719,"token_start":144,"token_end":153,"label":"author"},{"start":736,"end":752,"token_start":158,"token_end":162,"label":"title"},{"start":803,"end":807,"token_start":176,"token_end":176,"label":"year"},{"start":891,"end":930,"token_start":195,"token_end":201,"label":"author"},{"start":932,"end":1002,"token_start":203,"token_end":215,"label":"title"},{"start":1024,"end":1028,"token_start":221,"token_end":221,"label":"year"},{"start":1143,"end":1188,"token_start":254,"token_end":261,"label":"author"},{"start":1190,"end":1198,"token_start":263,"token_end":264,"label":"title"},{"start":1337,"end":1344,"token_start":294,"token_end":294,"label":"author"},{"start":1346,"end":1429,"token_start":296,"token_end":305,"label":"title"},{"start":1563,"end":1616,"token_start":334,"token_end":341,"label":"author"},{"start":1618,"end":1640,"token_start":343,"token_end":345,"label":"title"},{"start":1670,"end":1674,"token_start":353,"token_end":353,"label":"year"},{"start":1770,"end":1807,"token_start":375,"token_end":380,"label":"author"},{"start":1809,"end":1840,"token_start":382,"token_end":384,"label":"title"}] 2 | -------------------------------------------------------------------------------- /tests/prodigy/test_data/test_tokens_to_tsv_tokens.jsonl: -------------------------------------------------------------------------------- 1 | [{"text":"References","start":0,"end":10,"id":0},{"text":"\n \n ","start":11,"end":16,"id":1},{"text":"1","start":16,"end":17,"id":2},{"text":".","start":17,"end":18,"id":3},{"text":"\n 
","start":19,"end":21,"id":4},{"text":"United","start":21,"end":27,"id":5},{"text":"Nations","start":28,"end":35,"id":6},{"text":"Development","start":36,"end":47,"id":7},{"text":"Programme","start":48,"end":57,"id":8},{"text":"(","start":58,"end":59,"id":9},{"text":"UNDP","start":59,"end":63,"id":10},{"text":")","start":63,"end":64,"id":11},{"text":".","start":64,"end":65,"id":12},{"text":"A","start":66,"end":67,"id":13},{"text":"Guide","start":68,"end":73,"id":14},{"text":"to","start":74,"end":76,"id":15},{"text":"Civil","start":77,"end":82,"id":16},{"text":"Society","start":83,"end":90,"id":17},{"text":"Organizations","start":91,"end":104,"id":18},{"text":"\n ","start":105,"end":107,"id":19},{"text":"working","start":107,"end":114,"id":20},{"text":"on","start":115,"end":117,"id":21},{"text":"Democratic","start":118,"end":128,"id":22},{"text":"Governance","start":129,"end":139,"id":23},{"text":"[","start":140,"end":141,"id":24},{"text":"online","start":141,"end":147,"id":25},{"text":"publication].","start":148,"end":161,"id":26},{"text":"New","start":162,"end":165,"id":27},{"text":"York","start":166,"end":170,"id":28},{"text":",","start":170,"end":171,"id":29},{"text":"NY","start":172,"end":174,"id":30},{"text":";","start":174,"end":175,"id":31},{"text":"UNDP","start":176,"end":180,"id":32},{"text":";","start":180,"end":181,"id":33},{"text":"2005","start":182,"end":186,"id":34},{"text":".","start":186,"end":187,"id":35},{"text":"\n ","start":188,"end":190,"id":36},{"text":"(","start":190,"end":191,"id":37},{"text":"Available","start":191,"end":200,"id":38},{"text":"from","start":201,"end":205,"id":39},{"text":":","start":205,"end":206,"id":40},{"text":"\n ","start":207,"end":209,"id":41},{"text":"http://www.undp.org","start":209,"end":228,"id":42},{"text":"/","start":228,"end":229,"id":43},{"text":"content","start":229,"end":236,"id":44},{"text":"/","start":236,"end":237,"id":45},{"text":"dam","start":237,"end":240,"id":46},{"text":"/","start":240,"end":241,"id":47},{"text":"aplaws","start":241,"end":247,"id":48},{"text":"/","start":247,"end":248,"id":49},{"text":"publication","start":248,"end":259,"id":50},{"text":"/","start":259,"end":260,"id":51},{"text":"en","start":260,"end":262,"id":52},{"text":"/","start":262,"end":263,"id":53},{"text":"publications","start":263,"end":275,"id":54},{"text":"/","start":275,"end":276,"id":55},{"text":"democratic-","start":276,"end":287,"id":56},{"text":"\n ","start":287,"end":289,"id":57},{"text":"governance","start":289,"end":299,"id":58},{"text":"/","start":299,"end":300,"id":59},{"text":"oslo","start":300,"end":304,"id":60},{"text":"-","start":304,"end":305,"id":61},{"text":"governance","start":305,"end":315,"id":62},{"text":"-","start":315,"end":316,"id":63},{"text":"center","start":316,"end":322,"id":64},{"text":"/","start":322,"end":323,"id":65},{"text":"civic","start":323,"end":328,"id":66},{"text":"-","start":328,"end":329,"id":67},{"text":"engagement","start":329,"end":339,"id":68},{"text":"/","start":339,"end":340,"id":69},{"text":"a","start":340,"end":341,"id":70},{"text":"-","start":341,"end":342,"id":71},{"text":"guide","start":342,"end":347,"id":72},{"text":"-","start":347,"end":348,"id":73},{"text":"to","start":348,"end":350,"id":74},{"text":"-","start":350,"end":351,"id":75},{"text":"civil","start":351,"end":356,"id":76},{"text":"-","start":356,"end":357,"id":77},{"text":"society-","start":357,"end":365,"id":78},{"text":"\n 
","start":365,"end":367,"id":79},{"text":"organizations","start":367,"end":380,"id":80},{"text":"-","start":380,"end":381,"id":81},{"text":"working","start":381,"end":388,"id":82},{"text":"-","start":388,"end":389,"id":83},{"text":"on","start":389,"end":391,"id":84},{"text":"-","start":391,"end":392,"id":85},{"text":"democratic","start":392,"end":402,"id":86},{"text":"-","start":402,"end":403,"id":87},{"text":"governance-/3665%20Booklet_heleWEB_.pdf","start":403,"end":442,"id":88},{"text":"\n ","start":442,"end":444,"id":89},{"text":",","start":444,"end":445,"id":90},{"text":"\n ","start":446,"end":448,"id":91},{"text":"accessed","start":448,"end":456,"id":92},{"text":"15","start":457,"end":459,"id":93},{"text":"February","start":460,"end":468,"id":94},{"text":"2017","start":469,"end":473,"id":95},{"text":")","start":473,"end":474,"id":96},{"text":".","start":474,"end":475,"id":97},{"text":"\n ","start":476,"end":478,"id":98},{"text":"2","start":478,"end":479,"id":99},{"text":".","start":479,"end":480,"id":100},{"text":"\n ","start":481,"end":483,"id":101},{"text":"Mental","start":483,"end":489,"id":102},{"text":"Health","start":490,"end":496,"id":103},{"text":"Peer","start":497,"end":501,"id":104},{"text":"Connection","start":502,"end":512,"id":105},{"text":"(","start":513,"end":514,"id":106},{"text":"MHPC","start":514,"end":518,"id":107},{"text":")","start":518,"end":519,"id":108},{"text":".","start":519,"end":520,"id":109},{"text":"Mental","start":521,"end":527,"id":110},{"text":"Health","start":528,"end":534,"id":111},{"text":"Peer","start":535,"end":539,"id":112},{"text":"Connection","start":540,"end":550,"id":113},{"text":"[","start":551,"end":552,"id":114},{"text":"website].","start":552,"end":561,"id":115},{"text":"Buffalo","start":562,"end":569,"id":116},{"text":",","start":569,"end":570,"id":117},{"text":"\n ","start":571,"end":573,"id":118},{"text":"NY","start":573,"end":575,"id":119},{"text":";","start":575,"end":576,"id":120},{"text":"MHPC","start":577,"end":581,"id":121},{"text":";","start":581,"end":582,"id":122},{"text":"n.d","start":583,"end":586,"id":123},{"text":".","start":586,"end":587,"id":124},{"text":"(","start":588,"end":589,"id":125},{"text":"Available","start":589,"end":598,"id":126},{"text":"from","start":599,"end":603,"id":127},{"text":":","start":604,"end":605,"id":128},{"text":"\n ","start":606,"end":608,"id":129},{"text":"http://wnyil.org/mhpc.html","start":608,"end":634,"id":130},{"text":"\n ","start":634,"end":636,"id":131},{"text":",","start":636,"end":637,"id":132},{"text":"a","start":638,"end":639,"id":133},{"text":"ccessed","start":640,"end":647,"id":134},{"text":"15","start":648,"end":650,"id":135},{"text":"February","start":651,"end":659,"id":136},{"text":"2017","start":660,"end":664,"id":137},{"text":")","start":664,"end":665,"id":138},{"text":".","start":665,"end":666,"id":139},{"text":"\n ","start":667,"end":669,"id":140},{"text":"3","start":669,"end":670,"id":141},{"text":".","start":670,"end":671,"id":142},{"text":"\n 
","start":672,"end":674,"id":143},{"text":"Avery","start":674,"end":679,"id":144},{"text":"S","start":680,"end":681,"id":145},{"text":",","start":681,"end":682,"id":146},{"text":"Mental","start":683,"end":689,"id":147},{"text":"Health","start":690,"end":696,"id":148},{"text":"Peer","start":697,"end":701,"id":149},{"text":"Connection","start":702,"end":712,"id":150},{"text":"(","start":713,"end":714,"id":151},{"text":"MHPC","start":714,"end":718,"id":152},{"text":")","start":718,"end":719,"id":153},{"text":".","start":719,"end":720,"id":154},{"text":"Channels","start":721,"end":729,"id":155},{"text":"2013","start":730,"end":734,"id":156},{"text":",","start":734,"end":735,"id":157},{"text":"“","start":736,"end":737,"id":158},{"text":"Not","start":737,"end":740,"id":159},{"text":"Without","start":741,"end":748,"id":160},{"text":"Us","start":749,"end":751,"id":161},{"text":"”","start":751,"end":752,"id":162},{"text":"[","start":753,"end":754,"id":163},{"text":"video].","start":754,"end":761,"id":164},{"text":"\n ","start":762,"end":764,"id":165},{"text":"Western","start":764,"end":771,"id":166},{"text":"New","start":772,"end":775,"id":167},{"text":"York","start":776,"end":780,"id":168},{"text":"(","start":781,"end":782,"id":169},{"text":"WNY","start":782,"end":785,"id":170},{"text":")","start":785,"end":786,"id":171},{"text":";","start":786,"end":787,"id":172},{"text":"Squeeky","start":788,"end":795,"id":173},{"text":"Wheel","start":796,"end":801,"id":174},{"text":";","start":801,"end":802,"id":175},{"text":"2013","start":803,"end":807,"id":176},{"text":".","start":807,"end":808,"id":177},{"text":"(","start":809,"end":810,"id":178},{"text":"Available","start":810,"end":819,"id":179},{"text":"from","start":820,"end":824,"id":180},{"text":":","start":824,"end":825,"id":181},{"text":"\n ","start":826,"end":828,"id":182},{"text":"https://vimeo.com/62705552","start":828,"end":854,"id":183},{"text":",","start":854,"end":855,"id":184},{"text":"accessed","start":856,"end":864,"id":185},{"text":"15","start":865,"end":867,"id":186},{"text":"February","start":868,"end":876,"id":187},{"text":"2017","start":877,"end":881,"id":188},{"text":")","start":881,"end":882,"id":189},{"text":".","start":882,"end":883,"id":190},{"text":"\n ","start":884,"end":886,"id":191},{"text":"4","start":886,"end":887,"id":192},{"text":".","start":887,"end":888,"id":193},{"text":"\n ","start":889,"end":891,"id":194},{"text":"Alzheimer","start":891,"end":900,"id":195},{"text":"'s","start":900,"end":902,"id":196},{"text":"Disease","start":903,"end":910,"id":197},{"text":"International","start":911,"end":924,"id":198},{"text":"(","start":925,"end":926,"id":199},{"text":"ADI","start":926,"end":929,"id":200},{"text":")","start":929,"end":930,"id":201},{"text":".","start":930,"end":931,"id":202},{"text":"How","start":932,"end":935,"id":203},{"text":"to","start":936,"end":938,"id":204},{"text":"develop","start":939,"end":946,"id":205},{"text":"an","start":947,"end":949,"id":206},{"text":"Alzheimer","start":950,"end":959,"id":207},{"text":"'s","start":959,"end":961,"id":208},{"text":"association","start":962,"end":973,"id":209},{"text":"and","start":974,"end":977,"id":210},{"text":"get","start":978,"end":981,"id":211},{"text":"\n 
","start":982,"end":984,"id":212},{"text":"results","start":984,"end":991,"id":213},{"text":"[","start":992,"end":993,"id":214},{"text":"website].","start":993,"end":1002,"id":215},{"text":"United","start":1003,"end":1009,"id":216},{"text":"Kingdom","start":1010,"end":1017,"id":217},{"text":";","start":1017,"end":1018,"id":218},{"text":"ADI","start":1019,"end":1022,"id":219},{"text":";","start":1022,"end":1023,"id":220},{"text":"2006","start":1024,"end":1028,"id":221},{"text":".","start":1028,"end":1029,"id":222},{"text":"(","start":1030,"end":1031,"id":223},{"text":"Available","start":1031,"end":1040,"id":224},{"text":"from","start":1041,"end":1045,"id":225},{"text":":","start":1045,"end":1046,"id":226},{"text":"https:/","start":1047,"end":1054,"id":227},{"text":"/","start":1055,"end":1056,"id":228},{"text":"\n ","start":1056,"end":1058,"id":229},{"text":"www.alz.co.uk","start":1058,"end":1071,"id":230},{"text":"/","start":1071,"end":1072,"id":231},{"text":"how-","start":1072,"end":1076,"id":232},{"text":"\n ","start":1076,"end":1078,"id":233},{"text":"to","start":1078,"end":1080,"id":234},{"text":"-","start":1080,"end":1081,"id":235},{"text":"develop","start":1081,"end":1088,"id":236},{"text":"-","start":1088,"end":1089,"id":237},{"text":"an","start":1089,"end":1091,"id":238},{"text":"-","start":1091,"end":1092,"id":239},{"text":"association","start":1092,"end":1103,"id":240},{"text":"\n ","start":1103,"end":1105,"id":241},{"text":",","start":1105,"end":1106,"id":242},{"text":" ","start":1107,"end":1108,"id":243},{"text":"accessed","start":1108,"end":1116,"id":244},{"text":"15","start":1117,"end":1119,"id":245},{"text":"February","start":1120,"end":1128,"id":246},{"text":"2017","start":1129,"end":1133,"id":247},{"text":")","start":1133,"end":1134,"id":248},{"text":".","start":1134,"end":1135,"id":249},{"text":"\n ","start":1136,"end":1138,"id":250},{"text":"5","start":1138,"end":1139,"id":251},{"text":".","start":1139,"end":1140,"id":252},{"text":"\n ","start":1141,"end":1143,"id":253},{"text":"Normal","start":1143,"end":1149,"id":254},{"text":"Difference","start":1150,"end":1160,"id":255},{"text":"Mental","start":1161,"end":1167,"id":256},{"text":"Health","start":1168,"end":1174,"id":257},{"text":"Kenya","start":1175,"end":1180,"id":258},{"text":"(","start":1181,"end":1182,"id":259},{"text":"NDMHK","start":1182,"end":1187,"id":260},{"text":")","start":1187,"end":1188,"id":261},{"text":".","start":1188,"end":1189,"id":262},{"text":"About","start":1190,"end":1195,"id":263},{"text":"Us","start":1196,"end":1198,"id":264},{"text":"[","start":1199,"end":1200,"id":265},{"text":"website].","start":1200,"end":1209,"id":266},{"text":"Kenya","start":1210,"end":1215,"id":267},{"text":";","start":1215,"end":1216,"id":268},{"text":"NDMHK","start":1217,"end":1222,"id":269},{"text":";","start":1222,"end":1223,"id":270},{"text":"n.d","start":1224,"end":1227,"id":271},{"text":".","start":1227,"end":1228,"id":272},{"text":"\n ","start":1229,"end":1231,"id":273},{"text":"(","start":1231,"end":1232,"id":274},{"text":"Available","start":1232,"end":1241,"id":275},{"text":"from","start":1242,"end":1246,"id":276},{"text":":","start":1247,"end":1248,"id":277},{"text":"\n ","start":1249,"end":1251,"id":278},{"text":"http://www.normal-difference.org/?page_id=15","start":1251,"end":1295,"id":279},{"text":"\n 
","start":1295,"end":1297,"id":280},{"text":",","start":1297,"end":1298,"id":281},{"text":"ac","start":1299,"end":1301,"id":282},{"text":"cessed","start":1302,"end":1308,"id":283},{"text":"15","start":1309,"end":1311,"id":284},{"text":"February","start":1312,"end":1320,"id":285},{"text":"\n ","start":1321,"end":1323,"id":286},{"text":"2017","start":1323,"end":1327,"id":287},{"text":")","start":1327,"end":1328,"id":288},{"text":".","start":1328,"end":1329,"id":289},{"text":"\n ","start":1330,"end":1332,"id":290},{"text":"6","start":1332,"end":1333,"id":291},{"text":".","start":1333,"end":1334,"id":292},{"text":"\n ","start":1335,"end":1337,"id":293},{"text":"TOPSIDE","start":1337,"end":1344,"id":294},{"text":".","start":1344,"end":1345,"id":295},{"text":"Training","start":1346,"end":1354,"id":296},{"text":"Opportunities","start":1355,"end":1368,"id":297},{"text":"for","start":1369,"end":1372,"id":298},{"text":"Peer","start":1373,"end":1377,"id":299},{"text":"Supporters","start":1378,"end":1388,"id":300},{"text":"with","start":1389,"end":1393,"id":301},{"text":"Intellectual","start":1394,"end":1406,"id":302},{"text":"Disabilities","start":1407,"end":1419,"id":303},{"text":"in","start":1420,"end":1422,"id":304},{"text":"Europe","start":1423,"end":1429,"id":305},{"text":"\n ","start":1430,"end":1432,"id":306},{"text":"[","start":1432,"end":1433,"id":307},{"text":"website","start":1433,"end":1440,"id":308},{"text":"]","start":1440,"end":1441,"id":309},{"text":";","start":1441,"end":1442,"id":310},{"text":"TOPSIDE","start":1443,"end":1450,"id":311},{"text":";","start":1450,"end":1451,"id":312},{"text":"n.d","start":1452,"end":1455,"id":313},{"text":".","start":1455,"end":1456,"id":314},{"text":"(","start":1457,"end":1458,"id":315},{"text":"Available","start":1458,"end":1467,"id":316},{"text":"from","start":1468,"end":1472,"id":317},{"text":":","start":1473,"end":1474,"id":318},{"text":"\n ","start":1475,"end":1477,"id":319},{"text":"http://www.peer-support.eu/about-the-project/","start":1477,"end":1522,"id":320},{"text":"\n ","start":1522,"end":1524,"id":321},{"text":",","start":1524,"end":1525,"id":322},{"text":"\n ","start":1526,"end":1528,"id":323},{"text":"accessed","start":1528,"end":1536,"id":324},{"text":"15","start":1537,"end":1539,"id":325},{"text":"February","start":1540,"end":1548,"id":326},{"text":"2017","start":1549,"end":1553,"id":327},{"text":")","start":1553,"end":1554,"id":328},{"text":".","start":1554,"end":1555,"id":329},{"text":"\n ","start":1556,"end":1558,"id":330},{"text":"7","start":1558,"end":1559,"id":331},{"text":".","start":1559,"end":1560,"id":332},{"text":"\n ","start":1561,"end":1563,"id":333},{"text":"KOSHISH","start":1563,"end":1570,"id":334},{"text":"National","start":1571,"end":1579,"id":335},{"text":"Mental","start":1580,"end":1586,"id":336},{"text":"Health","start":1587,"end":1593,"id":337},{"text":"Self","start":1594,"end":1598,"id":338},{"text":"-","start":1598,"end":1599,"id":339},{"text":"help","start":1599,"end":1603,"id":340},{"text":"Organisation","start":1604,"end":1616,"id":341},{"text":".","start":1616,"end":1617,"id":342},{"text":"Advocacy","start":1618,"end":1626,"id":343},{"text":"and","start":1627,"end":1630,"id":344},{"text":"Awareness","start":1631,"end":1640,"id":345},{"text":"[","start":1641,"end":1642,"id":346},{"text":"website].","start":1642,"end":1651,"id":347},{"text":"\n 
","start":1652,"end":1654,"id":348},{"text":"Nepal","start":1654,"end":1659,"id":349},{"text":";","start":1659,"end":1660,"id":350},{"text":"KOSHISH","start":1661,"end":1668,"id":351},{"text":";","start":1668,"end":1669,"id":352},{"text":"2015","start":1670,"end":1674,"id":353},{"text":".","start":1674,"end":1675,"id":354},{"text":"(","start":1676,"end":1677,"id":355},{"text":"Available","start":1677,"end":1686,"id":356},{"text":"from","start":1687,"end":1691,"id":357},{"text":":","start":1691,"end":1692,"id":358},{"text":" \n ","start":1693,"end":1696,"id":359},{"text":"http://koshishnepal.org/advocacy","start":1696,"end":1728,"id":360},{"text":"\n ","start":1728,"end":1730,"id":361},{"text":",","start":1730,"end":1731,"id":362},{"text":" ","start":1732,"end":1733,"id":363},{"text":"accessed","start":1733,"end":1741,"id":364},{"text":"15","start":1742,"end":1744,"id":365},{"text":"\n ","start":1745,"end":1747,"id":366},{"text":"February","start":1747,"end":1755,"id":367},{"text":"2017","start":1756,"end":1760,"id":368},{"text":")","start":1760,"end":1761,"id":369},{"text":".","start":1761,"end":1762,"id":370},{"text":"\n ","start":1763,"end":1765,"id":371},{"text":"8","start":1765,"end":1766,"id":372},{"text":".","start":1766,"end":1767,"id":373},{"text":"\n ","start":1768,"end":1770,"id":374},{"text":"Dementia","start":1770,"end":1778,"id":375},{"text":"Alliance","start":1779,"end":1787,"id":376},{"text":"International","start":1788,"end":1801,"id":377},{"text":"(","start":1802,"end":1803,"id":378},{"text":"DAI","start":1803,"end":1806,"id":379},{"text":")","start":1806,"end":1807,"id":380},{"text":".","start":1807,"end":1808,"id":381},{"text":"Dementia","start":1809,"end":1817,"id":382},{"text":"Alliance","start":1818,"end":1826,"id":383},{"text":"International","start":1827,"end":1840,"id":384},{"text":"[","start":1841,"end":1842,"id":385},{"text":"website].","start":1842,"end":1851,"id":386},{"text":"Ankeny","start":1852,"end":1858,"id":387},{"text":",","start":1858,"end":1859,"id":388},{"text":"IA","start":1860,"end":1862,"id":389},{"text":";","start":1862,"end":1863,"id":390},{"text":"\n ","start":1864,"end":1866,"id":391},{"text":"DAI","start":1866,"end":1869,"id":392},{"text":";","start":1869,"end":1870,"id":393},{"text":"2014/2015","start":1871,"end":1880,"id":394},{"text":".","start":1880,"end":1881,"id":395},{"text":"(","start":1882,"end":1883,"id":396},{"text":"Available","start":1883,"end":1892,"id":397},{"text":"from","start":1893,"end":1897,"id":398},{"text":":","start":1897,"end":1898,"id":399},{"text":" \n ","start":1899,"end":1902,"id":400},{"text":"http://www.dementiaallianceinternational.org/","start":1902,"end":1947,"id":401},{"text":"\n ","start":1947,"end":1949,"id":402},{"text":",","start":1949,"end":1950,"id":403},{"text":" ","start":1951,"end":1952,"id":404},{"text":"accessed","start":1952,"end":1960,"id":405},{"text":"\n ","start":1961,"end":1963,"id":406},{"text":"15","start":1963,"end":1965,"id":407},{"text":"February","start":1966,"end":1974,"id":408},{"text":"2017","start":1975,"end":1979,"id":409},{"text":")","start":1979,"end":1980,"id":410},{"text":".","start":1980,"end":1981,"id":411},{"text":"\n ","start":1982,"end":1984,"id":412},{"text":"9","start":1984,"end":1985,"id":413},{"text":".","start":1985,"end":1986,"id":414}] 2 | -------------------------------------------------------------------------------- /tests/prodigy/test_labels_to_prodigy.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 
2 | # coding: utf-8 3 | 4 | from deep_reference_parser.prodigy import labels_to_prodigy 5 | 6 | 7 | def test_labels_to_prodigy(): 8 | 9 | tokens = [["Ackerman", "J", ".", "S", ".,", "Palladio", ",", "Torino", "1972", "."]] 10 | 11 | labels = [["b-r", "i-r", "i-r", "i-r", "i-r", "i-r", "i-r", "i-r", "i-r", "e-r"]] 12 | 13 | expected = [ 14 | { 15 | "text": "Ackerman J . S ., Palladio , Torino 1972 .", 16 | "tokens": [ 17 | {"text": "Ackerman", "id": 0, "start": 0, "end": 8}, 18 | {"text": "J", "id": 1, "start": 9, "end": 10}, 19 | {"text": ".", "id": 2, "start": 11, "end": 12}, 20 | {"text": "S", "id": 3, "start": 13, "end": 14}, 21 | {"text": ".,", "id": 4, "start": 15, "end": 17}, 22 | {"text": "Palladio", "id": 5, "start": 18, "end": 26}, 23 | {"text": ",", "id": 6, "start": 27, "end": 28}, 24 | {"text": "Torino", "id": 7, "start": 29, "end": 35}, 25 | {"text": "1972", "id": 8, "start": 36, "end": 40}, 26 | {"text": ".", "id": 9, "start": 41, "end": 42}, 27 | ], 28 | "spans": [ 29 | { 30 | "label": "b-r", 31 | "start": 0, 32 | "end": 8, 33 | "token_start": 0, 34 | "token_end": 0, 35 | }, 36 | { 37 | "label": "i-r", 38 | "start": 9, 39 | "end": 10, 40 | "token_start": 1, 41 | "token_end": 1, 42 | }, 43 | { 44 | "label": "i-r", 45 | "start": 11, 46 | "end": 12, 47 | "token_start": 2, 48 | "token_end": 2, 49 | }, 50 | { 51 | "label": "i-r", 52 | "start": 13, 53 | "end": 14, 54 | "token_start": 3, 55 | "token_end": 3, 56 | }, 57 | { 58 | "label": "i-r", 59 | "start": 15, 60 | "end": 17, 61 | "token_start": 4, 62 | "token_end": 4, 63 | }, 64 | { 65 | "label": "i-r", 66 | "start": 18, 67 | "end": 26, 68 | "token_start": 5, 69 | "token_end": 5, 70 | }, 71 | { 72 | "label": "i-r", 73 | "start": 27, 74 | "end": 28, 75 | "token_start": 6, 76 | "token_end": 6, 77 | }, 78 | { 79 | "label": "i-r", 80 | "start": 29, 81 | "end": 35, 82 | "token_start": 7, 83 | "token_end": 7, 84 | }, 85 | { 86 | "label": "i-r", 87 | "start": 36, 88 | "end": 40, 89 | "token_start": 8, 90 | "token_end": 8, 91 | }, 92 | { 93 | "label": "e-r", 94 | "start": 41, 95 | "end": 42, 96 | "token_start": 9, 97 | "token_end": 9, 98 | }, 99 | ], 100 | "meta": {"line": 0}, 101 | } 102 | ] 103 | 104 | out = labels_to_prodigy(tokens, labels) 105 | 106 | assert out == expected 107 | -------------------------------------------------------------------------------- /tests/prodigy/test_misc.py: -------------------------------------------------------------------------------- 1 | from deep_reference_parser.prodigy import prodigy_to_conll 2 | 3 | 4 | def test_prodigy_to_conll(): 5 | 6 | before = [ 7 | {"text": "References",}, 8 | {"text": "37. 
No single case of malaria reported in"}, 9 | { 10 | "text": "an essential requirement for the correct labelling of potency for therapeutic" 11 | }, 12 | {"text": "EQAS, quality control for STI"}, 13 | ] 14 | 15 | after = "DOCSTART\n\nReferences\n\n37\n.\nNo\nsingle\ncase\nof\nmalaria\nreported\nin\n\nan\nessential\nrequirement\nfor\nthe\ncorrect\nlabelling\nof\npotency\nfor\ntherapeutic\n\nEQAS\n,\nquality\ncontrol\nfor\nSTI" 16 | 17 | out = prodigy_to_conll(before) 18 | 19 | assert after == out 20 | -------------------------------------------------------------------------------- /tests/prodigy/test_numbered_reference_annotator.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # coding: utf-8 3 | 4 | import pytest 5 | import spacy 6 | from deep_reference_parser.prodigy.numbered_reference_annotator import ( 7 | NumberedReferenceAnnotator, 8 | ) 9 | 10 | 11 | @pytest.fixture(scope="function") 12 | def nra(): 13 | return NumberedReferenceAnnotator() 14 | 15 | 16 | def test_numbered_reference_splitter(nra): 17 | 18 | numbered_reference = { 19 | "text": "References\n 1. \n Global update on the health sector response to HIV, 2014. Geneva: World Health Organization; \n 2014:168. \n 2. \n WHO, UNICEF, UNAIDS. Global update on HIV treatment 2013: results, impact and \n opportunities. Geneva: World Health Organization; 2013:126. \n 3. \n Consolidated guidelines on the use of antiretroviral drugs for treating and preventing HIV infection: \n recommendations for a public health approach. Geneva: World Health Organization; 2013:272. \n 4.", 20 | "tokens": [ 21 | {"text": "References", "start": 0, "end": 10, "id": 0}, 22 | {"text": "\n ", "start": 10, "end": 12, "id": 1}, 23 | {"text": "1", "start": 12, "end": 13, "id": 2}, 24 | {"text": ".", "start": 13, "end": 14, "id": 3}, 25 | {"text": "\n ", "start": 15, "end": 17, "id": 4}, 26 | {"text": "Global", "start": 17, "end": 23, "id": 5}, 27 | {"text": "update", "start": 24, "end": 30, "id": 6}, 28 | {"text": "on", "start": 31, "end": 33, "id": 7}, 29 | {"text": "the", "start": 34, "end": 37, "id": 8}, 30 | {"text": "health", "start": 38, "end": 44, "id": 9}, 31 | {"text": "sector", "start": 45, "end": 51, "id": 10}, 32 | {"text": "response", "start": 52, "end": 60, "id": 11}, 33 | {"text": "to", "start": 61, "end": 63, "id": 12}, 34 | {"text": "HIV", "start": 64, "end": 67, "id": 13}, 35 | {"text": ",", "start": 67, "end": 68, "id": 14}, 36 | {"text": "2014", "start": 69, "end": 73, "id": 15}, 37 | {"text": ".", "start": 73, "end": 74, "id": 16}, 38 | {"text": "Geneva", "start": 75, "end": 81, "id": 17}, 39 | {"text": ":", "start": 81, "end": 82, "id": 18}, 40 | {"text": "World", "start": 83, "end": 88, "id": 19}, 41 | {"text": "Health", "start": 89, "end": 95, "id": 20}, 42 | {"text": "Organization", "start": 96, "end": 108, "id": 21}, 43 | {"text": ";", "start": 108, "end": 109, "id": 22}, 44 | {"text": "\n ", "start": 110, "end": 112, "id": 23}, 45 | {"text": "2014:168", "start": 112, "end": 120, "id": 24}, 46 | {"text": ".", "start": 120, "end": 121, "id": 25}, 47 | {"text": "\n ", "start": 122, "end": 124, "id": 26}, 48 | {"text": "2", "start": 124, "end": 125, "id": 27}, 49 | {"text": ".", "start": 125, "end": 126, "id": 28}, 50 | {"text": "\n ", "start": 127, "end": 129, "id": 29}, 51 | {"text": "WHO", "start": 129, "end": 132, "id": 30}, 52 | {"text": ",", "start": 132, "end": 133, "id": 31}, 53 | {"text": "UNICEF", "start": 134, "end": 140, "id": 32}, 54 | {"text": ",", "start": 
140, "end": 141, "id": 33}, 55 | {"text": "UNAIDS", "start": 142, "end": 148, "id": 34}, 56 | {"text": ".", "start": 148, "end": 149, "id": 35}, 57 | {"text": "Global", "start": 150, "end": 156, "id": 36}, 58 | {"text": "update", "start": 157, "end": 163, "id": 37}, 59 | {"text": "on", "start": 164, "end": 166, "id": 38}, 60 | {"text": "HIV", "start": 167, "end": 170, "id": 39}, 61 | {"text": "treatment", "start": 171, "end": 180, "id": 40}, 62 | {"text": "2013", "start": 181, "end": 185, "id": 41}, 63 | {"text": ":", "start": 185, "end": 186, "id": 42}, 64 | {"text": "results", "start": 187, "end": 194, "id": 43}, 65 | {"text": ",", "start": 194, "end": 195, "id": 44}, 66 | {"text": "impact", "start": 196, "end": 202, "id": 45}, 67 | {"text": "and", "start": 203, "end": 206, "id": 46}, 68 | {"text": "\n ", "start": 207, "end": 209, "id": 47}, 69 | {"text": "opportunities", "start": 209, "end": 222, "id": 48}, 70 | {"text": ".", "start": 222, "end": 223, "id": 49}, 71 | {"text": "Geneva", "start": 224, "end": 230, "id": 50}, 72 | {"text": ":", "start": 230, "end": 231, "id": 51}, 73 | {"text": "World", "start": 232, "end": 237, "id": 52}, 74 | {"text": "Health", "start": 238, "end": 244, "id": 53}, 75 | {"text": "Organization", "start": 245, "end": 257, "id": 54}, 76 | {"text": ";", "start": 257, "end": 258, "id": 55}, 77 | {"text": "2013:126", "start": 259, "end": 267, "id": 56}, 78 | {"text": ".", "start": 267, "end": 268, "id": 57}, 79 | {"text": "\n ", "start": 269, "end": 271, "id": 58}, 80 | {"text": "3", "start": 271, "end": 272, "id": 59}, 81 | {"text": ".", "start": 272, "end": 273, "id": 60}, 82 | {"text": "\n ", "start": 274, "end": 276, "id": 61}, 83 | {"text": "Consolidated", "start": 276, "end": 288, "id": 62}, 84 | {"text": "guidelines", "start": 289, "end": 299, "id": 63}, 85 | {"text": "on", "start": 300, "end": 302, "id": 64}, 86 | {"text": "the", "start": 303, "end": 306, "id": 65}, 87 | {"text": "use", "start": 307, "end": 310, "id": 66}, 88 | {"text": "of", "start": 311, "end": 313, "id": 67}, 89 | {"text": "antiretroviral", "start": 314, "end": 328, "id": 68}, 90 | {"text": "drugs", "start": 329, "end": 334, "id": 69}, 91 | {"text": "for", "start": 335, "end": 338, "id": 70}, 92 | {"text": "treating", "start": 339, "end": 347, "id": 71}, 93 | {"text": "and", "start": 348, "end": 351, "id": 72}, 94 | {"text": "preventing", "start": 352, "end": 362, "id": 73}, 95 | {"text": "HIV", "start": 363, "end": 366, "id": 74}, 96 | {"text": "infection", "start": 367, "end": 376, "id": 75}, 97 | {"text": ":", "start": 376, "end": 377, "id": 76}, 98 | {"text": "\n ", "start": 378, "end": 380, "id": 77}, 99 | {"text": "recommendations", "start": 380, "end": 395, "id": 78}, 100 | {"text": "for", "start": 396, "end": 399, "id": 79}, 101 | {"text": "a", "start": 400, "end": 401, "id": 80}, 102 | {"text": "public", "start": 402, "end": 408, "id": 81}, 103 | {"text": "health", "start": 409, "end": 415, "id": 82}, 104 | {"text": "approach", "start": 416, "end": 424, "id": 83}, 105 | {"text": ".", "start": 424, "end": 425, "id": 84}, 106 | {"text": "Geneva", "start": 426, "end": 432, "id": 85}, 107 | {"text": ":", "start": 432, "end": 433, "id": 86}, 108 | {"text": "World", "start": 434, "end": 439, "id": 87}, 109 | {"text": "Health", "start": 440, "end": 446, "id": 88}, 110 | {"text": "Organization", "start": 447, "end": 459, "id": 89}, 111 | {"text": ";", "start": 459, "end": 460, "id": 90}, 112 | {"text": "2013:272", "start": 461, "end": 469, "id": 91}, 113 | {"text": ".", "start": 469, 
"end": 470, "id": 92}, 114 | {"text": "\n", "start": 470, "end": 471, "id": 92}, 115 | {"text": "3", "start": 471, "end": 472, "id": 92}, 116 | {"text": ".", "start": 472, "end": 473, "id": 92}, 117 | ], 118 | } 119 | 120 | docs = list(nra.run([numbered_reference])) 121 | text = docs[0]["text"] 122 | spans = docs[0]["spans"] 123 | ref_1 = text[spans[0]["start"] : spans[0]["end"]] 124 | ref_2 = text[spans[1]["start"] : spans[1]["end"]] 125 | ref_3 = text[spans[2]["start"] : spans[2]["end"]] 126 | 127 | assert len(spans) == 3 128 | assert ( 129 | ref_1 130 | == "Global update on the health sector response to HIV, 2014. Geneva: World Health Organization; \n 2014:168." 131 | ) 132 | assert ( 133 | ref_2.strip() 134 | == "WHO, UNICEF, UNAIDS. Global update on HIV treatment 2013: results, impact and \n opportunities. Geneva: World Health Organization; 2013:126." 135 | ) 136 | assert ( 137 | ref_3.strip() 138 | == "Consolidated guidelines on the use of antiretroviral drugs for treating and preventing HIV infection: \n recommendations for a public health approach. Geneva: World Health Organization; 2013:272." 139 | ) 140 | 141 | 142 | def test_numbered_reference_splitter_line_endings(nra): 143 | """ 144 | Test case where there two line enedings immediately preceding the reference 145 | index. 146 | """ 147 | 148 | numbered_reference = { 149 | "text": "References\n\n1. \n Global update on the health sector response to HIV, 2014. Geneva: World Health Organization; \n 2014:168. \n\n2. \n WHO, UNICEF, UNAIDS. Global update on HIV treatment 2013: results, impact and \n opportunities. Geneva: World Health Organization; 2013:126.\n\n3.", 150 | "tokens": [ 151 | {"text": "References", "start": 0, "end": 10, "id": 0}, 152 | {"text": "\n\n", "start": 10, "end": 12, "id": 1}, 153 | {"text": "1", "start": 12, "end": 13, "id": 2}, 154 | {"text": ".", "start": 13, "end": 14, "id": 3}, 155 | {"text": "\n ", "start": 15, "end": 17, "id": 4}, 156 | {"text": "Global", "start": 17, "end": 23, "id": 5}, 157 | {"text": "update", "start": 24, "end": 30, "id": 6}, 158 | {"text": "on", "start": 31, "end": 33, "id": 7}, 159 | {"text": "the", "start": 34, "end": 37, "id": 8}, 160 | {"text": "health", "start": 38, "end": 44, "id": 9}, 161 | {"text": "sector", "start": 45, "end": 51, "id": 10}, 162 | {"text": "response", "start": 52, "end": 60, "id": 11}, 163 | {"text": "to", "start": 61, "end": 63, "id": 12}, 164 | {"text": "HIV", "start": 64, "end": 67, "id": 13}, 165 | {"text": ",", "start": 67, "end": 68, "id": 14}, 166 | {"text": "2014", "start": 69, "end": 73, "id": 15}, 167 | {"text": ".", "start": 73, "end": 74, "id": 16}, 168 | {"text": "Geneva", "start": 75, "end": 81, "id": 17}, 169 | {"text": ":", "start": 81, "end": 82, "id": 18}, 170 | {"text": "World", "start": 83, "end": 88, "id": 19}, 171 | {"text": "Health", "start": 89, "end": 95, "id": 20}, 172 | {"text": "Organization", "start": 96, "end": 108, "id": 21}, 173 | {"text": ";", "start": 108, "end": 109, "id": 22}, 174 | {"text": "\n ", "start": 110, "end": 112, "id": 23}, 175 | {"text": "2014:168", "start": 112, "end": 120, "id": 24}, 176 | {"text": ".", "start": 120, "end": 121, "id": 25}, 177 | {"text": "\n\n", "start": 122, "end": 124, "id": 26}, 178 | {"text": "2", "start": 124, "end": 125, "id": 27}, 179 | {"text": ".", "start": 125, "end": 126, "id": 28}, 180 | {"text": "\n ", "start": 127, "end": 129, "id": 29}, 181 | {"text": "WHO", "start": 129, "end": 132, "id": 30}, 182 | {"text": ",", "start": 132, "end": 133, "id": 31}, 183 | {"text": "UNICEF", 
"start": 134, "end": 140, "id": 32}, 184 | {"text": ",", "start": 140, "end": 141, "id": 33}, 185 | {"text": "UNAIDS", "start": 142, "end": 148, "id": 34}, 186 | {"text": ".", "start": 148, "end": 149, "id": 35}, 187 | {"text": "Global", "start": 150, "end": 156, "id": 36}, 188 | {"text": "update", "start": 157, "end": 163, "id": 37}, 189 | {"text": "on", "start": 164, "end": 166, "id": 38}, 190 | {"text": "HIV", "start": 167, "end": 170, "id": 39}, 191 | {"text": "treatment", "start": 171, "end": 180, "id": 40}, 192 | {"text": "2013", "start": 181, "end": 185, "id": 41}, 193 | {"text": ":", "start": 185, "end": 186, "id": 42}, 194 | {"text": "results", "start": 187, "end": 194, "id": 43}, 195 | {"text": ",", "start": 194, "end": 195, "id": 44}, 196 | {"text": "impact", "start": 196, "end": 202, "id": 45}, 197 | {"text": "and", "start": 203, "end": 206, "id": 46}, 198 | {"text": "\n ", "start": 207, "end": 209, "id": 47}, 199 | {"text": "opportunities", "start": 209, "end": 222, "id": 48}, 200 | {"text": ".", "start": 222, "end": 223, "id": 49}, 201 | {"text": "Geneva", "start": 224, "end": 230, "id": 50}, 202 | {"text": ":", "start": 230, "end": 231, "id": 51}, 203 | {"text": "World", "start": 232, "end": 237, "id": 52}, 204 | {"text": "Health", "start": 238, "end": 244, "id": 53}, 205 | {"text": "Organization", "start": 245, "end": 257, "id": 54}, 206 | {"text": ";", "start": 257, "end": 258, "id": 55}, 207 | {"text": "2013:126", "start": 259, "end": 267, "id": 56}, 208 | {"text": ".", "start": 260, "end": 261, "id": 57}, 209 | {"text": "\n\n", "start": 261, "end": 263, "id": 58}, 210 | {"text": "3", "start": 262, "end": 264, "id": 59}, 211 | {"text": ".", "start": 263, "end": 265, "id": 60}, 212 | ], 213 | } 214 | 215 | docs = list(nra.run([numbered_reference])) 216 | text = docs[0]["text"] 217 | spans = docs[0]["spans"] 218 | ref_1 = text[spans[0]["start"] : spans[0]["end"]] 219 | ref_2 = text[spans[1]["start"] : spans[1]["end"]] 220 | 221 | assert len(spans) == 2 222 | assert ( 223 | ref_1.strip() 224 | == "Global update on the health sector response to HIV, 2014. Geneva: World Health Organization; \n 2014:168." 225 | ) 226 | assert ( 227 | ref_2.strip() 228 | == "WHO, UNICEF, UNAIDS. Global update on HIV treatment 2013: results, impact and \n opportunities. Geneva: World Health Organization; 2013:126" 229 | ) 230 | -------------------------------------------------------------------------------- /tests/prodigy/test_prodigy_entrypoints.py: -------------------------------------------------------------------------------- 1 | """Simple tests that entrypoints run. 
Functionality is tested in other more 2 | specific tests 3 | """ 4 | 5 | import os 6 | 7 | import pytest 8 | 9 | from deep_reference_parser.prodigy import ( 10 | annotate_numbered_references, 11 | prodigy_to_tsv, 12 | reach_to_prodigy, 13 | reference_to_token_annotations, 14 | ) 15 | 16 | from .common import TEST_NUMBERED_REFERENCES, TEST_TOKEN_LABELLED, TEST_REACH 17 | 18 | 19 | @pytest.fixture(scope="session") 20 | def tmpdir(tmpdir_factory): 21 | return tmpdir_factory.mktemp("data") 22 | 23 | 24 | def test_annotate_numbered_references_entrypoint(tmpdir): 25 | annotate_numbered_references( 26 | TEST_NUMBERED_REFERENCES, os.path.join(tmpdir, "references.jsonl") 27 | ) 28 | 29 | 30 | def test_prodigy_to_tsv(tmpdir): 31 | prodigy_to_tsv( 32 | TEST_TOKEN_LABELLED, 33 | os.path.join(tmpdir, "tokens.tsv"), 34 | respect_lines=False, 35 | respect_docs=True, 36 | ) 37 | 38 | 39 | def test_prodigy_to_tsv_multiple_inputs(tmpdir): 40 | prodigy_to_tsv( 41 | TEST_TOKEN_LABELLED + "," + TEST_TOKEN_LABELLED, 42 | os.path.join(tmpdir, "tokens.tsv"), 43 | respect_lines=False, 44 | respect_docs=True, 45 | ) 46 | 47 | 48 | def test_reach_to_prodigy(tmpdir): 49 | reach_to_prodigy(TEST_REACH, os.path.join(tmpdir, "prodigy.jsonl")) 50 | 51 | 52 | def test_reference_to_token_annotations(tmpdir): 53 | reference_to_token_annotations( 54 | TEST_NUMBERED_REFERENCES, os.path.join(tmpdir, "tokens.jsonl") 55 | ) 56 | -------------------------------------------------------------------------------- /tests/prodigy/test_reach_to_prodigy.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # coding: utf-8 3 | 4 | import pytest 5 | from deep_reference_parser.prodigy.reach_to_prodigy import ReachToProdigy 6 | 7 | 8 | @pytest.fixture(scope="function") 9 | def stp(): 10 | ref_sections = [{}, {}, {}] 11 | return ReachToProdigy(ref_sections) 12 | 13 | 14 | def test_combine_n_rows(stp): 15 | 16 | doc = list(range(100, 200)) 17 | out = stp.combine_n_rows(doc, n=5, join_char=" ") 18 | 19 | last_in_doc = doc[len(doc) - 1] 20 | last_in_out = int(out[-1].split(" ")[-1]) 21 | 22 | assert last_in_doc == last_in_out 23 | 24 | assert out[0] == "100 101 102 103 104" 25 | assert out[-2] == "190 191 192 193 194" 26 | assert out[-1] == "195 196 197 198 199" 27 | 28 | 29 | def test_combine_n_rows_uneven_split(stp): 30 | 31 | doc = list(range(100, 200)) 32 | out = stp.combine_n_rows(doc, n=7, join_char=" ") 33 | 34 | last_in_doc = doc[len(doc) - 1] 35 | last_in_out = int(out[-1].split(" ")[-1]) 36 | 37 | assert last_in_doc == last_in_out 38 | assert len(out[-1].split(" ")) == 2 39 | assert len(out[-2].split(" ")) == 7 40 | 41 | assert out[0] == "100 101 102 103 104 105 106" 42 | assert out[-2] == "191 192 193 194 195 196 197" 43 | assert out[-1] == "198 199" 44 | -------------------------------------------------------------------------------- /tests/prodigy/test_reference_to_token_annotations.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # coding: utf-8 3 | 4 | import pytest 5 | 6 | from deep_reference_parser.io import read_jsonl 7 | from deep_reference_parser.prodigy.reference_to_token_annotations import TokenTagger 8 | 9 | from .common import TEST_REF_EXPECTED_SPANS, TEST_REF_SPANS, TEST_REF_TOKENS 10 | 11 | 12 | @pytest.fixture(scope="function") 13 | def splitter(): 14 | return TokenTagger(task="splitting", text=False) 15 | 16 | 17 | @pytest.fixture(scope="function") 18 | def parser(): 19 | return 
TokenTagger(task="parsing", text=True) 20 | 21 | 22 | @pytest.fixture(scope="module") 23 | def doc(): 24 | doc = {} 25 | doc["tokens"] = read_jsonl(TEST_REF_TOKENS)[0] 26 | doc["spans"] = read_jsonl(TEST_REF_SPANS)[0] 27 | 28 | return doc 29 | 30 | 31 | @pytest.fixture(scope="module") 32 | def expected(): 33 | spans = read_jsonl(TEST_REF_EXPECTED_SPANS) 34 | 35 | return spans 36 | 37 | 38 | def test_TokenTagger(splitter): 39 | 40 | doc = dict() 41 | 42 | doc["spans"] = [ 43 | {"start": 2, "end": 4, "token_start": 2, "token_end": 4, "label": "BE"}, 44 | ] 45 | 46 | doc["tokens"] = [ 47 | {"start": 0, "end": 0, "id": 0}, 48 | {"start": 1, "end": 1, "id": 1}, 49 | {"start": 2, "end": 2, "id": 2}, 50 | {"start": 3, "end": 3, "id": 3}, 51 | {"start": 4, "end": 4, "id": 4}, 52 | {"start": 5, "end": 5, "id": 5}, 53 | {"start": 6, "end": 6, "id": 6}, 54 | ] 55 | 56 | out = [ 57 | {"start": 0, "end": 0, "token_start": 0, "token_end": 0, "label": "o"}, 58 | {"start": 1, "end": 1, "token_start": 1, "token_end": 1, "label": "o"}, 59 | {"start": 2, "end": 2, "token_start": 2, "token_end": 2, "label": "b-r"}, 60 | {"start": 3, "end": 3, "token_start": 3, "token_end": 3, "label": "i-r"}, 61 | {"start": 4, "end": 4, "token_start": 4, "token_end": 4, "label": "e-r"}, 62 | {"start": 5, "end": 5, "token_start": 5, "token_end": 5, "label": "o"}, 63 | {"start": 6, "end": 6, "token_start": 6, "token_end": 6, "label": "o"}, 64 | ] 65 | 66 | tagged = splitter.run([doc]) 67 | 68 | assert out == tagged[0]["spans"] 69 | 70 | 71 | def test_create_span(splitter): 72 | 73 | tokens = [ 74 | {"start": 0, "end": 0, "id": 0}, 75 | {"start": 1, "end": 1, "id": 1}, 76 | {"start": 2, "end": 2, "id": 2}, 77 | ] 78 | 79 | after = {"start": 1, "end": 1, "token_start": 1, "token_end": 1, "label": "foo"} 80 | 81 | out = splitter.create_span(tokens=tokens, index=1, label="foo") 82 | 83 | assert out == after 84 | 85 | 86 | def test_split_long_span_three_token_span(splitter): 87 | 88 | tokens = [ 89 | {"start": 0, "end": 0, "id": 0}, 90 | {"start": 1, "end": 1, "id": 1}, 91 | {"start": 2, "end": 2, "id": 2}, 92 | {"start": 3, "end": 3, "id": 3}, 93 | {"start": 4, "end": 4, "id": 4}, 94 | {"start": 5, "end": 5, "id": 5}, 95 | {"start": 6, "end": 6, "id": 6}, 96 | ] 97 | 98 | span = {"start": 2, "end": 4, "token_start": 2, "token_end": 4, "label": "BE"} 99 | 100 | expected = [ 101 | {"start": 2, "end": 2, "token_start": 2, "token_end": 2, "label": "b-r"}, 102 | {"start": 3, "end": 3, "token_start": 3, "token_end": 3, "label": "i-r"}, 103 | {"start": 4, "end": 4, "token_start": 4, "token_end": 4, "label": "e-r"}, 104 | ] 105 | 106 | actual = splitter.split_long_span( 107 | tokens, span, start_label="b-r", end_label="e-r", inside_label="i-r" 108 | ) 109 | 110 | assert expected == actual 111 | 112 | 113 | def test_split_long_span_two_token_span(splitter): 114 | 115 | tokens = [ 116 | {"start": 0, "end": 0, "id": 0}, 117 | {"start": 1, "end": 1, "id": 1}, 118 | {"start": 2, "end": 2, "id": 2}, 119 | {"start": 3, "end": 3, "id": 3}, 120 | {"start": 4, "end": 4, "id": 4}, 121 | {"start": 5, "end": 5, "id": 5}, 122 | {"start": 6, "end": 6, "id": 6}, 123 | ] 124 | 125 | span = {"start": 2, "end": 3, "token_start": 2, "token_end": 3, "label": "BE"} 126 | 127 | expected = [ 128 | {"start": 2, "end": 2, "token_start": 2, "token_end": 2, "label": "b-r"}, 129 | {"start": 3, "end": 3, "token_start": 3, "token_end": 3, "label": "e-r"}, 130 | ] 131 | 132 | actual = splitter.split_long_span( 133 | tokens, span, start_label="b-r", end_label="e-r", 
inside_label="i-r" 134 | ) 135 | 136 | assert expected == actual 137 | 138 | 139 | def test_split_long_span_one_token_span(splitter): 140 | 141 | tokens = [ 142 | {"start": 0, "end": 0, "id": 0}, 143 | {"start": 1, "end": 1, "id": 1}, 144 | {"start": 2, "end": 2, "id": 2}, 145 | {"start": 3, "end": 3, "id": 3}, 146 | {"start": 4, "end": 4, "id": 4}, 147 | {"start": 5, "end": 5, "id": 5}, 148 | {"start": 6, "end": 6, "id": 6}, 149 | ] 150 | 151 | span = {"start": 2, "end": 2, "token_start": 2, "token_end": 2, "label": "BE"} 152 | 153 | expected = [ 154 | {"start": 2, "end": 2, "token_start": 2, "token_end": 2, "label": "b-r"}, 155 | ] 156 | 157 | actual = splitter.split_long_span( 158 | tokens, span, start_label="b-r", end_label="e-r", inside_label="i-r" 159 | ) 160 | 161 | assert expected == actual 162 | 163 | 164 | def test_reference_spans_be(splitter): 165 | 166 | tokens = [ 167 | {"start": 0, "end": 0, "id": 0}, 168 | {"start": 1, "end": 1, "id": 1}, 169 | {"start": 2, "end": 2, "id": 2}, 170 | {"start": 3, "end": 3, "id": 3}, 171 | {"start": 4, "end": 4, "id": 4}, 172 | {"start": 5, "end": 5, "id": 5}, 173 | {"start": 6, "end": 6, "id": 6}, 174 | ] 175 | 176 | spans = [{"start": 2, "end": 4, "token_start": 2, "token_end": 4, "label": "BE"}] 177 | 178 | after = [ 179 | {"start": 2, "end": 2, "token_start": 2, "token_end": 2, "label": "b-r"}, 180 | {"start": 3, "end": 3, "token_start": 3, "token_end": 3, "label": "i-r"}, 181 | {"start": 4, "end": 4, "token_start": 4, "token_end": 4, "label": "e-r"}, 182 | ] 183 | 184 | out = splitter.reference_spans(spans, tokens, task="splitting") 185 | 186 | assert out == after 187 | 188 | 189 | def test_reference_spans_bi(splitter): 190 | 191 | tokens = [ 192 | {"start": 0, "end": 0, "id": 0}, 193 | {"start": 1, "end": 1, "id": 1}, 194 | {"start": 2, "end": 2, "id": 2}, 195 | {"start": 3, "end": 3, "id": 3}, 196 | {"start": 4, "end": 4, "id": 4}, 197 | {"start": 5, "end": 5, "id": 5}, 198 | {"start": 6, "end": 6, "id": 6}, 199 | ] 200 | 201 | spans = [{"start": 2, "end": 4, "token_start": 2, "token_end": 4, "label": "BI"}] 202 | 203 | after = [ 204 | {"start": 2, "end": 2, "token_start": 2, "token_end": 2, "label": "b-r"}, 205 | {"start": 3, "end": 3, "token_start": 3, "token_end": 3, "label": "i-r"}, 206 | {"start": 4, "end": 4, "token_start": 4, "token_end": 4, "label": "i-r"}, 207 | ] 208 | 209 | out = splitter.reference_spans(spans, tokens, task="splitting") 210 | 211 | assert out == after 212 | 213 | 214 | def test_reference_spans_ie(splitter): 215 | 216 | tokens = [ 217 | {"start": 0, "end": 0, "id": 0}, 218 | {"start": 1, "end": 1, "id": 1}, 219 | {"start": 2, "end": 2, "id": 2}, 220 | {"start": 3, "end": 3, "id": 3}, 221 | {"start": 4, "end": 4, "id": 4}, 222 | {"start": 5, "end": 5, "id": 5}, 223 | {"start": 6, "end": 6, "id": 6}, 224 | ] 225 | 226 | spans = [{"start": 2, "end": 4, "token_start": 2, "token_end": 4, "label": "IE"}] 227 | 228 | after = [ 229 | {"start": 2, "end": 2, "token_start": 2, "token_end": 2, "label": "i-r"}, 230 | {"start": 3, "end": 3, "token_start": 3, "token_end": 3, "label": "i-r"}, 231 | {"start": 4, "end": 4, "token_start": 4, "token_end": 4, "label": "e-r"}, 232 | ] 233 | 234 | out = splitter.reference_spans(spans, tokens, task="splitting") 235 | 236 | assert out == after 237 | 238 | 239 | def test_reference_spans_ii(splitter): 240 | 241 | tokens = [ 242 | {"start": 0, "end": 0, "id": 0}, 243 | {"start": 1, "end": 1, "id": 1}, 244 | {"start": 2, "end": 2, "id": 2}, 245 | {"start": 3, "end": 3, "id": 3}, 246 | 
{"start": 4, "end": 4, "id": 4}, 247 | {"start": 5, "end": 5, "id": 5}, 248 | {"start": 6, "end": 6, "id": 6}, 249 | ] 250 | 251 | spans = [{"start": 2, "end": 4, "token_start": 2, "token_end": 4, "label": "II"}] 252 | 253 | after = [ 254 | {"start": 2, "end": 2, "token_start": 2, "token_end": 2, "label": "i-r"}, 255 | {"start": 3, "end": 3, "token_start": 3, "token_end": 3, "label": "i-r"}, 256 | {"start": 4, "end": 4, "token_start": 4, "token_end": 4, "label": "i-r"}, 257 | ] 258 | 259 | out = splitter.reference_spans(spans, tokens, task="splitting") 260 | 261 | assert out == after 262 | 263 | 264 | def test_reference_spans_parsing(splitter): 265 | 266 | tokens = [ 267 | {"start": 0, "end": 0, "id": 0}, 268 | {"start": 1, "end": 1, "id": 1}, 269 | {"start": 2, "end": 2, "id": 2}, 270 | {"start": 3, "end": 3, "id": 3}, 271 | {"start": 4, "end": 4, "id": 4}, 272 | {"start": 5, "end": 5, "id": 5}, 273 | {"start": 6, "end": 6, "id": 6}, 274 | ] 275 | 276 | spans = [ 277 | {"start": 2, "end": 4, "token_start": 2, "token_end": 4, "label": "author"} 278 | ] 279 | 280 | after = [ 281 | {"start": 2, "end": 2, "token_start": 2, "token_end": 2, "label": "author"}, 282 | {"start": 3, "end": 3, "token_start": 3, "token_end": 3, "label": "author"}, 283 | {"start": 4, "end": 4, "token_start": 4, "token_end": 4, "label": "author"}, 284 | ] 285 | 286 | out = splitter.reference_spans(spans, tokens, task="parsing") 287 | 288 | assert out == after 289 | 290 | 291 | def test_reference_spans_parsing_single_token(splitter): 292 | 293 | tokens = [ 294 | {"start": 0, "end": 0, "id": 0}, 295 | {"start": 1, "end": 1, "id": 1}, 296 | {"start": 2, "end": 2, "id": 2}, 297 | {"start": 3, "end": 3, "id": 3}, 298 | {"start": 4, "end": 4, "id": 4}, 299 | {"start": 5, "end": 5, "id": 5}, 300 | {"start": 6, "end": 6, "id": 6}, 301 | ] 302 | 303 | spans = [ 304 | {"start": 2, "end": 2, "token_start": 2, "token_end": 2, "label": "author"}, 305 | {"start": 4, "end": 4, "token_start": 4, "token_end": 4, "label": "year"}, 306 | ] 307 | 308 | expected = [ 309 | {"start": 2, "end": 2, "token_start": 2, "token_end": 2, "label": "author"}, 310 | {"start": 4, "end": 4, "token_start": 4, "token_end": 4, "label": "year"}, 311 | ] 312 | 313 | actual = splitter.reference_spans(spans, tokens, task="parsing") 314 | 315 | print(actual) 316 | 317 | assert actual == expected 318 | 319 | 320 | def test_outside_spans(splitter): 321 | 322 | tokens = [ 323 | {"start": 0, "end": 0, "id": 0}, 324 | {"start": 1, "end": 1, "id": 1}, 325 | {"start": 2, "end": 2, "id": 2}, 326 | {"start": 3, "end": 3, "id": 3}, 327 | {"start": 4, "end": 4, "id": 4}, 328 | {"start": 5, "end": 5, "id": 5}, 329 | {"start": 6, "end": 6, "id": 6}, 330 | ] 331 | 332 | spans = [ 333 | {"start": 2, "end": 4, "token_start": 2, "token_end": 4, "label": "b-r"}, 334 | {"start": 3, "end": 3, "token_start": 3, "token_end": 3, "label": "i-r"}, 335 | {"start": 4, "end": 4, "token_start": 4, "token_end": 4, "label": "e-r"}, 336 | ] 337 | 338 | after = [ 339 | {"start": 0, "end": 0, "token_start": 0, "token_end": 0, "label": "o"}, 340 | {"start": 1, "end": 1, "token_start": 1, "token_end": 1, "label": "o"}, 341 | {"start": 5, "end": 5, "token_start": 5, "token_end": 5, "label": "o"}, 342 | {"start": 6, "end": 6, "token_start": 6, "token_end": 6, "label": "o"}, 343 | ] 344 | 345 | out = splitter.outside_spans(spans, tokens) 346 | 347 | assert out == after 348 | 349 | 350 | def test_reference_spans_real_example(doc, parser, expected): 351 | 352 | actual = parser.run([doc])[0]["spans"] 
353 | assert actual == expected 354 | -------------------------------------------------------------------------------- /tests/prodigy/test_spacy_doc_to_prodigy.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # coding: utf-8 3 | 4 | import en_core_web_sm 5 | import pytest 6 | import spacy 7 | from deep_reference_parser.prodigy.spacy_doc_to_prodigy import SpacyDocToProdigy 8 | 9 | 10 | @pytest.fixture(scope="module") 11 | def nlp(): 12 | return en_core_web_sm.load() 13 | 14 | 15 | def test_return_one_prodigy_doc_fails_if_passed_wrong_type(): 16 | 17 | with pytest.raises(TypeError): 18 | wrong_format = [ 19 | "this is the text", 20 | {"entities": [[0, 1, "PERSON"], [2, 4, "COMPANY"]]}, 21 | ] 22 | 23 | spacy_to_prodigy = SpacyDocToProdigy() 24 | spacy_to_prodigy.return_one_prodigy_doc(wrong_format) 25 | 26 | 27 | def test_SpacyDocToProdigy(nlp): 28 | 29 | # https://www.theguardian.com/world/2019/oct/30/pinochet-economic-model-current-crisis-chile 30 | before = nlp( 31 | "After 12 days of mass demonstrations, rioting and human rights violations, the government of President Sebastián Piñera must now find a way out of the crisis that has engulfed Chile." 32 | ) 33 | 34 | stp = SpacyDocToProdigy() 35 | actual = stp.run([before]) 36 | 37 | expected = [ 38 | { 39 | "text": "After 12 days of mass demonstrations, rioting and human rights violations, the government of President Sebastián Piñera must now find a way out of the crisis that has engulfed Chile.", 40 | "spans": [ 41 | { 42 | "token_start": 1, 43 | "token_end": 3, 44 | "start": 6, 45 | "end": 13, 46 | "label": "DATE", 47 | }, 48 | { 49 | "token_start": 17, 50 | "token_end": 19, 51 | "start": 103, 52 | "end": 119, 53 | "label": "PERSON", 54 | }, 55 | { 56 | "token_start": 31, 57 | "token_end": 32, 58 | "start": 176, 59 | "end": 181, 60 | "label": "GPE", 61 | }, 62 | ], 63 | "tokens": [ 64 | {"text": "After", "start": 0, "end": 5, "id": 0}, 65 | {"text": "12", "start": 6, "end": 8, "id": 1}, 66 | {"text": "days", "start": 9, "end": 13, "id": 2}, 67 | {"text": "of", "start": 14, "end": 16, "id": 3}, 68 | {"text": "mass", "start": 17, "end": 21, "id": 4}, 69 | {"text": "demonstrations", "start": 22, "end": 36, "id": 5}, 70 | {"text": ",", "start": 36, "end": 37, "id": 6}, 71 | {"text": "rioting", "start": 38, "end": 45, "id": 7}, 72 | {"text": "and", "start": 46, "end": 49, "id": 8}, 73 | {"text": "human", "start": 50, "end": 55, "id": 9}, 74 | {"text": "rights", "start": 56, "end": 62, "id": 10}, 75 | {"text": "violations", "start": 63, "end": 73, "id": 11}, 76 | {"text": ",", "start": 73, "end": 74, "id": 12}, 77 | {"text": "the", "start": 75, "end": 78, "id": 13}, 78 | {"text": "government", "start": 79, "end": 89, "id": 14}, 79 | {"text": "of", "start": 90, "end": 92, "id": 15}, 80 | {"text": "President", "start": 93, "end": 102, "id": 16}, 81 | {"text": "Sebastián", "start": 103, "end": 112, "id": 17}, 82 | {"text": "Piñera", "start": 113, "end": 119, "id": 18}, 83 | {"text": "must", "start": 120, "end": 124, "id": 19}, 84 | {"text": "now", "start": 125, "end": 128, "id": 20}, 85 | {"text": "find", "start": 129, "end": 133, "id": 21}, 86 | {"text": "a", "start": 134, "end": 135, "id": 22}, 87 | {"text": "way", "start": 136, "end": 139, "id": 23}, 88 | {"text": "out", "start": 140, "end": 143, "id": 24}, 89 | {"text": "of", "start": 144, "end": 146, "id": 25}, 90 | {"text": "the", "start": 147, "end": 150, "id": 26}, 91 | {"text": "crisis", "start": 151, "end": 157, "id": 
27}, 92 | {"text": "that", "start": 158, "end": 162, "id": 28}, 93 | {"text": "has", "start": 163, "end": 166, "id": 29}, 94 | {"text": "engulfed", "start": 167, "end": 175, "id": 30}, 95 | {"text": "Chile", "start": 176, "end": 181, "id": 31}, 96 | {"text": ".", "start": 181, "end": 182, "id": 32}, 97 | ], 98 | } 99 | ] 100 | 101 | assert expected == actual 102 | -------------------------------------------------------------------------------- /tests/test_data/test_config.ini: -------------------------------------------------------------------------------- 1 | [DEFAULT] 2 | version = test 3 | 4 | [data] 5 | test_proportion = 0.25 6 | valid_proportion = 0.25 7 | data_path = data/processed/annotated/deep_reference_parser/ 8 | respect_line_endings = 0 9 | respect_doc_endings = 1 10 | line_limit = 250 11 | rodrigues_train = data/rodrigues/clean_test.txt 12 | rodrigues_test = 13 | rodrigues_valid = 14 | policy_train = data/2019.12.0_test.tsv 15 | policy_test = data/2019.12.0_test.tsv 16 | policy_valid = data/2019.12.0_test.tsv 17 | # This needs to have a trailing slash! 18 | s3_slug = https://datalabs-public.s3.eu-west-2.amazonaws.com/deep_reference_parser/ 19 | 20 | [build] 21 | output_path = models/test/ 22 | output = crf 23 | word_embeddings = embeddings/2020.1.1-wellcome-embeddings-10-test.txt 24 | pretrained_embedding = 0 25 | dropout = 0.5 26 | lstm_hidden = 100 27 | word_embedding_size = 10 28 | char_embedding_size = 100 29 | char_embedding_type = BILSTM 30 | optimizer = rmsprop 31 | 32 | [train] 33 | epochs = 1 34 | batch_size = 100 35 | early_stopping_patience = 5 36 | metric = val_f1 37 | 38 | [evaluate] 39 | out_file = evaluation_data.tsv 40 | -------------------------------------------------------------------------------- /tests/test_data/test_config_multitask.ini: -------------------------------------------------------------------------------- 1 | [DEFAULT] 2 | version = test 3 | 4 | [data] 5 | test_proportion = 0.25 6 | valid_proportion = 0.25 7 | data_path = data/ 8 | respect_line_endings = 0 9 | respect_doc_endings = 1 10 | line_limit = 150 11 | rodrigues_train = data/rodrigues/clean_test.txt 12 | rodrigues_test = 13 | rodrigues_valid = 14 | policy_train = data/processed/annotated/deep_reference_parser/multitask/2020.3.18_multitask_test.tsv 15 | policy_test = data/processed/annotated/deep_reference_parser/multitask/2020.3.18_multitask_test.tsv 16 | policy_valid = data/processed/annotated/deep_reference_parser/multitask/2020.3.18_multitask_test.tsv 17 | # This needs to have a trailing slash! 
18 | s3_slug = https://datalabs-public.s3.eu-west-2.amazonaws.com/deep_reference_parser/ 19 | 20 | [build] 21 | output_path = models/multitask/2020.4.5_multitask/ 22 | output = crf 23 | word_embeddings = embeddings/2020.1.1-wellcome-embeddings-300-test.txt 24 | pretrained_embedding = 0 25 | dropout = 0.5 26 | lstm_hidden = 400 27 | word_embedding_size = 300 28 | char_embedding_size = 100 29 | char_embedding_type = BILSTM 30 | optimizer = adam 31 | 32 | [train] 33 | epochs = 60 34 | batch_size = 100 35 | early_stopping_patience = 5 36 | metric = val_f1 37 | 38 | [evaluate] 39 | out_file = evaluation_data.tsv 40 | -------------------------------------------------------------------------------- /tests/test_data/test_jsonl.jsonl: -------------------------------------------------------------------------------- 1 | {"text": "a b c\n a b c", "tokens": [{"text": "a", "start": 0, "end": 1, "id": 0}, {"text": "b", "start": 2, "end": 3, "id": 1}, {"text": "c", "start": 4, "end": 5, "id": 2}, {"text": "\n ", "start": 5, "end": 7, "id": 3}, {"text": "a", "start": 7, "end": 8, "id": 4}, {"text": "b", "start": 9, "end": 10, "id": 5}, {"text": "c", "start": 11, "end": 12, "id": 6}], "spans": [{"start": 2, "end": 3, "token_start": 1, "token_end": 2, "label": "b"}, {"start": 4, "end": 5, "token_start": 2, "token_end": 3, "label": "i"}, {"start": 7, "end": 8, "token_start": 4, "token_end": 5, "label": "i"}, {"start": 9, "end": 10, "token_start": 5, "token_end": 6, "label": "e"}]} 2 | {"text": "a b c\n a b c", "tokens": [{"text": "a", "start": 0, "end": 1, "id": 0}, {"text": "b", "start": 2, "end": 3, "id": 1}, {"text": "c", "start": 4, "end": 5, "id": 2}, {"text": "\n ", "start": 5, "end": 7, "id": 3}, {"text": "a", "start": 7, "end": 8, "id": 4}, {"text": "b", "start": 9, "end": 10, "id": 5}, {"text": "c", "start": 11, "end": 12, "id": 6}], "spans": [{"start": 2, "end": 3, "token_start": 1, "token_end": 2, "label": "b"}, {"start": 4, "end": 5, "token_start": 2, "token_end": 3, "label": "i"}, {"start": 7, "end": 8, "token_start": 4, "token_end": 5, "label": "i"}, {"start": 9, "end": 10, "token_start": 5, "token_end": 6, "label": "e"}]} 3 | {"text": "a b c\n a b c", "tokens": [{"text": "a", "start": 0, "end": 1, "id": 0}, {"text": "b", "start": 2, "end": 3, "id": 1}, {"text": "c", "start": 4, "end": 5, "id": 2}, {"text": "\n ", "start": 5, "end": 7, "id": 3}, {"text": "a", "start": 7, "end": 8, "id": 4}, {"text": "b", "start": 9, "end": 10, "id": 5}, {"text": "c", "start": 11, "end": 12, "id": 6}], "spans": [{"start": 2, "end": 3, "token_start": 1, "token_end": 2, "label": "b"}, {"start": 4, "end": 5, "token_start": 2, "token_end": 3, "label": "i"}, {"start": 7, "end": 8, "token_start": 4, "token_end": 5, "label": "i"}, {"start": 9, "end": 10, "token_start": 5, "token_end": 6, "label": "e"}]} 4 | -------------------------------------------------------------------------------- /tests/test_data/test_load_tsv.tsv: -------------------------------------------------------------------------------- 1 | the i-r a 2 | focus i-r a 3 | in i-r a 4 | Daloa i-r a 5 | , i-r a 6 | Côte i-r a 7 | d’Ivoire]. 
i-r a 8 | 9 | Bulletin i-r a 10 | de i-r a 11 | la i-r a 12 | Société i-r a 13 | de i-r a 14 | Pathologie i-r a 15 | 16 | Exotique i-r a 17 | et i-r a 18 | token 19 | 20 | 21 | -------------------------------------------------------------------------------- /tests/test_data/test_references.txt: -------------------------------------------------------------------------------- 1 | 1 Sibbald, A, Eason, W, McAdam, J, and Hislop, A (2001). The establishment phase of a silvopastoral national network experiment in the UK. Agroforestry systems, 39, 39–53. 2 | 2 Silva, J and Rego, F (2003). Root distribution of a Mediterranean shrubland in Portugal. Plant and Soil, 255 (2), 529–540. 3 | 3 Sims, R, Schock, R, Adegbululgbe, A, Fenhann, J, Konstantinaviciute, I, Moomaw, W, Nimir, H, Schlamadinger, B, Torres-Martínez, J, Turner, C, Uchiyama, Y, Vuori, S, Wamukonya, N, and X. Zhang (2007). Energy Supply. In Metz, B, Davidson, O, Bosch, P, Dave, R, and Meyer, L (eds.), Climate Change 2007: Mitigation. Contribution of Working Group III to the Fourth Assessment Report of the Intergovernmental Panel on Climate Change, Cambridge University Press, Cambridge, United Kingdom and New York, NY, USA. 4 | -------------------------------------------------------------------------------- /tests/test_data/test_tsv_predict.tsv: -------------------------------------------------------------------------------- 1 | the 2 | focus 3 | in 4 | Daloa 5 | , 6 | Côte 7 | d’Ivoire]. 8 | 9 | Bulletin 10 | de 11 | la 12 | Société 13 | de 14 | Pathologie 15 | 16 | Exotique 17 | et 18 | 19 | -------------------------------------------------------------------------------- /tests/test_data/test_tsv_train.tsv: -------------------------------------------------------------------------------- 1 | 2 | 3 | the i-r 4 | focus i-r 5 | in i-r 6 | Daloa i-r 7 | , i-r 8 | Côte i-r 9 | d’Ivoire]. 
i-r 10 | 11 | Bulletin i-r 12 | de i-r 13 | la i-r 14 | Société i-r 15 | de i-r 16 | Pathologie i-r 17 | 18 | Exotique i-r 19 | et i-r 20 | 21 | -------------------------------------------------------------------------------- /tests/test_deep_reference_parser.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # coding: utf-8 3 | 4 | import os 5 | import shutil 6 | import tempfile 7 | 8 | import pytest 9 | from deep_reference_parser import DeepReferenceParser, get_config, load_tsv 10 | from deep_reference_parser.common import download_model_artefact 11 | from wasabi import msg 12 | 13 | from .common import TEST_CFG, TEST_TSV_PREDICT, TEST_TSV_TRAIN 14 | 15 | 16 | @pytest.fixture(scope="session") 17 | def tmpdir(tmpdir_factory): 18 | return tmpdir_factory.mktemp("data") 19 | 20 | 21 | @pytest.fixture(scope="session") 22 | def cfg(): 23 | cfg = get_config(TEST_CFG) 24 | 25 | artefacts = [ 26 | "indices.pickle", 27 | "weights.h5", 28 | ] 29 | 30 | S3_SLUG = cfg["data"]["s3_slug"] 31 | OUTPUT_PATH = cfg["build"]["output_path"] 32 | WORD_EMBEDDINGS = cfg["build"]["word_embeddings"] 33 | 34 | for artefact in artefacts: 35 | with msg.loading(f"Could not find {artefact} locally, downloading..."): 36 | try: 37 | artefact = os.path.join(OUTPUT_PATH, artefact) 38 | download_model_artefact(artefact, S3_SLUG) 39 | msg.good(f"Found {artefact}") 40 | except: 41 | msg.fail(f"Could not download {S3_SLUG}{artefact}") 42 | 43 | # Check on word embedding and download if not exists 44 | 45 | WORD_EMBEDDINGS = cfg["build"]["word_embeddings"] 46 | 47 | with msg.loading(f"Could not find {WORD_EMBEDDINGS} locally, downloading..."): 48 | try: 49 | download_model_artefact(WORD_EMBEDDINGS, S3_SLUG) 50 | msg.good(f"Found {WORD_EMBEDDINGS}") 51 | except: 52 | msg.fail(f"Could not download {S3_SLUG}{WORD_EMBEDDINGS}") 53 | 54 | return cfg 55 | 56 | @pytest.mark.slow 57 | @pytest.mark.integration 58 | def test_DeepReferenceParser_train(tmpdir, cfg): 59 | """ 60 | This test creates the artefacts that will be used in the next test. 61 | """ 62 | 63 | X_test, y_test = load_tsv(TEST_TSV_TRAIN) 64 | 65 | X_test = X_test[0:100] 66 | y_test = [y_test[0:100]] 67 | 68 | drp = DeepReferenceParser( 69 | X_train=X_test, 70 | X_test=X_test, 71 | X_valid=X_test, 72 | y_train=y_test, 73 | y_test=y_test, 74 | y_valid=y_test, 75 | max_len=250, 76 | output_path=tmpdir, 77 | 78 | ) 79 | 80 | # Prepare the data 81 | 82 | drp.prepare_data(save=True) 83 | 84 | # Build the model architecture 85 | 86 | drp.build_model( 87 | output=cfg["build"]["output"], 88 | word_embeddings=cfg["build"]["word_embeddings"], 89 | pretrained_embedding=cfg["build"]["pretrained_embedding"], 90 | dropout=float(cfg["build"]["dropout"]), 91 | lstm_hidden=int(cfg["build"]["lstm_hidden"]), 92 | word_embedding_size=int(cfg["build"]["word_embedding_size"]), 93 | char_embedding_size=int(cfg["build"]["char_embedding_size"]), 94 | ) 95 | 96 | # Train the model (quickly) 97 | 98 | drp.train_model( 99 | epochs=int(cfg["train"]["epochs"]), batch_size=int(cfg["train"]["batch_size"]) 100 | ) 101 | 102 | # Evaluate the model. This will write some evaluation data to the 103 | # temporary directory.
104 | 105 | drp.evaluate(load_weights=False, test_set=True, validation_set=True) 106 | 107 | examples = [ 108 | "This is an example".split(" "), 109 | "This is also an example".split(" "), 110 | "And so is this".split(" "), 111 | ] 112 | 113 | 114 | @pytest.mark.slow 115 | @pytest.mark.integration 116 | def test_DeepReferenceParser_predict(tmpdir, cfg): 117 | """ 118 | You must run this test after the previous one, or it will fail. 119 | """ 120 | 121 | drp = DeepReferenceParser( 122 | # Nothing will be written here 123 | # output_path=cfg["build"]["output_path"] 124 | output_path=tmpdir 125 | ) 126 | 127 | # Load mapping dicts from the baseline model 128 | 129 | drp.load_data(tmpdir) 130 | 131 | # Build the model architecture 132 | 133 | drp.build_model( 134 | output=cfg["build"]["output"], 135 | word_embeddings=cfg["build"]["word_embeddings"], 136 | pretrained_embedding=False, 137 | dropout=float(cfg["build"]["dropout"]), 138 | lstm_hidden=int(cfg["build"]["lstm_hidden"]), 139 | word_embedding_size=int(cfg["build"]["word_embedding_size"]), 140 | char_embedding_size=int(cfg["build"]["char_embedding_size"]), 141 | ) 142 | 143 | examples = [ 144 | "This is an example".split(" "), 145 | "This is also an example".split(" "), 146 | "And so is this".split(" "), 147 | ] 148 | 149 | preds = drp.predict(examples, load_weights=True)[0] 150 | 151 | assert len(preds) == len(examples) 152 | 153 | assert len(preds[0]) == len(examples[0]) 154 | assert len(preds[1]) == len(examples[1]) 155 | assert len(preds[2]) == len(examples[2]) 156 | 157 | assert isinstance(preds[0][0], str) 158 | assert isinstance(preds[1][0], str) 159 | assert isinstance(preds[2][0], str) 160 | -------------------------------------------------------------------------------- /tests/test_deep_reference_parser_entrypoints.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # coding: utf-8 3 | 4 | import pytest 5 | 6 | from deep_reference_parser.parse import Parser 7 | from deep_reference_parser.split import Splitter 8 | from deep_reference_parser.split_parse import SplitParser 9 | 10 | from .common import TEST_CFG, TEST_CFG_MULTITASK, TEST_REFERENCES 11 | 12 | 13 | @pytest.fixture 14 | def splitter(): 15 | return Splitter(TEST_CFG) 16 | 17 | 18 | @pytest.fixture 19 | def parser(): 20 | return Parser(TEST_CFG) 21 | 22 | 23 | @pytest.fixture 24 | def split_parser(): 25 | return SplitParser(TEST_CFG_MULTITASK) 26 | 27 | 28 | @pytest.fixture 29 | def text(): 30 | with open(TEST_REFERENCES, "r") as fb: 31 | text = fb.read() 32 | 33 | return text 34 | 35 | 36 | @pytest.mark.slow 37 | def test_splitter_list_output(text, splitter): 38 | """ 39 | Test that the splitter entrypoint works as expected. 40 | 41 | If the model artefacts and embeddings are not present, this test will 42 | download them, which can be slow. 43 | """ 44 | out = splitter.split(text, return_tokens=False, verbose=False) 45 | 46 | assert isinstance(out, list) 47 | 48 | 49 | @pytest.mark.slow 50 | def test_parser_list_output(text, parser): 51 | """ 52 | Test that the parser entrypoint works as expected. 53 | 54 | If the model artefacts and embeddings are not present, this test will 55 | download them, which can be slow. 56 | """ 57 | out = parser.parse(text, verbose=False) 58 | 59 | assert isinstance(out, list) 60 | 61 | 62 | @pytest.mark.slow 63 | def test_split_parser_list_output(text, split_parser): 64 | """ 65 | Test that the split_parse (multitask) entrypoint works as expected.
66 | 67 | If the model artefacts and embeddings are not present, this test will 68 | download them, which can be slow. 69 | """ 70 | out = split_parser.split_parse(text, return_tokens=False, verbose=False) 71 | print(out) 72 | 73 | assert isinstance(out, list) 74 | 75 | 76 | def test_splitter_tokens_output(text, splitter): 77 | """Test that the splitter returns (token, label) tuples when return_tokens=True.""" 78 | 79 | out = splitter.split(text, return_tokens=True, verbose=False) 80 | 81 | assert isinstance(out, list) 82 | assert isinstance(out[0], tuple) 83 | assert len(out[0]) == 2 84 | assert isinstance(out[0][0], str) 85 | assert isinstance(out[0][1], str) 86 | 87 | 88 | def test_parser_tokens_output(text, parser): 89 | """Test that the parser returns (token, label) tuples.""" 90 | 91 | out = parser.parse(text, verbose=False) 92 | 93 | assert isinstance(out, list) 94 | assert isinstance(out[0], tuple) 95 | assert len(out[0]) == 2 96 | assert isinstance(out[0][0], str) 97 | assert isinstance(out[0][1], str) 98 | 99 | 100 | def test_split_parser_tokens_output(text, split_parser): 101 | """Test that split_parse returns 3-element token tuples when return_tokens=True.""" 102 | 103 | out = split_parser.split_parse(text, return_tokens=True, verbose=False) 104 | 105 | assert isinstance(out[0], tuple) 106 | assert len(out[0]) == 3 107 | assert isinstance(out[0][0], str) 108 | assert isinstance(out[0][1], str) 109 | assert isinstance(out[0][2], str) 110 | -------------------------------------------------------------------------------- /tests/test_io.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # coding: utf-8 3 | 4 | import os 5 | 6 | import pytest 7 | 8 | from deep_reference_parser.io.io import ( 9 | read_jsonl, 10 | write_jsonl, 11 | load_tsv, 12 | write_tsv, 13 | _split_list_by_linebreaks, 14 | _unpack, 15 | ) 16 | from deep_reference_parser.reference_utils import yield_token_label_pairs 17 | 18 | from .common import TEST_JSONL, TEST_TSV_TRAIN, TEST_TSV_PREDICT, TEST_LOAD_TSV 19 | 20 | 21 | @pytest.fixture(scope="module") 22 | def tmpdir(tmpdir_factory): 23 | return tmpdir_factory.mktemp("data") 24 | 25 | def test_unpack(): 26 | 27 | before = [ 28 | ( 29 | ("token0", "token1", "token2", "token3"), 30 | ("label0", "label1", "label2", "label3") 31 | ), 32 | ( 33 | ("token0", "token1", "token2"), 34 | ("label0", "label1", "label2") 35 | ), 36 | ] 37 | 38 | expected = [ 39 | ( 40 | ("token0", "token1", "token2", "token3"), 41 | ("token0", "token1", "token2"), 42 | ), 43 | ( 44 | ("label0", "label1", "label2", "label3"), 45 | ("label0", "label1", "label2") 46 | ), 47 | ] 48 | 49 | actual = _unpack(before) 50 | 51 | assert expected == actual 52 | 53 | def test_write_tsv(tmpdir): 54 | 55 | expected = ( 56 | ( 57 | ("the", "focus", "in", "Daloa", ",", "Côte", "d’Ivoire]."), 58 | ("Bulletin", "de", "la", "Société", "de", "Pathologie"), 59 | ("Exotique", "et"), 60 | ), 61 | ( 62 | ("i-r", "i-r", "i-r", "i-r", "i-r", "i-r", "i-r"), 63 | ("i-r", "i-r", "i-r", "i-r", "i-r", "i-r"), 64 | ("i-r", "i-r"), 65 | ), 66 | ) 67 | 68 | token_label_tuples = list(yield_token_label_pairs(expected[0], expected[1])) 69 | 70 | PATH = os.path.join(tmpdir, "test_tsv.tsv") 71 | write_tsv(token_label_tuples, PATH) 72 | actual = load_tsv(os.path.join(PATH)) 73 | 74 | assert expected == actual 75 | 76 | 77 | def test_load_tsv_train(): 78 | """ 79 | Text of TEST_TSV_TRAIN: 80 | 81 | ``` 82 | the i-r 83 | focus i-r 84 | in i-r 85 | Daloa i-r 86 | , i-r 87 | Côte i-r 88 | d’Ivoire].
i-r 89 | 90 | Bulletin i-r 91 | de i-r 92 | la i-r 93 | Société i-r 94 | de i-r 95 | Pathologie i-r 96 | 97 | Exotique i-r 98 | et i-r 99 | ``` 100 | """ 101 | 102 | expected = ( 103 | ( 104 | ("the", "focus", "in", "Daloa", ",", "Côte", "d’Ivoire]."), 105 | ("Bulletin", "de", "la", "Société", "de", "Pathologie"), 106 | ("Exotique", "et"), 107 | ), 108 | ( 109 | ("i-r", "i-r", "i-r", "i-r", "i-r", "i-r", "i-r"), 110 | ("i-r", "i-r", "i-r", "i-r", "i-r", "i-r"), 111 | ("i-r", "i-r"), 112 | ), 113 | ) 114 | 115 | actual = load_tsv(TEST_TSV_TRAIN) 116 | 117 | assert len(actual[0][0]) == len(expected[0][0]) 118 | assert len(actual[0][1]) == len(expected[0][1]) 119 | assert len(actual[0][2]) == len(expected[0][2]) 120 | 121 | assert len(actual[1][0]) == len(expected[1][0]) 122 | assert len(actual[1][1]) == len(expected[1][1]) 123 | assert len(actual[1][2]) == len(expected[1][2]) 124 | 125 | assert actual == expected 126 | 127 | 128 | def test_load_tsv_predict(): 129 | """ 130 | Text of TEST_TSV_PREDICT: 131 | 132 | ``` 133 | the 134 | focus 135 | in 136 | Daloa 137 | , 138 | Côte 139 | d’Ivoire]. 140 | 141 | Bulletin 142 | de 143 | la 144 | Société 145 | de 146 | Pathologie 147 | 148 | Exotique 149 | et 150 | ``` 151 | """ 152 | 153 | expected = ( 154 | ( 155 | ("the", "focus", "in", "Daloa", ",", "Côte", "d’Ivoire]."), 156 | ("Bulletin", "de", "la", "Société", "de", "Pathologie"), 157 | ("Exotique", "et"), 158 | ), 159 | ) 160 | 161 | actual = load_tsv(TEST_TSV_PREDICT) 162 | 163 | assert actual == expected 164 | 165 | 166 | def test_load_tsv_train_multiple_labels(): 167 | """ 168 | Text of TEST_LOAD_TSV: 169 | 170 | ``` 171 | the i-r a 172 | focus i-r a 173 | in i-r a 174 | Daloa i-r a 175 | , i-r a 176 | Côte i-r a 177 | d’Ivoire]. i-r a 178 | 179 | Bulletin i-r a 180 | de i-r a 181 | la i-r a 182 | Société i-r a 183 | de i-r a 184 | Pathologie i-r a 185 | 186 | Exotique i-r a 187 | et i-r a 188 | token 189 | 190 | ``` 191 | """ 192 | 193 | expected = ( 194 | ( 195 | ("the", "focus", "in", "Daloa", ",", "Côte", "d’Ivoire]."), 196 | ("Bulletin", "de", "la", "Société", "de", "Pathologie"), 197 | ("Exotique", "et"), 198 | ), 199 | ( 200 | ("i-r", "i-r", "i-r", "i-r", "i-r", "i-r", "i-r"), 201 | ("i-r", "i-r", "i-r", "i-r", "i-r", "i-r"), 202 | ("i-r", "i-r"), 203 | ), 204 | ( 205 | ("a", "a", "a", "a", "a", "a", "a"), 206 | ("a", "a", "a", "a", "a", "a"), 207 | ("a", "a"), 208 | ), 209 | ) 210 | 211 | actual = load_tsv(TEST_LOAD_TSV) 212 | 213 | assert actual == expected 214 | 215 | 216 | def test_yield_token_label_pairs(): 217 | 218 | tokens = [ 219 | [], 220 | ["the", "focus", "in", "Daloa", ",", "Côte", "d’Ivoire]."], 221 | ["Bulletin", "de", "la", "Société", "de", "Pathologie"], 222 | ["Exotique", "et"], 223 | ] 224 | 225 | labels = [ 226 | [], 227 | ["i-r", "i-r", "i-r", "i-r", "i-r", "i-r", "i-r"], 228 | ["i-r", "i-r", "i-r", "i-r", "i-r", "i-r"], 229 | ["i-r", "i-r"], 230 | ] 231 | 232 | expected = [ 233 | (None, None), 234 | ("the", "i-r"), 235 | ("focus", "i-r"), 236 | ("in", "i-r"), 237 | ("Daloa", "i-r"), 238 | (",", "i-r"), 239 | ("Côte", "i-r"), 240 | ("d’Ivoire].", "i-r"), 241 | (None, None), 242 | ("Bulletin", "i-r"), 243 | ("de", "i-r"), 244 | ("la", "i-r"), 245 | ("Société", "i-r"), 246 | ("de", "i-r"), 247 | ("Pathologie", "i-r"), 248 | (None, None), 249 | ("Exotique", "i-r"), 250 | ("et", "i-r"), 251 | (None, None), 252 | ] 253 | 254 | actual = list(yield_token_label_pairs(tokens, labels)) 255 | 256 | assert expected == actual 257 | 258 | 259 | def test_read_jsonl(): 260 |
261 | expected = [ 262 | { 263 | "text": "a b c\n a b c", 264 | "tokens": [ 265 | {"text": "a", "start": 0, "end": 1, "id": 0}, 266 | {"text": "b", "start": 2, "end": 3, "id": 1}, 267 | {"text": "c", "start": 4, "end": 5, "id": 2}, 268 | {"text": "\n ", "start": 5, "end": 7, "id": 3}, 269 | {"text": "a", "start": 7, "end": 8, "id": 4}, 270 | {"text": "b", "start": 9, "end": 10, "id": 5}, 271 | {"text": "c", "start": 11, "end": 12, "id": 6}, 272 | ], 273 | "spans": [ 274 | {"start": 2, "end": 3, "token_start": 1, "token_end": 2, "label": "b"}, 275 | {"start": 4, "end": 5, "token_start": 2, "token_end": 3, "label": "i"}, 276 | {"start": 7, "end": 8, "token_start": 4, "token_end": 5, "label": "i"}, 277 | {"start": 9, "end": 10, "token_start": 5, "token_end": 6, "label": "e"}, 278 | ], 279 | } 280 | ] 281 | 282 | expected = expected * 3 283 | 284 | actual = read_jsonl(TEST_JSONL) 285 | assert expected == actual 286 | 287 | 288 | def test_write_jsonl(tmpdir): 289 | 290 | expected = [ 291 | { 292 | "text": "a b c\n a b c", 293 | "tokens": [ 294 | {"text": "a", "start": 0, "end": 1, "id": 0}, 295 | {"text": "b", "start": 2, "end": 3, "id": 1}, 296 | {"text": "c", "start": 4, "end": 5, "id": 2}, 297 | {"text": "\n ", "start": 5, "end": 7, "id": 3}, 298 | {"text": "a", "start": 7, "end": 8, "id": 4}, 299 | {"text": "b", "start": 9, "end": 10, "id": 5}, 300 | {"text": "c", "start": 11, "end": 12, "id": 6}, 301 | ], 302 | "spans": [ 303 | {"start": 2, "end": 3, "token_start": 1, "token_end": 2, "label": "b"}, 304 | {"start": 4, "end": 5, "token_start": 2, "token_end": 3, "label": "i"}, 305 | {"start": 7, "end": 8, "token_start": 4, "token_end": 5, "label": "i"}, 306 | {"start": 9, "end": 10, "token_start": 5, "token_end": 6, "label": "e"}, 307 | ], 308 | } 309 | ] 310 | 311 | expected = expected * 3 312 | 313 | temp_file = os.path.join(tmpdir, "file.jsonl") 314 | 315 | write_jsonl(expected, temp_file) 316 | actual = read_jsonl(temp_file) 317 | 318 | assert expected == actual 319 | 320 | 321 | def test_split_list_by_linebreaks(): 322 | 323 | lst = ["a", "b", "c", None, "d"] 324 | expected = [["a", "b", "c"], ["d"]] 325 | 326 | actual = _split_list_by_linebreaks(lst) 327 | assert expected == actual 328 | 329 | def test_list_by_linebreaks_ending_in_None(): 330 | 331 | lst = ["a", "b", "c", float("nan"), "d", None] 332 | expected = [["a", "b", "c"], ["d"]] 333 | 334 | actual = _split_list_by_linebreaks(lst) 335 | assert expected == actual 336 | 337 | def test_list_by_linebreaks_starting_in_None(): 338 | 339 | lst = [None, "a", "b", "c", None, "d"] 340 | expected = [["a", "b", "c"], ["d"]] 341 | 342 | actual = _split_list_by_linebreaks(lst) 343 | assert expected == actual -------------------------------------------------------------------------------- /tests/test_model_utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # coding: utf-8 3 | 4 | import pytest 5 | 6 | from deep_reference_parser.model_utils import remove_padding_from_predictions 7 | 8 | 9 | def test_remove_pre_padding(): 10 | 11 | predictions = [ 12 | ["pad", "pad", "pad", "pad", "token", "token", "token"], 13 | ["pad", "pad", "pad", "pad", "pad", "token", "token"], 14 | ["pad", "pad", "pad", "pad", "pad", "pad", "token"], 15 | ] 16 | 17 | X = [ 18 | ["token", "token", "token"], 19 | ["token", "token"], 20 | ["token"], 21 | ] 22 | 23 | out = remove_padding_from_predictions(X, predictions, "pre") 24 | 25 | assert out == X 26 | 27 | 28 | def test_remove_post_padding(): 29 | 30 | predictions = [ 31 | ["token", "token", "token", "pad", "pad", "pad",
"pad"], 32 | ["token", "token", "pad", "pad", "pad", "pad", "pad"], 33 | ["token", "pad", "pad", "pad", "pad", "pad", "pad"], 34 | ] 35 | X = [ 36 | ["token", "token", "token"], 37 | ["token", "token"], 38 | ["token"], 39 | ] 40 | 41 | out = remove_padding_from_predictions(X, predictions, "post") 42 | 43 | assert out == X 44 | -------------------------------------------------------------------------------- /tests/test_reference_utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # coding: utf-8 3 | 4 | import pytest 5 | 6 | from deep_reference_parser.reference_utils import break_into_chunks 7 | 8 | 9 | def test_break_into_chunks(): 10 | 11 | before = ["a", "b", "c", "d", "e"] 12 | expected = [["a", "b"], ["c", "d"], ["e"]] 13 | 14 | actual = break_into_chunks(before, max_words=2) 15 | 16 | assert expected == actual 17 | -------------------------------------------------------------------------------- /tox.ini: -------------------------------------------------------------------------------- 1 | [tox] 2 | envlist = py37 3 | 4 | [testenv] 5 | deps=-rrequirements_test.txt 6 | commands=pytest --tb=line --cov=deep_reference_parser --cov-append --disable-warnings 7 | --------------------------------------------------------------------------------