├── .gitignore ├── .travis.yml ├── CONTRIBUTING.md ├── Dockerfile ├── LICENSE.txt ├── MANIFEST.in ├── README.md ├── docs ├── guide_to_saber_api.md ├── img │ └── saber_logo.png ├── index.md ├── installation.md ├── quick_start.md └── resources.md ├── mkdocs.yml ├── notebooks └── lightning_tour.ipynb ├── saber ├── __init__.py ├── cli │ ├── __init__.py │ ├── app.py │ └── train.py ├── config.ini ├── config.py ├── constants.py ├── dataset.py ├── embeddings.py ├── metrics.py ├── models │ ├── __init__.py │ ├── base_model.py │ └── multi_task_lstm_crf.py ├── preprocessor.py ├── saber.py ├── tests │ ├── __init__.py │ ├── resources │ │ ├── __init__.py │ │ ├── dummy_config.ini │ │ ├── dummy_constants.py │ │ ├── dummy_dataset_1 │ │ │ ├── test.tsv │ │ │ ├── train.tsv │ │ │ └── valid.tsv │ │ ├── dummy_dataset_2 │ │ │ ├── test.tsv │ │ │ ├── train.tsv │ │ │ └── valid.tsv │ │ ├── dummy_word_embeddings │ │ │ └── dummy_word_embeddings.txt │ │ └── helpers.py │ ├── test_app_utils.py │ ├── test_base_model.py │ ├── test_config.py │ ├── test_data_utils.py │ ├── test_dataset.py │ ├── test_embeddings.py │ ├── test_generic_utils.py │ ├── test_grounding_utils.py │ ├── test_metrics.py │ ├── test_model_utils.py │ ├── test_multi_task_lstm_crf.py │ ├── test_preprocessor.py │ ├── test_saber.py │ ├── test_text_utils.py │ └── test_trainer.py ├── trainer.py └── utils │ ├── __init__.py │ ├── app_utils.py │ ├── data_utils.py │ ├── generic_utils.py │ ├── grounding_utils.py │ ├── model_utils.py │ └── text_utils.py ├── setup.cfg ├── setup.py └── tox.ini /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | bin/ 19 | share/ 20 | include/ 21 | local/ 22 | man/ 23 | lib/ 24 | lib64/ 25 | parts/ 26 | sdist/ 27 
| var/ 28 | wheels/ 29 | *.egg-info/ 30 | .installed.cfg 31 | *.egg 32 | 33 | # PyInstaller 34 | # Usually these files are written by a python script from a template 35 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 36 | *.manifest 37 | *.spec 38 | 39 | # Installer logs 40 | pip-log.txt 41 | pip-delete-this-directory.txt 42 | pip-selfcheck.json 43 | 44 | # Unit test / coverage reports 45 | htmlcov/ 46 | .tox/ 47 | .coverage 48 | .coverage.* 49 | .cache 50 | nosetests.xml 51 | coverage.xml 52 | *.cover 53 | .hypothesis/ 54 | 55 | # Translations 56 | *.mo 57 | *.pot 58 | 59 | # Django stuff: 60 | *.log 61 | local_settings.py 62 | 63 | # Flask stuff: 64 | instance/ 65 | .webassets-cache 66 | 67 | # Scrapy stuff: 68 | .scrapy 69 | 70 | # Sphinx documentation 71 | docs/_build/ 72 | 73 | # PyBuilder 74 | target/ 75 | 76 | # Jupyter Notebook 77 | .ipynb_checkpoints 78 | 79 | # pyenv 80 | .python-version 81 | 82 | # celery beat schedule file 83 | celerybeat-schedule 84 | 85 | # SageMath parsed files 86 | *.sage.py 87 | 88 | # dotenv 89 | .env 90 | 91 | # virtualenv 92 | .venv 93 | venv/ 94 | ENV/ 95 | 96 | # Spyder project settings 97 | .spyderproject 98 | .spyproject 99 | 100 | # Rope project settings 101 | .ropeproject 102 | 103 | # mkdocs documentation 104 | /site 105 | 106 | # mypy 107 | .mypy_cache/ 108 | 109 | # pytest 110 | .pytest_cache 111 | 112 | # IDE 113 | .idea 114 | *.iml 115 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: python 2 | 3 | os: 4 | - linux 5 | - osx 6 | 7 | env: TOXENV=py 8 | 9 | matrix: 10 | include: 11 | - env: TOXENV=manifest 12 | - env: TOXENV=pyroma 13 | allow_failures: 14 | - os: osx 15 | - python: "3.5" 16 | 17 | python: 18 | - "3.6" 19 | - "3.5" 20 | 21 | install: pip install tox pytest-cov coveralls 22 | 23 | script: tox 24 | 25 | after_success: coveralls 26 | 
-------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | **Working on your first Pull Request?** You can learn how from this *free* series [How to Contribute to an Open Source Project on GitHub](https://egghead.io/series/how-to-contribute-to-an-open-source-project-on-github) 2 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM python:3.6 2 | WORKDIR /app 3 | COPY . /app 4 | RUN pip install . 5 | #RUN pip install git+https://github.com/BaderLab/saber.git 6 | RUN pip install git+https://www.github.com/keras-team/keras-contrib.git 7 | RUN pip install https://github.com/huggingface/neuralcoref-models/releases/download/en_coref_md-3.0.0/en_coref_md-3.0.0.tar.gz 8 | CMD ["python", "-m", "saber.cli.app"] 9 | EXPOSE 5000 10 | -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Bader Lab, University of Toronto 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | # Include README(s) 2 | include README.md CONTRIBUTING.md 3 | 4 | # Include license file 5 | include LICENSE.txt 6 | 7 | # Include config file(s) 8 | global-include *.ini *.yml 9 | 10 | # Include Dockerfile 11 | include Dockerfile 12 | 13 | # Include test resources 14 | graft saber/tests 15 | 16 | # Include docs 17 | graft docs 18 | 19 | # Don't include notebooks 20 | prune notebooks 21 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

2 | 3 |

4 | 5 |

6 | Saber 7 |

8 | 9 |

10 | 11 | Travis CI 13 | 14 | 15 | Codacy Status 16 | 17 | 18 | Coverage Status 19 | 20 | 21 | PRs Welcome 22 | 23 | 24 | License 25 | 26 | 27 | Colab 28 | 29 | 30 | 31 | Slack 32 | 33 | 34 |

35 | 36 |

Saber (Sequence Annotator for Biomedical Entities and Relations) is a deep-learning based tool for information extraction in the biomedical domain. 37 |

38 | 39 |

40 | Installation • 41 | Quickstart • 42 | Documentation 43 |

44 | 45 | ## Installation 46 | 47 | _Note! This is a work in progress. Many things are broken, and the codebase is not stable._ 48 | 49 | To install Saber, you will need `python3.6`. 50 | 51 | ### Latest PyPI stable release 52 | 53 | [![PyPI-Status](https://img.shields.io/pypi/v/saber.svg?colorB=blue)](https://pypi.org/project/saber/) 54 | [![PyPI-Downloads](https://img.shields.io/pypi/dm/saber.svg?colorB=blue&logo=python&logoColor=white)](https://pypi.org/project/saber) 55 | [![Libraries-Dependents](https://img.shields.io/librariesio/dependent-repos/pypi/saber.svg?colorB=blue&logo=koding&logoColor=white)](https://github.com/baderlab/saber/network/dependents) 56 | 57 | ```sh 58 | (saber) $ pip install saber 59 | ``` 60 | 61 | > The install from PyPI is currently broken, please install using the instructions below. 62 | 63 | ### Latest development release on GitHub 64 | 65 | [![GitHub-Status](https://img.shields.io/github/tag-date/baderlab/saber.svg?logo=github)](https://github.com/baderlab/saber/releases) 66 | [![GitHub-Stars](https://img.shields.io/github/stars/baderlab/saber.svg?logo=github&label=stars)](https://github.com/baderlab/saber/stargazers) 67 | [![GitHub-Forks](https://img.shields.io/github/forks/baderlab/saber.svg?colorB=blue&logo=github&logoColor=white)](https://github.com/BaderLab/saber/network/members) 68 | [![GitHub-Commits](https://img.shields.io/github/commit-activity/y/baderlab/saber.svg?logo=git&logoColor=white)](https://github.com/baderlab/saber/graphs/commit-activity) 69 | [![GitHub-Updated](https://img.shields.io/github/last-commit/baderlab/saber.svg?colorB=blue&logo=github)](https://github.com/baderlab/saber/pulse) 70 | 71 | Pull and install straight from GitHub 72 | 73 | ```sh 74 | (saber) $ pip install git+https://github.com/BaderLab/saber.git 75 | ``` 76 | 77 | or install by cloning the repository 78 | 79 | ```sh 80 | (saber) $ git clone https://github.com/BaderLab/saber.git 81 | (saber) $ cd saber 82 | ``` 83 | 84 | and then using 
either `pip` 85 | 86 | ```sh 87 | (saber) $ pip install -e . 88 | ``` 89 | or `setuptools` 90 | 91 | ```sh 92 | (saber) $ python setup.py install 93 | ``` 94 | 95 | See the [documentation](https://baderlab.github.io/saber/installation/) for more detailed installation instructions. 96 | 97 | ## Quickstart 98 | 99 | If your goal is to use Saber to annotate biomedical text, then you can either use the [web-service](#web-service) or a [pre-trained model](#pre-trained-models). If you simply want to check Saber out, without installing anything locally, try the [Google Colaboratory](#google-colaboratory) notebook. 100 | 101 | ### Google Colaboratory 102 | 103 | The fastest way to check out Saber is by following along with the Google Colaboratory notebook ([![Colab](https://img.shields.io/badge/launch-Google%20Colab-orange.svg)](https://colab.research.google.com/drive/1WD7oruVuTo6p_908MQWXRBdLF3Vw2MPo)). In order to be able to run the cells, select "Open in Playground" or, alternatively, save a copy to your own Google Drive account (File > Save a copy in Drive). 104 | 105 | ### Web-service 106 | 107 | To use Saber as a **local** web-service, run 108 | 109 | ``` 110 | (saber) $ python -m saber.cli.app 111 | ``` 112 | 113 | or, if you prefer, you can pull & run the Saber image from **Docker Hub** 114 | 115 | ```sh 116 | # Pull Saber image from Docker Hub 117 | $ docker pull pathwaycommons/saber 118 | # Run docker (use `-dt` instead of `-it` to run container in background) 119 | $ docker run -it --rm -p 5000:5000 --name saber pathwaycommons/saber 120 | ``` 121 | 122 | There are currently two endpoints, `/annotate/text` and `/annotate/pmid`. Both expect a `POST` request with a JSON payload, e.g., 123 | 124 | ```json 125 | { 126 | "text": "The phosphorylation of Hdm2 by MK2 promotes the ubiquitination of p53." 
127 | } 128 | ``` 129 | 130 | or 131 | 132 | ```json 133 | { 134 | "pmid": 11835401 135 | } 136 | ``` 137 | 138 | For example, running the web-service locally and using `cURL` 139 | 140 | ```sh 141 | $ curl -X POST 'http://localhost:5000/annotate/text' \ 142 | --data '{"text": "The phosphorylation of Hdm2 by MK2 promotes the ubiquitination of p53."}' 143 | ``` 144 | 145 | Documentation for the Saber web-service API can be found [here](https://baderlab.github.io/saber-api-docs/). 146 | 147 | ### Pre-trained models 148 | 149 | First, import the `Saber` class. This is the interface to Saber 150 | 151 | ```python 152 | from saber.saber import Saber 153 | ``` 154 | 155 | then create a `Saber` object 156 | 157 | ```python 158 | saber = Saber() 159 | ``` 160 | 161 | and then load the model of our choice 162 | 163 | ```python 164 | saber.load('PRGE') 165 | ``` 166 | 167 | To annotate text with the model, just call the `Saber.annotate()` method 168 | 169 | ```python 170 | saber.annotate("The phosphorylation of Hdm2 by MK2 promotes the ubiquitination of p53.") 171 | ``` 172 | See the [documentation](https://baderlab.github.io/saber/quick_start/#pre-trained-models) for more details on using pre-trained models. 173 | 174 | ## Documentation 175 | 176 | Documentation for the Saber package can be found [here](https://baderlab.github.io/saber/). The web-service API has its own documentation [here](https://baderlab.github.io/saber-api-docs/#introduction). 
177 | 178 | You can also call `help()` on any Saber method for more information 179 | 180 | ```python 181 | from saber import Saber 182 | 183 | saber = Saber() 184 | 185 | help(saber.annotate) 186 | ``` 187 | 188 | or pass the `--help` flag to any of the command-line interfaces 189 | 190 | ``` 191 | python -m saber.cli.train --help 192 | ``` 193 | 194 | Feel free to open an issue or reach out to us on our slack channel ([![Slack](https://img.shields.io/badge/slack-@saber--nlp-blueviolet.svg?logo=slack)](https://join.slack.com/t/saber-nlp/shared_invite/enQtNzE0MzY5ODM3MTc0LWZmY2VjMTY5MjllMmIzNDhkM2VhZjk5ODE1MDYyZjE5OGFjYWVhY2I2NDk5Yjk1N2Q3NTI4YTdhMTI5MjRiOGY)) for more help. 195 | 196 | 197 | -------------------------------------------------------------------------------- /docs/guide_to_saber_api.md: -------------------------------------------------------------------------------- 1 | # Guide to the Saber API 2 | 3 | You can interact with Saber as a web-service (explained in [Quick Start: Web-service](https://baderlab.github.io/saber/quick_start/#web-service)), [command line tool](#command-line-tool), or as a [python package](#python-package). If you created a virtual environment, _remember to activate it first_. 4 | 5 | ### Command line tool 6 | 7 | Currently, the command line tool simply trains the model. To use it, call 8 | 9 | ``` 10 | (saber) $ python -m saber.cli.train 11 | ``` 12 | 13 | along with any command line arguments. For example, to train the model on the [NCBI Disease](https://www.ncbi.nlm.nih.gov/CBBresearch/Dogan/DISEASE/) corpus 14 | 15 | ``` 16 | (saber) $ python -m saber.cli.train --dataset_folder NCBI_Disease_BIO 17 | ``` 18 | 19 | !!! tip 20 | See [Resources: Datasets](https://baderlab.github.io/saber/resources/#datasets) for help preparing datasets and word embeddings for training. 21 | 22 | Run `python -m saber.cli.train --help` to see all possible arguments. 
23 | 24 | Of course, supplying arguments at the command line can quickly become cumbersome. Saber also allows you to provide a configuration file, which can be specified like so 25 | 26 | ``` 27 | (saber) $ python -m saber.cli.train --config_filepath path/to/config.ini 28 | ``` 29 | 30 | Copy the contents of the [default config file](https://github.com/BaderLab/saber/blob/master/saber/config.ini) to a new `*.ini` file in order to get started. 31 | 32 | !!! note 33 | Arguments supplied at the command line overwrite those found in the configuration file, e.g., 34 | 35 | ``` 36 | (saber) $ python -m saber.cli.train --dataset_folder path/to/dataset --k_folds 10 37 | ``` 38 | 39 | would overwrite the arguments for `dataset_folder` and `k_folds` found in the configuration file. 40 | 41 | ### Python package 42 | 43 | You can also import Saber and interact with it as a python package. Saber exposes its functionality through the `Saber` class. Here is just about everything Saber does in one script: 44 | 45 | ```python 46 | from saber.saber import Saber 47 | 48 | # First, create a Saber object, which exposes Sabers functionality 49 | saber = Saber() 50 | 51 | # Load a dataset and create a model (provide a list of datasets to use multi-task learning!) 52 | saber.load_dataset('path/to/datasets/GENIA') 53 | saber.build(model_name='MT-LSTM-CRF') 54 | 55 | # Train and save a model 56 | saber.train() 57 | saber.save('pretrained_models/GENIA') 58 | 59 | # Load a model 60 | del saber 61 | saber = Saber() 62 | saber.load('pretrained_models/GENIA') 63 | 64 | # Perform prediction on raw text, get resulting annotation 65 | raw_text = 'The phosphorylation of Hdm2 by MK2 promotes the ubiquitination of p53.' 
66 | annotation = saber.annotate(raw_text) 67 | 68 | # Use transfer learning to continue training on a new dataset 69 | saber.load_dataset('path/to/datasets/CRAFT') 70 | saber.train() 71 | ``` 72 | 73 | #### Transfer learning 74 | 75 | Transfer learning is as easy as training, saving, loading, and then continuing training of a model. Here is an example 76 | 77 | ```python 78 | # Create and train a model on GENIA corpus 79 | saber = Saber() 80 | saber.load_dataset('path/to/datasets/GENIA') 81 | saber.build(model_name='MT-LSTM-CRF') 82 | saber.train() 83 | saber.save('pretrained_models/GENIA') 84 | 85 | # Load that model 86 | del saber 87 | saber = Saber() 88 | saber.load('pretrained_models/GENIA') 89 | 90 | # Use transfer learning to continue training on a new dataset 91 | saber.load_dataset('path/to/datasets/CRAFT') 92 | saber.train() 93 | ``` 94 | 95 | !!! note 96 | This is currently only supported by the `mt-lstm-crf` model. 97 | 98 | #### Multi-task learning 99 | 100 | Multi-task learning is as easy as specifying multiple dataset paths, either in the `config` file, at the command line via the flag `--dataset_folder`, or as an argument to `load_dataset()`. The number of datasets is arbitrary. 101 | 102 | Here is an example using the last method 103 | 104 | ```python 105 | saber = Saber() 106 | 107 | # Simply pass multiple dataset paths as a list to load_dataset to use multi-task learning. 108 | saber.load_dataset(['path/to/datasets/NCBI_Disease', 'path/to/datasets/Linnaeus']) 109 | 110 | saber.build(model_name='MT-LSTM-CRF') 111 | saber.train() 112 | ``` 113 | 114 | !!! note 115 | This is currently only supported by the `mt-lstm-crf` model. 116 | 117 | #### Training on GPUs 118 | 119 | Saber will automatically train on as many GPUs as are available. In order for this to work, you must have [CUDA](https://developer.nvidia.com/cuda-downloads) and, optionally, [CudDNN](https://developer.nvidia.com/cudnn) installed. 
If you are using conda to manage your environment, then these are installed for you when you call 120 | 121 | ``` 122 | (saber) $ conda install tensorflow-gpu 123 | ``` 124 | 125 | Otherwise, install them yourself and use `pip` to install `tensorflow-gpu` 126 | 127 | ``` 128 | (saber) $ pip install tensorflow-gpu 129 | ``` 130 | 131 | To control which GPUs Saber trains on, you can use the `CUDA_VISIBLE_DEVICES` environment variable, e.g., 132 | 133 | ``` 134 | # To train exclusively on CPU 135 | (saber) $ CUDA_VISIBLE_DEVICES="" python -m saber.cli.train 136 | 137 | # To train on 1 GPU with ID=0 138 | (saber) $ CUDA_VISIBLE_DEVICES="0" python -m saber.cli.train 139 | 140 | # To train on 2 GPUs with IDs=0,2 141 | (saber) $ CUDA_VISIBLE_DEVICES="0,2" python -m saber.cli.train 142 | ``` 143 | 144 | !!! tip 145 | You can get information about your NVIDIA GPUs by typing `nvidia-smi` at the command line (assuming the GPUs are setup properly and the nvidia driver is installed). 146 | 147 | #### Saving and loading models 148 | 149 | In the following sections we introduce the saving and loading of models. 
150 | 151 | ##### Saving a model 152 | 153 | Assuming the model has already been created (see above), we can easily save our model like so 154 | 155 | ```python 156 | save_dir = 'path/to/pretrained_models/mymodel' 157 | saber.save(save_dir) 158 | ``` 159 | 160 | ##### Loading a model 161 | 162 | Lets illustrate loading a model with a new `Saber` object 163 | 164 | ```python 165 | # Delete our previous Saber object (if it exists) 166 | del saber 167 | # Create a new Saber object 168 | saber = Saber() 169 | # Load a previous model 170 | saber.load(path_to_saved_model) 171 | ``` 172 | -------------------------------------------------------------------------------- /docs/img/saber_logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BaderLab/saber/876be6bfdb1bc5b18cbcfa848c94b0d20c940f02/docs/img/saber_logo.png -------------------------------------------------------------------------------- /docs/index.md: -------------------------------------------------------------------------------- 1 | __Saber__ (**S**equence **A**nnotator for **B**iomedical **E**ntities and **R**elations) is a deep-learning based tool for __information extraction__ in the biomedical domain. 2 | 3 | The neural network model used is a BiLSTM-CRF [[1](https://arxiv.org/abs/1603.01360), [2](https://arxiv.org/abs/1603.01354)]; a state-of-the-art architecture for sequence labelling. The model is implemented using [Keras](https://keras.io/) / [Tensorflow](https://www.tensorflow.org). 4 | 5 | The goal is that Saber will eventually perform all the important steps in text-mining of biomedical literature: 6 | 7 | - Coreference resolution (:white_check_mark:) 8 | - Biomedical named entity recognition (BioNER) (:white_check_mark:) 9 | - Entity linking / grounding / normalization (:white_check_mark:) 10 | - Simple relation extraction (:soon:) 11 | - Event extraction (:soon:) 12 | 13 | Pull requests are welcome! 
If you encounter any bugs, please open an issue in the [GitHub repository](https://github.com/BaderLab/saber). 14 | -------------------------------------------------------------------------------- /docs/installation.md: -------------------------------------------------------------------------------- 1 | # Installation 2 | 3 | To install Saber, you will need `python3.6`. If not already installed, `python3` can be installed via 4 | 5 | - The [official installer](https://www.python.org/downloads/) 6 | - [Homebrew](https://brew.sh), on MacOS (`brew install python3`) 7 | - [Miniconda3](https://conda.io/miniconda.html) / [Anaconda3](https://www.anaconda.com/download/) 8 | 9 | !!! note 10 | Run `python --version` at the command line to make sure installation was successful. You may need to type `python3` (not just `python`) depending on your install method. 11 | 12 | (OPTIONAL) Activate your virtual environment (see [below](#optional-creating-and-activating-virtual-environments) for help) 13 | 14 | ```sh 15 | $ conda activate saber 16 | # Notice your command prompt has changed to indicate that the environment is active 17 | (saber) $ 18 | ``` 19 | 20 | ## Latest PyPI stable release 21 | 22 | [![PyPI-Status](https://img.shields.io/pypi/v/saber.svg?colorB=blue&style=flat-square)](https://pypi.org/project/saber/) 23 | [![PyPI-Downloads](https://img.shields.io/pypi/dm/saber.svg?colorB=blue&style=flat-square&logo=python&logoColor=white)](https://pypi.org/project/saber) 24 | [![Libraries-Dependents](https://img.shields.io/librariesio/dependent-repos/pypi/saber.svg?colorB=blue&style=flat-square&logo=koding&logoColor=white)](https://github.com/baderlab/saber/network/dependents) 25 | 26 | ```sh 27 | (saber) $ pip install saber 28 | ``` 29 | 30 | !!! error 31 | The install from PyPI is currently broken, please install using the instructions below. 
32 | 33 | ## Latest development release on GitHub 34 | 35 | [![GitHub-Status](https://img.shields.io/github/tag-date/baderlab/saber.svg?logo=github&style=flat-square)](https://github.com/baderlab/saber/releases) 36 | [![GitHub-Stars](https://img.shields.io/github/stars/baderlab/saber.svg?logo=github&label=stars&style=flat-square)](https://github.com/baderlab/saber/stargazers) 37 | [![GitHub-Forks](https://img.shields.io/github/forks/baderlab/saber.svg?colorB=blue&logo=github&logoColor=white&style=flat-square)](https://github.com/BaderLab/saber/network/members) 38 | [![GitHub-Commits](https://img.shields.io/github/commit-activity/y/baderlab/saber.svg?logo=git&logoColor=white&style=flat-square)](https://github.com/baderlab/saber/graphs/commit-activity) 39 | [![GitHub-Updated](https://img.shields.io/github/last-commit/baderlab/saber.svg?colorB=blue&logo=github&style=flat-square)](https://github.com/baderlab/saber/pulse) 40 | 41 | Pull and install straight from GitHub 42 | 43 | ```sh 44 | (saber) $ pip install git+https://github.com/BaderLab/saber.git 45 | ``` 46 | 47 | or install by cloning the repository 48 | 49 | ```sh 50 | (saber) $ git clone https://github.com/BaderLab/saber.git 51 | (saber) $ cd saber 52 | ``` 53 | 54 | and then using either `pip` 55 | 56 | ```sh 57 | (saber) $ pip install -e . 58 | ``` 59 | or `setuptools` 60 | 61 | ```sh 62 | (saber) $ python setup.py install 63 | ``` 64 | 65 | !!! note 66 | See [Running tests](#running-tests) for a way to verify your installation. 67 | 68 | ## (OPTIONAL) Creating and activating virtual environments 69 | 70 | When using `pip` it is generally recommended to install packages in a virtual environment to avoid modifying system state. 
To create a virtual environment named `saber` 71 | 72 | ### Using virtualenv or venv 73 | 74 | Using [virtualenv](https://virtualenv.pypa.io/en/stable/) 75 | 76 | ``` 77 | $ virtualenv --python=python3 /path/to/new/venv/saber 78 | ``` 79 | 80 | Using [venv](https://docs.python.org/3/library/venv.html) 81 | 82 | ``` 83 | $ python3 -m venv /path/to/new/venv/saber 84 | ``` 85 | 86 | Next, you need to activate the environment. 87 | 88 | ``` 89 | $ source /path/to/new/venv/saber/bin/activate 90 | # Notice your command prompt has changed to indicate that the environment is active 91 | (saber) $ 92 | ``` 93 | 94 | ### Using Conda 95 | 96 | If you use [Conda](https://conda.io/docs/) / [Miniconda](https://repo.continuum.io/miniconda/Miniconda3-latest-Linux-x86_64.sh), you can create an environment named `saber` by running 97 | 98 | ``` 99 | $ conda create -n saber python=3.6 100 | ``` 101 | 102 | To activate the environment 103 | 104 | ``` 105 | $ conda activate saber 106 | # Notice your command prompt has changed to indicate that the environment is active 107 | (saber) $ 108 | ``` 109 | 110 | !!! note 111 | You do not _need_ to name the environment `saber`. 112 | 113 | ## Running tests 114 | 115 | Saber's test suite can be found in `saber/tests`. 
If Saber is already installed, you can run `pytest` on the installation directory 116 | 117 | ``` 118 | # Install pytest 119 | (saber) $ pip install pytest 120 | # Find out where Saber is installed 121 | (saber) $ INSTALL_DIR=$(python -c "import os; import saber; print(os.path.dirname(saber.__file__))") 122 | # Run tests on that installation directory 123 | (saber) $ python -m pytest $INSTALL_DIR 124 | ``` 125 | 126 | Alternatively, to clone Saber, install it, and run the test suite all in one go 127 | 128 | ``` 129 | (saber) $ git clone https://github.com/BaderLab/saber.git 130 | (saber) $ cd saber 131 | (saber) $ python setup.py test 132 | ``` 133 | -------------------------------------------------------------------------------- /docs/quick_start.md: -------------------------------------------------------------------------------- 1 | # Quick Start 2 | 3 | If your goal is to use Saber to annotate biomedical text, then you can either use the [web-service](#web-service) or a [pre-trained model](#pre-trained-models). If you simply want to check Saber out, without installing anything locally, try the [Google Colaboratory](#google-colaboratory) notebook. 4 | 5 | ## Google Colaboratory 6 | 7 | The fastest way to check out Saber is by following along with the Google Colaboratory notebook ([![Colab](https://img.shields.io/badge/launch-Google%20Colab-orange.svg?style=flat-square)](https://colab.research.google.com/drive/1WD7oruVuTo6p_908MQWXRBdLF3Vw2MPo)). In order to be able to run the cells, select "Open in Playground" or, alternatively, save a copy to your own Google Drive account (File > Save a copy in Drive). 
8 | 9 | ## Web-service 10 | 11 | To use Saber as a **local** web-service, run 12 | 13 | ``` 14 | (saber) $ python -m saber.cli.app 15 | ``` 16 | 17 | or, if you prefer, you can pull & run the Saber image from **Docker Hub** 18 | 19 | ``` 20 | # Pull Saber image from Docker Hub 21 | $ docker pull pathwaycommons/saber 22 | # Run docker (use `-dt` instead of `-it` to run container in background) 23 | $ docker run -it --rm -p 5000:5000 --name saber pathwaycommons/saber 24 | ``` 25 | 26 | !!! tip 27 | Alternatively, you can clone the GitHub repository and build the container from the `Dockerfile` with `docker build -t saber .` 28 | 29 | There are currently two endpoints, `/annotate/text` and `/annotate/pmid`. Both expect a `POST` request with a JSON payload, e.g. 30 | 31 | ```json 32 | { 33 | "text": "The phosphorylation of Hdm2 by MK2 promotes the ubiquitination of p53." 34 | } 35 | ``` 36 | 37 | or 38 | 39 | ```json 40 | { 41 | "pmid": 11835401 42 | } 43 | ``` 44 | 45 | For example, with the web-service running locally 46 | 47 | ``` bash tab="Bash" 48 | curl -X POST 'http://localhost:5000/annotate/text' \ 49 | --data '{"text": "The phosphorylation of Hdm2 by MK2 promotes the ubiquitination of p53."}' 50 | ``` 51 | 52 | ``` python tab="python" 53 | import requests # assuming you have requests package installed! 54 | 55 | url = "http://localhost:5000/annotate/text" 56 | payload = {"text": "The phosphorylation of Hdm2 by MK2 promotes the ubiquitination of p53."} 57 | response = requests.post(url, json=payload) 58 | 59 | print(response.text) 60 | print(response.status_code, response.reason) 61 | ``` 62 | 63 | !!! warning 64 | The first request to the web-service will be slow (~60s). This is because a large language 65 | model needs to be loaded into memory. 66 | 67 | Documentation for the Saber web-service API can be found [here](https://baderlab.github.io/saber-api-docs/). We hope to provide a live version of the web-service soon! 
68 | 69 | ## Pre-trained models 70 | 71 | First, import `Saber`. This class coordinates training, annotation, saving and loading of models and datasets. In short, this is the interface to Saber. 72 | 73 | ```python 74 | from saber.saber import Saber 75 | ``` 76 | 77 | To load a pre-trained model, first create a `Saber` object 78 | 79 | ```python 80 | saber = Saber() 81 | ``` 82 | 83 | and then load the model of our choice 84 | 85 | ```python 86 | saber.load('PRGE') 87 | ``` 88 | 89 | !!! tip 90 | See [Resources: Pre-trained models](../resources#pre-trained-models) for pre-trained model names and details. You will need an internet connection to download a pre-trained model. 91 | 92 | To annotate text with the model, just call the `Saber.annotate()` method 93 | 94 | ```python 95 | saber.annotate("The phosphorylation of Hdm2 by MK2 promotes the ubiquitination of p53.") 96 | ``` 97 | 98 | !!! warning 99 | The `Saber.annotate()` method will be slow the first time you call it (~60s). This is because a large language model needs to be loaded into memory. 100 | 101 | ### Coreference Resolution 102 | 103 | [**Coreference**](https://en.wikipedia.org/wiki/Coreference) occurs when two or more expressions in a text refer to the same person or thing, that is, they have the same **referent**. Take the following example: 104 | 105 | _"__IL-6__ supports tumour growth and metastasising in terminal patients, and __it__ significantly engages in cancer cachexia (including anorexia) and depression associated with malignancy."_ 106 | 107 | Clearly, "__it__" referes to "__IL-6__". If we do not resolve this coreference, then "__it__" will not be labeled as an entity and any relation or event it is mentioned in will not be extracted. Saber uses [NeuralCoref](https://github.com/huggingface/neuralcoref), a state-of-the-art coreference resolution tool based on neural nets and built on top of [Spacy](https://spacy.io). 
To use it, just supply the argument `coref=True` (which is `False` by default) to the `Saber.annotate()` method 108 | 109 | ```python 110 | text = "IL-6 supports tumour growth and metastasising in terminal patients, and it significantly engages in cancer cachexia (including anorexia) and depression associated with malignancy." 111 | # WITHOUT coreference resolution 112 | saber.annotate(text, coref=False) 113 | # WITH coreference resolution 114 | saber.annotate(text, coref=True) 115 | ``` 116 | 117 | !!! note 118 | If you are using the web-service, simply supply `"coref": true` in your `JSON` payload to resolve coreferences. 119 | 120 | Saber currently takes the simplest possible approach: replace all coreference mentions with their referent, and then feed the resolved text to the model that identifies named entities. 121 | 122 | ### Grounding 123 | 124 | **Grounding** (sometimes called **entity linking** or **normalization**) involves mapping each annotated entity to a unique identifier in an external resource such as a database or ontology. To ground entities in a call to `Saber.annotate()`, simply pass the argument `ground=True` 125 | 126 | ```python 127 | saber.annotate('The phosphorylation of Hdm2 by MK2 promotes the ubiquitination of p53.', ground=True) 128 | ``` 129 | 130 | The grounding functionality is implemented by the [EXTRACT 2.0 API](https://extract.jensenlab.org/). Note that you will need an internet connection or grounding will fail. Also note that `Saber.annotate()` will take slightly longer to return a response when `ground=True` (up to a few seconds). 131 | 132 | See [Resources: Pre-trained models](../resources#pre-trained-models) for a list of the the external resources each entity type (annotated by the pre-trained models) is grounded to. 133 | 134 | !!! note 135 | If you are using the web-service, simply supply `"ground": true` in your `JSON` payload to ground entities. 
136 | 137 | ### Working with annotations 138 | 139 | The `Saber.annotate()` method returns a simple `dict` object 140 | 141 | ```python 142 | ann = saber.annotate("The phosphorylation of Hdm2 by MK2 promotes the ubiquitination of p53.") 143 | ``` 144 | 145 | which contains the keys `title`, `text` and `ents` 146 | 147 | - `title`: contains the title of the article, if provided 148 | - `text`: contains the text (which is minimally processed) the model was deployed on 149 | - `ents`: contains a list of entities present in the `text` that were annotated by the model 150 | 151 | For example, to see all entities annotated by the model, call 152 | 153 | ```python 154 | ann['ents'] 155 | ``` 156 | 157 | #### Converting annotations to JSON 158 | 159 | The `Saber.annotate()` method returns a `dict` object, but can be converted to a `JSON` formatted string for ease-of-use in downstream applications 160 | 161 | ```python 162 | import json 163 | 164 | # convert to json object 165 | json_ann = json.dumps(ann) 166 | 167 | # convert back to python dictionary 168 | ann = json.loads(json_ann) 169 | ``` 170 | -------------------------------------------------------------------------------- /docs/resources.md: -------------------------------------------------------------------------------- 1 | # Resources 2 | 3 | Saber is ready to go out-of-the-box when using the __web-service__ or a __pre-trained model__. However, if you plan on training your own models, you will need to provide a dataset (or datasets!) and, ideally, pre-trained word embeddings. 4 | 5 | ## Pre-trained models 6 | 7 | Pre-trained model names can be passed to `Saber.load()` (see [Quick Start: Pre-trained Models](https://baderlab.github.io/saber/quick_start/#pre-trained-models)). Appending `"*-large"` to the model name (e.g. `"PRGE-large"`) will download a much larger model, which should perform slightly better than the base model. 
8 | 9 | Identifier | Semantic Group | Identified entity types | Namespace | 10 | ---------- | -------------- | ----------------------- | --------- | 11 | `CHED` | Chemicals | Abbreviations and Acronyms, Molecular Formulas, Chemical database identifiers, IUPAC names, Trivial (common names of chemicals and trademark names), Family (chemical families with a defined structure) and Multiple (non-continuous mentions of chemicals in text) | [PubChem Compounds](https://pubchem.ncbi.nlm.nih.gov/) 12 | `DISO` | Disorders | Acquired Abnormality, Anatomical Abnormality, Cell or Molecular Dysfunction, Congenital Abnormality, Disease or Syndrome, Mental or Behavioral Dysfunction, Neoplastic Process, Pathologic Function, Sign or Symptom | [Disease Ontology](http://disease-ontology.org/) 13 | `LIVB` | Organisms | Species, Taxa | [NCBI Taxonomy](https://www.ncbi.nlm.nih.gov/taxonomy) 14 | `PRGE` | Genes and Gene Products | Genes, Gene Products | [STRING](https://string-db.org/) 15 | 16 | ## Datasets 17 | 18 | Currently, Saber requires corpora to be in a **CoNLL** format with a BIO or IOBES tag scheme, e.g.: 19 | 20 | ``` 21 | Selegiline B-CHED 22 | - O 23 | induced O 24 | postural B-DISO 25 | hypotension I-DISO 26 | ... 27 | ``` 28 | 29 | Corpora in such a format are collected [here](https://github.com/BaderLab/Biomedical-Corpora) for convenience. 30 | 31 | !!! info 32 | Many of the corpora in the BIO and IOBES tag format were originally collected by [Crichton _et al_., 2017](https://doi.org/10.1186/s12859-017-1776-8), [here](https://github.com/cambridgeltl/MTL-Bioinformatics-2016). 33 | 34 | In this format, the first column contains each token of an input sentence, the last column contains the token's tag, all columns are separated by tabs, and all sentences by a newline. 
35 | 36 | Of course, not all corpora are distributed in the CoNLL format: 37 | 38 | - Corpora in the **Standoff** format can be converted to **CoNLL** format using [this](https://github.com/spyysalo/standoff2conll) tool. 39 | - Corpora in **PubTator** format can be converted to **Standoff** first using [this](https://github.com/spyysalo/pubtator) tool. 40 | 41 | Saber infers the "training strategy" based on the structure of the dataset folder: 42 | 43 | - To use k-fold cross-validation, simply provide a `train.*` file in your dataset folder. 44 | 45 | E.g. 46 | ``` 47 | . 48 | ├── NCBI_Disease 49 | │   └── train.tsv 50 | ``` 51 | 52 | - To use a train/valid/test strategy, provide `train.*` and `test.*` files in your dataset folder. Optionally, you can provide a `valid.*` file. If not provided, a random 10% of examples from `train.*` are used as the validation set. 53 | 54 | E.g. 55 | ``` 56 | . 57 | ├── NCBI_Disease 58 | │   ├── test.tsv 59 | │   └── train.tsv 60 | ``` 61 | 62 | ## Word embeddings 63 | 64 | When training new models, you can (and should) provide your own pre-trained word embeddings with the `pretrained_embeddings` argument (either at the command line or in the configuration file). Saber expects all word embeddings to be in the `word2vec` file format. [Pyysalo _et al_. 2013](https://pdfs.semanticscholar.org/e2f2/8568031e1902d4f8ee818261f0f2c20de6dd.pdf) provide word embeddings that work quite well in the biomedical domain, which can be downloaded [here](http://bio.nlplab.org). 
Alternatively, from the command line call: 65 | 66 | ``` 67 | # Replace this with a location you want to save the embeddings to 68 | $ mkdir path/to/word_embeddings 69 | # Note: this file is over 4GB 70 | $ wget http://evexdb.org/pmresources/vec-space-models/wikipedia-pubmed-and-PMC-w2v.bin -O path/to/word_embeddings/wikipedia-pubmed-and-PMC-w2v.bin 71 | ``` 72 | 73 | To use these word embeddings with Saber, provide their path in the `pretrained_embeddings` argument (either in the `config` file or at the command line). Alternatively, pass their path to `Saber.load_embeddings()`. For example: 74 | 75 | ```python 76 | from saber.saber import Saber 77 | 78 | saber = Saber() 79 | 80 | saber.load_dataset('path/to/dataset') 81 | # load the embeddings here 82 | saber.load_embeddings('path/to/word_embeddings/wikipedia-pubmed-and-PMC-w2v.bin') 83 | 84 | saber.build() 85 | saber.train() 86 | ``` 87 | 88 | ### GloVe 89 | 90 | To use [GloVe](https://nlp.stanford.edu/projects/glove/) embeddings, just convert them to the [word2vec](https://code.google.com/archive/p/word2vec/) format first: 91 | 92 | ``` 93 | (saber) $ python 94 | >>> from gensim.scripts.glove2word2vec import glove2word2vec 95 | >>> glove_input_file = 'glove.txt' 96 | >>> word2vec_output_file = 'word2vec.txt' 97 | >>> glove2word2vec(glove_input_file, word2vec_output_file) 98 | ``` 99 | -------------------------------------------------------------------------------- /mkdocs.yml: -------------------------------------------------------------------------------- 1 | # Project information 2 | site_name: Saber 3 | site_description: 'Documentation for Saber' 4 | site_author: 'John Giorgi' 5 | site_url: 'https://baderlab.github.io/saber/' 6 | 7 | # Repository 8 | repo_name: 'BaderLab/saber' 9 | repo_url: 'https://github.com/BaderLab/saber' 10 | 11 | # Site map 12 | nav: 13 | - About: index.md 14 | - Installation: installation.md 15 | - Quick Start: quick_start.md 16 | - Guide to the Saber API: guide_to_saber_api.md 17 | - Resources: resources.md 18 | 19 | # Configuration 20 | theme: 21 
| language: 'en' 22 | name: 'material' 23 | logo: 24 | icon: 'memory' 25 | palette: 26 | primary: 'deep purple' 27 | accent: 'deep purple' 28 | 29 | # Customization 30 | extra: 31 | social: 32 | - type: 'github' 33 | link: 'https://github.com/BaderLab' 34 | 35 | # Extensions 36 | markdown_extensions: 37 | - admonition 38 | - codehilite 39 | - pymdownx.superfences 40 | - pymdownx.details 41 | - pymdownx.emoji: 42 | emoji_generator: !!python/name:pymdownx.emoji.to_svg 43 | - toc: 44 | permalink: true 45 | -------------------------------------------------------------------------------- /saber/__init__.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | 4 | from datetime import datetime 5 | 6 | # set Tensorflow logging level 7 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' 8 | 9 | # if applicable, delete the existing log file to generate a fresh log file during each execution 10 | try: 11 | os.remove("saber.log") 12 | except OSError: 13 | pass 14 | 15 | # create the logger 16 | logging.basicConfig(filename="saber.log", 17 | level=logging.DEBUG, 18 | format='%(name)s - %(levelname)s - %(message)s') 19 | 20 | # log the date to start 21 | logging.info('Saber invocation: %s\n%s', datetime.now(), '=' * 75) 22 | -------------------------------------------------------------------------------- /saber/cli/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BaderLab/saber/876be6bfdb1bc5b18cbcfa848c94b0d20c940f02/saber/cli/__init__.py -------------------------------------------------------------------------------- /saber/cli/app.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """Simple web service which exposes Saber's functionality via a RESTful API. 
3 | """ 4 | import argparse 5 | import logging 6 | 7 | from flask import Flask, jsonify, redirect, request 8 | from waitress import serve 9 | 10 | from .. import constants 11 | from ..utils import app_utils 12 | 13 | app = Flask(__name__) 14 | 15 | LOGGER = logging.getLogger(__name__) 16 | 17 | @app.route('/') 18 | def serve_api_docs(): 19 | """Flask view that redirects to the Saber API docs from route '/'. 20 | """ 21 | return redirect('https://baderlab.github.io/saber-api-docs/') 22 | 23 | @app.route('/annotate/text', methods=['POST']) 24 | def annotate_text(): 25 | """Annotates raw text recieved in a POST request. 26 | 27 | Returns: 28 | JSON formatted string. 29 | """ 30 | parsed_request_json = app_utils.parse_request_json(request) 31 | # raw text to perform annotation on 32 | text = parsed_request_json['text'] 33 | 34 | annotation = predict(text=text, 35 | ents=parsed_request_json['ents'], 36 | coref=parsed_request_json['coref'], 37 | ground=parsed_request_json['ground']) 38 | 39 | return jsonify(annotation) 40 | 41 | @app.route('/annotate/pmid', methods=['POST']) 42 | def annotate_pmid(): 43 | """Annotates the abstract of a document with the PubMed ID recieved in a POST request. 44 | 45 | Returns: 46 | JSON formatted string. 47 | """ 48 | parsed_request_json = app_utils.parse_request_json(request) 49 | # use Entrez Utilities Web Service API to get the abtract text 50 | title, abstract = app_utils.get_pubmed_text(parsed_request_json['pmid']) 51 | text = '{}\n{}'.format(title, abstract) 52 | 53 | annotation = predict(text=text, 54 | ents=parsed_request_json['ents'], 55 | coref=parsed_request_json['coref'], 56 | ground=parsed_request_json['ground']) 57 | 58 | return jsonify(annotation) 59 | 60 | def predict(text, ents, coref=False, ground=False): 61 | """Annotates raw text (`text`) for entities according to their boolean value in `ents`. 62 | 63 | Args: 64 | text (str): Raw text to be annotated. 
65 | ents: Dictionary of entity, boolean pairs representing whether or not to annotate the text 66 | for the given entities. 67 | 68 | Returns: 69 | Dictionary containing the annotated entities and processed text. 70 | """ 71 | annotations = [] 72 | for ent, value in ents.items(): 73 | if value: 74 | # TEMP: Weird solution to a weird bug 75 | # https://github.com/tensorflow/tensorflow/issues/14356#issuecomment-385962623 76 | with GRAPH.as_default(): 77 | annotations.append(MODELS[ent].annotate(text, coref=coref, ground=ground)) 78 | 79 | # if multiple models, combine annotations into one object 80 | final_annotation = annotations[0] 81 | if len(annotations) > 1: 82 | combined_ents = app_utils.combine_annotations(annotations) 83 | final_annotation['ents'] = combined_ents 84 | 85 | return final_annotation 86 | 87 | if __name__ == '__main__': 88 | # parse command line arguments 89 | parser = argparse.ArgumentParser(description='Saber web-service.') 90 | parser.add_argument('-p', '--port', help='Port number. Defaults to 5000', type=int, default=5000) 91 | args = vars(parser.parse_args()) 92 | # load the pre-trained models 93 | MODELS, GRAPH = app_utils.load_models(constants.ENTITIES) 94 | 95 | serve(app, host='0.0.0.0', port=args['port']) 96 | -------------------------------------------------------------------------------- /saber/cli/train.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """A simple script to train a model with Saber. 3 | 4 | Run the script with: 5 | ``` 6 | python -m saber.cli.train 7 | ``` 8 | e.g. 
9 | ``` 10 | python -m saber.cli.train --dataset_folder ./datasets/NCBI_disease_BIO --epochs 25 11 | ``` 12 | """ 13 | import logging 14 | 15 | from ..config import Config 16 | from ..saber import Saber 17 | 18 | 19 | def main(): 20 | """Coordinates a complete training cycle, including reading in a config, loading dataset(s), 21 | training the model, and saving the models weights.""" 22 | config = Config(cli=True) 23 | saber = Saber(config) 24 | 25 | if config.pretrained_model: 26 | saber.load(config.pretrained_model) 27 | 28 | saber.load_dataset() 29 | 30 | # don't build a new model if pre-trained one was provided 31 | if not config.pretrained_model: 32 | # don't load embeddings if a pre-trained model was provided 33 | if config.pretrained_embeddings: 34 | saber.load_embeddings() 35 | saber.build() 36 | 37 | try: 38 | saber.train() 39 | except KeyboardInterrupt: 40 | print("\nQutting Saber...") 41 | logging.warning('Saber was terminated early due to KeyboardInterrupt') 42 | finally: 43 | if config.save_model: 44 | saber.save() 45 | 46 | if __name__ == '__main__': 47 | main() 48 | -------------------------------------------------------------------------------- /saber/config.ini: -------------------------------------------------------------------------------- 1 | [mode] 2 | # Possible models: [MT-LSTM-CRF, ] 3 | model_name = MT-LSTM-CRF 4 | # If True, model is compressed and saved in output_folder at the end of training. Model weights are 5 | # taken from the last epoch. 6 | save_model = False 7 | 8 | [data] 9 | # You can specify multiple datasets by listing their paths, separated by a comma. 10 | dataset_folder = ./datasets/NCBI_Disease_BIO 11 | output_folder = ./output 12 | # Path to pre-trained model. To train a model from scratch, leave this blank. 13 | pretrained_model = 14 | # Path to pre-trained word embeddings. In order to use random initialization, leave this blank. 
15 | # Note that you can leave this blank when loading a pre-trained model 16 | # (via 'Saber.load()') that was trained with pre-trained embeddings 17 | pretrained_embeddings = 18 | 19 | [model] 20 | # If pre-trained word embeddings are provided, word_embed_dim will be the same size as these 21 | # embeddings and this argument will be ignored. 22 | word_embed_dim = 200 23 | char_embed_dim = 30 24 | 25 | [training] 26 | # Values chosen for each hyperparameter represent sensible defaults that perform well across a wide 27 | # range of NLP tasks (POS tagging, Chunking, NER, etc.) and thus should only be changed in special 28 | # circumstances. 29 | optimizer = nadam 30 | activation = relu 31 | # Set to 0 to turn off gradient normalization. 32 | grad_norm = 1.0 33 | # For certain optimizers, these values are ignored. See compile_model() in 34 | # saber/utils/model_utils.py. 35 | learning_rate = 0.0 36 | decay = 0.0 37 | 38 | # Three dropout values must be specified (separated by a comma), corresponding to the dropout rate 39 | # to apply to the input, output and recurrent connections respectively. Must be a value between 0.0 40 | # and 1.0. 41 | dropout_rate = 0.3, 0.3, 0.1 42 | 43 | batch_size = 32 44 | # If a test partition is supplied at 'dataset_folder' (test.*) then this argument is ignored, and a 45 | # simple train/valid/test scheme is used. A valid partition (valid.*) may optionally be provided 46 | # along with the test partition. If none is found, 10% of examples are randomly selected. 47 | k_folds = 5 48 | epochs = 50 49 | 50 | # Matching criteria used when determining whether or not a prediction is a true-positive. Choices 51 | # are 'left' for left-boundary matching, 'right' for right-boundary matching and 'exact' for 52 | # exact-boundary matching. 
53 | criteria = exact 54 | 55 | [advanced] 56 | verbose = False 57 | debug = False 58 | # If True, per-epoch logs which can be visualized with TensorBoard are written to output_folder 59 | # Note: These logs can be quite large. 60 | tensorboard = False 61 | # If True, then during training the models weights for each and every epoch will be saved. 62 | # Otherwise, weights are only saved for epochs that achieve a new best on validation loss. 63 | save_all_weights = False 64 | # If True, tokens that occur less than 1 time in the training dataset (hapax legomenon) are replaced 65 | # with a special unknown token. This should result in faster loading times of pre-trained word 66 | # embeddings and faster training times. 67 | replace_rare_tokens = True 68 | # If True, then all pre-trained word embeddings provided via pretrained_embeddings are loaded. 69 | # Otherwise, only pre-trained embeddings for tokens found in the dataset(s) at dataset_folder are 70 | # loaded. For evaluation, it's best to leave this as False. For models that will be deployed, it's 71 | # best to set it to True. 72 | load_all_embeddings = False 73 | # If True, then pre-trained word embeddings will be fine-tuned with the other parameters of the 74 | # neural network during training. Generally, you should not set this to True unless you have a very 75 | # large training dataset. 76 | # NOTE: if 'pretrained_embeddings' are not provided, they will be randomly initialized and 77 | # fine-tuned during training, ignoring this argument. 78 | fine_tune_word_embeddings = False 79 | # TEMP. Set to true if variational dropout should be used. 80 | variational_dropout = False 81 | -------------------------------------------------------------------------------- /saber/constants.py: -------------------------------------------------------------------------------- 1 | """Collection of constants used by Saber. 
2 | """ 3 | from pkg_resources import resource_filename 4 | 5 | __version__ = '0.1.0-alpha' 6 | 7 | # DISPLACY OPTIONS 8 | # entity colours 9 | COLOURS = {'PRGE': 'linear-gradient(90deg, #aa9cfc, #fc9ce7)', 10 | 'DISO': 'linear-gradient(90deg, #ef9a9a, #f44336)', 11 | 'CHED': 'linear-gradient(90deg, #1DE9B6, #A7FFEB)', 12 | 'LIVB': 'linear-gradient(90deg, #FF4081, #F8BBD0)', 13 | 'CL': 'linear-gradient(90deg, #00E5FF, #84FFFF)', 14 | } 15 | # entity options 16 | OPTIONS = {'colors': COLOURS} 17 | 18 | # SPECIAL TOKENS 19 | UNK = '' # out-of-vocabulary token 20 | PAD = '' # sequence pad token 21 | START = '' # start-of-sentence token 22 | END = '' # end-of-sentence token 23 | OUTSIDE_TAG = 'O' # 'outside' tag of the IOB, BIO, and IOBES tag formats 24 | 25 | # MISC. 26 | PAD_VALUE = 0 # value of sequence pad 27 | NUM_RARE = 1 # tokens that occur less than NUM_RARE times are replaced UNK 28 | # mapping of special tokens to contants 29 | INITIAL_MAPPING = {'word': {PAD: 0, UNK: 1}, 'tag': {PAD: 0}} 30 | # keys into dictionaries containing information for different partitions of a dataset 31 | PARTITIONS = ['train', 'valid', 'test'] 32 | 33 | # FILEPATHS / FILENAMES 34 | # train, valid and test filename patterns 35 | TRAIN_FILE = 'train.*' 36 | VALID_FILE = 'valid.*' 37 | TEST_FILE = 'test.*' 38 | # pre-trained models 39 | ENTITIES = {'ANAT': False, 40 | 'CHED': True, 41 | 'DISO': True, 42 | 'LIVB': False, 43 | 'PRGE': True, 44 | 'TRIG': False} 45 | # Google Drive File IDs for the pre-trained models 46 | PRETRAINED_MODELS = { 47 | 'PRGE': '1xOmxpgNjQJK8OJSvih9wW5AITGQX6ODT', 48 | 'DISO': '1qmrBuqz75KM57Ug5MiDBfp0d5H3S_5ih', 49 | 'CHED': '13s9wvu3Mc8fG73w51KD8RArA31vsuL1c', 50 | } 51 | # relative path to pre-trained model directory 52 | PRETRAINED_MODEL_DIR = resource_filename(__name__, 'pretrained_models') 53 | MODEL_FILENAME = 'model_params.json' 54 | WEIGHTS_FILENAME = 'model_weights.hdf5' 55 | ATTRIBUTES_FILENAME = 'attributes.pickle' 56 | CONFIG_FILENAME = 
'config.ini' 57 | 58 | # MODEL SETTINGS 59 | # batch size to use when performing model prediction 60 | PRED_BATCH_SIZE = 256 61 | # max length of a sentence 62 | MAX_SENT_LEN = 100 63 | # max length of a character sequence (word) 64 | MAX_CHAR_LEN = 25 65 | # number of units in the LSTM layers 66 | UNITS_WORD_LSTM = 200 67 | UNITS_CHAR_LSTM = 200 68 | UNITS_DENSE = UNITS_WORD_LSTM // 2 69 | # possible models 70 | MODEL_NAMES = ['mt-lstm-crf',] 71 | 72 | # EXTRACT 2.0 API 73 | # arguments passed in a get request to the EXTRACT 2.0 API to specify entity type 74 | ENTITY_TYPES = {'CHED': -1, 'DISO': -26, 'LIVB': -2} 75 | # the namespaces of the external resources that EXTRACT 2.0 grounds too 76 | NAMESPACES = {'CHED': 'PubChem Compound', 77 | 'DISO': 'Disease Ontology', 78 | 'LIVB': 'NCBI Taxonomy', 79 | 'PRGE': 'STRING', 80 | } 81 | 82 | # RESTful API 83 | # endpoint for Entrez Utilities Web Service API 84 | EUTILS_API_ENDPOINT = ('https://eutils.ncbi.nlm.nih.gov/entrez/eutils/efetch.fcgi?retmode=xml&db=' 85 | 'pubmed&id=') 86 | # CONFIG 87 | CONFIG_ARGS = ['model_name', 'save_model', 'dataset_folder', 'output_folder', 88 | 'pretrained_model', 'pretrained_embeddings', 'word_embed_dim', 89 | 'char_embed_dim', 'optimizer', 'activation', 'learning_rate', 'decay', 'grad_norm', 90 | 'dropout_rate', 'batch_size', 'k_folds', 'epochs', 'criteria', 'verbose', 91 | 'debug', 'save_all_weights', 'tensorboard', 'replace_rare_tokens', 92 | 'load_all_embeddings', 'fine_tune_word_embeddings', 'variational_dropout'] 93 | -------------------------------------------------------------------------------- /saber/dataset.py: -------------------------------------------------------------------------------- 1 | """Contains the Dataset class, which handles the loading and storage of datasets. 2 | """ 3 | import logging 4 | import os 5 | from itertools import chain 6 | 7 | from keras.utils import to_categorical 8 | from nltk.corpus.reader.conll import ConllCorpusReader 9 | 10 | from . 
import constants 11 | from .preprocessor import Preprocessor 12 | from .utils import data_utils, generic_utils 13 | 14 | LOGGER = logging.getLogger(__name__) 15 | 16 | class Dataset(object): 17 | """A class for handling datasets. Expects datasets to be in tab-seperated CoNLL format, where 18 | each line contains a token and its tag (seperated by a tab) and each sentence is seperated 19 | by a blank line. 20 | 21 | Example corpus: 22 | ''' 23 | The O 24 | transcription O 25 | of O 26 | most O 27 | RP B-PRGE 28 | genes I-PRGE 29 | ... 30 | ''' 31 | 32 | Args: 33 | directory (str): Path to directory containing CoNLL formatted dataset(s). 34 | replace_rare_tokens (bool): True if rare tokens should be replaced with a special unknown 35 | token. Threshold for considering tokens rare can be found at `saber.constants.NUM_RARE`. 36 | """ 37 | def __init__(self, directory=None, replace_rare_tokens=True, **kwargs): 38 | self.directory = directory 39 | # don't load corpus unless directory was passed on object construction 40 | if self.directory is not None: 41 | self.directory = data_utils.get_filepaths(directory) 42 | self.conll_parser = ConllCorpusReader(directory, '.conll', ('words', 'pos')) 43 | 44 | self.replace_rare_tokens = replace_rare_tokens 45 | 46 | # word, character and tag sequences from dataset (per partition) 47 | self.type_seq = {'train': None, 'valid': None, 'test': None} 48 | # mappings of word, characters, and tag types to unique integer IDs 49 | self.type_to_idx = {'word': None, 'char': None, 'tag': None} 50 | # reverse mapping of unique integer IDs to tag types 51 | self.idx_to_tag = None 52 | # same as type_seq but all words, characters and tags have been mapped to unique integer IDs 53 | self.idx_seq = {'train': None, 'valid': None, 'test': None} 54 | 55 | for key, value in kwargs.items(): 56 | setattr(self, key, value) 57 | 58 | def load(self): 59 | """Coordinates the loading of a given data set at `self.directory`. 
60 | 61 | For a given dataset in CoNLL format at `self.directory`, coordinates the loading of data and 62 | updates the appropriate instance attributes. Expects `self.directory` to be a directory 63 | containing a single file, `train.*` and optionally two additional files, `valid.*` and 64 | `test.*`. 65 | 66 | Raises: 67 | ValueError if `self.directory` is None. 68 | """ 69 | if self.directory is None: 70 | err_msg = "`Dataset.directory` is None; must be provided before call to `Dataset.load`" 71 | LOGGER.error('ValueError %s', err_msg) 72 | raise ValueError(err_msg) 73 | 74 | # unique words, chars and tags from CoNLL formatted dataset 75 | types = self._get_types() 76 | # map each word, char, and tag type to a unique integer 77 | self._get_idx_maps(types) 78 | 79 | # get word, char, and tag sequences from CoNLL formatted dataset 80 | self._get_type_seq() 81 | # get final representation used for training 82 | self.get_idx_seq() 83 | 84 | # useful during prediction / annotation 85 | self.idx_to_tag = generic_utils.reverse_dict(self.type_to_idx['tag']) 86 | 87 | def _get_types(self): 88 | """Collects the sets of all words, characters and tags in a CoNLL formatted dataset. 89 | 90 | For the CoNLL formatted dataset given at `self.directory`, updates `self.types` with the 91 | sets of all words (word types), characters (character types) and tags (tag types). All types 92 | are shared across all partitions, that is, word, char and tag types are collected from the 93 | train and, if provided, valid/test partitions found at `self.directory/train.*`, 94 | `self.directory/valid.*` and `self.directory/test.*`. 
95 | """ 96 | types = {'word': [constants.PAD, constants.UNK], 97 | 'char': [constants.PAD, constants.UNK], 98 | 'tag': [constants.PAD], 99 | } 100 | 101 | for _, filepath in self.directory.items(): 102 | if filepath is not None: 103 | conll_file = os.path.basename(filepath) # get name of conll file 104 | types['word'].extend(set(self.conll_parser.words(conll_file))) 105 | types['char'].extend(set(chain(*[list(w) for w in self.conll_parser.words(conll_file)]))) 106 | types['tag'].extend(set([tag[-1] for tag in self.conll_parser.tagged_words(conll_file)])) 107 | 108 | # ensure that we have only unique types 109 | types['word'] = list(set(types['word'])) 110 | types['char'] = list(set(types['char'])) 111 | types['tag'] = list(set(types['tag'])) 112 | 113 | return types 114 | 115 | def _get_type_seq(self): 116 | """Loads sequence data from a CoNLL format data set given at `self.directory`. 117 | 118 | For the CoNLL formatted dataset given at `self.directory`, updates `self.type_seq` with 119 | lists containing the word, character and tag sequences for the train and, if provided, 120 | valid/test partitions found at `self.directory/train.*`, `self.directory/valid.*` and 121 | `self.directory/test.*`. 
122 | """ 123 | for partition, filepath in self.directory.items(): 124 | if filepath is not None: 125 | conll_file = os.path.basename(filepath) # get name of conll file 126 | 127 | # collect sequence data 128 | sents = list(self.conll_parser.sents(conll_file)) 129 | tagged_sents = list(self.conll_parser.tagged_sents(conll_file)) 130 | 131 | word_seq = Preprocessor.replace_rare_tokens(sents) if self.replace_rare_tokens else sents 132 | char_seq = [[[c for c in w] for w in s] for s in sents] 133 | tag_seq = [[t[-1] for t in s] for s in tagged_sents] 134 | 135 | # update the class attributes 136 | self.type_seq[partition] = {'word': word_seq, 'char': char_seq, 'tag': tag_seq} 137 | 138 | def _get_idx_maps(self, types, initial_mapping=None): 139 | """Updates `self.type_to_idx` with mappings from word, char and tag types to unique int IDs. 140 | """ 141 | initial_mapping = constants.INITIAL_MAPPING if initial_mapping is None else initial_mapping 142 | # generate type to index mappings 143 | self.type_to_idx['word'] = Preprocessor.type_to_idx(types['word'], initial_mapping['word']) 144 | self.type_to_idx['char'] = Preprocessor.type_to_idx(types['char'], initial_mapping['word']) 145 | self.type_to_idx['tag'] = Preprocessor.type_to_idx(types['tag'], initial_mapping['tag']) 146 | 147 | def get_idx_seq(self): 148 | """Updates `self.idx_seq` with the final representation of the data used for training. 149 | 150 | Updates `self.idx_seq` with numpy arrays, by using `self.type_to_idx` to map all elements 151 | in `self.type_seq` to their corresponding integer IDs, for the train and, if provided, 152 | valid/test partitions found at `self.directory/train.*`, `self.directory/valid.*` and 153 | `self.directory/test.*`. 
154 | """ 155 | for partition, filepath in self.directory.items(): 156 | if filepath is not None: 157 | self.idx_seq[partition] = { 158 | 'word': Preprocessor.get_type_idx_sequence(self.type_seq[partition]['word'], 159 | self.type_to_idx['word'], 160 | type_='word'), 161 | 'char': Preprocessor.get_type_idx_sequence(self.type_seq[partition]['word'], 162 | self.type_to_idx['char'], 163 | type_='char'), 164 | 'tag': Preprocessor.get_type_idx_sequence(self.type_seq[partition]['tag'], 165 | self.type_to_idx['tag'], 166 | type_='tag'), 167 | } 168 | # one-hot encode our targets 169 | self.idx_seq[partition]['tag'] = to_categorical(self.idx_seq[partition]['tag']) 170 | -------------------------------------------------------------------------------- /saber/embeddings.py: -------------------------------------------------------------------------------- 1 | """Contains the Embedding class, which provides all code for working with pre-trained embeddings. 2 | """ 3 | import numpy as np 4 | from gensim.models import KeyedVectors 5 | 6 | from . import constants 7 | from .preprocessor import Preprocessor 8 | 9 | 10 | class Embeddings(object): 11 | """A class for loading and working with pre-trained word embeddings. 12 | 13 | Args: 14 | filepath (str): Path to file which contains pre-trained word embeddings. 15 | token_map (dict): A dictionary which maps tokens to unique integer IDs. 16 | """ 17 | def __init__(self, filepath, token_map, **kwargs): 18 | self.filepath = filepath 19 | self.token_map = token_map 20 | 21 | self.matrix = None # matrix containing row vectors for all embedded tokens 22 | self.num_found = None # number of loaded embeddings 23 | self.num_embed = None # final count of embedded words 24 | self.dimension = None # dimension of these embeddings 25 | 26 | for key, value in kwargs.items(): 27 | setattr(self, key, value) 28 | 29 | def load(self, binary=True, load_all=False): 30 | """Coordinates the loading of pre-trained word embeddings. 
31 | 32 | Creates an embedding matrix from the pre-trained word embeddings given at `self.filepath`, 33 | whose ith row corresponds to the word embedding for the word with value i in 34 | `self.token_map`. Updates the instance attributes `self.matrix, `self.num_found`, 35 | `self.dimension`, and `self.num_embed`. 36 | 37 | Args: 38 | binary (bool): True if pre-trained embeddings are in C binary format, False if they are 39 | in C text format. Defaults to True. 40 | load_all (bool): True if all embeddings should be loaded. False if only words that 41 | appear in `self.token_map` should be loaded. Defaults to False. 42 | 43 | Returns: 44 | The token map used to map the embeddings to an embedding matrix. 45 | """ 46 | # prepare the embedding indices 47 | embedding_idx = self._prepare_embedding_index(binary) 48 | self.num_found, self.dimension = len(embedding_idx), len(list(embedding_idx.values())[0]) 49 | self.matrix, type_to_idx = self._prepare_embedding_matrix(embedding_idx, load_all) 50 | self.num_embed = self.matrix.shape[0] # num of embedded words 51 | 52 | return type_to_idx 53 | 54 | def _prepare_embedding_index(self, binary=True): 55 | """Returns an embedding index for pre-trained token embeddings. 56 | 57 | For pre-trained word embeddings given at `self.filepath`, returns a 58 | dictionary mapping words to their embedding (an 'embedding index'). If `self.debug` is 59 | True, only the first ten thousand vectors are loaded. 60 | 61 | Args: 62 | binary (bool): True if pre-trained embeddings are in C binary format, False if they are 63 | in C text format. Defaults to True. 64 | 65 | Returns: 66 | Dictionary mapping words to pre-trained word embeddings, known as an 'embedding index'. 
67 | """ 68 | limit = 10000 if self.__dict__.get("debug", False) else None 69 | vectors = KeyedVectors.load_word2vec_format(self.filepath, binary=binary, limit=limit) 70 | embedding_idx = {word: vectors[word] for word in vectors.vocab} 71 | 72 | return embedding_idx 73 | 74 | def _prepare_embedding_matrix(self, embedding_idx, load_all=False): 75 | """Returns an embedding matrix containing all pre-trained embeddings in `embedding_idx`. 76 | 77 | Creates an embedding matrix from `embedding_idx`, where the ith row contains the 78 | embedding for the word with value i in `self.token_map`. If no embedding exists for a given 79 | word in `embedding_idx`, the zero vector is used instead. 80 | 81 | Args: 82 | embedding_idx (dict): A Dictionary mapping words to their embeddings. 83 | load_all (bool): True if all embeddings should be loaded. False if only words that 84 | appear in `self.token_map` should be loaded. Defaults to False. 85 | 86 | Returns: 87 | A matrix whos ith row corresponds to the word embedding for the word with value i in 88 | `self.token_map`. 89 | """ 90 | type_to_idx = None 91 | if load_all: 92 | # overwrite provided token map 93 | type_to_idx = self._generate_type_to_idx(embedding_idx) 94 | self.token_map = type_to_idx['word'] 95 | 96 | # initialize the embeddings matrix 97 | embedding_matrix = np.zeros((len(self.token_map), self.dimension)) 98 | 99 | # lookup embeddings for every word in the dataset 100 | for word, i in self.token_map.items(): 101 | token_embedding = embedding_idx.get(word) 102 | if token_embedding is not None: 103 | # words not found in embedding index will be all-zeros. 104 | embedding_matrix[i] = token_embedding 105 | 106 | return embedding_matrix, type_to_idx 107 | 108 | @classmethod 109 | def _generate_type_to_idx(self, embedding_idx): 110 | """Returns a dictionary mapping tokens in `embedding_idx` to unique integer IDs. 
111 | """ 112 | word_types, char_types = list(embedding_idx.keys()), [] 113 | for word in word_types: 114 | char_types.extend(list(word)) 115 | char_types = list(set(char_types)) 116 | 117 | type_to_idx = { 118 | 'word': Preprocessor.type_to_idx(word_types, constants.INITIAL_MAPPING['word']), 119 | 'char': Preprocessor.type_to_idx(char_types, constants.INITIAL_MAPPING['word']) 120 | } 121 | 122 | return type_to_idx 123 | -------------------------------------------------------------------------------- /saber/models/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from . import multi_task_lstm_crf 4 | -------------------------------------------------------------------------------- /saber/models/base_model.py: -------------------------------------------------------------------------------- 1 | """Contains the BaseModel class, the parent class to all Keras models in Saber. 2 | """ 3 | import json 4 | import logging 5 | 6 | from keras import optimizers 7 | from keras.models import model_from_json 8 | 9 | LOGGER = logging.getLogger(__name__) 10 | 11 | class BaseKerasModel(object): 12 | """Parent class of all Keras model classes implemented by Saber. 13 | """ 14 | def __init__(self, config, datasets, embeddings=None, **kwargs): 15 | self.config = config # hyperparameters and model details 16 | self.datasets = datasets # dataset(s) tied to this instance 17 | self.embeddings = embeddings # pre-trained word embeddings tied to this instance 18 | self.models = [] # Keras model(s) tied to this instance 19 | 20 | for key, value in kwargs.items(): 21 | setattr(self, key, value) 22 | 23 | def save(self, weights_filepath, model_filepath, model=0): 24 | """Save a model to disk. 25 | 26 | Saves a keras model to disk, by saving its architecture as a json file at `model_filepath` 27 | and its weights as a hdf5 file at `model_filepath`. 
28 | 29 | Args: 30 | weights_filepath (str): filepath to the models wieghts (.hdf5 file). 31 | model_filepath (str): filepath to the models architecture (.json file). 32 | model (int): which model from `self.models` to save. 33 | """ 34 | with open(model_filepath, 'w') as f: 35 | model_json = self.models[model].to_json() 36 | json.dump(json.loads(model_json), f, sort_keys=True, indent=4) 37 | self.models[model].save_weights(weights_filepath) 38 | 39 | def load(self, weights_filepath, model_filepath): 40 | """Load a model from disk. 41 | 42 | Loads a keras model from disk by loading its architecture from a json file at `model_filepath` 43 | and its weights from a hdf5 file at `model_filepath`. 44 | 45 | Args: 46 | weights_filepath (str): filepath to the models weights (.hdf5 file). 47 | model_filepath (str): filepath to the models architecture (.json file). 48 | """ 49 | with open(model_filepath) as f: 50 | model = model_from_json(f.read()) 51 | model.load_weights(weights_filepath) 52 | self.models.append(model) 53 | 54 | def prepare_data_for_training(self): 55 | """Returns a list containing the training data for each dataset at `self.datasets`. 56 | 57 | For each dataset at `self.datasets`, collects the data to be used for training. 58 | Each dataset is represented by a dictionary, where the keys 'x_' and 59 | 'y_' contain the inputs and targets for each partition 'train', 'valid', and 60 | 'test'. 
61 | """ 62 | training_data = [] 63 | for ds in self.datasets: 64 | # collect train data, must be provided 65 | x_train = [ds.idx_seq['train']['word'], ds.idx_seq['train']['char']] 66 | y_train = ds.idx_seq['train']['tag'] 67 | # collect valid and test data, may not be provided 68 | x_valid, y_valid, x_test, y_test = None, None, None, None 69 | if ds.idx_seq['valid'] is not None: 70 | x_valid = [ds.idx_seq['valid']['word'], ds.idx_seq['valid']['char']] 71 | y_valid = ds.idx_seq['valid']['tag'] 72 | if ds.idx_seq['test'] is not None: 73 | x_test = [ds.idx_seq['test']['word'], ds.idx_seq['test']['char']] 74 | y_test = ds.idx_seq['test']['tag'] 75 | 76 | training_data.append({'x_train': x_train, 'y_train': y_train, 'x_valid': x_valid, 77 | 'y_valid': y_valid, 'x_test': x_test, 'y_test': y_test}) 78 | 79 | return training_data 80 | 81 | def _compile(self, model, loss_function, optimizer, lr=0.01, decay=0.0, clipnorm=0.0): 82 | """Compiles a model specified with Keras. 83 | 84 | See https://keras.io/optimizers/ for more info on each optimizer. 85 | 86 | Args: 87 | model: Keras model object to compile 88 | loss_function: Keras loss_function object to compile model with 89 | optimizer (str): the optimizer to use during training 90 | lr (float): learning rate to use during training 91 | decay (float): per epoch decay rate 92 | clipnorm (float): gradient normalization threshold 93 | """ 94 | # The parameters of these optimizers can be freely tuned. 95 | if optimizer == 'sgd': 96 | optimizer_ = optimizers.SGD(lr=lr, decay=decay, clipnorm=clipnorm) 97 | elif optimizer == 'adam': 98 | optimizer_ = optimizers.Adam(lr=lr, decay=decay, clipnorm=clipnorm) 99 | elif optimizer == 'adamax': 100 | optimizer_ = optimizers.Adamax(lr=lr, decay=decay, clipnorm=clipnorm) 101 | # It is recommended to leave the parameters of this optimizer at their 102 | # default values (except the learning rate, which can be freely tuned). 
103 | # This optimizer is usually a good choice for recurrent neural networks 104 | elif optimizer == 'rmsprop': 105 | optimizer_ = optimizers.RMSprop(lr=lr, clipnorm=clipnorm) 106 | # It is recommended to leave the parameters of these optimizers at their 107 | # default values. 108 | elif optimizer == 'adagrad': 109 | optimizer_ = optimizers.Adagrad(clipnorm=clipnorm) 110 | elif optimizer == 'adadelta': 111 | optimizer_ = optimizers.Adadelta(clipnorm=clipnorm) 112 | elif optimizer == 'nadam': 113 | optimizer_ = optimizers.Nadam(clipnorm=clipnorm) 114 | else: 115 | err_msg = "Argument for `optimizer` is invalid, got: {}".format(optimizer) 116 | LOGGER.error('ValueError %s', err_msg) 117 | raise ValueError(err_msg) 118 | 119 | model.compile(optimizer=optimizer_, loss=loss_function) 120 | -------------------------------------------------------------------------------- /saber/models/multi_task_lstm_crf.py: -------------------------------------------------------------------------------- 1 | """Contains the Multi-task BiLSTM-CRF (MT-BILSTM-CRF) Keras model for squence labelling. 2 | """ 3 | import logging 4 | 5 | import tensorflow as tf 6 | from keras.layers import (LSTM, Bidirectional, Concatenate, Dense, Dropout, 7 | Embedding, SpatialDropout1D, TimeDistributed) 8 | from keras.models import Input, Model, model_from_json 9 | from keras.utils import multi_gpu_model 10 | from keras_contrib.layers.crf import CRF 11 | from keras_contrib.losses.crf_losses import crf_loss 12 | 13 | from .. import constants 14 | from .base_model import BaseKerasModel 15 | 16 | LOGGER = logging.getLogger(__name__) 17 | 18 | class MultiTaskLSTMCRF(BaseKerasModel): 19 | """A Keras implementation of a BiLSTM-CRF for sequence labeling. 20 | 21 | A BiLSTM-CRF for NER implementation in Keras. 
class MultiTaskLSTMCRF(BaseKerasModel):
    """A Keras implementation of a BiLSTM-CRF for sequence labeling.

    A BiLSTM-CRF for NER implemented in Keras. Supports multi-task learning by default:
    pass multiple Dataset objects via `datasets` to the constructor and the model will
    share the parameters of all layers, except for the final output layer, across all
    datasets, where each dataset represents a sequence labelling task.

    Args:
        config (Config): Contains a set of harmonized arguments provided in a *.ini
            file and, optionally, from the command line.
        datasets (list): A list of Dataset objects.
        embeddings (Embeddings): Pre-trained word embeddings; its `matrix` attribute is
            a numpy array where the ith row contains the vector embedding for the ith
            word type.

    References:
        - Guillaume Lample, Miguel Ballesteros, Sandeep Subramanian, Kazuya Kawakami,
          Chris Dyer. "Neural Architectures for Named Entity Recognition".
          Proceedings of NAACL 2016. https://arxiv.org/abs/1603.01360
    """
    def __init__(self, config, datasets, embeddings=None, **kwargs):
        super().__init__(config, datasets, embeddings, **kwargs)

    def load(self, weights_filepath, model_filepath):
        """Load a model from disk.

        Loads a model from disk by loading its architecture from a json file at
        `model_filepath` and its weights from a hdf5 file at `weights_filepath`. Unlike
        the parent implementation, the CRF layer is registered as a custom object so the
        architecture can be deserialized.

        Args:
            weights_filepath (str): Filepath to the models weights (.hdf5 file).
            model_filepath (str): Filepath to the models architecture (.json file).
        """
        with open(model_filepath) as f:
            model = model_from_json(f.read(), custom_objects={'CRF': CRF})
        model.load_weights(weights_filepath)
        self.models.append(model)

    def specify(self):
        """Specifies a multi-task BiLSTM-CRF for sequence tagging using Keras.

        Implements a hybrid bidirectional long short-term memory network-conditional
        random field (BiLSTM-CRF) multi-task model for sequence tagging. One Keras
        model per dataset is appended to `self.models`; all layers except the final CRF
        classifier are shared across datasets.
        """
        # specify any shared layers outside the for loop
        # word-level embedding layer
        if self.embeddings is None:
            word_embeddings = Embedding(input_dim=len(self.datasets[0].type_to_idx['word']) + 1,
                                        output_dim=self.config.word_embed_dim,
                                        mask_zero=True,
                                        name="word_embedding_layer")
        else:
            word_embeddings = Embedding(input_dim=self.embeddings.num_embed,
                                        output_dim=self.embeddings.dimension,
                                        mask_zero=True,
                                        weights=[self.embeddings.matrix],
                                        trainable=self.config.fine_tune_word_embeddings,
                                        name="word_embedding_layer")
        # character-level embedding layer
        char_embeddings = Embedding(input_dim=len(self.datasets[0].type_to_idx['char']) + 1,
                                    output_dim=self.config.char_embed_dim,
                                    mask_zero=True,
                                    name="char_embedding_layer")
        # char-level BiLSTM
        char_BiLSTM = TimeDistributed(Bidirectional(LSTM(constants.UNITS_CHAR_LSTM // 2)),
                                      name='character_BiLSTM')
        # word-level BiLSTMs
        word_BiLSTM_1 = Bidirectional(LSTM(units=constants.UNITS_WORD_LSTM // 2,
                                           return_sequences=True,
                                           dropout=self.config.dropout_rate['input'],
                                           recurrent_dropout=self.config.dropout_rate['recurrent']),
                                      name="word_BiLSTM_1")
        word_BiLSTM_2 = Bidirectional(LSTM(units=constants.UNITS_WORD_LSTM // 2,
                                           return_sequences=True,
                                           dropout=self.config.dropout_rate['input'],
                                           recurrent_dropout=self.config.dropout_rate['recurrent']),
                                      name="word_BiLSTM_2")

        # get all unique tag types across all datasets
        all_tags = [ds.type_to_idx['tag'] for ds in self.datasets]
        all_tags = set(x for l in all_tags for x in l)

        # feedforward before CRF, maps each time step to a vector
        dense_layer = TimeDistributed(Dense(len(all_tags), activation=self.config.activation),
                                      name='dense_layer')

        # specify model, taking into account the shared layers
        for dataset in self.datasets:
            # word-level embedding
            word_ids = Input(shape=(None, ), dtype='int32', name='word_id_inputs')
            word_embed = word_embeddings(word_ids)

            # character-level embedding
            char_ids = Input(shape=(None, None), dtype='int32', name='char_id_inputs')
            char_embed = char_embeddings(char_ids)

            # character-level BiLSTM + dropout. Spatial dropout applies the same
            # dropout mask to all timesteps, which is necessary to implement
            # variational dropout (https://arxiv.org/pdf/1512.05287.pdf)
            char_embed = char_BiLSTM(char_embed)
            if self.config.variational_dropout:
                LOGGER.info('Used variational dropout')
                char_embed = SpatialDropout1D(self.config.dropout_rate['output'])(char_embed)

            # concatenate word- and char-level embeddings + dropout
            model = Concatenate()([word_embed, char_embed])
            model = Dropout(self.config.dropout_rate['output'])(model)

            # word-level BiLSTM + dropout
            model = word_BiLSTM_1(model)
            if self.config.variational_dropout:
                model = SpatialDropout1D(self.config.dropout_rate['output'])(model)
            model = word_BiLSTM_2(model)
            if self.config.variational_dropout:
                model = SpatialDropout1D(self.config.dropout_rate['output'])(model)

            # feedforward before CRF
            model = dense_layer(model)

            # CRF output layer; one per dataset (task), sized to that dataset's tag set
            crf = CRF(len(dataset.type_to_idx['tag']), name='crf_classifier')
            output_layer = crf(model)

            # fully specified model; instantiated under /cpu:0 per the Keras
            # multi-GPU example, see:
            # https://github.com/keras-team/keras/blob/bf1378f39d02b7d0b53ece5458f9275ac8208046/keras/utils/multi_gpu_utils.py
            with tf.device('/cpu:0'):
                model = Model(inputs=[word_ids, char_ids], outputs=[output_layer])
            self.models.append(model)

    def compile(self):
        """Compiles the BiLSTM-CRF.

        Compiles the Keras model(s) at `self.models`. If multiple GPUs are detected, a
        model capable of training on all of them is compiled.
        """
        for i in range(len(self.models)):
            # parallelize the model if multiple GPUs are available
            # https://github.com/keras-team/keras/pull/9226
            # awfully bad practice but this was the example given by Keras docs;
            # catch Exception (not a bare except) so KeyboardInterrupt/SystemExit
            # still propagate
            try:
                self.models[i] = multi_gpu_model(self.models[i])
                LOGGER.info('Compiling the model on multiple GPUs')
            except Exception:
                LOGGER.info('Compiling the model on a single CPU or GPU')

            self._compile(model=self.models[i],
                          loss_function=crf_loss,
                          optimizer=self.config.optimizer,
                          lr=self.config.learning_rate,
                          decay=self.config.decay,
                          clipnorm=self.config.grad_norm)

    def prepare_for_transfer(self, datasets):
        """Prepares the BiLSTM-CRF for transfer learning by recreating its last layer.

        Prepares the BiLSTM-CRF model(s) at `self.models` for transfer learning by
        removing their CRF classifiers and replacing them with un-trained CRF
        classifiers of the appropriate size (i.e. number of units equal to number of
        output tags) for the target datasets.

        Args:
            datasets (list): A list of target Dataset objects; replaces
                `self.datasets`.

        References:
            - https://stackoverflow.com/questions/41378461/how-to-use-models-from-keras-applications-for-transfer-learnig/41386444#41386444
        """
        self.datasets = datasets  # replace with target datasets
        models, self.models = self.models, []  # wipe models

        for dataset, model in zip(self.datasets, models):
            # remove the old CRF classifier and define a new one
            # NOTE(review): `model.layers.pop()` only mutates the layer list; the new
            # head is rebuilt below from `model.layers[-1].output` -- confirm the old
            # CRF is fully detached in the target Keras version.
            model.layers.pop()
            new_crf = CRF(len(dataset.type_to_idx['tag']), name='target_crf_classifier')
            # create the new model with the untrained CRF on top of the shared layers
            new_input = model.input
            new_output = new_crf(model.layers[-1].output)
            self.models.append(Model(new_input, new_output))

        self.compile()
10 | dataset_folder = saber/tests/resources/dummy_dataset_1 11 | output_folder = ../output 12 | # Path to pre-trained model. To train a model from scratch, leave this blank. 13 | pretrained_model = 14 | # Path to pre-trained word embeddings. In order to use random initialization, leave this blank. 15 | # Note that you can leave this blank when loading a pre-trained model 16 | # (via 'SequenceProcessor.load()') that was trained with pre-tained embeddings 17 | pretrained_embeddings = saber/tests/resources/dummy_word_embeddings/dummy_word_embeddings.txt 18 | 19 | [model] 20 | # If pre-trained word embeddings are provided, word_embed_dim will be the same size as these 21 | # embeddings and this argument will be ignored. 22 | word_embed_dim = 200 23 | char_embed_dim = 30 24 | 25 | [training] 26 | # Values chosen for each hyperparameter represent sensible defaults that perform well across a wide 27 | # range of NLP tasks (POS tagging, Chunking, NER, etc.) and thus should only be changed in special 28 | # circumstances. 29 | optimizer = nadam 30 | activation = relu 31 | # Set to 0 to turn off gradient normalization. 32 | grad_norm = 1.0 33 | # For certain optimizers, these values are ignored. See compile_model() in 34 | # saber/utils/model_utils.py. 35 | learning_rate = 0.0 36 | decay = 0.0 37 | 38 | # Three dropout values must be specified (separated by a comma), corresponding to the dropout rate 39 | # to apply to the input, output and recurrent connections respectively. Must be a value between 0.0 40 | # and 1.0. 41 | dropout_rate = 0.3, 0.3, 0.1 42 | 43 | batch_size = 32 44 | # If a test partition is supplied at 'dataset_folder' (test.*) then this argument is ignored, and a 45 | # simple train/valid/test scheme is used. A valid partition (valid.*) may optionally be provided 46 | # along with the test parition. If none is found, 10% of examples are randomly selected. 
47 | k_folds = 2 48 | epochs = 50 49 | 50 | # Matching criteria used when determining whether or not a prediction is a true-positive. Choices 51 | # are 'left' for left-boundary matching, 'right' for right-boundary matching and 'exact' for 52 | # exact-boundary matching. 53 | criteria = exact 54 | 55 | [advanced] 56 | verbose = False 57 | debug = False 58 | # If True, per-epoch logs which can be visualized with TensorBoard are written to output_folder 59 | # Note: These logs can be quite large. 60 | tensorboard = False 61 | # If True, then during training the models weights for each and every epoch will be saved. 62 | # Otherwise, weights are only saved for epochs that achieve a new best on validation loss. 63 | save_all_weights = False 64 | # If True, tokens that occur less than 1 time in the training dataset (hapax legomenon) are replaced 65 | # with a special unknown token. This should result in faster loading times of pre-trained word 66 | # embeddings and faster training times. 67 | replace_rare_tokens = False 68 | # If True, then all pre-trained word embeddings provided via pretrained_embeddings are loaded. 69 | # Otherwise, only pre-trained embeddings for tokens found in the dataset(s) at dataset_folder are 70 | # loaded. For evaluation, it's best to leave this as False. For models that will be deployed, it's 71 | # best to set it to True. 72 | load_all_embeddings = False 73 | # If True, then pre-trained word embeddings will be fine-tuned with the other parameters of the 74 | # neural network during training. Generally, you should not set this to True unless you have a very 75 | # large training dataset. 76 | # NOTE: if 'pretrained_embeddings' are not provided, they will be randomly initialized and 77 | # fine-tuned during training, ignoring this argument. 78 | fine_tune_word_embeddings = False 79 | # TEMP. Set to true if variational dropout should be used. 
80 | variational_dropout = False 81 | -------------------------------------------------------------------------------- /saber/tests/resources/dummy_constants.py: -------------------------------------------------------------------------------- 1 | """Constants used by the unit tests. 2 | """ 3 | import logging 4 | import os 5 | 6 | import numpy as np 7 | from pkg_resources import resource_filename 8 | 9 | from ... import constants 10 | 11 | LOGGER = logging.getLogger(__name__) 12 | 13 | # relative paths for test resources 14 | PATH_TO_DUMMY_DATASET_1 = resource_filename(__name__, 'dummy_dataset_1') 15 | PATH_TO_DUMMY_DATASET_2 = resource_filename(__name__, 'dummy_dataset_2') 16 | PATH_TO_DUMMY_CONFIG = resource_filename(__name__, 'dummy_config.ini') 17 | PATH_TO_DUMMY_EMBEDDINGS = resource_filename(__name__, 'dummy_word_embeddings/dummy_word_embeddings.txt') 18 | 19 | ######################################### DUMMY EMBEDDINGS ######################################### 20 | 21 | # for testing embeddings 22 | DUMMY_TOKEN_MAP = {'': 0, '': 1, 'the': 2, 'quick': 3, 'brown': 4, 'fox': 5} 23 | DUMMY_CHAR_MAP = {'': 0, '': 1, 'r': 2, 'u': 3, 'c': 4, 'f': 5, 'e': 6, 'o': 7, 'x': 8, 24 | 'h': 9, 'b': 10, 'n': 11, 'w': 12, 'i': 13, 't': 14, 'q': 15, 'k': 16} 25 | DUMMY_EMBEDDINGS_INDEX = { 26 | 'the': [0.15580128, -0.07108746, 0.055198, -0.14199848, 0.0005317868], 27 | 'quick': [-0.011208724, 0.21213274, -0.17233513, -0.4401193, 0.13930725], 28 | 'brown': [0.12754257, -0.07938199, 0.083904505, -0.24103324, 0.0084449835], 29 | 'fox': [0.2947119, 0.14794342, 0.10318808, 0.09019197, -0.24244581] 30 | } 31 | DUMMY_EMBEDDINGS_MATRIX = np.array([ 32 | [0.0, 0.0, 0.0, 0.0, 0.0], 33 | [0.0, 0.0, 0.0, 0.0, 0.0], 34 | [0.15580128, -0.07108746, 0.055198, -0.14199848, 0.0005317868], 35 | [-0.011208724, 0.21213274, -0.17233513, -0.4401193, 0.13930725], 36 | [0.12754257, -0.07938199, 0.083904505, -0.24103324, 0.0084449835], 37 | [0.2947119, 0.14794342, 0.10318808, 0.09019197, -0.24244581] 
38 | ]) 39 | 40 | ########################################## DUMMY DATASET ########################################## 41 | 42 | DUMMY_WORD_SEQ = np.array([ 43 | ['Human', 'APC2', 'maps', 'to', 'chromosome', '19p13', '.'], 44 | ['The', 'absence', 'of', 'functional', 'C7', 'activity', 'could', 'not', 'be', 'accounted', 45 | 'for', 'on', 'the', 'basis', 'of', 'an', 'inhibitor', '.'], 46 | ]) 47 | DUMMY_CHAR_SEQ = np.array([ 48 | [['H', 'u', 'm', 'a', 'n'], ['A', 'P', 'C', '2'], ['m', 'a', 'p', 's'], ['t', 'o'], 49 | ['c', 'h', 'r', 'o', 'm', 'o', 's', 'o', 'm', 'e'], ['1', '9', 'p', '1', '3'], ['.']], 50 | [['T', 'h', 'e'], ['a', 'b', 's', 'e', 'n', 'c', 'e'], ['o', 'f'], 51 | ['f', 'u', 'n', 'c', 't', 'i', 'o', 'n', 'a', 'l'], ['C', '7'], 52 | ['a', 'c', 't', 'i', 'v', 'i', 't', 'y'], ['c', 'o', 'u', 'l', 'd'], ['n', 'o', 't'], 53 | ['b', 'e'], ['a', 'c', 'c', 'o', 'u', 'n', 't', 'e', 'd'], ['f', 'o', 'r'], ['o', 'n'], 54 | ['t', 'h', 'e'], ['b', 'a', 's', 'i', 's'], ['o', 'f'], ['a', 'n'], 55 | ['i', 'n', 'h', 'i', 'b', 'i', 't', 'o', 'r'], ['.']]]) 56 | DUMMY_TAG_SEQ = np.array([ 57 | ['O', 'O', 'O', 'O', 'O', 'O', 'O'], 58 | ['O', 'B-DISO', 'I-DISO', 'I-DISO', 'E-DISO', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 59 | 'O', 'O', 'O'], 60 | ]) 61 | DUMMY_WORD_TYPES = ['Human', 'APC2', 'maps', 'to', 'chromosome', '19p13', '.', 'The', 'absence', 62 | 'of', 'functional', 'C7', 'activity', 'could', 'not', 'be', 'accounted', 63 | 'for', 'on', 'the', 'basis', 'an', 'inhibitor', constants.PAD, constants.UNK] 64 | DUMMY_CHAR_TYPES = ['2', 's', 'c', 'T', 'd', 'e', 'H', 'h', 'a', 'b', 'v', 'C', 'm', 't', '9', 'p', 65 | 'r', '3', 'u', '.', 'o', '7', 'n', 'f', 'y', 'l', '1', 'i', 'A', 'P', 66 | constants.PAD, constants.UNK] 67 | DUMMY_TAG_TYPES = ['O', 'B-DISO', 'I-DISO', 'E-DISO', constants.PAD] 68 | 69 | ########################################### DUMMY CONFIG ########################################### 70 | 71 | # Sections of the .ini file 72 | CONFIG_SECTIONS = 
['mode', 'data', 'model', 'training', 'advanced'] 73 | 74 | # Arg values before any processing 75 | DUMMY_ARGS_NO_PROCESSING = {'model_name': 'MT-LSTM-CRF', 76 | 'save_model': 'False', 77 | 'dataset_folder': 'saber/tests/resources/dummy_dataset_1', 78 | 'output_folder': '../output', 79 | 'pretrained_model': '', 80 | 'pretrained_embeddings': ('saber/tests/resources/' 81 | 'dummy_word_embeddings/' 82 | 'dummy_word_embeddings.txt'), 83 | 'word_embed_dim': '200', 84 | 'char_embed_dim': '30', 85 | 'optimizer': 'nadam', 86 | 'activation': 'relu', 87 | 'learning_rate': '0.0', 88 | 'grad_norm': '1.0', 89 | 'decay': '0.0', 90 | 'dropout_rate': '0.3, 0.3, 0.1', 91 | 'batch_size': '32', 92 | 'k_folds': '2', 93 | 'epochs': '50', 94 | 'criteria': 'exact', 95 | 'verbose': 'False', 96 | 'debug': 'False', 97 | 'save_all_weights': 'False', 98 | 'tensorboard': 'False', 99 | 'replace_rare_tokens': 'False', 100 | 'load_all_embeddings': 'False', 101 | 'fine_tune_word_embeddings': 'False', 102 | # TEMP 103 | 'variational_dropout': 'False', 104 | } 105 | # Final arg values when args provided in only config file 106 | DUMMY_ARGS_NO_CLI_ARGS = {'model_name': 'mt-lstm-crf', 107 | 'save_model': False, 108 | 'dataset_folder': [PATH_TO_DUMMY_DATASET_1], 109 | 'output_folder': os.path.abspath('../output'), 110 | 'pretrained_model': '', 111 | 'pretrained_embeddings': PATH_TO_DUMMY_EMBEDDINGS, 112 | 'word_embed_dim': 200, 113 | 'char_embed_dim': 30, 114 | 'optimizer': 'nadam', 115 | 'activation': 'relu', 116 | 'learning_rate': 0.0, 117 | 'decay': 0.0, 118 | 'grad_norm': 1.0, 119 | 'dropout_rate': {'input': 0.3, 'output':0.3, 'recurrent': 0.1}, 120 | 'batch_size': 32, 121 | 'k_folds': 2, 122 | 'epochs': 50, 123 | 'criteria': 'exact', 124 | 'verbose': False, 125 | 'debug': False, 126 | 'save_all_weights': False, 127 | 'tensorboard': False, 128 | 'replace_rare_tokens': False, 129 | 'load_all_embeddings': False, 130 | 'fine_tune_word_embeddings': False, 131 | # TEMP 132 | 'variational_dropout': 
False, 133 | } 134 | # Final arg values when args provided in config file and from CLI 135 | DUMMY_COMMAND_LINE_ARGS = {'optimizer': 'sgd', 136 | 'grad_norm': 1.0, 137 | 'learning_rate': 0.05, 138 | 'decay': 0.5, 139 | 'dropout_rate': [0.6, 0.6, 0.2], 140 | # the dataset and embeddings are used for test purposes so they must 141 | # point to the correct resources, this can be ensured by passing their 142 | # paths here 143 | 'dataset_folder': [PATH_TO_DUMMY_DATASET_1], 144 | 'pretrained_embeddings': PATH_TO_DUMMY_EMBEDDINGS, 145 | } 146 | DUMMY_ARGS_WITH_CLI_ARGS = {'model_name': 'mt-lstm-crf', 147 | 'save_model': False, 148 | 'dataset_folder': [PATH_TO_DUMMY_DATASET_1], 149 | 'output_folder': os.path.abspath('../output'), 150 | 'pretrained_model': '', 151 | 'pretrained_embeddings': PATH_TO_DUMMY_EMBEDDINGS, 152 | 'word_embed_dim': 200, 153 | 'char_embed_dim': 30, 154 | 'optimizer': 'sgd', 155 | 'activation': 'relu', 156 | 'learning_rate': 0.05, 157 | 'decay': 0.5, 158 | 'grad_norm': 1.0, 159 | 'dropout_rate': {'input': 0.6, 'output': 0.6, 'recurrent': 0.2}, 160 | 'batch_size': 32, 161 | 'k_folds': 2, 162 | 'epochs': 50, 163 | 'criteria': 'exact', 164 | 'verbose': False, 165 | 'debug': False, 166 | 'save_all_weights': False, 167 | 'tensorboard': False, 168 | 'replace_rare_tokens': False, 169 | 'load_all_embeddings': False, 170 | 'fine_tune_word_embeddings': False, 171 | # TEMP 172 | 'variational_dropout': False, 173 | } 174 | ########################################### WEB SERVICE ########################################### 175 | 176 | DUMMY_ENTITIES = {'ANAT': False, 177 | 'CHED': True, 178 | 'DISO': False, 179 | 'LIVB': True, 180 | 'PRGE': True, 181 | 'TRIG': False, 182 | } 183 | -------------------------------------------------------------------------------- /saber/tests/resources/dummy_dataset_1/test.tsv: -------------------------------------------------------------------------------- 1 | Human O 2 | APC2 O 3 | maps O 4 | to O 5 | chromosome O 6 | 19p13 O 7 | . 
O 8 | 9 | The O 10 | absence B-DISO 11 | of I-DISO 12 | functional I-DISO 13 | C7 E-DISO 14 | activity O 15 | could O 16 | not O 17 | be O 18 | accounted O 19 | for O 20 | on O 21 | the O 22 | basis O 23 | of O 24 | an O 25 | inhibitor O 26 | . O 27 | -------------------------------------------------------------------------------- /saber/tests/resources/dummy_dataset_1/train.tsv: -------------------------------------------------------------------------------- 1 | Human O 2 | APC2 O 3 | maps O 4 | to O 5 | chromosome O 6 | 19p13 O 7 | . O 8 | 9 | The O 10 | absence B-DISO 11 | of I-DISO 12 | functional I-DISO 13 | C7 E-DISO 14 | activity O 15 | could O 16 | not O 17 | be O 18 | accounted O 19 | for O 20 | on O 21 | the O 22 | basis O 23 | of O 24 | an O 25 | inhibitor O 26 | . O 27 | -------------------------------------------------------------------------------- /saber/tests/resources/dummy_dataset_1/valid.tsv: -------------------------------------------------------------------------------- 1 | Human O 2 | APC2 O 3 | maps O 4 | to O 5 | chromosome O 6 | 19p13 O 7 | . O 8 | 9 | The O 10 | absence B-DISO 11 | of I-DISO 12 | functional I-DISO 13 | C7 E-DISO 14 | activity O 15 | could O 16 | not O 17 | be O 18 | accounted O 19 | for O 20 | on O 21 | the O 22 | basis O 23 | of O 24 | an O 25 | inhibitor O 26 | . O 27 | -------------------------------------------------------------------------------- /saber/tests/resources/dummy_dataset_2/test.tsv: -------------------------------------------------------------------------------- 1 | Myoc B-PRGE 2 | alleles O 3 | do O 4 | not O 5 | associate O 6 | with O 7 | the O 8 | magnitude O 9 | of O 10 | IOP O 11 | 12 | Mutations O 13 | in O 14 | the O 15 | myocilin B-PRGE 16 | gene O 17 | ( O 18 | MYOC B-PRGE 19 | ) O 20 | cause O 21 | human O 22 | glaucoma O 23 | . 
O 24 | -------------------------------------------------------------------------------- /saber/tests/resources/dummy_dataset_2/train.tsv: -------------------------------------------------------------------------------- 1 | Myoc B-PRGE 2 | alleles O 3 | do O 4 | not O 5 | associate O 6 | with O 7 | the O 8 | magnitude O 9 | of O 10 | IOP O 11 | 12 | Mutations O 13 | in O 14 | the O 15 | myocilin B-PRGE 16 | gene O 17 | ( O 18 | MYOC B-PRGE 19 | ) O 20 | cause O 21 | human O 22 | glaucoma O 23 | . O 24 | -------------------------------------------------------------------------------- /saber/tests/resources/dummy_dataset_2/valid.tsv: -------------------------------------------------------------------------------- 1 | Myoc B-PRGE 2 | alleles O 3 | do O 4 | not O 5 | associate O 6 | with O 7 | the O 8 | magnitude O 9 | of O 10 | IOP O 11 | 12 | Mutations O 13 | in O 14 | the O 15 | myocilin B-PRGE 16 | gene O 17 | ( O 18 | MYOC B-PRGE 19 | ) O 20 | cause O 21 | human O 22 | glaucoma O 23 | . O 24 | -------------------------------------------------------------------------------- /saber/tests/resources/dummy_word_embeddings/dummy_word_embeddings.txt: -------------------------------------------------------------------------------- 1 | 4 5 2 | the 0.15580128 -0.07108746 0.055198 -0.14199848 0.0005317868 3 | quick -0.011208724 0.21213274 -0.17233513 -0.4401193 0.13930725 4 | brown 0.12754257 -0.07938199 0.083904505 -0.24103324 0.0084449835 5 | fox 0.2947119 0.14794342 0.10318808 0.09019197 -0.24244581 6 | -------------------------------------------------------------------------------- /saber/tests/resources/helpers.py: -------------------------------------------------------------------------------- 1 | """Any and all helper functions for Sabers unit tests. 2 | """ 3 | import configparser 4 | import os 5 | 6 | from ... import constants 7 | 8 | 9 | def assert_type_to_idx_as_expected(actual, expected): 10 | """Asserts that a `type_to_idx` mapping is as expected. 
This involves checking that it contains 11 | the expected, keys, the expected values, and that the values are a consecutive mapping of 12 | integers beginning at 0. 13 | """ 14 | # check keys 15 | assert all(word in actual['word'] for word in expected['word']) 16 | assert all(char in actual['char'] for char in expected['char']) 17 | # check that values are consectutive mapping of integers between 0 and length of the dictionary 18 | assert all(id in range(0, len(actual['word'])) for id in actual['word'].values()) 19 | assert all(id in range(0, len(actual['char'])) for id in actual['char'].values()) 20 | # check initial mapping items 21 | assert all(word in actual['word'] for word in constants.INITIAL_MAPPING['word']) 22 | assert all(word in actual['char'] for word in constants.INITIAL_MAPPING['word']) 23 | 24 | def load_saved_config(filepath): 25 | """Load a saved config.ConfigParser object at 'filepath/config.ini'. 26 | 27 | Args: 28 | filepath (str): filepath to the saved config file 'config.ini' 29 | 30 | Returns: 31 | parsed config.ConfigParser object at 'filepath/config.ini'. 32 | """ 33 | saved_config_filepath = os.path.join(filepath, 'config.ini') 34 | saved_config = configparser.ConfigParser() 35 | saved_config.read(saved_config_filepath) 36 | 37 | return saved_config 38 | 39 | def unprocess_args(args): 40 | """Unprocesses processed config args. 41 | 42 | Given a dictionary of arguments ('arg'), returns a dictionary where all values have been 43 | converted to string representation. 44 | 45 | Returns: 46 | args, where all values have been replaced by a str representation. 
47 | """ 48 | unprocessed_args = {} 49 | for arg, value in args.items(): 50 | if isinstance(value, list): 51 | unprocessed_arg = ', '.join(value) 52 | elif isinstance(value, dict): 53 | dict_values = [str(v) for v in value.values()] 54 | unprocessed_arg = ', '.join(dict_values) 55 | else: 56 | unprocessed_arg = str(value) 57 | 58 | unprocessed_args[arg] = unprocessed_arg 59 | 60 | return unprocessed_args 61 | -------------------------------------------------------------------------------- /saber/tests/test_app_utils.py: -------------------------------------------------------------------------------- 1 | """Any and all unit tests for the app_utils (saber/utils/app_utils.py). 2 | """ 3 | import pytest 4 | 5 | from ..utils import app_utils 6 | from .resources.dummy_constants import * 7 | 8 | ############################################ UNIT TESTS ############################################ 9 | 10 | def test_get_pubmed_xml_errors(): 11 | """Asserts that call to `app_utils.get_pubmed_xml()` raises a ValueError error when an invalid 12 | value for argument `pmid` is passed.""" 13 | invalid_pmids = [["test"], "test", 0.0, 0, -1, (42,)] 14 | 15 | for pmid in invalid_pmids: 16 | with pytest.raises(ValueError): 17 | app_utils.get_pubmed_xml(pmid) 18 | 19 | def test_harmonize_entities(): 20 | """Asserts that app_utils.harmonize_entities() returns the expected results.""" 21 | # single bool test 22 | one_on_test = {'PRGE': True} 23 | one_on_expected = {'ANAT': False, 'CHED': False, 'DISO': False, 24 | 'LIVB': False, 'PRGE': True, 'TRIG': False} 25 | # multi bool test 26 | multi_on_test = {'PRGE': True, 'CHED': True, 'TRIG': False} 27 | multi_on_expected = {'ANAT': False, 'CHED': True, 'DISO': False, 28 | 'LIVB': False, 'PRGE': True, 'TRIG': False} 29 | # null test 30 | none_on_test = {} 31 | none_on_expected = {'ANAT': False, 'CHED': False, 'DISO': False, 32 | 'LIVB': False, 'PRGE': False, 'TRIG': False} 33 | 34 | assert one_on_expected == \ 35 | 
app_utils.harmonize_entities(DUMMY_ENTITIES, one_on_test) 36 | assert multi_on_expected == \ 37 | app_utils.harmonize_entities(DUMMY_ENTITIES, multi_on_test) 38 | assert none_on_expected == \ 39 | app_utils.harmonize_entities(DUMMY_ENTITIES, none_on_test) 40 | -------------------------------------------------------------------------------- /saber/tests/test_base_model.py: -------------------------------------------------------------------------------- 1 | """Any and all unit tests for the BaseKerasModel (saber/models/base_model.py). 2 | """ 3 | import pytest 4 | 5 | from ..config import Config 6 | from ..dataset import Dataset 7 | from ..embeddings import Embeddings 8 | from ..models.base_model import BaseKerasModel 9 | from .resources.dummy_constants import * 10 | 11 | ######################################### PYTEST FIXTURES ######################################### 12 | 13 | @pytest.fixture 14 | def dummy_config(): 15 | """Returns an instance of a Config object.""" 16 | return Config(PATH_TO_DUMMY_CONFIG) 17 | 18 | @pytest.fixture 19 | def dummy_dataset_1(): 20 | """Returns a single dummy Dataset instance after calling `Dataset.load()`. 21 | """ 22 | # Don't replace rare tokens for the sake of testing 23 | dataset = Dataset(directory=PATH_TO_DUMMY_DATASET_1, replace_rare_tokens=False) 24 | dataset.load() 25 | 26 | return dataset 27 | 28 | @pytest.fixture 29 | def dummy_dataset_2(): 30 | """Returns a single dummy Dataset instance after calling `Dataset.load()`. 31 | """ 32 | # Don't replace rare tokens for the sake of testing 33 | dataset = Dataset(directory=PATH_TO_DUMMY_DATASET_2, replace_rare_tokens=False) 34 | dataset.load() 35 | 36 | return dataset 37 | 38 | @pytest.fixture 39 | def dummy_embeddings(dummy_dataset_1): 40 | """Returns an instance of an `Embeddings()` object AFTER the `.load()` method is called. 
41 | """ 42 | embeddings = Embeddings(filepath=PATH_TO_DUMMY_EMBEDDINGS, 43 | token_map=dummy_dataset_1.idx_to_tag) 44 | embeddings.load(binary=False) # txt file format is easier to test 45 | return embeddings 46 | 47 | @pytest.fixture 48 | def single_model(dummy_config, dummy_dataset_1, dummy_embeddings): 49 | """Returns an instance of MultiTaskLSTMCRF initialized with the default configuration.""" 50 | model = BaseKerasModel(config=dummy_config, 51 | datasets=[dummy_dataset_1], 52 | # to test passing of arbitrary keyword args to constructor 53 | totally_arbitrary='arbitrary') 54 | return model 55 | 56 | @pytest.fixture 57 | def single_model_embeddings(dummy_config, dummy_dataset_1, dummy_embeddings): 58 | """Returns an instance of MultiTaskLSTMCRF initialized with the default configuration file and 59 | loaded embeddings""" 60 | model = BaseKerasModel(config=dummy_config, 61 | datasets=[dummy_dataset_1], 62 | embeddings=dummy_embeddings, 63 | # to test passing of arbitrary keyword args to constructor 64 | totally_arbitrary='arbitrary') 65 | return model 66 | 67 | ############################################ UNIT TESTS ############################################ 68 | 69 | def test_compile_value_error(single_model): 70 | """Asserts that `BaseKerasModel._compile()` returns a ValueError when an invalid argument for 71 | `optimizer` is passed. 72 | """ 73 | with pytest.raises(ValueError): 74 | single_model._compile('arbitrary', 'arbitrary', 'invalid') 75 | 76 | def test_attributes_init_of_single_model(dummy_config, dummy_dataset_1, single_model): 77 | """Asserts instance attributes are initialized correctly when single `MultiTaskLSTMCRF` model is 78 | initialized without embeddings (`embeddings` attribute is None.) 
79 | """ 80 | assert isinstance(single_model, BaseKerasModel) 81 | # attributes that are passed to __init__ 82 | assert single_model.config is dummy_config 83 | assert single_model.datasets[0] is dummy_dataset_1 84 | assert single_model.embeddings is None 85 | # other instance attributes 86 | assert single_model.models == [] 87 | # test that we can pass arbitrary keyword arguments 88 | assert single_model.totally_arbitrary == 'arbitrary' 89 | 90 | def test_attributes_init_of_single_model_embeddings(dummy_config, dummy_dataset_1, 91 | dummy_embeddings, single_model_embeddings): 92 | """Asserts instance attributes are initialized correctly when single `MultiTaskLSTMCRF` model is 93 | initialized with embeddings (`embeddings` attribute is not None.) 94 | """ 95 | assert isinstance(single_model_embeddings, BaseKerasModel) 96 | # attributes that are passed to __init__ 97 | assert single_model_embeddings.config is dummy_config 98 | assert single_model_embeddings.datasets[0] is dummy_dataset_1 99 | assert single_model_embeddings.embeddings is dummy_embeddings 100 | # other instance attributes 101 | assert single_model_embeddings.models == [] 102 | # test that we can pass arbitrary keyword arguments 103 | assert single_model_embeddings.totally_arbitrary == 'arbitrary' 104 | 105 | def test_prepare_data_for_training(dummy_dataset_1, single_model): 106 | """Assert that the values returned from call to `BaseKerasModel.prepare_data_for_training()` are 107 | as expected. 
108 | """ 109 | training_data = single_model.prepare_data_for_training() 110 | partitions = ['x_train', 'y_train', 'x_valid', 'y_valid', 'x_test', 'y_test'] 111 | 112 | # assert each item in training_data contains the expected keys 113 | assert all(partition in data for data in training_data for partition in partitions) 114 | 115 | # assert that the items in training_data contain the expected values 116 | assert all(data['x_train'] == [dummy_dataset_1.idx_seq['train']['word'], dummy_dataset_1.idx_seq['train']['char']] 117 | for data in training_data) 118 | assert all(data['x_valid'] == [dummy_dataset_1.idx_seq['valid']['word'], dummy_dataset_1.idx_seq['valid']['char']] 119 | for data in training_data) 120 | assert all(data['x_test'] == [dummy_dataset_1.idx_seq['test']['word'], dummy_dataset_1.idx_seq['test']['char']] 121 | for data in training_data) 122 | assert all(np.array_equal(data['y_train'], dummy_dataset_1.idx_seq['train']['tag']) for data in training_data) 123 | assert all(np.array_equal(data['y_valid'], dummy_dataset_1.idx_seq['valid']['tag']) for data in training_data) 124 | assert all(np.array_equal(data['y_test'], dummy_dataset_1.idx_seq['test']['tag']) for data in training_data) 125 | -------------------------------------------------------------------------------- /saber/tests/test_config.py: -------------------------------------------------------------------------------- 1 | """Contains any and all unit tests for the config.Config class (saber/config.py). 2 | """ 3 | import pytest 4 | 5 | from .. 
import config 6 | from .resources.dummy_constants import * 7 | from .resources.helpers import * 8 | 9 | ######################################### PYTEST FIXTURES ######################################### 10 | 11 | @pytest.fixture 12 | def config_no_cli_args(): 13 | """Returns an instance of a config.Config object after parsing the dummy config file with no command 14 | line interface (CLI) args.""" 15 | # parse the dummy config 16 | dummy_config = config.Config(PATH_TO_DUMMY_CONFIG) 17 | 18 | return dummy_config 19 | 20 | @pytest.fixture 21 | def config_with_cli_args(): 22 | """Returns an instance of a config.config.Config object after parsing the dummy config file with command line 23 | interface (CLI) args.""" 24 | # parse the dummy config, leave cli false and instead pass command line args manually 25 | dummy_config = config.Config(PATH_TO_DUMMY_CONFIG) 26 | # this is a bit of a hack, but need to simulate providing commands at the command line 27 | dummy_config.cli_args = DUMMY_COMMAND_LINE_ARGS 28 | dummy_config.harmonize_args(DUMMY_COMMAND_LINE_ARGS) 29 | 30 | return dummy_config 31 | 32 | ############################################ UNIT TESTS ############################################ 33 | 34 | def test_process_args_no_cli_args(config_no_cli_args): 35 | """Asserts the config.Config.config object contains the expected attributes after initializing a config.Config 36 | object without CLI args.""" 37 | # check filepath attribute 38 | assert config_no_cli_args.filepath == PATH_TO_DUMMY_CONFIG 39 | # check that the config file contains the same values as DUMMY_ARGS_NO_PROCESSING 40 | config = config_no_cli_args.config 41 | for section in CONFIG_SECTIONS: 42 | for arg, value in config[section].items(): 43 | assert value == DUMMY_ARGS_NO_PROCESSING[arg] 44 | # check cli_args attribute 45 | assert config_no_cli_args.cli_args == {} 46 | 47 | def test_process_args_with_cli_args(config_with_cli_args): 48 | """Asserts the config.Config.config object contains the 
expected attributes after initializing a config.Config 49 | object with CLI args.""" 50 | # check filepath attribute 51 | assert config_with_cli_args.filepath == os.path.join(os.path.dirname( \ 52 | os.path.os.path.abspath(__file__)), PATH_TO_DUMMY_CONFIG) 53 | config = config_with_cli_args.config 54 | # check that the config file contains the same values as DUMMY_ARGS_NO_PROCESSING 55 | for section in CONFIG_SECTIONS: 56 | for arg, value in config[section].items(): 57 | assert value == DUMMY_ARGS_NO_PROCESSING[arg] 58 | # check cli_args attribute 59 | assert config_with_cli_args.cli_args == DUMMY_COMMAND_LINE_ARGS 60 | 61 | def test_config_attributes_no_cli_args(config_no_cli_args): 62 | """Asserts that the class attributes of a config.Config object are of the expected value/type after 63 | objects initialization, with NO command line arguments. 64 | """ 65 | # check that we get the values we expected 66 | for arg, value in DUMMY_ARGS_NO_CLI_ARGS.items(): 67 | assert value == getattr(config_no_cli_args, arg) 68 | 69 | def test_config_attributes_with_cli_args(config_with_cli_args): 70 | """Asserts that the class attributes of a config.Config object are of the expected value/type after 71 | object initialization, taking into account command line arguments, which take precedence over 72 | config arguments. 73 | """ 74 | # check that we get the values we expected, specifically, check that our command line arguments 75 | # have overwritten our config arguments 76 | for arg, value in DUMMY_ARGS_WITH_CLI_ARGS.items(): 77 | assert value == getattr(config_with_cli_args, arg) 78 | 79 | def test_resolve_filepath(config_no_cli_args): 80 | """Asserts that `Config._resolve_filepath()` returns the expected values. 
81 | """ 82 | # tests for when neither filepath nor cli_args arguments are provided 83 | filepath_none_cli_args_none_expected = resource_filename(config.__name__, constants.CONFIG_FILENAME) 84 | filepath_none_cli_args_none_actual = config_no_cli_args._resolve_filepath(filepath=None, cli_args={}) 85 | # tests for when cli_args argument is provided 86 | filepath_none_cli_args_expected = 'arbitrary/filepath/to/config.ini' 87 | dummy_cli_args = {'config_filepath': filepath_none_cli_args_expected} 88 | filepath_none_cli_args_actual = config_no_cli_args._resolve_filepath(filepath=None, 89 | cli_args=dummy_cli_args) 90 | # tests for when filepath argument is provided 91 | filepath_cli_args_none_expected = filepath_none_cli_args_expected 92 | filepath_cli_args_none_actual = config_no_cli_args._resolve_filepath(filepath=filepath_cli_args_none_expected, 93 | cli_args={}) 94 | # tests for when both filepath and cli_args arguments are provided 95 | filepath_cli_args_expected = filepath_none_cli_args_expected 96 | filepath_cli_args_actual = config_no_cli_args._resolve_filepath(filepath=filepath_cli_args_expected, 97 | cli_args=dummy_cli_args) 98 | 99 | assert filepath_none_cli_args_none_expected == filepath_none_cli_args_none_actual 100 | assert filepath_none_cli_args_expected == filepath_none_cli_args_actual 101 | assert filepath_cli_args_none_expected == filepath_cli_args_none_actual 102 | assert filepath_cli_args_expected == filepath_cli_args_actual 103 | 104 | def test_key_error(tmpdir): 105 | """Assert that a KeyError is raised when Config object is initialized with a value for 106 | `filepath` that does does contain a valid *.ini file. 
107 | """ 108 | with pytest.raises(KeyError): 109 | dummy_config = config.Config(tmpdir.strpath) 110 | 111 | def test_save_no_cli_args(config_no_cli_args, tmpdir): 112 | """Asserts that a saved config file contains the correct arguments and values.""" 113 | # save the config to temporary directory created by py.test 114 | config_no_cli_args.save(tmpdir.strpath) 115 | # load the saved config 116 | saved_config = load_saved_config(tmpdir.strpath) 117 | # need to 'unprocess' the args to check them against the saved config file 118 | unprocessed_args = unprocess_args(DUMMY_ARGS_NO_CLI_ARGS) 119 | # ensure the saved config file matches the original arguments used to create it 120 | for section in CONFIG_SECTIONS: 121 | for arg, value in saved_config[section].items(): 122 | assert value == unprocessed_args[arg] 123 | 124 | def test_save_with_cli_args(config_with_cli_args, tmpdir): 125 | """Asserts that a saved config file contains the correct arguments and values, taking into 126 | account command line arguments, which take precedence over config arguments. 127 | """ 128 | # save the config to temporary directory created by py.test 129 | config_with_cli_args.save(tmpdir.strpath) 130 | # load the saved config 131 | saved_config = load_saved_config(tmpdir.strpath) 132 | # need to 'unprocess' the args to check them against the saved config file 133 | unprocessed_args = unprocess_args(DUMMY_ARGS_WITH_CLI_ARGS) 134 | # ensure the saved config file matches the original arguments used to create it 135 | for section in CONFIG_SECTIONS: 136 | for arg, value in saved_config[section].items(): 137 | assert value == unprocessed_args[arg] 138 | -------------------------------------------------------------------------------- /saber/tests/test_data_utils.py: -------------------------------------------------------------------------------- 1 | """Any and all unit tests for the data_utils (saber/utils/data_utils.py). 
2 | """ 3 | import numpy as np 4 | import pytest 5 | 6 | from ..config import Config 7 | from ..dataset import Dataset 8 | from ..utils import data_utils 9 | from .resources.dummy_constants import * 10 | 11 | ######################################### PYTEST FIXTURES ######################################### 12 | 13 | @pytest.fixture 14 | def dummy_config(): 15 | """Returns an instance of a Config object.""" 16 | dummy_config = Config(PATH_TO_DUMMY_CONFIG) 17 | return dummy_config 18 | 19 | @pytest.fixture 20 | def dummy_dataset_1(): 21 | """Returns a single dummy Dataset instance after calling Dataset.load(). 22 | """ 23 | # Don't replace rare tokens for the sake of testing 24 | dataset = Dataset(directory=PATH_TO_DUMMY_DATASET_1, replace_rare_tokens=False) 25 | dataset.load() 26 | 27 | return dataset 28 | 29 | @pytest.fixture 30 | def dummy_dataset_2(): 31 | """Returns a single dummy Dataset instance after calling `Dataset.load()`. 32 | """ 33 | # Don't replace rare tokens for the sake of testing 34 | dataset = Dataset(directory=PATH_TO_DUMMY_DATASET_2, replace_rare_tokens=False) 35 | dataset.load() 36 | 37 | return dataset 38 | 39 | @pytest.fixture 40 | def dummy_compound_dataset(dummy_config): 41 | """ 42 | """ 43 | dummy_config.dataset_folder = [PATH_TO_DUMMY_DATASET_1, PATH_TO_DUMMY_DATASET_2] 44 | dummy_config.replace_rare_tokens = False 45 | dataset = data_utils.load_compound_dataset(dummy_config) 46 | 47 | return dataset 48 | 49 | @pytest.fixture(scope='session') 50 | def dummy_dataset_paths_all(tmpdir_factory): 51 | """Creates and returns the path to a temporary dataset folder, and train, valid, test files. 
52 | """ 53 | # create a dummy dataset folder 54 | dummy_dir = tmpdir_factory.mktemp('dummy_dataset') 55 | # create train, valid and train partitions in this folder 56 | train_file = dummy_dir.join('train.tsv') 57 | train_file.write('arbitrary') # need to write content or else the file wont exist 58 | valid_file = dummy_dir.join('valid.tsv') 59 | valid_file.write('arbitrary') 60 | test_file = dummy_dir.join('test.tsv') 61 | test_file.write('arbitrary') 62 | 63 | return dummy_dir.strpath, train_file.strpath, valid_file.strpath, test_file.strpath 64 | 65 | @pytest.fixture(scope='session') 66 | def dummy_dataset_paths_no_valid(tmpdir_factory): 67 | """Creates and returns the path to a temporary dataset folder, and train, and test files. 68 | """ 69 | # create a dummy dataset folder 70 | dummy_dir = tmpdir_factory.mktemp('dummy_dataset') 71 | # create train, valid and train partitions in this folder 72 | train_file = dummy_dir.join('train.tsv') 73 | train_file.write('arbitrary') # need to write content or else the file wont exist 74 | test_file = dummy_dir.join('test.tsv') 75 | test_file.write('arbitrary') 76 | 77 | return dummy_dir.strpath, train_file.strpath, test_file.strpath 78 | 79 | ############################################ UNIT TESTS ############################################ 80 | 81 | def test_get_filepaths_value_error(tmpdir): 82 | """Asserts that a ValueError is raised when `data_utils.get_filepaths(tmpdir)` is called and 83 | no file '/train.*' exists. 84 | """ 85 | with pytest.raises(ValueError): 86 | data_utils.get_filepaths(tmpdir.strpath) 87 | 88 | def test_get_filepaths_all(dummy_dataset_paths_all): 89 | """Asserts that `data_utils.get_filepaths()` returns the expected filepaths when all partitions 90 | (train/test/valid) are provided. 
91 | """ 92 | dummy_dataset_directory, train_filepath, valid_filepath, test_filepath = dummy_dataset_paths_all 93 | expected = {'train': train_filepath, 94 | 'valid': valid_filepath, 95 | 'test': test_filepath 96 | } 97 | actual = data_utils.get_filepaths(dummy_dataset_directory) 98 | 99 | assert actual == expected 100 | 101 | def test_get_filepaths_no_valid(dummy_dataset_paths_no_valid): 102 | """Asserts that `data_utils.get_filepaths()` returns the expected filepaths when train and 103 | test partitions are provided. 104 | """ 105 | dummy_dataset_directory, train_filepath, test_filepath = dummy_dataset_paths_no_valid 106 | expected = {'train': train_filepath, 107 | 'valid': None, 108 | 'test': test_filepath 109 | } 110 | actual = data_utils.get_filepaths(dummy_dataset_directory) 111 | 112 | assert actual == expected 113 | 114 | def test_load_single_dataset(dummy_config, dummy_dataset_1): 115 | """Asserts that `data_utils.load_single_dataset()` returns the expected value. 116 | """ 117 | actual = data_utils.load_single_dataset(dummy_config) 118 | expected = [dummy_dataset_1] 119 | 120 | # essentially redundant, but if we dont return a [Dataset] object then the error message from 121 | # the final test could be cryptic 122 | assert isinstance(actual, list) 123 | assert len(actual) == 1 124 | assert isinstance(actual[0], Dataset) 125 | # the test we actually care about, least roundabout way of asking if the two Dataset objects 126 | # are identical 127 | assert dir(actual[0].__dict__) == dir(expected[0].__dict__) 128 | 129 | def test_load_compound_dataset_unchanged_attributes(dummy_dataset_1, 130 | dummy_dataset_2, 131 | dummy_compound_dataset): 132 | """Asserts that attributes of `Dataset` objects which are expected to remain unchanged 133 | are unchanged after call to `data_utils.load_compound_dataset()`. 
134 | """ 135 | actual = dummy_compound_dataset 136 | expected = [dummy_dataset_1, dummy_dataset_2] 137 | 138 | # essentially redundant, but if we dont return a [Dataset, Dataset] object then the error 139 | # messages from the downstream tests could be cryptic 140 | assert isinstance(actual, list) 141 | assert len(actual) == 2 142 | assert all([isinstance(ds, Dataset) for ds in actual]) 143 | 144 | # attributes that are unchanged in case of compound dataset 145 | assert actual[0].directory == expected[0].directory 146 | assert actual[0].replace_rare_tokens == expected[0].replace_rare_tokens 147 | assert actual[0].type_seq == expected[0].type_seq 148 | assert actual[0].type_to_idx['tag'] == expected[0].type_to_idx['tag'] 149 | assert actual[0].idx_to_tag == expected[0].idx_to_tag 150 | 151 | assert actual[-1].directory == expected[-1].directory 152 | assert actual[-1].replace_rare_tokens == expected[-1].replace_rare_tokens 153 | assert actual[-1].type_seq == expected[-1].type_seq 154 | assert actual[-1].type_to_idx['tag'] == expected[-1].type_to_idx['tag'] 155 | assert actual[-1].idx_to_tag == expected[-1].idx_to_tag 156 | 157 | def test_load_compound_dataset_changed_attributes(dummy_dataset_1, 158 | dummy_dataset_2, 159 | dummy_compound_dataset): 160 | """Asserts that attributes of `Dataset` objects which are expected to be changed are changed 161 | after call to `data_utils.load_compound_dataset()`. 
162 | """ 163 | actual = dummy_compound_dataset 164 | expected = [dummy_dataset_1, dummy_dataset_2] 165 | 166 | # essentially redundant, but if we don't return a [Dataset, Dataset] object then the error 167 | # messages from the downstream tests could be cryptic 168 | assert isinstance(actual, list) 169 | assert len(actual) == 2 170 | assert all([isinstance(ds, Dataset) for ds in actual]) 171 | 172 | # attributes that are changed in case of compound dataset 173 | assert actual[0].type_to_idx['word'] == actual[-1].type_to_idx['word'] 174 | assert actual[0].type_to_idx['char'] == actual[-1].type_to_idx['char'] 175 | 176 | # TODO: Need to assert that all types in idx_seq map to the same integers 177 | # across the compound datasets 178 | 179 | def test_setup_dataset_for_transfer(dummy_dataset_1, dummy_dataset_2): 180 | """Asserts that the `type_to_idx` attribute of a "source" dataset and a "target" dataset are 181 | as expected after call to `data_utils.setup_dataset_for_transfer()`. 182 | """ 183 | source_type_to_idx = dummy_dataset_1.type_to_idx 184 | data_utils.setup_dataset_for_transfer(dummy_dataset_2, source_type_to_idx) 185 | 186 | assert all(dummy_dataset_2.type_to_idx[type_] == source_type_to_idx[type_] for type_ in ['word', 'char']) 187 | -------------------------------------------------------------------------------- /saber/tests/test_dataset.py: -------------------------------------------------------------------------------- 1 | """Contains any and all unit tests for the `Dataset` class (saber/dataset.py). 2 | """ 3 | import os 4 | 5 | import numpy as np 6 | import pytest 7 | from nltk.corpus.reader.conll import ConllCorpusReader 8 | 9 | from .. 
import constants 10 | from ..dataset import Dataset 11 | from ..utils import generic_utils 12 | from .resources.dummy_constants import * 13 | 14 | # TODO (johngiorgi): Need to include tests for valid/test partitions 15 | # TODO (johngiorgi): Need to include tests for compound datasets 16 | 17 | ######################################### PYTEST FIXTURES ######################################### 18 | @pytest.fixture 19 | def empty_dummy_dataset(): 20 | """Returns an empty single dummy Dataset instance. 21 | """ 22 | # Don't replace rare tokens for the sake of testing 23 | return Dataset(directory=PATH_TO_DUMMY_DATASET_1, replace_rare_tokens=False, 24 | # to test passing of arbitrary keyword args to constructor 25 | totally_arbitrary='arbitrary') 26 | 27 | @pytest.fixture 28 | def loaded_dummy_dataset(): 29 | """Returns a single dummy Dataset instance after calling Dataset.load(). 30 | """ 31 | # Don't replace rare tokens for the sake of testing 32 | dataset = Dataset(directory=PATH_TO_DUMMY_DATASET_1, replace_rare_tokens=False) 33 | dataset.load() 34 | 35 | return dataset 36 | 37 | ############################################ UNIT TESTS ############################################ 38 | 39 | # Generic object tests 40 | 41 | def test_attributes_after_initilization_of_dataset(empty_dummy_dataset): 42 | """Asserts instance attributes are initialized correctly when dataset is empty (i.e., 43 | `Dataset.load()` has not been called). 
44 | """ 45 | # attributes that are passed to __init__ 46 | for partition in empty_dummy_dataset.directory: 47 | expected = os.path.join(PATH_TO_DUMMY_DATASET_1, '{}.tsv'.format(partition)) 48 | assert empty_dummy_dataset.directory[partition] == expected 49 | assert not empty_dummy_dataset.replace_rare_tokens 50 | # other instance attributes 51 | assert empty_dummy_dataset.conll_parser.root == PATH_TO_DUMMY_DATASET_1 52 | assert empty_dummy_dataset.type_seq == {'train': None, 'valid': None, 'test': None} 53 | assert empty_dummy_dataset.type_to_idx == {'word': None, 'char': None, 'tag': None} 54 | assert empty_dummy_dataset.idx_to_tag is None 55 | assert empty_dummy_dataset.idx_seq == {'train': None, 'valid': None, 'test': None} 56 | # test that we can pass arbitrary keyword arguments 57 | assert empty_dummy_dataset.totally_arbitrary == 'arbitrary' 58 | 59 | def test_value_error_load(empty_dummy_dataset): 60 | """Asserts that `Dataset.load()` raises a ValueError when `Dataset.directory` is None. 61 | """ 62 | # Set directory to None to force error to arise 63 | empty_dummy_dataset.directory = None 64 | with pytest.raises(ValueError): 65 | empty_dummy_dataset.load() 66 | 67 | # SINGLE DATASET 68 | 69 | def test_get_types_single_dataset(empty_dummy_dataset): 70 | """Asserts that `Dataset._get_types()` returns the expected values. 71 | """ 72 | actual = empty_dummy_dataset._get_types() 73 | expected = {'word': DUMMY_WORD_TYPES, 'char': DUMMY_CHAR_TYPES, 'tag': DUMMY_TAG_TYPES} 74 | 75 | # sort allows us to assert that the two lists are identical 76 | assert all(actual['word'].sort() == v.sort() for k, v in expected.items()) 77 | 78 | # Tests on unloaded Dataset object (`Dataset.load()` was not called) 79 | 80 | def test_get_type_seq_single_dataset_before_load(empty_dummy_dataset): 81 | """Asserts that `Dataset.type_seq` is updated as expected after call to 82 | `Dataset._get_type_seq()`. 
83 | """ 84 | empty_dummy_dataset._get_type_seq() 85 | 86 | assert np.array_equal(empty_dummy_dataset.type_seq['train']['word'], DUMMY_WORD_SEQ) 87 | assert np.array_equal(empty_dummy_dataset.type_seq['train']['char'], DUMMY_CHAR_SEQ) 88 | assert np.array_equal(empty_dummy_dataset.type_seq['train']['tag'], DUMMY_TAG_SEQ) 89 | 90 | def test_get_idx_maps_single_dataset_before_load(empty_dummy_dataset): 91 | """Asserts that `Dataset.type_to_idx` is updated as expected after successive calls to 92 | `Dataset._get_types()` and `Dataset._get_idx_maps()`. 93 | """ 94 | types = empty_dummy_dataset._get_types() 95 | empty_dummy_dataset._get_idx_maps(types) 96 | 97 | # ensure that index mapping is a contigous sequence of numbers starting at 0 98 | assert generic_utils.is_consecutive(empty_dummy_dataset.type_to_idx['word'].values()) 99 | assert generic_utils.is_consecutive(empty_dummy_dataset.type_to_idx['char'].values()) 100 | assert generic_utils.is_consecutive(empty_dummy_dataset.type_to_idx['tag'].values()) 101 | # ensure that type to index mapping contains the expected keys 102 | assert all(key in DUMMY_WORD_TYPES for key in empty_dummy_dataset.type_to_idx['word']) 103 | assert all(key in DUMMY_CHAR_TYPES for key in empty_dummy_dataset.type_to_idx['char']) 104 | assert all(key in DUMMY_TAG_TYPES for key in empty_dummy_dataset.type_to_idx['tag']) 105 | 106 | def test_get_idx_maps_single_dataset_before_load_special_tokens(empty_dummy_dataset): 107 | """Asserts that `Dataset.type_to_idx` contains the special tokens as keys with expected values 108 | after successive calls to `Dataset._get_types()` and `Dataset._get_idx_maps()`. 
109 | """ 110 | types = empty_dummy_dataset._get_types() 111 | empty_dummy_dataset._get_idx_maps(types) 112 | # assert special tokens are mapped to the correct indices 113 | assert all(empty_dummy_dataset.type_to_idx['word'][k] == v for k, v in constants.INITIAL_MAPPING['word'].items()) 114 | assert all(empty_dummy_dataset.type_to_idx['char'][k] == v for k, v in constants.INITIAL_MAPPING['word'].items()) 115 | assert all(empty_dummy_dataset.type_to_idx['tag'][k] == v for k, v in constants.INITIAL_MAPPING['tag'].items()) 116 | 117 | def test_get_idx_seq_single_dataset_before_load(empty_dummy_dataset): 118 | """Asserts that `Dataset.idx_seq` is updated as expected after successive calls to 119 | `Dataset._get_type_seq()`, `Dataset._get_idx_maps()` and `Dataset.get_idx_seq()`. 120 | """ 121 | types = empty_dummy_dataset._get_types() 122 | empty_dummy_dataset._get_type_seq() 123 | empty_dummy_dataset._get_idx_maps(types) 124 | empty_dummy_dataset.get_idx_seq() 125 | 126 | # as a workaround to testing this directly, just check that shapes are as expected 127 | expected_word_idx_shape = (len(DUMMY_WORD_SEQ), constants.MAX_SENT_LEN) 128 | expected_char_idx_shape = (len(DUMMY_WORD_SEQ), constants.MAX_SENT_LEN, constants.MAX_CHAR_LEN) 129 | expected_tag_idx_shape = (len(DUMMY_WORD_SEQ), constants.MAX_SENT_LEN, len(DUMMY_TAG_TYPES)) 130 | 131 | assert all(empty_dummy_dataset.idx_seq[partition]['word'].shape == expected_word_idx_shape 132 | for partition in ['train', 'test', 'valid']) 133 | assert all(empty_dummy_dataset.idx_seq[partition]['char'].shape == expected_char_idx_shape 134 | for partition in ['train', 'test', 'valid']) 135 | assert all(empty_dummy_dataset.idx_seq[partition]['tag'].shape == expected_tag_idx_shape 136 | for partition in ['train', 'test', 'valid']) 137 | 138 | # tests on loaded Dataset object (`Dataset.load()` was called) 139 | 140 | def test_get_type_seq_single_dataset_after_load(loaded_dummy_dataset): 141 | """Asserts that `Dataset.type_seq` is 
updated as expected after call to `Dataset.load()`. 142 | """ 143 | assert np.array_equal(loaded_dummy_dataset.type_seq['train']['word'], DUMMY_WORD_SEQ) 144 | assert np.array_equal(loaded_dummy_dataset.type_seq['train']['char'], DUMMY_CHAR_SEQ) 145 | assert np.array_equal(loaded_dummy_dataset.type_seq['train']['tag'], DUMMY_TAG_SEQ) 146 | 147 | def test_get_idx_maps_single_dataset_after_load(loaded_dummy_dataset): 148 | """Asserts that `Dataset.type_to_idx` is updated as expected after call to `Dataset.load()`. 149 | """ 150 | # ensure that index mapping is a contiguous sequence of numbers starting at 0 151 | # (i.e., no gaps and no duplicate indices in any of the word/char/tag maps) 152 | assert generic_utils.is_consecutive(loaded_dummy_dataset.type_to_idx['word'].values()) 153 | assert generic_utils.is_consecutive(loaded_dummy_dataset.type_to_idx['char'].values()) 154 | assert generic_utils.is_consecutive(loaded_dummy_dataset.type_to_idx['tag'].values()) 155 | # ensure that type to index mapping contains the expected keys 156 | assert all(key in DUMMY_WORD_TYPES for key in loaded_dummy_dataset.type_to_idx['word']) 157 | assert all(key in DUMMY_CHAR_TYPES for key in loaded_dummy_dataset.type_to_idx['char']) 158 | assert all(key in DUMMY_TAG_TYPES for key in loaded_dummy_dataset.type_to_idx['tag']) 159 | 160 | def test_get_idx_maps_single_dataset_after_load_special_tokens(loaded_dummy_dataset): 161 | """Asserts that `Dataset.type_to_idx` contains the special tokens as keys with expected values 162 | after successive calls to `Dataset._get_types()` and `Dataset.get_idx_maps()`. 
163 | """ 164 | # assert special tokens are mapped to the correct indices 165 | assert all(loaded_dummy_dataset.type_to_idx['word'][k] == v for k, v in constants.INITIAL_MAPPING['word'].items()) 166 | assert all(loaded_dummy_dataset.type_to_idx['char'][k] == v for k, v in constants.INITIAL_MAPPING['word'].items()) 167 | assert all(loaded_dummy_dataset.type_to_idx['tag'][k] == v for k, v in constants.INITIAL_MAPPING['tag'].items()) 168 | 169 | def test_get_idx_seq_after_load(loaded_dummy_dataset): 170 | """Asserts that `Dataset.idx_seq` is updated as expected after calls to `Dataset.load()`. 171 | """ 172 | # as a workaround to testing this directly, just check that shapes are as expected 173 | expected_word_idx_shape = (len(DUMMY_WORD_SEQ), constants.MAX_SENT_LEN) 174 | expected_char_idx_shape = (len(DUMMY_WORD_SEQ), constants.MAX_SENT_LEN, constants.MAX_CHAR_LEN) 175 | expected_tag_idx_shape = (len(DUMMY_WORD_SEQ), constants.MAX_SENT_LEN, len(DUMMY_TAG_TYPES)) 176 | 177 | assert all(loaded_dummy_dataset.idx_seq[partition]['word'].shape == expected_word_idx_shape 178 | for partition in ['train', 'test', 'valid']) 179 | assert all(loaded_dummy_dataset.idx_seq[partition]['char'].shape == expected_char_idx_shape 180 | for partition in ['train', 'test', 'valid']) 181 | assert all(loaded_dummy_dataset.idx_seq[partition]['tag'].shape == expected_tag_idx_shape 182 | for partition in ['train', 'test', 'valid']) 183 | 184 | # COMPOUND DATASET 185 | -------------------------------------------------------------------------------- /saber/tests/test_embeddings.py: -------------------------------------------------------------------------------- 1 | """Any and all unit tests for the Embeddings class (saber/embeddings.py). 
2 | """ 3 | import numpy as np 4 | import pytest 5 | 6 | from ..embeddings import Embeddings 7 | from .resources import helpers 8 | from .resources.dummy_constants import * 9 | 10 | # TODO (johngiorgi): write tests using a binary format file 11 | # TODO (johngiorgi): write tests to test for debug functionality 12 | 13 | ######################################### PYTEST FIXTURES ######################################### 14 | 15 | @pytest.fixture 16 | def dummy_embedding_idx(): 17 | """Returns embedding index from call to `Embeddings._prepare_embedding_index()`. 18 | """ 19 | embeddings = Embeddings(filepath=PATH_TO_DUMMY_EMBEDDINGS, token_map=DUMMY_TOKEN_MAP) 20 | embedding_idx = embeddings._prepare_embedding_index(binary=False) 21 | return embedding_idx 22 | 23 | @pytest.fixture 24 | def dummy_embedding_matrix_and_type_to_idx(): 25 | """Returns the `embedding_matrix` and `type_to_index` objects from call to 26 | `Embeddings._prepare_embedding_matrix(load_all=False)`. 27 | """ 28 | embeddings = Embeddings(filepath=PATH_TO_DUMMY_EMBEDDINGS, token_map=DUMMY_TOKEN_MAP) 29 | embedding_idx = embeddings._prepare_embedding_index(binary=False) 30 | embeddings.num_found = len(embedding_idx) 31 | embeddings.dimension = len(list(embedding_idx.values())[0]) 32 | embedding_matrix, type_to_idx = embeddings._prepare_embedding_matrix(embedding_idx, load_all=False) 33 | embeddings.num_embed = embedding_matrix.shape[0] # num of embedded words 34 | 35 | return embedding_matrix, type_to_idx 36 | 37 | @pytest.fixture 38 | def dummy_embedding_matrix_and_type_to_idx_load_all(): 39 | """Returns the embedding matrix and type to index objects from call to 40 | `Embeddings._prepare_embedding_matrix(load_all=True)`. 
41 | """ 42 | # this should be different than DUMMY_TOKEN_MAP for a reliable test 43 | test = {"This": 0, "is": 1, "a": 2, "test": 3} 44 | 45 | embeddings = Embeddings(filepath=PATH_TO_DUMMY_EMBEDDINGS, token_map=test) 46 | embedding_idx = embeddings._prepare_embedding_index(binary=False) 47 | embeddings.num_found = len(embedding_idx) 48 | embeddings.dimension = len(list(embedding_idx.values())[0]) 49 | embedding_matrix, type_to_idx = embeddings._prepare_embedding_matrix(embedding_idx, load_all=True) 50 | embeddings.num_embed = embedding_matrix.shape[0] # num of embedded words 51 | 52 | return embedding_matrix, type_to_idx 53 | 54 | @pytest.fixture 55 | def dummy_embeddings_before_load(): 56 | """Returns an instance of an Embeddings() object BEFORE the `Embeddings.load()` method is 57 | called. 58 | """ 59 | return Embeddings(filepath=PATH_TO_DUMMY_EMBEDDINGS, 60 | token_map=DUMMY_TOKEN_MAP, 61 | # to test passing of arbitrary keyword args to constructor 62 | totally_arbitrary='arbitrary') 63 | 64 | @pytest.fixture 65 | def dummy_embeddings_after_load(): 66 | """Returns an instance of an Embeddings() object AFTER `Embeddings.load(load_all=False)` is 67 | called. 68 | """ 69 | embeddings = Embeddings(filepath=PATH_TO_DUMMY_EMBEDDINGS, token_map=DUMMY_TOKEN_MAP) 70 | embeddings.load(binary=False, load_all=False) # txt file format is easier to test 71 | return embeddings 72 | 73 | @pytest.fixture 74 | def dummy_embeddings_after_load_with_load_all(): 75 | """Returns an instance of an Embeddings() object AFTER `Embeddings.load(load_all=True)` is 76 | called. 
77 | """ 78 | # this should be different than DUMMY_TOKEN_MAP for a reliable test 79 | test = {"This": 0, "is": 1, "a": 2, "test": 3} 80 | 81 | embeddings = Embeddings(filepath=PATH_TO_DUMMY_EMBEDDINGS, token_map=test) 82 | embeddings.load(binary=False, load_all=True) # txt file format is easier to test 83 | return embeddings 84 | 85 | ############################################ UNIT TESTS ############################################ 86 | 87 | def test_initialization(dummy_embeddings_before_load): 88 | """Asserts that Embeddings object contains the expected attribute values after initialization. 89 | """ 90 | # test attributes whos values are passed to the constructor 91 | assert dummy_embeddings_before_load.filepath == PATH_TO_DUMMY_EMBEDDINGS 92 | assert dummy_embeddings_before_load.token_map == DUMMY_TOKEN_MAP 93 | # test attributes initialized with default values 94 | assert dummy_embeddings_before_load.matrix is None 95 | assert dummy_embeddings_before_load.num_found is None 96 | assert dummy_embeddings_before_load.num_embed is None 97 | assert dummy_embeddings_before_load.dimension is None 98 | # test that we can pass arbitrary keyword arguments 99 | assert dummy_embeddings_before_load.totally_arbitrary == 'arbitrary' 100 | 101 | def test_prepare_embedding_index(dummy_embedding_idx): 102 | """Asserts that we get the expected value back after call to 103 | `Embeddings._prepare_embedding_index()`. 104 | """ 105 | # need to check keys and values differently 106 | assert dummy_embedding_idx.keys() == DUMMY_EMBEDDINGS_INDEX.keys() 107 | assert all(np.allclose(actual, expected) for actual, expected in 108 | zip(dummy_embedding_idx.values(), DUMMY_EMBEDDINGS_INDEX.values())) 109 | 110 | def test_prepare_embedding_matrix(dummy_embedding_matrix_and_type_to_idx): 111 | """Asserts that we get the expected value back after successive calls to 112 | `Embeddings._prepare_embedding_index()` and 113 | `Embeddings._prepare_embedding_matrix(load_all=False)`. 
114 | """ 115 | # expected_values 116 | embedding_matrix_expected, type_to_idx_expected = DUMMY_EMBEDDINGS_MATRIX, None 117 | # actual values 118 | embedding_matrix_actual, type_to_idx_actual = dummy_embedding_matrix_and_type_to_idx 119 | 120 | assert np.allclose(embedding_matrix_actual, embedding_matrix_expected) 121 | assert type_to_idx_actual is type_to_idx_expected 122 | 123 | def test_prepare_embedding_matrix_load_all(dummy_embedding_matrix_and_type_to_idx_load_all): 124 | """Asserts that we get the expected value back after successive calls to 125 | `Embeddings._prepare_embedding_index()` and 126 | `Embeddings._prepare_embedding_matrix(load_all=True)`. 127 | """ 128 | # expected values 129 | embedding_matrix_expected = DUMMY_EMBEDDINGS_MATRIX 130 | type_to_idx_expected = {"word": DUMMY_TOKEN_MAP, "char": DUMMY_CHAR_MAP} 131 | # actual values 132 | embedding_matrix_actual, type_to_idx_actual = dummy_embedding_matrix_and_type_to_idx_load_all 133 | 134 | assert np.allclose(embedding_matrix_actual, embedding_matrix_expected) 135 | helpers.assert_type_to_idx_as_expected(actual=type_to_idx_actual, expected=type_to_idx_expected) 136 | 137 | def test_matrix_after_load(dummy_embeddings_after_load): 138 | """Asserts that pre-trained token embeddings are loaded correctly when 139 | `Embeddings.load(load_all=False)` is called.""" 140 | assert np.allclose(dummy_embeddings_after_load.matrix, DUMMY_EMBEDDINGS_MATRIX) 141 | 142 | def test_matrix_after_load_with_load_all(dummy_embeddings_after_load): 143 | """Asserts that pre-trained token embeddings are loaded correctly when 144 | `Embeddings.load(load_all=True)` is called.""" 145 | assert np.allclose(dummy_embeddings_after_load.matrix, DUMMY_EMBEDDINGS_MATRIX) 146 | 147 | def test_attributes_after_load(dummy_embedding_idx, dummy_embeddings_after_load): 148 | """Asserts that attributes of Embeddings object are updated as expected after 149 | `Embeddings.load(load_all=False)` is called. 
150 | """ 151 | # expected values 152 | num_found_expected = len(dummy_embedding_idx) 153 | dimension_expected = len(list(dummy_embedding_idx.values())[0]) 154 | num_embed_expected = dummy_embeddings_after_load.matrix.shape[0] 155 | # actual values 156 | num_found_actual = dummy_embeddings_after_load.num_found 157 | dimension_actual = dummy_embeddings_after_load.dimension 158 | num_embed_actual = dummy_embeddings_after_load.num_embed 159 | 160 | assert num_found_expected == num_found_actual 161 | assert dimension_expected == dimension_actual 162 | assert num_embed_expected == num_embed_actual 163 | 164 | def test_attributes_after_load_with_load_all(dummy_embedding_idx, 165 | dummy_embeddings_after_load_with_load_all): 166 | """Asserts that attributes of Embeddings object are updated as expected after 167 | `Embeddings.load(load_all=True)` is called. 168 | """ 169 | # expected values 170 | num_found_expected = len(dummy_embedding_idx) 171 | dimension_expected = len(list(dummy_embedding_idx.values())[0]) 172 | num_embed_expected = dummy_embeddings_after_load_with_load_all.matrix.shape[0] 173 | # actual values 174 | num_found_actual = dummy_embeddings_after_load_with_load_all.num_found 175 | dimension_actual = dummy_embeddings_after_load_with_load_all.dimension 176 | num_embed_actual = dummy_embeddings_after_load_with_load_all.num_embed 177 | 178 | assert num_found_expected == num_found_actual 179 | assert dimension_expected == dimension_actual 180 | assert num_embed_expected == num_embed_actual 181 | 182 | def test_generate_type_to_idx(dummy_embeddings_before_load): 183 | """Asserts that the dictionary returned from 'Embeddings._generate_type_to_idx()' is as 184 | expected. 
185 | """ 186 | test = {'This': 0, 'is': 1, 'a': 2, 'test': 3} 187 | 188 | # expected values 189 | expected = { 190 | "word": list(test.keys()), 191 | "char": [] 192 | } 193 | for word in expected['word']: 194 | expected['char'].extend(list(word)) 195 | expected['char'] = list(set(expected['char'])) 196 | # actual values 197 | actual = dummy_embeddings_before_load._generate_type_to_idx(test) 198 | 199 | helpers.assert_type_to_idx_as_expected(actual=actual, expected=expected) 200 | -------------------------------------------------------------------------------- /saber/tests/test_generic_utils.py: -------------------------------------------------------------------------------- 1 | """Any and all unit tests for the generic_utils (saber/utils/generic_utils.py). 2 | """ 3 | import os 4 | 5 | import pytest 6 | 7 | from ..config import Config 8 | from ..utils import generic_utils 9 | from .resources.dummy_constants import * 10 | 11 | ######################################### PYTEST FIXTURES ######################################### 12 | 13 | @pytest.fixture(scope='session') 14 | def dummy_dir(tmpdir_factory): 15 | """Returns the path to a temporary directory. 16 | """ 17 | dummy_dir = tmpdir_factory.mktemp('dummy_dir') 18 | return dummy_dir.strpath 19 | 20 | @pytest.fixture 21 | def dummy_config(): 22 | """Returns an instance of a Config object.""" 23 | dummy_config = Config(PATH_TO_DUMMY_CONFIG) 24 | return dummy_config 25 | 26 | ############################################ UNIT TESTS ############################################ 27 | 28 | def test_is_consecutive_empty(): 29 | """Asserts that `generic_utils.is_consecutive()` returns the expected value when passed an 30 | empty list. 
31 | """ 32 | test = [] 33 | 34 | expected = True 35 | actual = generic_utils.is_consecutive(test) 36 | 37 | assert actual == expected 38 | 39 | def test_is_consecutive_simple_sorted_list_no_duplicates(): 40 | """Asserts that `generic_utils.is_consecutive()` returns the expected value when passed a 41 | simple sorted list with no duplicates. 42 | """ 43 | test_true = [0, 1, 2, 3, 4, 5] 44 | test_false = [1, 2, 3, 4, 5, 6] 45 | 46 | expected_true = True 47 | expected_false = False 48 | 49 | actual_true = generic_utils.is_consecutive(test_true) 50 | actual_false = generic_utils.is_consecutive(test_false) 51 | 52 | assert actual_true == expected_true 53 | assert actual_false == expected_false 54 | 55 | def test_is_consecutive_simple_unsorted_list_no_duplicates(): 56 | """Asserts that `generic_utils.is_consecutive()` returns the expected value when passed a 57 | simple unsorted list with no duplicates. 58 | """ 59 | test_true = [0, 1, 3, 2, 4, 5] 60 | test_false = [1, 2, 4, 3, 5, 6] 61 | 62 | expected_true = True 63 | expected_false = False 64 | 65 | actual_true = generic_utils.is_consecutive(test_true) 66 | actual_false = generic_utils.is_consecutive(test_false) 67 | 68 | assert actual_true == expected_true 69 | assert actual_false == expected_false 70 | 71 | def test_is_consecutive_simple_sorted_list_duplicates(): 72 | """Asserts that `generic_utils.is_consecutive()` returns the expected value when passed a 73 | simple sorted list with duplicates. 74 | """ 75 | test = [0, 1, 2, 3, 3, 4, 5] 76 | 77 | expected = False 78 | actual = generic_utils.is_consecutive(test) 79 | 80 | assert actual == expected 81 | 82 | def test_is_consecutive_simple_unsorted_list_duplicates(): 83 | """Asserts that `generic_utils.is_consecutive()` returns the expected value when passed a 84 | simple unsorted list with duplicates. 
85 | """ 86 | test = [0, 1, 4, 3, 3, 2, 5] 87 | 88 | expected = False 89 | actual = generic_utils.is_consecutive(test) 90 | 91 | assert actual == expected 92 | 93 | def test_is_consecutive_simple_dict_no_duplicates(): 94 | """Asserts that `generic_utils.is_consecutive()` returns the expected value when passed a 95 | dictionaries values no duplicates. 96 | """ 97 | test_true = {'a': 0, 'b': 1, 'c': 2, 'd': 3, 'e': 4, 'f': 5} 98 | test_false = {'a': 1, 'b': 2, 'c': 3, 'd': 4, 'e': 5, 'f': 6} 99 | 100 | expected_true = True 101 | expected_false = False 102 | 103 | actual_true = generic_utils.is_consecutive(test_true.values()) 104 | actual_false = generic_utils.is_consecutive(test_false.values()) 105 | 106 | assert actual_true == expected_true 107 | assert actual_false == expected_false 108 | 109 | def test_is_consecutive_simple_dict_duplicates(): 110 | """Asserts that `generic_utils.is_consecutive()` returns the expected value when passed a 111 | dictionaries values with duplicates. 112 | """ 113 | test = {'a': 0, 'b': 1, 'c': 2, 'd': 3, 'e': 3, 'f': 4, 'g': 5} 114 | 115 | expected = False 116 | actual = generic_utils.is_consecutive(test) 117 | 118 | assert actual == expected 119 | 120 | def test_reverse_dict_empty(): 121 | """Asserts that `generic_utils.reverse_dictionary()` returns the expected value when given an 122 | empty dictionary. 123 | """ 124 | test = {} 125 | expected = {} 126 | actual = generic_utils.reverse_dict(test) 127 | 128 | assert actual == expected 129 | 130 | def test_reverse_mapping_simple(): 131 | """Asserts that `generic_utils.reverse_dictionary()` returns the expected value when given a 132 | simply dictionary. 
133 | """ 134 | test = {'a': 1, 'b': 2, 'c': 3} 135 | 136 | expected = {1: 'a', 2: 'b', 3: 'c'} 137 | actual = generic_utils.reverse_dict(test) 138 | 139 | assert actual == expected 140 | 141 | def test_make_dir_new(tmpdir): 142 | """Assert that `generic_utils.make_dir()` creates a directory as expected when it does not 143 | already exist. 144 | """ 145 | dummy_dirpath = os.path.join(tmpdir.strpath, 'dummy_dir') 146 | generic_utils.make_dir(dummy_dirpath) 147 | assert os.path.isdir(dummy_dirpath) 148 | 149 | def test_make_dir_exists(dummy_dir): 150 | """Assert that `generic_utils.make_dir()` fails silently when trying to create a directory that 151 | already exists. 152 | """ 153 | generic_utils.make_dir(dummy_dir) 154 | assert os.path.isdir(dummy_dir) 155 | 156 | def test_clean_path(): 157 | """Asserts that filepath returned by `generic_utils.clean_path()` is as expected. 158 | """ 159 | test = ' this/is//a/test/ ' 160 | expected = os.path.abspath('this/is/a/test') 161 | 162 | assert generic_utils.clean_path(test) == expected 163 | 164 | def test_decompress_model(): 165 | """Asserts that `generic_utils.decompress_model()` decompresses a given directory. 166 | """ 167 | pass 168 | 169 | def test_compress_model(): 170 | """Asserts that `generic_utils.compress_model()` compresses a given directory. 171 | """ 172 | pass 173 | -------------------------------------------------------------------------------- /saber/tests/test_grounding_utils.py: -------------------------------------------------------------------------------- 1 | """Any and all unit tests for the grounding_utils (saber/utils/grounding_utils.py). 2 | """ 3 | import copy 4 | 5 | import pytest 6 | 7 | from .. import constants 8 | from ..utils import grounding_utils 9 | 10 | 11 | @pytest.fixture 12 | def blank_annotation(): 13 | """Returns an annotation with no identified entities. 
14 | """ 15 | annotation = {"ents": [], 16 | "text": "This is a test with no entities.", 17 | "title": ""} 18 | return annotation 19 | 20 | @pytest.fixture 21 | def ched_annotation(): 22 | """Returns an annotation with chemical entities (CHED) identified. 23 | """ 24 | annotation = {"ents": [{"text": "glucose", "label": "CHED", "start": 0, "end": 0}, 25 | {"text": "fructose", "label": "CHED", "start": 0, "end": 0}], 26 | "text": "glucose and fructose", 27 | "title": ""} 28 | 29 | return annotation 30 | 31 | @pytest.fixture 32 | def diso_annotation(): 33 | """Returns an annotation with disease entities (DISO) identified. 34 | """ 35 | annotation = {"ents": [{"text": "cancer", "label": "DISO", "start": 0, "end": 0}, 36 | {"text": "cystic fibrosis", "label": "DISO", "start": 0, "end": 0}], 37 | "text": "cancer and cystic fibrosis", 38 | "title": ""} 39 | 40 | return annotation 41 | 42 | @pytest.fixture 43 | def livb_annotation(): 44 | """Returns an annotation with species entities (LIVB) identified. 45 | """ 46 | annotation = {"ents": [{"text": "mouse", "label": "LIVB", "start": 0, "end": 0}, 47 | {"text": "human", "label": "LIVB", "start": 0, "end": 0}], 48 | "text": "mouse and human", 49 | "title": ""} 50 | 51 | return annotation 52 | 53 | @pytest.fixture 54 | def prge_annotation(): 55 | """Returns an annotation with protein/gene entities (PRGE) identified. 56 | """ 57 | annotation = {"ents": [{"text": "p53", "label": "PRGE", "start": 0, "end": 0}, 58 | {"text": "MK2", "label": "PRGE", "start": 0, "end": 0}], 59 | "text": "p53 and MK2", 60 | "title": ""} 61 | 62 | return annotation 63 | 64 | def test_ground_no_entites(blank_annotation): 65 | """Asserts that `grounding_utils.ground()` returns the expected value for a simple example with 66 | no identified entities. 
67 | """ 68 | 69 | actual = grounding_utils.ground(blank_annotation) 70 | expected = blank_annotation 71 | 72 | assert actual == expected 73 | 74 | def test_ground_chemicals(ched_annotation): 75 | """Asserts that `grounding_utils.ground()` returns the expected value for a simple example with 76 | chemical entities. 77 | """ 78 | actual = grounding_utils.ground(copy.deepcopy(ched_annotation)) 79 | 80 | # create expected value 81 | glucose_xrefs = [ 82 | {'namespace': constants.NAMESPACES['CHED'], 'id': 'CIDs00005793'}, 83 | {'namespace': constants.NAMESPACES['CHED'], 'id': 'CIDs10954115'}, 84 | {'namespace': constants.NAMESPACES['CHED'], 'id': 'CIDs53782692'}, 85 | ] 86 | fructose_xrefs = [{'namespace': constants.NAMESPACES['CHED'], 'id': 'CIDs00439709'}] 87 | 88 | ched_annotation['ents'][0].update(xrefs=glucose_xrefs) 89 | ched_annotation['ents'][1].update(xrefs=fructose_xrefs) 90 | 91 | expected = ched_annotation 92 | 93 | assert actual == expected 94 | 95 | def test_ground_diso(diso_annotation): 96 | """Asserts that `grounding_utils.ground()` returns the expected value for a simple example with 97 | disease entities. 98 | """ 99 | actual = grounding_utils.ground(copy.deepcopy(diso_annotation)) 100 | 101 | # create expected value 102 | cancer_xrefs = [{'namespace': constants.NAMESPACES['DISO'], 'id': 'DOID:162'}] 103 | cystic_fibrosis_xrefs = [{'namespace': constants.NAMESPACES['DISO'], 'id': 'DOID:1485'}] 104 | 105 | diso_annotation['ents'][0].update(xrefs=cancer_xrefs) 106 | diso_annotation['ents'][1].update(xrefs=cystic_fibrosis_xrefs) 107 | 108 | expected = diso_annotation 109 | 110 | assert actual == expected 111 | 112 | def test_ground_livb(livb_annotation): 113 | """Asserts that `grounding_utils.ground()` returns the expected value for a simple example with 114 | species entities. 
115 | """ 116 | actual = grounding_utils.ground(copy.deepcopy(livb_annotation)) 117 | 118 | # create expected value 119 | mouse_xrefs = [ 120 | {'namespace': constants.NAMESPACES['LIVB'], 'id': '10090'}, 121 | {'namespace': constants.NAMESPACES['LIVB'], 'id': '10088'}, 122 | ] 123 | human_xrefs = [{'namespace': constants.NAMESPACES['LIVB'], 'id': '9606'}] 124 | 125 | livb_annotation['ents'][0].update(xrefs=mouse_xrefs) 126 | livb_annotation['ents'][1].update(xrefs=human_xrefs) 127 | 128 | expected = livb_annotation 129 | 130 | assert actual == expected 131 | 132 | def test_ground_prge(prge_annotation): 133 | """Asserts that `grounding_utils.ground()` returns the expected value for a simple example with 134 | species entities. 135 | """ 136 | actual = grounding_utils.ground(copy.deepcopy(prge_annotation)) 137 | 138 | # create expected value 139 | p53_xrefs = [ 140 | {'namespace': constants.NAMESPACES['PRGE'], 'id': 'ENSP00000269305', 'organism-id': '9606'} 141 | ] 142 | mk2_xrefs = [ 143 | {'namespace': constants.NAMESPACES['PRGE'], 'id': 'ENSP00000356070', 'organism-id': '9606'}, 144 | {'namespace': constants.NAMESPACES['PRGE'], 'id': 'ENSP00000433109', 'organism-id': '9606'}, 145 | ] 146 | 147 | prge_annotation['ents'][0].update(xrefs=p53_xrefs) 148 | prge_annotation['ents'][1].update(xrefs=mk2_xrefs) 149 | 150 | expected = prge_annotation 151 | 152 | assert actual == expected 153 | -------------------------------------------------------------------------------- /saber/tests/test_metrics.py: -------------------------------------------------------------------------------- 1 | """Any and all unit tests for the Metrics class (saber/metrics.py). 2 | """ 3 | import pytest 4 | 5 | from .. 
"""Any and all unit tests for the Metrics class (saber/metrics.py).
"""
import pytest

from .. import constants
from ..config import Config
from ..dataset import Dataset
from ..metrics import Metrics
from ..utils import model_utils
from .resources.dummy_constants import *

PATH_TO_METRICS_OUTPUT = 'totally/arbitrary'

######################################### PYTEST FIXTURES #########################################

@pytest.fixture
def dummy_config():
    """Returns an instance of a Config object built from the dummy config file."""
    return Config(PATH_TO_DUMMY_CONFIG)

@pytest.fixture
def dummy_dataset():
    """Returns a single dummy Dataset instance after calling `Dataset.load()`."""
    # rare tokens are kept as-is for the sake of testing
    dataset = Dataset(directory=PATH_TO_DUMMY_DATASET_1, replace_rare_tokens=False)
    dataset.load()

    return dataset

@pytest.fixture
def dummy_output_dir(tmpdir, dummy_config):
    """Returns list of output directories rooted at the pytest tmpdir."""
    # make sure top-level directory is the pytest tmpdir
    dummy_config.output_folder = tmpdir.strpath
    return model_utils.prepare_output_directory(dummy_config)

@pytest.fixture
def dummy_training_data(dummy_dataset):
    """Returns training data from `dummy_dataset` (train partition only)."""
    idx_seq = dummy_dataset.idx_seq['train']
    return {
        'x_train': [idx_seq['word'], idx_seq['char']],
        'x_valid': None,
        'x_test': None,
        'y_train': idx_seq['tag'],
        'y_valid': None,
        'y_test': None,
    }

@pytest.fixture
def dummy_metrics(dummy_config, dummy_dataset, dummy_training_data, dummy_output_dir):
    """Returns an instance of Metrics."""
    return Metrics(config=dummy_config,
                   training_data=dummy_training_data,
                   index_map=dummy_dataset.idx_to_tag,
                   output_dir=dummy_output_dir,
                   # to test passing of arbitrary keyword args to constructor
                   totally_arbitrary='arbitrary')

############################################ UNIT TESTS ############################################

def test_attributes_after_initilization(dummy_config,
                                        dummy_dataset,
                                        dummy_output_dir,
                                        dummy_training_data,
                                        dummy_metrics):
    """Asserts instance attributes are initialized correctly when Metrics object is initialized."""
    # attributes passed through __init__ unchanged
    assert dummy_metrics.config is dummy_config
    assert dummy_metrics.training_data is dummy_training_data
    assert dummy_metrics.index_map is dummy_dataset.idx_to_tag
    assert dummy_metrics.output_dir == dummy_output_dir
    # attributes the constructor derives itself
    assert dummy_metrics.current_epoch == 0
    assert dummy_metrics.performance_metrics == {p: [] for p in constants.PARTITIONS}
    # arbitrary keyword arguments should become instance attributes
    assert dummy_metrics.totally_arbitrary == 'arbitrary'

def test_precision_recall_f1_support_value_error():
    """Asserts that call to `Metrics.get_precision_recall_f1_support` raises a `ValueError`
    when an invalid value for parameter `criteria` is passed."""
    # the spans themselves are totally arbitrary
    y_true = [('test', 0, 3), ('test', 4, 7), ('test', 8, 11)]
    y_pred = [('test', 0, 3), ('test', 4, 7), ('test', 8, 11)]

    # anything but 'exact', 'left', or 'right' should throw an error
    for invalid_arg in ['right ', 'LEFT', 'eXact', 0, []]:
        with pytest.raises(ValueError):
            Metrics.get_precision_recall_f1_support(y_true, y_pred, criteria=invalid_arg)
"""Any and all unit tests for the model_utils (saber/utils/model_utils.py).
"""
import os

import pytest
from keras.callbacks import ModelCheckpoint, TensorBoard

from ..config import Config
from ..utils import model_utils
from .resources.dummy_constants import *

######################################### PYTEST FIXTURES #########################################

@pytest.fixture
def dummy_config():
    """Returns an instance of a Config object."""
    return Config(PATH_TO_DUMMY_CONFIG)

@pytest.fixture
def dummy_output_dir(tmpdir, dummy_config):
    """Returns list of output directories rooted at the pytest tmpdir."""
    # make sure top-level directory is the pytest tmpdir
    dummy_config.output_folder = tmpdir.strpath
    return model_utils.prepare_output_directory(dummy_config)

############################################ UNIT TESTS ############################################

def test_prepare_output_directory(dummy_config, dummy_output_dir):
    """Assert that `model_utils.prepare_output_directory()` creates the expected directories
    with the expected content.
    """
    # TODO (johngiorgi): need to test the actual output of the function!
    for dir_ in dummy_output_dir:
        # each directory must exist and contain a copy of the config file
        assert os.path.isdir(dir_)
        assert os.path.isfile(os.path.join(dir_, 'config.ini'))

def test_prepare_pretrained_model_dir(dummy_config):
    """Asserts that filepath returned by `model_utils.prepare_pretrained_model_dir()` is as
    expected.
    """
    dataset = os.path.basename(dummy_config.dataset_folder[0])
    expected = os.path.join(dummy_config.output_folder, constants.PRETRAINED_MODEL_DIR, dataset)

    assert model_utils.prepare_pretrained_model_dir(dummy_config) == expected

def test_setup_checkpoint_callback(dummy_config, dummy_output_dir):
    """Check that we get the expected results from call to
    `model_utils.setup_checkpoint_callback()`.
    """
    simple_actual = model_utils.setup_checkpoint_callback(dummy_config, dummy_output_dir)
    blank_actual = model_utils.setup_checkpoint_callback(dummy_config, [])

    # one ModelCheckpoint callback per output directory
    assert len(simple_actual) == len(dummy_output_dir)
    assert all(isinstance(callback, ModelCheckpoint) for callback in simple_actual)
    # blank input should return blank list
    assert blank_actual == []

def test_setup_tensorboard_callback(dummy_output_dir):
    """Check that we get the expected results from call to
    `model_utils.setup_tensorboard_callback()`.
    """
    simple_actual = model_utils.setup_tensorboard_callback(dummy_output_dir)
    blank_actual = model_utils.setup_tensorboard_callback([])

    # one TensorBoard callback per output directory
    assert len(simple_actual) == len(dummy_output_dir)
    assert all(isinstance(callback, TensorBoard) for callback in simple_actual)
    # blank input should return blank list
    assert blank_actual == []

def test_setup_metrics_callback():
    """TODO (reviewer): this test is a stub — `model_utils.setup_metrics_callback()` is
    currently untested."""
    pass

def test_setup_callbacks(dummy_config, dummy_output_dir):
    """Check that we get the expected results from call to
    `model_utils.setup_callbacks()`.
    """
    # setup callbacks with config.tensorboard == True
    dummy_config.tensorboard = True
    with_tensorboard_actual = model_utils.setup_callbacks(dummy_config, dummy_output_dir)
    # setup callbacks with config.tensorboard == False
    dummy_config.tensorboard = False
    without_tensorboard_actual = model_utils.setup_callbacks(dummy_config, dummy_output_dir)

    blank_actual = []

    # one callback of each kind per output directory
    assert all(len(cbs) == len(dummy_output_dir) for cbs in with_tensorboard_actual)
    assert all(len(cbs) == len(dummy_output_dir) for cbs in without_tensorboard_actual)

    # all objects in returned lists should be of expected type
    assert all(isinstance(cb, ModelCheckpoint) for cb in with_tensorboard_actual[0])
    assert all(isinstance(cb, TensorBoard) for cb in with_tensorboard_actual[1])
    assert all(isinstance(cb, ModelCheckpoint) for cb in without_tensorboard_actual[0])

    # blank input should return blank list
    assert blank_actual == []

def test_precision_recall_f1_support():
    """Asserts that model_utils.precision_recall_f1_support returns the expected values."""
    true_positives, false_positives, false_negatives = 100, 10, 20

    precision = true_positives / (true_positives + false_positives)
    recall = true_positives / (true_positives + false_negatives)
    f1 = 2 * precision * recall / (precision + recall)
    support = true_positives + false_negatives

    # no null counts
    actual_no_null = model_utils.precision_recall_f1_support(true_positives, false_positives,
                                                             false_negatives)
    assert actual_no_null == (precision, recall, f1, support)

    # TP == 0 collapses every score to zero
    actual_tp_null = model_utils.precision_recall_f1_support(0, false_positives, false_negatives)
    assert actual_tp_null == (0., 0., 0., false_negatives)

    # FP == 0 gives perfect precision
    f1_perfect_precision = 2 * 1. * recall / (1. + recall)
    actual_fp_null = model_utils.precision_recall_f1_support(true_positives, 0, false_negatives)
    assert actual_fp_null == (1., recall, f1_perfect_precision, support)

    # FN == 0 gives perfect recall
    f1_perfect_recall = 2 * precision * 1. / (precision + 1.)
    actual_fn_null = model_utils.precision_recall_f1_support(true_positives, false_positives, 0)
    assert actual_fn_null == (precision, 1., f1_perfect_recall, true_positives)

    # all null counts
    assert model_utils.precision_recall_f1_support(0, 0, 0) == (0., 0., 0., 0)
33 | """ 34 | # Don't replace rare tokens for the sake of testing 35 | dataset = Dataset(directory=PATH_TO_DUMMY_DATASET_2, replace_rare_tokens=False) 36 | dataset.load() 37 | 38 | return dataset 39 | 40 | @pytest.fixture 41 | def dummy_embeddings(dummy_dataset_1): 42 | """Returns an instance of an `Embeddings()` object AFTER the `.load()` method is called. 43 | """ 44 | embeddings = Embeddings(filepath=PATH_TO_DUMMY_EMBEDDINGS, 45 | token_map=dummy_dataset_1.idx_to_tag) 46 | embeddings.load(binary=False) # txt file format is easier to test 47 | return embeddings 48 | 49 | @pytest.fixture 50 | def single_model(dummy_config, dummy_dataset_1, dummy_embeddings): 51 | """Returns an instance of MultiTaskLSTMCRF initialized with the default configuration.""" 52 | model = MultiTaskLSTMCRF(config=dummy_config, 53 | datasets=[dummy_dataset_1], 54 | # to test passing of arbitrary keyword args to constructor 55 | totally_arbitrary='arbitrary') 56 | return model 57 | 58 | @pytest.fixture 59 | def single_model_specify(single_model): 60 | """Returns an instance of MultiTaskLSTMCRF initialized with the default configuration file and 61 | a single specified model.""" 62 | single_model.specify() 63 | 64 | return single_model 65 | 66 | @pytest.fixture 67 | def single_model_embeddings(dummy_config, dummy_dataset_1, dummy_embeddings): 68 | """Returns an instance of MultiTaskLSTMCRF initialized with the default configuration file and 69 | loaded embeddings""" 70 | model = MultiTaskLSTMCRF(config=dummy_config, 71 | datasets=[dummy_dataset_1], 72 | embeddings=dummy_embeddings, 73 | # to test passing of arbitrary keyword args to constructor 74 | totally_arbitrary='arbitrary') 75 | return model 76 | 77 | @pytest.fixture 78 | def single_model_embeddings_specify(single_model_embeddings): 79 | """Returns an instance of MultiTaskLSTMCRF initialized with the default configuration file, 80 | loaded embeddings and single specified model.""" 81 | single_model_embeddings.specify() 82 | 83 | return 
single_model_embeddings 84 | 85 | ############################################ UNIT TESTS ############################################ 86 | 87 | def test_attributes_init_of_single_model(dummy_config, dummy_dataset_1, single_model): 88 | """Asserts instance attributes are initialized correctly when single `MultiTaskLSTMCRF` model is 89 | initialized without embeddings (`embeddings` attribute is None.) 90 | """ 91 | assert isinstance(single_model, (MultiTaskLSTMCRF, BaseKerasModel)) 92 | # attributes that are passed to __init__ 93 | assert single_model.config is dummy_config 94 | assert single_model.datasets[0] is dummy_dataset_1 95 | assert single_model.embeddings is None 96 | # other instance attributes 97 | assert single_model.models == [] 98 | # test that we can pass arbitrary keyword arguments 99 | assert single_model.totally_arbitrary == 'arbitrary' 100 | 101 | def test_attributes_init_of_single_model_specify(dummy_config, dummy_dataset_1, single_model_specify): 102 | """Asserts instance attributes are initialized correctly when single `MultiTaskLSTMCRF` 103 | model is initialized without embeddings (`embeddings` attribute is None) and 104 | `MultiTaskLSTMCRF.specify()` has been called. 
105 | """ 106 | assert isinstance(single_model_specify, (MultiTaskLSTMCRF, BaseKerasModel)) 107 | # attributes that are passed to __init__ 108 | assert single_model_specify.config is dummy_config 109 | assert single_model_specify.datasets[0] is dummy_dataset_1 110 | assert single_model_specify.embeddings is None 111 | # other instance attributes 112 | assert all([isinstance(model, Model) for model in single_model_specify.models]) 113 | # test that we can pass arbitrary keyword arguments 114 | assert single_model_specify.totally_arbitrary == 'arbitrary' 115 | 116 | def test_attributes_init_of_single_model_embeddings(dummy_config, dummy_dataset_1, 117 | dummy_embeddings, single_model_embeddings): 118 | """Asserts instance attributes are initialized correctly when single `MultiTaskLSTMCRF` model is 119 | initialized with embeddings (`embeddings` attribute is not None.) 120 | """ 121 | assert isinstance(single_model_embeddings, (MultiTaskLSTMCRF, BaseKerasModel)) 122 | # attributes that are passed to __init__ 123 | assert single_model_embeddings.config is dummy_config 124 | assert single_model_embeddings.datasets[0] is dummy_dataset_1 125 | assert single_model_embeddings.embeddings is dummy_embeddings 126 | # other instance attributes 127 | assert single_model_embeddings.models == [] 128 | # test that we can pass arbitrary keyword arguments 129 | assert single_model_embeddings.totally_arbitrary == 'arbitrary' 130 | 131 | def test_attributes_init_of_single_model_embeddings_specify(dummy_config, 132 | dummy_dataset_1, 133 | dummy_embeddings, 134 | single_model_embeddings_specify): 135 | """Asserts instance attributes are initialized correctly when single MultiTaskLSTMCRF 136 | model is initialized with embeddings (`embeddings` attribute is not None) and 137 | `MultiTaskLSTMCRF.specify()` has been called. 
138 | """ 139 | assert isinstance(single_model_embeddings_specify, (MultiTaskLSTMCRF, BaseKerasModel)) 140 | # attributes that are passed to __init__ 141 | assert single_model_embeddings_specify.config is dummy_config 142 | assert single_model_embeddings_specify.datasets[0] is dummy_dataset_1 143 | assert single_model_embeddings_specify.embeddings is dummy_embeddings 144 | # other instance attributes 145 | assert all([isinstance(model, Model) for model in single_model_embeddings_specify.models]) 146 | # test that we can pass arbitrary keyword arguments 147 | assert single_model_embeddings_specify.totally_arbitrary == 'arbitrary' 148 | 149 | def test_crf_after_transfer(single_model_specify, dummy_dataset_2): 150 | """Asserts that the CRF output layer of a model is replaced with a new layer when 151 | `MultiTaskLSTMCRF.prepare_for_transfer()` is called by testing that the `name` attribute 152 | of the final layer. 153 | """ 154 | # shorten test statements 155 | test_model = single_model_specify 156 | 157 | # get output layer names before transfer 158 | expected_before_transfer = ['crf_classifier'] 159 | actual_before_transfer = [model.layers[-1].name for model in test_model.models] 160 | # get output layer names after transfer 161 | test_model.prepare_for_transfer([dummy_dataset_2]) 162 | expected_after_transfer = ['target_crf_classifier'] 163 | actual_after_transfer = [model.layers[-1].name for model in test_model.models] 164 | 165 | assert actual_before_transfer == expected_before_transfer 166 | assert actual_after_transfer == expected_after_transfer 167 | -------------------------------------------------------------------------------- /saber/tests/test_preprocessor.py: -------------------------------------------------------------------------------- 1 | """Contains any and all unit tests for the `Preprocessor` class (saber/preprocessor.py). 2 | """ 3 | import en_coref_md 4 | import pytest 5 | 6 | from .. 
"""Contains any and all unit tests for the `Preprocessor` class (saber/preprocessor.py).
"""
import en_coref_md
import pytest

from .. import constants
from ..preprocessor import Preprocessor

######################################### PYTEST FIXTURES #########################################

@pytest.fixture
def preprocessor():
    """Returns an instance of a Preprocessor object."""
    return Preprocessor()

@pytest.fixture
def nlp():
    """Returns a spaCy NLP model."""
    return en_coref_md.load()

############################################ UNIT TESTS ############################################

def test_process_text(preprocessor, nlp):
    """Asserts that call to Preprocessor._process_text() returns the expected results."""
    # simple test and its expected value: per-sentence tokens plus character offsets
    simple_text = nlp("Simple example. With two sentences!")
    simple_expected = ([['Simple', 'example', '.'], ['With', 'two', 'sentences', '!']],
                       [[(0, 6), (7, 14), (14, 15)],
                        [(16, 20), (21, 24), (25, 34), (34, 35)]])
    # the empty string should produce empty token/offset lists
    blank_text = nlp("")

    assert preprocessor._process_text(simple_text) == simple_expected
    assert preprocessor._process_text(blank_text) == ([], [])

def test_type_to_idx_value_error():
    """Asserts that Preprocessor.type_to_idx() raises a ValueError for an invalid
    `initial_mapping` argument."""
    invalid_input = {'a': 0, 'b': 2, 'c': 3}
    with pytest.raises(ValueError):
        Preprocessor.type_to_idx([], initial_mapping=invalid_input)

def test_type_to_idx_empty_input():
    """Asserts that call to Preprocessor.type_to_idx() returns the expected results when
    an empty list is passed as input."""
    assert Preprocessor.type_to_idx([]) == {}

def test_type_to_idx_simple_input():
    """Asserts that call to Preprocessor.type_to_idx() returns the expected results when
    a simple list of strings is passed as input."""
    tokens = ["This", "is", "a", "test", "."]

    assert Preprocessor.type_to_idx(tokens) == {'This': 0, 'is': 1, 'a': 2, 'test': 3, '.': 4}

def test_type_to_idx_intial_mapping():
    """Asserts that call to Preprocessor.type_to_idx() returns the expected results when
    a simple list of strings is passed as input along with an `initial_mapping` argument."""
    tokens = ["This", "is", "a", "test", "."]
    initial_mapping = {'This': 0, 'is': 1}

    actual = Preprocessor.type_to_idx(tokens, initial_mapping=initial_mapping)

    assert actual == {'This': 0, 'is': 1, 'a': 2, 'test': 3, '.': 4}

def test_get_type_to_idx_sequence():
    """TODO (reviewer): this test computes a candidate input and expected value but asserts
    nothing — `Preprocessor.get_type_idx_sequence()` is effectively untested."""
    simple_seq = ["This", "is", "a", "test", ".", constants.UNK]
    simple_type_to_idx = Preprocessor.type_to_idx(simple_seq)
    simple_expected = [0, 1, 2, 3, 4]
    simple_actual = Preprocessor.get_type_idx_sequence(simple_seq, type_to_idx=simple_type_to_idx)

    pass

def test_chunk_entities():
    """Asserts that call to Preprocessor.chunk_entities() returns the expected results."""
    # a single entity type
    assert Preprocessor.chunk_entities(['B-PRGE', 'I-PRGE', 'O', 'B-PRGE']) == \
        [('PRGE', 0, 2), ('PRGE', 3, 4)]
    # two entity types
    assert Preprocessor.chunk_entities(['B-LIVB', 'I-LIVB', 'O', 'B-PRGE']) == \
        [('LIVB', 0, 2), ('PRGE', 3, 4)]
    # I- tags with no preceding B- tag yield no chunks
    assert Preprocessor.chunk_entities(['O', 'I-CHED', 'I-CHED', 'O']) == []
    # the empty sequence
    assert Preprocessor.chunk_entities([]) == []

def test_sterilize():
    """Asserts that call to Preprocessor.sterilize() returns the expected results."""
    # leading and trailing whitespace is stripped
    assert Preprocessor.sterilize(" This is an easy test. ") == "This is an easy test."
    # inline spacing errors are collapsed
    assert Preprocessor.sterilize("This is a test with improper spacing. ") == \
        "This is a test with improper spacing."
    # the empty string passes through unchanged
    assert Preprocessor.sterilize("") == ""
"""Any and all unit tests for text_utils (saber/utils/text_utils.py).
"""
import en_coref_md
import pytest

from ..utils import text_utils

######################################### PYTEST FIXTURES #########################################

@pytest.fixture
def nlp():
    """Returns an instance of a spaCy nlp object after replacing the default tokenizer with
    our modified one."""
    custom_nlp = en_coref_md.load()
    custom_nlp.tokenizer = text_utils.biomedical_tokenizer(custom_nlp)
    return custom_nlp

############################################ UNIT TESTS ############################################

def test_biomedical_tokenizer(nlp):
    """Asserts that call to customized spaCy tokenizer returns the expected results."""

    def tokenize(text):
        # helper: run the custom tokenizer and collect the token strings
        return [token.text for token in nlp(text)]

    # the empty string
    assert tokenize("") == []

    # simple test with no complexities
    assert tokenize("This is an easy test.") == ["This", "is", "an", "easy", "test", "."]

    # complicated test with some important edge cases
    complicated_test = ("This test's tokenizers handeling of very-tricky situations, 3X, "
                        "more/or/less.")
    assert tokenize(complicated_test) == \
        ["This", "test", "'", "s", "tokenizers", "handeling", "of",
         "very", "-", "tricky", "situations", ",", "3X", ",", "more", "/", "or",
         "/", "less", "."]

    # these tests were taken straight from training data
    from_CHED_ds = ("The results have shown that the degradation product p-choloroaniline is not "
                    "a significant factor in chlorhexidine-digluconate associated erosive "
                    "cystitis.")
    assert tokenize(from_CHED_ds) == \
        ['The', 'results', 'have', 'shown', 'that', 'the', 'degradation',
         'product', 'p', '-', 'choloroaniline', 'is', 'not', 'a', 'significant',
         'factor', 'in', 'chlorhexidine', '-', 'digluconate', 'associated',
         'erosive', 'cystitis', '.']

    from_DISO_ds = ("Rats were treated with seven day intravenous infusion of fucoidan "
                    "(30 micrograms h-1) or vehicle.")
    assert tokenize(from_DISO_ds) == \
        ['Rats', 'were', 'treated', 'with', 'seven', 'day', 'intravenous',
         'infusion', 'of', 'fucoidan', '(', '30', 'micrograms', 'h', '-', '1',
         ')', 'or', 'vehicle', '.']

    from_LIVB_ds = ("Methanoregula formicica sp. nov., a methane-producing archaeon isolated from "
                    "methanogenic sludge.")
    assert tokenize(from_LIVB_ds) == \
        ['Methanoregula', 'formicica', 'sp', '.', 'nov', '.', ',', 'a',
         'methane', '-', 'producing', 'archaeon', 'isolated', 'from',
         'methanogenic', 'sludge', '.']

    from_PRGE_ds = ("Here we report the cloning, expression, and biochemical characterization of "
                    "the 32-kDa subunit of human (h) TFIID, termed hTAFII32.")
    assert tokenize(from_PRGE_ds) == \
        ['Here', 'we', 'report', 'the', 'cloning', ',', 'expression', ',',
         'and', 'biochemical', 'characterization', 'of', 'the', '32', '-',
         'kDa', 'subunit', 'of', 'human', '(', 'h', ')', 'TFIID', ',', 'termed',
         'hTAFII32', '.']
2 | """ 3 | import logging 4 | import random 5 | 6 | from .utils import data_utils, model_utils 7 | 8 | LOGGER = logging.getLogger(__name__) 9 | 10 | class Trainer(object): 11 | """A class for co-ordinating the training of Keras model(s). 12 | 13 | Args: 14 | config (Config): A Config object which contains a set of harmonized arguments provided in 15 | a *.ini file and, optionally, from the command line. 16 | datasets (list): A list containing one or more Dataset objects. 17 | model (BaseModel): The model to train. 18 | """ 19 | def __init__(self, config, datasets, model): 20 | self.config = config # hyperparameters and model details 21 | self.datasets = datasets # dataset(s) tied to this instance 22 | self.model = model # model tied to this instance 23 | 24 | self.output_dir = model_utils.prepare_output_directory(self.config) 25 | self.callbacks = model_utils.setup_callbacks(self.config, self.output_dir) 26 | self.training_data = self.model.prepare_data_for_training() 27 | 28 | def train(self): 29 | """Co-ordinates the training of Keras model(s) at `self.model.models`. 30 | 31 | Coordinates the training of one or more Keras models (given at `self.model.models`). If a 32 | valid or test set is provided (`Dataset.directory['valid']` or `Dataset.directory['test']` 33 | are not None) a simple train/valid/test strategy is used. Otherwise, cross-validation is 34 | used. 35 | 36 | Args: 37 | callbacks (dict): Dictionary containing Keras callback objects. 38 | output_dir (list): List of directories to save model output to, one for each model. 39 | """ 40 | # TODO: ugly, is there a better way to check for this? what if dif ds follow dif schemes? 41 | if (self.training_data[0]['x_valid'] is not None or 42 | self.training_data[0]['x_test'] is not None): 43 | self._train_valid_test() 44 | else: 45 | self._cross_validation() 46 | 47 | def _train_valid_test(self): 48 | """Trains a Keras model with a standard train/valid/test strategy. 
49 | 50 | Trains a Keras model (`self.model.models`), or models in the case of multi-task learning 51 | (`self.model.models` is a list of Keras models) using a simple train/valid/test strategy. 52 | Minimally expects a train partition and one or both of valid and test partitions to be 53 | supplied in the Dataset objects at `self.datasets`. 54 | 55 | Args: 56 | callbacks (dict): Dictionary containing Keras callback objects. 57 | output_dir (list): List of directories to save model output to, one for each model. 58 | """ 59 | print('Using train/test/valid strategy...') 60 | LOGGER.info('Using a train/test/valid strategy for training') 61 | # use 10% of train data as validation data if no validation data provided 62 | if self.training_data[0]['x_valid'] is None: 63 | self.training_data = data_utils.collect_valid_data(self.training_data) 64 | # get list of Keras Callback objects for computing/storing metrics 65 | metrics = model_utils.setup_metrics_callback(config=self.config, 66 | datasets=self.datasets, 67 | training_data=self.training_data, 68 | output_dir=self.output_dir) 69 | # training loop 70 | for epoch in range(self.config.epochs): 71 | print('Global epoch: {}/{}\n{}'.format(epoch + 1, self.config.epochs, '-' * 20)) 72 | # get a random ordering of the dataset/model indices 73 | ds_idx = random.sample(range(0, len(self.datasets)), len(self.datasets)) 74 | for i in ds_idx: 75 | self.model.models[i].fit(x=self.training_data[i]['x_train'], 76 | y=self.training_data[i]['y_train'], 77 | batch_size=self.config.batch_size, 78 | callbacks=[cb[i] for cb in self.callbacks] + [metrics[i]], 79 | validation_data=(self.training_data[i]['x_valid'], 80 | self.training_data[i]['y_valid']), 81 | verbose=1, 82 | # required for Keras to properly display current epoch 83 | initial_epoch=epoch, 84 | epochs=epoch + 1) 85 | 86 | def _cross_validation(self): 87 | """Trains a Keras model with a cross-validation strategy. 
88 | 89 | Trains a Keras model (self.model.models) or models in the case of multi-task learning 90 | (self.model.models is a list of Keras models) using a cross-validation strategy. Expects 91 | only a train partition to be supplied in `training_data`. 92 | 93 | Args: 94 | training_data (dict): a dictionary of dictionaries, where the first set of keys are 95 | dataset indices (0, 1, ...) and the second set of keys are dataset partitions 96 | ('X_train', 'y_train', 'X_valid', ...) 97 | output_dir (lst): a list of output directories, one for each dataset 98 | callbacks: a Keras CallBack object for per epoch model checkpointing. 99 | """ 100 | print('Using {}-fold cross-validation strategy...'.format(self.config.k_folds)) 101 | LOGGER.info('Using a %s-fold cross-validation strategy for training', self.config.k_folds) 102 | # get the train/valid partitioned data for all datasets and all folds 103 | self.training_data = data_utils.collect_cv_data(self.training_data, self.config.k_folds) 104 | # training loop 105 | for fold in range(self.config.k_folds): 106 | # get list of Keras Callback objects for computing/storing metrics 107 | metrics = model_utils.setup_metrics_callback(config=self.config, 108 | datasets=self.datasets, 109 | training_data=self.training_data, 110 | output_dir=self.output_dir, 111 | fold=fold) 112 | for epoch in range(self.config.epochs): 113 | train_info = (fold + 1, self.config.k_folds, epoch + 1, self.config.epochs) 114 | print('Fold: {}/{}; Global epoch: {}/{}\n{}'.format(*train_info, '-' * 30)) 115 | # get a random ordering of the dataset/model indices 116 | ds_idx = random.sample(range(0, len(self.datasets)), len(self.datasets)) 117 | for i in ds_idx: 118 | self.model.models[i].fit( 119 | x=self.training_data[i][fold]['x_train'], 120 | y=self.training_data[i][fold]['y_train'], 121 | batch_size=self.config.batch_size, 122 | callbacks=[cb[i] for cb in self.callbacks] + [metrics[i]], 123 | validation_data=(self.training_data[i][fold]['x_valid'], 
def get_pubmed_xml(pmid):
    """Uses the Entrez Utilities Web Service API to fetch the XML representation of a PubMed
    document.

    Args:
        pmid (int): the PubMed ID of the abstract to fetch

    Returns:
        The raw response from the Entrez Utilities Web Service API, or None if the request
        failed with an HTTP error.

    Raises:
        ValueError: if `pmid` is not an integer.
        ValueError: if `pmid` has a value less than 1.
        AssertionError: if the requested PubMed ID `pmid` and the returned PubMed ID do not
            match.
    """
    if not isinstance(pmid, int):
        err_msg = "Argument 'pmid' must be of type {}, not {}.".format(int, type(pmid))
        LOGGER.error('ValueError %s', err_msg)
        raise ValueError(err_msg)
    if pmid < 1:
        err_msg = "Argument 'pmid' must have a value of 1 or greater. Got {}".format(pmid)
        LOGGER.error('ValueError %s', err_msg)
        raise ValueError(err_msg)

    try:
        request = '{}{}'.format(constants.EUTILS_API_ENDPOINT, pmid)
        response = urlopen(request).read()
    except HTTPError:
        err_msg = ("HTTP Error 400: Bad Request was returned. Check that the supplied value for "
                   "'pmid' ({}) is a valid PubMed ID.".format(pmid))
        traceback.print_exc()
        LOGGER.error('HTTPError %s', err_msg)
        print(err_msg)
        # BUG FIX: `response` was previously left unbound on this path, so the
        # `return response` below raised UnboundLocalError. Bind it to None so callers
        # receive an explicit "no result" instead of a crash.
        response = None
    else:
        root = get_root(response)
        response_pmid = root.find('PubmedArticle').find('MedlineCitation').find('PMID').text
        # ensure that the requested and returned PubMed IDs are the same
        if not int(response_pmid) == pmid:
            err_msg = ('Requested PubMed ID and PubMed ID returned by Entrez Utilities Web '
                       'Service API do not match.')
            LOGGER.error('AssertionError %s', err_msg)
            raise AssertionError(err_msg)

    return response
def get_root(xml):
    """Parses `xml` and returns the root element of the document.

    Args:
        xml (str): a string containing the contents of an XML file.

    Returns:
        The root Element of the parsed XML document `xml`.
    """
    return ET.fromstring(xml)

def load_models(ents):
    """Loads a pre-trained Saber model for every entity in `ents` whose value is True.

    Args:
        ents (dict): keys are entity names (str) and values are booleans indicating whether a
            model for that entity should be loaded.

    Returns:
        A two-tuple `(models, graph)`, where `models` maps each loaded entity name to a Saber
        instance with its pre-trained model loaded, and `graph` is the default TensorFlow graph.
    """
    models = {}
    for entity, requested in ents.items():
        if not requested:
            continue
        # create and load the pre-trained model for this entity
        saber_instance = Saber()
        saber_instance.load(entity)
        models[entity] = saber_instance
    # TEMP: Weird solution to a weird bug.
    # https://github.com/tensorflow/tensorflow/issues/14356#issuecomment-385962623
    graph = tf.get_default_graph()

    return models, graph

def harmonize_entities(default_ents, requested_ents):
    """Harmonizes two entity-to-boolean dictionaries, giving precedence to `requested_ents`.

    Every entity in `default_ents` starts out False; values for entities that also appear in
    `requested_ents` are then overridden by the requested values. Entities present only in
    `requested_ents` are ignored.

    Args:
        default_ents (dict): contains entity (str): boolean key, value pairs representing which
            entities should be annotated in a given text.
        requested_ents (dict): contains entity (str): boolean key, value pairs representing which
            entities should be predicted in a given text.

    Returns:
        A dict with exactly the keys of `default_ents`, where a key's value comes from
        `requested_ents` when present there and is False otherwise.
    """
    entities = {ent: False for ent in default_ents}
    for ent, value in requested_ents.items():
        if ent in entities:
            entities[ent] = value

    return entities
146 | """ 147 | request_json = request.get_json(force=True) 148 | parsed_request_json = { 149 | 'text': request_json.get('text', None), 150 | 'pmid': request_json.get('pmid', None), 151 | 'ents': request_json.get('ents', None), 152 | 'coref': request_json.get('coref', False), 153 | 'ground': request_json.get('ground', False), 154 | } 155 | 156 | # decide which entities to annotate 157 | default_ents, requested_ents = constants.ENTITIES, parsed_request_json['ents'] 158 | if requested_ents is not None: 159 | parsed_request_json['ents'] = harmonize_entities(default_ents, requested_ents) 160 | else: 161 | parsed_request_json['ents'] = default_ents 162 | 163 | return parsed_request_json 164 | 165 | def combine_annotations(annotations): 166 | """Given a list of annotations made by a Saber model, combines all annotations under one dict. 167 | 168 | Args: 169 | annotations (list): a list of annotations returned by a Saber model 170 | 171 | Returns: 172 | a dict containing all annotations in `annotations`. 173 | """ 174 | combined_anns = [] 175 | for ann in annotations: 176 | combined_anns.extend(ann['ents']) 177 | # create json containing combined annotation 178 | return combined_anns 179 | -------------------------------------------------------------------------------- /saber/utils/generic_utils.py: -------------------------------------------------------------------------------- 1 | """A collection of generic helper/utility functions. 2 | """ 3 | import errno 4 | import logging 5 | import os 6 | import shutil 7 | 8 | from setuptools.archive_util import unpack_archive 9 | 10 | LOGGER = logging.getLogger(__name__) 11 | 12 | def is_consecutive(lst): 13 | """Returns True if `lst` contains all numbers from 0 to `len(lst)` with no duplicates. 14 | """ 15 | return sorted(lst) == list(range(len(lst))) 16 | 17 | def reverse_dict(mapping): 18 | """Returns a dictionary composed of the reverse v, k pairs of a dictionary `mapping`. 
19 | """ 20 | return {v: k for k, v in mapping.items()} 21 | 22 | # https://stackoverflow.com/questions/273192/how-can-i-create-a-directory-if-it-does-not-exist#273227 23 | def make_dir(directory): 24 | """Creates a directory at `directory` if it does not already exist. 25 | """ 26 | try: 27 | os.makedirs(directory) 28 | except OSError as err: 29 | if err.errno != errno.EEXIST: 30 | raise 31 | 32 | def clean_path(filepath): 33 | """Returns normalized and absolutized `filepath`. 34 | """ 35 | filepath = filepath.strip() if isinstance(filepath, str) else filepath 36 | return os.path.abspath(os.path.normpath(filepath)) 37 | 38 | def extract_directory(directory): 39 | """Extracts bz2 compressed directory at `directory` if directory is compressed. 40 | """ 41 | if not os.path.isdir(directory): 42 | head, _ = os.path.split(os.path.abspath(directory)) 43 | 44 | print('Unzipping...', end=' ', flush=True) 45 | unpack_archive(directory + '.tar.bz2', extract_dir=head) 46 | 47 | def compress_directory(directory): 48 | """Compresses a given directory using bz2 compression. 49 | 50 | Raises: 51 | ValueError: if no directory at `directory` exists or if `directory`.tar.bz2 already exists. 52 | """ 53 | # clean/normalize directory 54 | directory = os.path.abspath(os.path.normcase(os.path.normpath(directory))) 55 | 56 | # raise ValueError if directory.tar.bz2 already exists or if directory not valid 57 | output_filepath = '{}.tar.bz2'.format(directory) 58 | if os.path.exists(output_filepath): 59 | err_msg = "{} already exists.".format(output_filepath) 60 | LOGGER.error('ValueError %s', err_msg) 61 | raise ValueError(err_msg) 62 | if not os.path.exists(directory): 63 | err_msg = "File or directory at 'directory' does not exist." 
def ground(annotation):
    """Maps entities in `annotation` to unique identifiers in an external database or ontology.

    For each entry in `annotation['ents']`, the text representing the annotation (`ent['text']`)
    is mapped to a unique identifier in an external database or ontology (if such a unique
    identifier is found). Each annotation in `annotation` is updated with an 'xrefs' key which
    contains a dictionary with information representing the mapping.

    This function relies on the EXTRACT API to perform the mapping.

    Args:
        annotation (dict): A dict containing a list of annotations at key 'ents'. Each annotation
            is expected to have a key 'text'.

    Returns:
        `annotation`, updated in place: an 'xrefs' key is added to every entity for which the
        EXTRACT API returned at least one identifier.

    Resources:
        - EXTRACT 2.0 API: https://extract.jensenlab.org/
    """
    # base GET request; the text to ground is appended as the `document` query parameter
    request = 'https://tagger.jensenlab.org/GetEntities?format=tsv&document='

    # bucket the annotations made by Saber by their entity label
    annotations = {
        'CHED': [ent for ent in annotation['ents'] if ent['label'] == 'CHED'],
        'DISO': [ent for ent in annotation['ents'] if ent['label'] == 'DISO'],
        'LIVB': [ent for ent in annotation['ents'] if ent['label'] == 'LIVB'],
        'PRGE': [ent for ent in annotation['ents'] if ent['label'] == 'PRGE'],
    }

    for label, anns in annotations.items():
        if anns:
            # append to the GET request the text to ground ('+'-joined) along with its
            # entity type (when one is mapped for this label)
            current_request = '{}{}'.format(request, '+'.join([ann['text'] for ann in anns]))
            if label in constants.ENTITY_TYPES:
                current_request += '&entity_types={}'.format(constants.ENTITY_TYPES[label])

            # request to EXTRACT 2.0 API; the response is TSV, one entry per line
            # (presumably first column is the matched text and last is the identifier —
            # NOTE(review): confirm against the EXTRACT API docs)
            response = requests.get(current_request).text
            entries = [entry.split('\t') for entry in response.split('\n')] if response else []

            # maps grounded text -> list of xref dicts built from the API response
            xrefs = {}

            # collect unique identifiers returned by EXTRACT 2.0 API
            for entry in entries:
                xref = {'namespace': constants.NAMESPACES[label], 'id': entry[-1]}
                # in the future, the EXTRACT 2.0 API will assign organism-ids to PRGE labels
                if label == 'PRGE':
                    xref['organism-id'] = entry[1]

                if entry[0] in xrefs:
                    xrefs[entry[0]].append(xref)
                else:
                    xrefs[entry[0]] = [xref]

            # update annotations with an 'xrefs' field wherever a mapping was found
            for ann in anns:
                if ann['text'] in xrefs:
                    ann.update(xrefs=xrefs[ann['text']])

    return annotation
2 | """ 3 | import os 4 | from time import strftime 5 | 6 | from keras.callbacks import ModelCheckpoint, TensorBoard 7 | 8 | from .. import constants 9 | from ..metrics import Metrics 10 | from .generic_utils import make_dir 11 | 12 | # I/O 13 | 14 | def prepare_output_directory(config): 15 | """Create output directories `config.output_folder/config.dataset_folder` for each dataset. 16 | 17 | Creates the following directory structure: 18 | . 19 | ├── config.output_folder 20 | | └── 21 | | └── 22 | | └── train_session___
__ 23 | | └── 24 | | └── train_session___
__ 25 | | └── 26 | | └── train_session___
__ 27 | 28 | In the case of only a single dataset, 29 | and are 30 | collapsed into a single directory. Saves a copy of the config file used to train the model 31 | (`config`) to the top level of this directory. 32 | 33 | Args: 34 | config (Config): A Config object which contains a set of harmonized arguments provided in 35 | a *.ini file and, optionally, from the command line. 36 | 37 | Returns: 38 | a list of directory paths to the subdirectories 39 | train_session___
__, one for each dataset in `dataset_folder`. 40 | """ 41 | output_dirs = [] 42 | output_folder = config.output_folder 43 | # if multiple datasets, create additional directory to house all output directories 44 | if len(config.dataset_folder) > 1: 45 | dataset_names = '_'.join([os.path.basename(ds) for ds in config.dataset_folder]) 46 | output_folder = os.path.join(output_folder, dataset_names) 47 | make_dir(output_folder) 48 | 49 | for dataset in config.dataset_folder: 50 | # create a subdirectory for each datasets name 51 | dataset_dir = os.path.join(output_folder, os.path.basename(dataset)) 52 | # create a subdirectory for each train session 53 | train_session_dir = strftime("train_session_%a_%b_%d_%I_%M_%S").lower() 54 | dataset_train_session_dir = os.path.join(dataset_dir, train_session_dir) 55 | output_dirs.append(dataset_train_session_dir) 56 | make_dir(dataset_train_session_dir) 57 | 58 | # copy config file to top level directory 59 | config.save(dataset_train_session_dir) 60 | 61 | return output_dirs 62 | 63 | def prepare_pretrained_model_dir(config): 64 | """Returns path to top-level directory to save a pre-trained model. 65 | 66 | Returns a directory path to save a pre-trained model based on `config.dataset_folder` and 67 | `config.output_folder`. The folder which contains the saved model is named from each dataset 68 | name in `config.dataset_folder` joined by an underscore: 69 | . 70 | ├── config.output_folder 71 | | └── 72 | | └── 73 | 74 | config (Config): A Config object which contains a set of harmonized arguments provided in 75 | a *.ini file and, optionally, from the command line. 76 | 77 | Returns: 78 | Full path to save a pre-trained model based on `config.dataset_folder` and 79 | `config.dataset_folder`. 
80 | """ 81 | ds_names = '_'.join([os.path.basename(ds) for ds in config.dataset_folder]) 82 | return os.path.join(config.output_folder, constants.PRETRAINED_MODEL_DIR, ds_names) 83 | 84 | # Callbacks 85 | 86 | def setup_checkpoint_callback(config, output_dir): 87 | """Sets up per epoch model checkpointing. 88 | 89 | Sets up model checkpointing by creating a Keras CallBack for each output directory in 90 | `output_dir` (corresponding to individual datasets). 91 | 92 | Args: 93 | output_dir (lst): A list of output directories, one for each dataset. 94 | 95 | Returns: 96 | checkpointer: A Keras CallBack object for per epoch model checkpointing. 97 | """ 98 | checkpointers = [] 99 | for dir_ in output_dir: 100 | # if only saving best weights, filepath needs to be the same so it gets overwritten 101 | if config.save_all_weights: 102 | filepath = os.path.join(dir_, 'weights_epoch_{epoch:03d}_val_loss_{val_loss:.4f}.hdf5') 103 | else: 104 | filepath = os.path.join(dir_, 'weights_best_epoch.hdf5') 105 | 106 | checkpointer = ModelCheckpoint(filepath=filepath, 107 | monitor='val_loss', 108 | save_best_only=(not config.save_all_weights), 109 | save_weights_only=True) 110 | checkpointers.append(checkpointer) 111 | 112 | return checkpointers 113 | 114 | def setup_tensorboard_callback(output_dir): 115 | """Setup logs for use with TensorBoard. 116 | 117 | This callback writes a log for TensorBoard, which allows you to visualize dynamic graphs of 118 | your training and test metrics, as well as activation histograms for the different layers in 119 | your model. Logs are saved as `tensorboard_logs` at the top level of each directory in 120 | `output_dir`. 121 | 122 | Args: 123 | output_dir (lst): A list of output directories, one for each dataset. 124 | 125 | Returns: 126 | A list of Keras CallBack object for logging TensorBoard visualizations. 
def setup_metrics_callback(config, datasets, training_data, output_dir, fold=None):
    """Creates Keras Metrics Callback objects, one for each dataset in `datasets`.

    Args:
        config (Config): Contains a set of harmonized arguments provided in a *.ini file and,
            optionally, from the command line.
        datasets (list): A list of Dataset objects.
        training_data (dict): A dictionary containing training data (inputs and targets); when
            `fold` is given, each dataset's entry is further indexed by fold.
        output_dir (list): List of directories to save model output to, one for each model.
        fold (int): The current fold in k-fold cross-validation. Defaults to None.

    Returns:
        A list of Metric objects, one for each dataset in `datasets`.
    """
    callbacks = []
    for idx, dataset in enumerate(datasets):
        # cross-validated runs nest each dataset's training data by fold
        if fold is None:
            eval_data = training_data[idx]
        else:
            eval_data = training_data[idx][fold]
        callbacks.append(Metrics(config=config,
                                 training_data=eval_data,
                                 index_map=dataset.idx_to_tag,
                                 output_dir=output_dir[idx],
                                 fold=fold))

    return callbacks
174 | """ 175 | callbacks = [] 176 | # model checkpointing 177 | callbacks.append(setup_checkpoint_callback(config, output_dir)) 178 | # tensorboard 179 | if config.tensorboard: 180 | callbacks.append(setup_tensorboard_callback(output_dir)) 181 | 182 | return callbacks 183 | 184 | # Evaluation metrics 185 | 186 | def precision_recall_f1_support(true_positives, false_positives, false_negatives): 187 | """Returns the precision, recall, F1 and support from TP, FP and FN counts. 188 | 189 | Returns a four-tuple containing the precision, recall, F1-score and support 190 | For the given true_positive (TP), false_positive (FP) and 191 | false_negative (FN) counts. 192 | 193 | Args: 194 | true_positives (int): Number of true-positives predicted by classifier. 195 | false_positives (int): Number of false-positives predicted by classifier. 196 | false_negatives (int): Number of false-negatives predicted by classifier. 197 | 198 | Returns: 199 | Four-tuple containing (precision, recall, f1, support). 200 | """ 201 | precision = true_positives / (true_positives + false_positives) if true_positives > 0 else 0. 202 | recall = true_positives / (true_positives + false_negatives) if true_positives > 0 else 0. 203 | f1_score = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0. 204 | support = true_positives + false_negatives 205 | 206 | return precision, recall, f1_score, support 207 | 208 | # Saving/loading 209 | 210 | def load_pretrained_model(config, datasets, weights_filepath, model_filepath): 211 | """Loads a pre-trained Keras model from its pre-trained weights and architecture files. 212 | 213 | Loads a pre-trained Keras model given by its pre-trained weights (`weights_filepath`) and 214 | architecture files (`model_filepath`). The type of model to load is specificed in 215 | `config.model_name`. 
# NERsuite-like tokenization: alnum sequences preserved as single
# tokens, rest are single-character tokens.
# https://github.com/spyysalo/standoff2conll/blob/master/common.py
INFIX_RE = re.compile(r'''([0-9a-zA-Z]+|[^0-9a-zA-Z])''')

# https://spacy.io/usage/linguistic-features#native-tokenizers
def biomedical_tokenizer(nlp):
    """Returns a spaCy Tokenizer customized for better handling of biomedical text.

    Uses `INFIX_RE` as the infix finditer so that runs of alphanumeric characters stay single
    tokens while every other character becomes its own token (NERsuite-like behaviour).

    Args:
        nlp: a spaCy Language object; its vocab is used to build the tokenizer.

    Returns:
        A spaCy Tokenizer instance.
    """
    return Tokenizer(nlp.vocab, infix_finditer=INFIX_RE.finditer)
https://github.com/huggingface/neuralcoref-models/releases/download/en_coref_md-3.0.0/en_coref_md-3.0.0.tar.gz', 47 | ], 48 | include_package_data=True, 49 | # allows us to install + run tests with `python setup.py test` 50 | # https://docs.pytest.org/en/latest/goodpractices.html#integrating-with-setuptools-python-setup-py-test-pytest-runner 51 | setup_requires=['pytest-runner'], 52 | tests_require=['pytest'], 53 | zip_safe=False, 54 | ) 55 | -------------------------------------------------------------------------------- /tox.ini: -------------------------------------------------------------------------------- 1 | [tox] 2 | envlist = 3 | manifest 4 | pyroma 5 | py 6 | 7 | [testenv] 8 | commands = pytest --cov=saber -v 9 | deps = 10 | pytest 11 | pytest-cov 12 | coveralls 13 | 14 | [testenv:manifest] 15 | deps = check-manifest 16 | skip_install = true 17 | commands = check-manifest 18 | 19 | [testenv:pyroma] 20 | deps = 21 | pygments 22 | pyroma 23 | skip_install = true 24 | commands = pyroma --min=10 . 25 | description = Run the pyroma tool to check the project's package friendliness. 26 | 27 | [testenv:coverage-clean] 28 | deps = coverage 29 | skip_install = true 30 | commands = coverage erase 31 | 32 | [testenv:coverage-report] 33 | deps = coverage 34 | skip_install = true 35 | commands = 36 | coverage combine 37 | coverage report 38 | 39 | #################### 40 | # Deployment tools # 41 | #################### 42 | [testenv:build] 43 | skip_install = true 44 | deps = 45 | wheel 46 | setuptools 47 | commands = 48 | python setup.py -q sdist bdist_wheel 49 | 50 | [testenv:release] 51 | skip_install = true 52 | deps = 53 | {[testenv:build]deps} 54 | twine >= 1.5.0 55 | commands = 56 | {[testenv:build]commands} 57 | twine upload --skip-existing dist/* 58 | --------------------------------------------------------------------------------