├── .gitignore
├── .travis.yml
├── CONTRIBUTING.md
├── Dockerfile
├── LICENSE.txt
├── MANIFEST.in
├── README.md
├── docs
├── guide_to_saber_api.md
├── img
│ └── saber_logo.png
├── index.md
├── installation.md
├── quick_start.md
└── resources.md
├── mkdocs.yml
├── notebooks
└── lightning_tour.ipynb
├── saber
├── __init__.py
├── cli
│ ├── __init__.py
│ ├── app.py
│ └── train.py
├── config.ini
├── config.py
├── constants.py
├── dataset.py
├── embeddings.py
├── metrics.py
├── models
│ ├── __init__.py
│ ├── base_model.py
│ └── multi_task_lstm_crf.py
├── preprocessor.py
├── saber.py
├── tests
│ ├── __init__.py
│ ├── resources
│ │ ├── __init__.py
│ │ ├── dummy_config.ini
│ │ ├── dummy_constants.py
│ │ ├── dummy_dataset_1
│ │ │ ├── test.tsv
│ │ │ ├── train.tsv
│ │ │ └── valid.tsv
│ │ ├── dummy_dataset_2
│ │ │ ├── test.tsv
│ │ │ ├── train.tsv
│ │ │ └── valid.tsv
│ │ ├── dummy_word_embeddings
│ │ │ └── dummy_word_embeddings.txt
│ │ └── helpers.py
│ ├── test_app_utils.py
│ ├── test_base_model.py
│ ├── test_config.py
│ ├── test_data_utils.py
│ ├── test_dataset.py
│ ├── test_embeddings.py
│ ├── test_generic_utils.py
│ ├── test_grounding_utils.py
│ ├── test_metrics.py
│ ├── test_model_utils.py
│ ├── test_multi_task_lstm_crf.py
│ ├── test_preprocessor.py
│ ├── test_saber.py
│ ├── test_text_utils.py
│ └── test_trainer.py
├── trainer.py
└── utils
│ ├── __init__.py
│ ├── app_utils.py
│ ├── data_utils.py
│ ├── generic_utils.py
│ ├── grounding_utils.py
│ ├── model_utils.py
│ └── text_utils.py
├── setup.cfg
├── setup.py
└── tox.ini
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | env/
12 | build/
13 | develop-eggs/
14 | dist/
15 | downloads/
16 | eggs/
17 | .eggs/
18 | bin/
19 | share/
20 | include/
21 | local/
22 | man/
23 | lib/
24 | lib64/
25 | parts/
26 | sdist/
27 | var/
28 | wheels/
29 | *.egg-info/
30 | .installed.cfg
31 | *.egg
32 |
33 | # PyInstaller
34 | # Usually these files are written by a python script from a template
35 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
36 | *.manifest
37 | *.spec
38 |
39 | # Installer logs
40 | pip-log.txt
41 | pip-delete-this-directory.txt
42 | pip-selfcheck.json
43 |
44 | # Unit test / coverage reports
45 | htmlcov/
46 | .tox/
47 | .coverage
48 | .coverage.*
49 | .cache
50 | nosetests.xml
51 | coverage.xml
52 | *.cover
53 | .hypothesis/
54 |
55 | # Translations
56 | *.mo
57 | *.pot
58 |
59 | # Django stuff:
60 | *.log
61 | local_settings.py
62 |
63 | # Flask stuff:
64 | instance/
65 | .webassets-cache
66 |
67 | # Scrapy stuff:
68 | .scrapy
69 |
70 | # Sphinx documentation
71 | docs/_build/
72 |
73 | # PyBuilder
74 | target/
75 |
76 | # Jupyter Notebook
77 | .ipynb_checkpoints
78 |
79 | # pyenv
80 | .python-version
81 |
82 | # celery beat schedule file
83 | celerybeat-schedule
84 |
85 | # SageMath parsed files
86 | *.sage.py
87 |
88 | # dotenv
89 | .env
90 |
91 | # virtualenv
92 | .venv
93 | venv/
94 | ENV/
95 |
96 | # Spyder project settings
97 | .spyderproject
98 | .spyproject
99 |
100 | # Rope project settings
101 | .ropeproject
102 |
103 | # mkdocs documentation
104 | /site
105 |
106 | # mypy
107 | .mypy_cache/
108 |
109 | # pytest
110 | .pytest_cache
111 |
112 | # IDE
113 | .idea
114 | *.iml
115 |
--------------------------------------------------------------------------------
/.travis.yml:
--------------------------------------------------------------------------------
1 | language: python
2 |
3 | os:
4 | - linux
5 | - osx
6 |
7 | env: TOXENV=py
8 |
9 | matrix:
10 | include:
11 | - env: TOXENV=manifest
12 | - env: TOXENV=pyroma
13 | allow_failures:
14 | - os: osx
15 | - python: "3.5"
16 |
17 | python:
18 | - "3.6"
19 | - "3.5"
20 |
21 | install: pip install tox pytest-cov coveralls
22 |
23 | script: tox
24 |
25 | after_success: coveralls
26 |
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | **Working on your first Pull Request?** You can learn how from this *free* series [How to Contribute to an Open Source Project on GitHub](https://egghead.io/series/how-to-contribute-to-an-open-source-project-on-github)
2 |
--------------------------------------------------------------------------------
/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM python:3.6
2 | WORKDIR /app
3 | COPY . /app
4 | RUN pip install .
5 | #RUN pip install git+https://github.com/BaderLab/saber.git
6 | RUN pip install git+https://www.github.com/keras-team/keras-contrib.git
7 | RUN pip install https://github.com/huggingface/neuralcoref-models/releases/download/en_coref_md-3.0.0/en_coref_md-3.0.0.tar.gz
8 | CMD ["python", "-m", "saber.cli.app"]
9 | EXPOSE 5000
10 |
--------------------------------------------------------------------------------
/LICENSE.txt:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2018 Bader Lab, University of Toronto
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
1 | # Include README(s)
2 | include README.md CONTRIBUTING.md
3 |
4 | # Include license file
5 | include LICENSE.txt
6 |
7 | # Include config file(s)
8 | global-include *.ini *.yml
9 |
10 | # Include Dockerfile
11 | include Dockerfile
12 |
13 | # Include test resources
14 | graft saber/tests
15 |
16 | # Include docs
17 | graft docs
18 |
19 | # Don't include notebooks
20 | prune notebooks
21 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 | Saber
7 |
8 |
9 |
10 |
11 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 |
25 |
26 |
27 |
28 |
29 |
30 |
31 |
32 |
33 |
34 |
35 |
36 | Saber (Sequence Annotator for Biomedical Entities and Relations) is a deep-learning based tool for information extraction in the biomedical domain.
37 |
38 |
39 |
40 | Installation •
41 | Quickstart •
42 | Documentation
43 |
44 |
45 | ## Installation
46 |
47 | _Note! This is a work in progress. Many things are broken, and the codebase is not stable._
48 |
49 | To install Saber, you will need `python3.6`.
50 |
51 | ### Latest PyPI stable release
52 |
53 | [](https://pypi.org/project/saber/)
54 | [](https://pypi.org/project/saber)
55 | [](https://github.com/baderlab/saber/network/dependents)
56 |
57 | ```sh
58 | (saber) $ pip install saber
59 | ```
60 |
61 | > The install from PyPI is currently broken, please install using the instructions below.
62 |
63 | ### Latest development release on GitHub
64 |
65 | [](https://github.com/baderlab/saber/releases)
66 | [](https://github.com/baderlab/saber/stargazers)
67 | [](https://github.com/BaderLab/saber/network/members)
68 | [](https://github.com/baderlab/saber/graphs/commit-activity)
69 | [](https://github.com/baderlab/saber/pulse)
70 |
71 | Pull and install straight from GitHub
72 |
73 | ```sh
74 | (saber) $ pip install git+https://github.com/BaderLab/saber.git
75 | ```
76 |
77 | or install by cloning the repository
78 |
79 | ```sh
80 | (saber) $ git clone https://github.com/BaderLab/saber.git
81 | (saber) $ cd saber
82 | ```
83 |
84 | and then using either `pip`
85 |
86 | ```sh
87 | (saber) $ pip install -e .
88 | ```
89 | or `setuptools`
90 |
91 | ```sh
92 | (saber) $ python setup.py install
93 | ```
94 |
95 | See the [documentation](https://baderlab.github.io/saber/installation/) for more detailed installation instructions.
96 |
97 | ## Quickstart
98 |
99 | If your goal is to use Saber to annotate biomedical text, then you can either use the [web-service](#web-service) or a [pre-trained model](#pre-trained-models). If you simply want to check Saber out, without installing anything locally, try the [Google Colaboratory](#google-colaboratory) notebook.
100 |
101 | ### Google Colaboratory
102 |
103 | The fastest way to check out Saber is by following along with the Google Colaboratory notebook ([](https://colab.research.google.com/drive/1WD7oruVuTo6p_908MQWXRBdLF3Vw2MPo)). In order to be able to run the cells, select "Open in Playground" or, alternatively, save a copy to your own Google Drive account (File > Save a copy in Drive).
104 |
105 | ### Web-service
106 |
107 | To use Saber as a **local** web-service, run
108 |
109 | ```
110 | (saber) $ python -m saber.cli.app
111 | ```
112 |
113 | or, if you prefer, you can pull & run the Saber image from **Docker Hub**
114 |
115 | ```sh
116 | # Pull Saber image from Docker Hub
117 | $ docker pull pathwaycommons/saber
118 | # Run docker (use `-dt` instead of `-it` to run container in background)
119 | $ docker run -it --rm -p 5000:5000 --name saber pathwaycommons/saber
120 | ```
121 |
122 | There are currently two endpoints, `/annotate/text` and `/annotate/pmid`. Both expect a `POST` request with a JSON payload, e.g.,
123 |
124 | ```json
125 | {
126 | "text": "The phosphorylation of Hdm2 by MK2 promotes the ubiquitination of p53."
127 | }
128 | ```
129 |
130 | or
131 |
132 | ```json
133 | {
134 | "pmid": 11835401
135 | }
136 | ```
137 |
138 | For example, running the web-service locally and using `cURL`
139 |
140 | ```sh
141 | $ curl -X POST 'http://localhost:5000/annotate/text' \
142 | --data '{"text": "The phosphorylation of Hdm2 by MK2 promotes the ubiquitination of p53."}'
143 | ```
144 |
145 | Documentation for the Saber web-service API can be found [here](https://baderlab.github.io/saber-api-docs/).
146 |
147 | ### Pre-trained models
148 |
149 | First, import the `Saber` class. This is the interface to Saber
150 |
151 | ```python
152 | from saber.saber import Saber
153 | ```
154 |
155 | then create a `Saber` object
156 |
157 | ```python
158 | saber = Saber()
159 | ```
160 |
161 | and then load the model of our choice
162 |
163 | ```python
164 | saber.load('PRGE')
165 | ```
166 |
167 | To annotate text with the model, just call the `Saber.annotate()` method
168 |
169 | ```python
170 | saber.annotate("The phosphorylation of Hdm2 by MK2 promotes the ubiquitination of p53.")
171 | ```
172 | See the [documentation](https://baderlab.github.io/saber/quick_start/#pre-trained-models) for more details on using pre-trained models.
173 |
174 | ## Documentation
175 |
176 | Documentation for the Saber package can be found [here](https://baderlab.github.io/saber/). The web-service API has its own documentation [here](https://baderlab.github.io/saber-api-docs/#introduction).
177 |
178 | You can also call `help()` on any Saber method for more information
179 |
180 | ```python
181 | from saber import Saber
182 |
183 | saber = Saber()
184 |
185 | help(saber.annotate)
186 | ```
187 |
188 | or pass the `--help` flag to any of the command-line interfaces
189 |
190 | ```
191 | python -m saber.cli.train --help
192 | ```
193 |
194 | Feel free to open an issue or reach out to us on our slack channel ([](https://join.slack.com/t/saber-nlp/shared_invite/enQtNzE0MzY5ODM3MTc0LWZmY2VjMTY5MjllMmIzNDhkM2VhZjk5ODE1MDYyZjE5OGFjYWVhY2I2NDk5Yjk1N2Q3NTI4YTdhMTI5MjRiOGY)) for more help.
195 |
196 |
197 |
--------------------------------------------------------------------------------
/docs/guide_to_saber_api.md:
--------------------------------------------------------------------------------
1 | # Guide to the Saber API
2 |
3 | You can interact with Saber as a web-service (explained in [Quick Start: Web-service](https://baderlab.github.io/saber/quick_start/#web-service)), [command line tool](#command-line-tool), or as a [python package](#python-package). If you created a virtual environment, _remember to activate it first_.
4 |
5 | ### Command line tool
6 |
7 | Currently, the command line tool simply trains the model. To use it, call
8 |
9 | ```
10 | (saber) $ python -m saber.cli.train
11 | ```
12 |
13 | along with any command line arguments. For example, to train the model on the [NCBI Disease](https://www.ncbi.nlm.nih.gov/CBBresearch/Dogan/DISEASE/) corpus
14 |
15 | ```
16 | (saber) $ python -m saber.cli.train --dataset_folder NCBI_Disease_BIO
17 | ```
18 |
19 | !!! tip
20 | See [Resources: Datasets](https://baderlab.github.io/saber/resources/#datasets) for help preparing datasets and word embeddings for training.
21 |
22 | Run `python -m saber.cli.train --help` to see all possible arguments.
23 |
24 | Of course, supplying arguments at the command line can quickly become cumbersome. Saber also allows you to provide a configuration file, which can be specified like so
25 |
26 | ```
27 | (saber) $ python -m saber.cli.train --config_filepath path/to/config.ini
28 | ```
29 |
30 | Copy the contents of the [default config file](https://github.com/BaderLab/saber/blob/master/saber/config.ini) to a new `*.ini` file in order to get started.
31 |
32 | !!! note
33 | Arguments supplied at the command line overwrite those found in the configuration file, e.g.,
34 |
35 | ```
36 | (saber) $ python -m saber.cli.train --dataset_folder path/to/dataset --k_folds 10
37 | ```
38 |
39 | would overwrite the arguments for `dataset_folder` and `k_folds` found in the configuration file.
40 |
41 | ### Python package
42 |
43 | You can also import Saber and interact with it as a python package. Saber exposes its functionality through the `Saber` class. Here is just about everything Saber does in one script:
44 |
45 | ```python
46 | from saber.saber import Saber
47 |
48 | # First, create a Saber object, which exposes Saber's functionality
49 | saber = Saber()
50 |
51 | # Load a dataset and create a model (provide a list of datasets to use multi-task learning!)
52 | saber.load_dataset('path/to/datasets/GENIA')
53 | saber.build(model_name='MT-LSTM-CRF')
54 |
55 | # Train and save a model
56 | saber.train()
57 | saber.save('pretrained_models/GENIA')
58 |
59 | # Load a model
60 | del saber
61 | saber = Saber()
62 | saber.load('pretrained_models/GENIA')
63 |
64 | # Perform prediction on raw text, get resulting annotation
65 | raw_text = 'The phosphorylation of Hdm2 by MK2 promotes the ubiquitination of p53.'
66 | annotation = saber.annotate(raw_text)
67 |
68 | # Use transfer learning to continue training on a new dataset
69 | saber.load_dataset('path/to/datasets/CRAFT')
70 | saber.train()
71 | ```
72 |
73 | #### Transfer learning
74 |
75 | Transfer learning is as easy as training, saving, loading, and then continuing training of a model. Here is an example
76 |
77 | ```python
78 | # Create and train a model on GENIA corpus
79 | saber = Saber()
80 | saber.load_dataset('path/to/datasets/GENIA')
81 | saber.build(model_name='MT-LSTM-CRF')
82 | saber.train()
83 | saber.save('pretrained_models/GENIA')
84 |
85 | # Load that model
86 | del saber
87 | saber = Saber()
88 | saber.load('pretrained_models/GENIA')
89 |
90 | # Use transfer learning to continue training on a new dataset
91 | saber.load_dataset('path/to/datasets/CRAFT')
92 | saber.train()
93 | ```
94 |
95 | !!! note
96 | This is currently only supported by the `mt-lstm-crf` model.
97 |
98 | #### Multi-task learning
99 |
100 | Multi-task learning is as easy as specifying multiple dataset paths, either in the `config` file, at the command line via the flag `--dataset_folder`, or as an argument to `load_dataset()`. The number of datasets is arbitrary.
101 |
102 | Here is an example using the last method
103 |
104 | ```python
105 | saber = Saber()
106 |
107 | # Simply pass multiple dataset paths as a list to load_dataset to use multi-task learning.
108 | saber.load_dataset(['path/to/datasets/NCBI_Disease', 'path/to/datasets/Linnaeus'])
109 |
110 | saber.build(model_name='MT-LSTM-CRF')
111 | saber.train()
112 | ```
113 |
114 | !!! note
115 | This is currently only supported by the `mt-lstm-crf` model.
116 |
117 | #### Training on GPUs
118 |
119 | Saber will automatically train on as many GPUs as are available. In order for this to work, you must have [CUDA](https://developer.nvidia.com/cuda-downloads) and, optionally, [cuDNN](https://developer.nvidia.com/cudnn) installed. If you are using conda to manage your environment, then these are installed for you when you call
120 |
121 | ```
122 | (saber) $ conda install tensorflow-gpu
123 | ```
124 |
125 | Otherwise, install them yourself and use `pip` to install `tensorflow-gpu`
126 |
127 | ```
128 | (saber) $ pip install tensorflow-gpu
129 | ```
130 |
131 | To control which GPUs Saber trains on, you can use the `CUDA_VISIBLE_DEVICES` environment variable, e.g.,
132 |
133 | ```
134 | # To train exclusively on CPU
135 | (saber) $ CUDA_VISIBLE_DEVICES="" python -m saber.cli.train
136 |
137 | # To train on 1 GPU with ID=0
138 | (saber) $ CUDA_VISIBLE_DEVICES="0" python -m saber.cli.train
139 |
140 | # To train on 2 GPUs with IDs=0,2
141 | (saber) $ CUDA_VISIBLE_DEVICES="0,2" python -m saber.cli.train
142 | ```
143 |
144 | !!! tip
145 | You can get information about your NVIDIA GPUs by typing `nvidia-smi` at the command line (assuming the GPUs are setup properly and the nvidia driver is installed).
146 |
147 | #### Saving and loading models
148 |
149 | In the following sections we introduce the saving and loading of models.
150 |
151 | ##### Saving a model
152 |
153 | Assuming the model has already been created (see above), we can easily save our model like so
154 |
155 | ```python
156 | save_dir = 'path/to/pretrained_models/mymodel'
157 | saber.save(save_dir)
158 | ```
159 |
160 | ##### Loading a model
161 |
162 | Lets illustrate loading a model with a new `Saber` object
163 |
164 | ```python
165 | # Delete our previous Saber object (if it exists)
166 | del saber
167 | # Create a new Saber object
168 | saber = Saber()
169 | # Load a previous model
170 | saber.load(path_to_saved_model)
171 | ```
172 |
--------------------------------------------------------------------------------
/docs/img/saber_logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BaderLab/saber/876be6bfdb1bc5b18cbcfa848c94b0d20c940f02/docs/img/saber_logo.png
--------------------------------------------------------------------------------
/docs/index.md:
--------------------------------------------------------------------------------
1 | __Saber__ (**S**equence **A**nnotator for **B**iomedical **E**ntities and **R**elations) is a deep-learning based tool for __information extraction__ in the biomedical domain.
2 |
3 | The neural network model used is a BiLSTM-CRF [[1](https://arxiv.org/abs/1603.01360), [2](https://arxiv.org/abs/1603.01354)]; a state-of-the-art architecture for sequence labelling. The model is implemented using [Keras](https://keras.io/) / [Tensorflow](https://www.tensorflow.org).
4 |
5 | The goal is that Saber will eventually perform all the important steps in text-mining of biomedical literature:
6 |
7 | - Coreference resolution (:white_check_mark:)
8 | - Biomedical named entity recognition (BioNER) (:white_check_mark:)
9 | - Entity linking / grounding / normalization (:white_check_mark:)
10 | - Simple relation extraction (:soon:)
11 | - Event extraction (:soon:)
12 |
13 | Pull requests are welcome! If you encounter any bugs, please open an issue in the [GitHub repository](https://github.com/BaderLab/saber).
14 |
--------------------------------------------------------------------------------
/docs/installation.md:
--------------------------------------------------------------------------------
1 | # Installation
2 |
3 | To install Saber, you will need `python3.6`. If not already installed, `python3` can be installed via
4 |
5 | - The [official installer](https://www.python.org/downloads/)
6 | - [Homebrew](https://brew.sh), on MacOS (`brew install python3`)
7 | - [Miniconda3](https://conda.io/miniconda.html) / [Anaconda3](https://www.anaconda.com/download/)
8 |
9 | !!! note
10 | Run `python --version` at the command line to make sure installation was successful. You may need to type `python3` (not just `python`) depending on your install method.
11 |
12 | (OPTIONAL) Activate your virtual environment (see [below](#optional-creating-and-activating-virtual-environments) for help)
13 |
14 | ```sh
15 | $ conda activate saber
16 | # Notice your command prompt has changed to indicate that the environment is active
17 | (saber) $
18 | ```
19 |
20 | ## Latest PyPI stable release
21 |
22 | [](https://pypi.org/project/saber/)
23 | [](https://pypi.org/project/saber)
24 | [](https://github.com/baderlab/saber/network/dependents)
25 |
26 | ```sh
27 | (saber) $ pip install saber
28 | ```
29 |
30 | !!! error
31 | The install from PyPI is currently broken, please install using the instructions below.
32 |
33 | ## Latest development release on GitHub
34 |
35 | [](https://github.com/baderlab/saber/releases)
36 | [](https://github.com/baderlab/saber/stargazers)
37 | [](https://github.com/BaderLab/saber/network/members)
38 | [](https://github.com/baderlab/saber/graphs/commit-activity)
39 | [](https://github.com/baderlab/saber/pulse)
40 |
41 | Pull and install straight from GitHub
42 |
43 | ```sh
44 | (saber) $ pip install git+https://github.com/BaderLab/saber.git
45 | ```
46 |
47 | or install by cloning the repository
48 |
49 | ```sh
50 | (saber) $ git clone https://github.com/BaderLab/saber.git
51 | (saber) $ cd saber
52 | ```
53 |
54 | and then using either `pip`
55 |
56 | ```sh
57 | (saber) $ pip install -e .
58 | ```
59 | or `setuptools`
60 |
61 | ```sh
62 | (saber) $ python setup.py install
63 | ```
64 |
65 | !!! note
66 | See [Running tests](#running-tests) for a way to verify your installation.
67 |
68 | ## (OPTIONAL) Creating and activating virtual environments
69 |
70 | When using `pip` it is generally recommended to install packages in a virtual environment to avoid modifying system state. To create a virtual environment named `saber`
71 |
72 | ### Using virtualenv or venv
73 |
74 | Using [virtualenv](https://virtualenv.pypa.io/en/stable/)
75 |
76 | ```
77 | $ virtualenv --python=python3 /path/to/new/venv/saber
78 | ```
79 |
80 | Using [venv](https://docs.python.org/3/library/venv.html)
81 |
82 | ```
83 | $ python3 -m venv /path/to/new/venv/saber
84 | ```
85 |
86 | Next, you need to activate the environment.
87 |
88 | ```
89 | $ source /path/to/new/venv/saber/bin/activate
90 | # Notice your command prompt has changed to indicate that the environment is active
91 | (saber) $
92 | ```
93 |
94 | ### Using Conda
95 |
96 | If you use [Conda](https://conda.io/docs/) / [Miniconda](https://repo.continuum.io/miniconda/Miniconda3-latest-Linux-x86_64.sh), you can create an environment named `saber` by running
97 |
98 | ```
99 | $ conda create -n saber python=3.6
100 | ```
101 |
102 | To activate the environment
103 |
104 | ```
105 | $ conda activate saber
106 | # Notice your command prompt has changed to indicate that the environment is active
107 | (saber) $
108 | ```
109 |
110 | !!! note
111 | You do not _need_ to name the environment `saber`.
112 |
113 | ## Running tests
114 |
115 | Saber's test suite can be found in `saber/tests`. If Saber is already installed, you can run `pytest` on the installation directory
116 |
117 | ```
118 | # Install pytest
119 | (saber) $ pip install pytest
120 | # Find out where Saber is installed
121 | (saber) $ INSTALL_DIR=$(python -c "import os; import saber; print(os.path.dirname(saber.__file__))")
122 | # Run tests on that installation directory
123 | (saber) $ python -m pytest $INSTALL_DIR
124 | ```
125 |
126 | Alternatively, to clone Saber, install it, and run the test suite all in one go
127 |
128 | ```
129 | (saber) $ git clone https://github.com/BaderLab/saber.git
130 | (saber) $ cd saber
131 | (saber) $ python setup.py test
132 | ```
133 |
--------------------------------------------------------------------------------
/docs/quick_start.md:
--------------------------------------------------------------------------------
1 | # Quick Start
2 |
3 | If your goal is to use Saber to annotate biomedical text, then you can either use the [web-service](#web-service) or a [pre-trained model](#pre-trained-models). If you simply want to check Saber out, without installing anything locally, try the [Google Colaboratory](#google-colaboratory) notebook.
4 |
5 | ## Google Colaboratory
6 |
7 | The fastest way to check out Saber is by following along with the Google Colaboratory notebook ([](https://colab.research.google.com/drive/1WD7oruVuTo6p_908MQWXRBdLF3Vw2MPo)). In order to be able to run the cells, select "Open in Playground" or, alternatively, save a copy to your own Google Drive account (File > Save a copy in Drive).
8 |
9 | ## Web-service
10 |
11 | To use Saber as a **local** web-service, run
12 |
13 | ```
14 | (saber) $ python -m saber.cli.app
15 | ```
16 |
17 | or, if you prefer, you can pull & run the Saber image from **Docker Hub**
18 |
19 | ```
20 | # Pull Saber image from Docker Hub
21 | $ docker pull pathwaycommons/saber
22 | # Run docker (use `-dt` instead of `-it` to run container in background)
23 | $ docker run -it --rm -p 5000:5000 --name saber pathwaycommons/saber
24 | ```
25 |
26 | !!! tip
27 | Alternatively, you can clone the GitHub repository and build the container from the `Dockerfile` with `docker build -t saber .`
28 |
29 | There are currently two endpoints, `/annotate/text` and `/annotate/pmid`. Both expect a `POST` request with a JSON payload, e.g.
30 |
31 | ```json
32 | {
33 | "text": "The phosphorylation of Hdm2 by MK2 promotes the ubiquitination of p53."
34 | }
35 | ```
36 |
37 | or
38 |
39 | ```json
40 | {
41 | "pmid": 11835401
42 | }
43 | ```
44 |
45 | For example, with the web-service running locally
46 |
47 | ``` bash tab="Bash"
48 | curl -X POST 'http://localhost:5000/annotate/text' \
49 | --data '{"text": "The phosphorylation of Hdm2 by MK2 promotes the ubiquitination of p53."}'
50 | ```
51 |
52 | ``` python tab="python"
53 | import requests # assuming you have requests package installed!
54 |
55 | url = "http://localhost:5000/annotate/text"
56 | payload = {"text": "The phosphorylation of Hdm2 by MK2 promotes the ubiquitination of p53."}
57 | response = requests.post(url, json=payload)
58 |
59 | print(response.text)
60 | print(response.status_code, response.reason)
61 | ```
62 |
63 | !!! warning
64 | The first request to the web-service will be slow (~60s). This is because a large language
65 | model needs to be loaded into memory.
66 |
67 | Documentation for the Saber web-service API can be found [here](https://baderlab.github.io/saber-api-docs/). We hope to provide a live version of the web-service soon!
68 |
69 | ## Pre-trained models
70 |
71 | First, import `Saber`. This class coordinates training, annotation, saving and loading of models and datasets. In short, this is the interface to Saber.
72 |
73 | ```python
74 | from saber.saber import Saber
75 | ```
76 |
77 | To load a pre-trained model, first create a `Saber` object
78 |
79 | ```python
80 | saber = Saber()
81 | ```
82 |
83 | and then load the model of our choice
84 |
85 | ```python
86 | saber.load('PRGE')
87 | ```
88 |
89 | !!! tip
90 | See [Resources: Pre-trained models](../resources#pre-trained-models) for pre-trained model names and details. You will need an internet connection to download a pre-trained model.
91 |
92 | To annotate text with the model, just call the `Saber.annotate()` method
93 |
94 | ```python
95 | saber.annotate("The phosphorylation of Hdm2 by MK2 promotes the ubiquitination of p53.")
96 | ```
97 |
98 | !!! warning
99 | The `Saber.annotate()` method will be slow the first time you call it (~60s). This is because a large language model needs to be loaded into memory.
100 |
101 | ### Coreference Resolution
102 |
103 | [**Coreference**](https://en.wikipedia.org/wiki/Coreference) occurs when two or more expressions in a text refer to the same person or thing, that is, they have the same **referent**. Take the following example:
104 |
105 | _"__IL-6__ supports tumour growth and metastasising in terminal patients, and __it__ significantly engages in cancer cachexia (including anorexia) and depression associated with malignancy."_
106 |
107 | Clearly, "__it__" refers to "__IL-6__". If we do not resolve this coreference, then "__it__" will not be labeled as an entity and any relation or event it is mentioned in will not be extracted. Saber uses [NeuralCoref](https://github.com/huggingface/neuralcoref), a state-of-the-art coreference resolution tool based on neural nets and built on top of [Spacy](https://spacy.io). To use it, just supply the argument `coref=True` (which is `False` by default) to the `Saber.annotate()` method
108 |
109 | ```python
110 | text = "IL-6 supports tumour growth and metastasising in terminal patients, and it significantly engages in cancer cachexia (including anorexia) and depression associated with malignancy."
111 | # WITHOUT coreference resolution
112 | saber.annotate(text, coref=False)
113 | # WITH coreference resolution
114 | saber.annotate(text, coref=True)
115 | ```
116 |
117 | !!! note
118 | If you are using the web-service, simply supply `"coref": true` in your `JSON` payload to resolve coreferences.
119 |
120 | Saber currently takes the simplest possible approach: replace all coreference mentions with their referent, and then feed the resolved text to the model that identifies named entities.
121 |
122 | ### Grounding
123 |
124 | **Grounding** (sometimes called **entity linking** or **normalization**) involves mapping each annotated entity to a unique identifier in an external resource such as a database or ontology. To ground entities in a call to `Saber.annotate()`, simply pass the argument `ground=True`
125 |
126 | ```python
127 | saber.annotate('The phosphorylation of Hdm2 by MK2 promotes the ubiquitination of p53.', ground=True)
128 | ```
129 |
130 | The grounding functionality is implemented by the [EXTRACT 2.0 API](https://extract.jensenlab.org/). Note that you will need an internet connection or grounding will fail. Also note that `Saber.annotate()` will take slightly longer to return a response when `ground=True` (up to a few seconds).
131 |
132 | See [Resources: Pre-trained models](../resources#pre-trained-models) for a list of the external resources each entity type (annotated by the pre-trained models) is grounded to.
133 |
134 | !!! note
135 | If you are using the web-service, simply supply `"ground": true` in your `JSON` payload to ground entities.
136 |
137 | ### Working with annotations
138 |
139 | The `Saber.annotate()` method returns a simple `dict` object
140 |
141 | ```python
142 | ann = saber.annotate("The phosphorylation of Hdm2 by MK2 promotes the ubiquitination of p53.")
143 | ```
144 |
145 | which contains the keys `title`, `text` and `ents`
146 |
147 | - `title`: contains the title of the article, if provided
148 | - `text`: contains the text (which is minimally processed) the model was deployed on
149 | - `ents`: contains a list of entities present in the `text` that were annotated by the model
150 |
151 | For example, to see all entities annotated by the model, call
152 |
153 | ```python
154 | ann['ents']
155 | ```
156 |
157 | #### Converting annotations to JSON
158 |
159 | The `Saber.annotate()` method returns a `dict` object, but can be converted to a `JSON` formatted string for ease-of-use in downstream applications
160 |
161 | ```python
162 | import json
163 |
164 | # convert to json object
165 | json_ann = json.dumps(ann)
166 |
167 | # convert back to python dictionary
168 | ann = json.loads(json_ann)
169 | ```
170 |
--------------------------------------------------------------------------------
/docs/resources.md:
--------------------------------------------------------------------------------
1 | # Resources
2 |
3 | Saber is ready to go out-of-the-box when using the __web-service__ or a __pre-trained model__. However, if you plan on training your own models, you will need to provide a dataset (or datasets!) and, ideally, pre-trained word embeddings.
4 |
5 | ## Pre-trained models
6 |
7 | Pre-trained model names can be passed to `Saber.load()` (see [Quick Start: Pre-trained Models](https://baderlab.github.io/saber/quick_start/#pre-trained-models)). Appending `"-large"` to the model name (e.g. `"PRGE-large"`) will download a much larger model, which should perform slightly better than the base model.
8 |
9 | Identifier | Semantic Group | Identified entity types | Namespace |
10 | ---------- | -------------- | ----------------------- | --------- |
11 | `CHED` | Chemicals | Abbreviations and Acronyms, Molecular Formulas, Chemical database identifiers, IUPAC names, Trivial (common names of chemicals and trademark names), Family (chemical families with a defined structure) and Multiple (non-continuous mentions of chemicals in text) | [PubChem Compounds](https://pubchem.ncbi.nlm.nih.gov/)
12 | `DISO` | Disorders | Acquired Abnormality, Anatomical Abnormality, Cell or Molecular Dysfunction, Congenital Abnormality, Disease or Syndrome, Mental or Behavioral Dysfunction, Neoplastic Process, Pathologic Function, Sign or Symptom | [Disease Ontology](http://disease-ontology.org/)
13 | `LIVB` | Organisms | Species, Taxa | [NCBI Taxonomy](https://www.ncbi.nlm.nih.gov/taxonomy)
14 | `PRGE` | Genes and Gene Products | Genes, Gene Products | [STRING](https://string-db.org/)
15 |
16 | ## Datasets
17 |
18 | Currently, Saber requires corpora to be in a **CoNLL** format with a BIO or IOBES tag scheme, e.g.:
19 |
20 | ```
21 | Selegiline B-CHED
22 | - O
23 | induced O
24 | postural B-DISO
25 | hypotension I-DISO
26 | ...
27 | ```
28 |
29 | Corpora in such a format are collected [here](https://github.com/BaderLab/Biomedical-Corpora) for convenience.
30 |
31 | !!! info
32 | Many of the corpora in the BIO and IOBES tag format were originally collected by [Crichton _et al_., 2017](https://doi.org/10.1186/s12859-017-1776-8), [here](https://github.com/cambridgeltl/MTL-Bioinformatics-2016).
33 |
34 | In this format, the first column contains each token of an input sentence, the last column contains the token's tag, all columns are separated by tabs, and all sentences by a newline.
35 |
36 | Of course, not all corpora are distributed in the CoNLL format:
37 |
38 | - Corpora in the **Standoff** format can be converted to **CoNLL** format using [this](https://github.com/spyysalo/standoff2conll) tool.
39 | - Corpora in **PubTator** format can be converted to **Standoff** first using [this](https://github.com/spyysalo/pubtator) tool.
40 |
41 | Saber infers the "training strategy" based on the structure of the dataset folder:
42 |
43 | - To use k-fold cross-validation, simply provide a `train.*` file in your dataset folder.
44 |
45 | E.g.
46 | ```
47 | .
48 | ├── NCBI_Disease
49 | │ └── train.tsv
50 | ```
51 |
52 | - To use a train/valid/test strategy, provide `train.*` and `test.*` files in your dataset folder. Optionally, you can provide a `valid.*` file. If not provided, a random 10% of examples from `train.*` are used as the validation set.
53 |
54 | E.g.
55 | ```
56 | .
57 | ├── NCBI_Disease
58 | │ ├── test.tsv
59 | │ └── train.tsv
60 | ```
61 |
62 | ## Word embeddings
63 |
64 | When training new models, you can (and should) provide your own pre-trained word embeddings with the `pretrained_embeddings` argument (either at the command line or in the configuration file). Saber expects all word embeddings to be in the `word2vec` file format. [Pyysalo _et al_. 2013](https://pdfs.semanticscholar.org/e2f2/8568031e1902d4f8ee818261f0f2c20de6dd.pdf) provide word embeddings that work quite well in the biomedical domain, which can be downloaded [here](http://bio.nlplab.org). Alternatively, from the command line call:
65 |
66 | ```
67 | # Replace this with a location you want to save the embeddings to
68 | $ mkdir path/to/word_embeddings
69 | # Note: this file is over 4GB
70 | $ wget http://evexdb.org/pmresources/vec-space-models/wikipedia-pubmed-and-PMC-w2v.bin -O path/to/word_embeddings/wikipedia-pubmed-and-PMC-w2v.bin
71 | ```
72 |
73 | To use these word embeddings with Saber, provide their path in the `pretrained_embeddings` argument (either in the `config` file or at the command line). Alternatively, pass their path to `Saber.load_embeddings()`. For example:
74 |
75 | ```python
76 | from saber.saber import Saber
77 |
78 | saber = Saber()
79 |
80 | saber.load_dataset('path/to/dataset')
81 | # load the embeddings here
82 | saber.load_embeddings('path/to/word_embeddings')
83 |
84 | saber.build()
85 | saber.train()
86 | ```
87 |
88 | ### GloVe
89 |
90 | To use [GloVe](https://nlp.stanford.edu/projects/glove/) embeddings, just convert them to the [word2vec](https://code.google.com/archive/p/word2vec/) format first:
91 |
92 | ```
93 | (saber) $ python
94 | >>> from gensim.scripts.glove2word2vec import glove2word2vec
95 | >>> glove_input_file = 'glove.txt'
96 | >>> word2vec_output_file = 'word2vec.txt'
97 | >>> glove2word2vec(glove_input_file, word2vec_output_file)
98 | ```
99 |
--------------------------------------------------------------------------------
/mkdocs.yml:
--------------------------------------------------------------------------------
1 | # Project information
2 | site_name: Saber
3 | site_description: 'Documentation for Saber'
4 | site_author: 'John Giorgi'
5 | site_url: 'https://baderlab.github.io/saber/'
6 |
7 | # Repository
8 | repo_name: 'BaderLab/saber'
9 | repo_url: 'https://github.com/BaderLab/saber'
10 |
11 | # Site map
12 | nav:
13 | - About: index.md
14 | - Installation: installation.md
15 | - Quick Start: quick_start.md
16 | - Guide to the Saber API: guide_to_saber_api.md
17 | - Resources: resources.md
18 |
19 | # Configuration
20 | theme:
21 | language: 'en'
22 | name: 'material'
23 | logo:
24 | icon: 'memory'
25 | palette:
26 | primary: 'deep purple'
27 | accent: 'deep purple'
28 |
29 | # Customization
30 | extra:
31 | social:
32 | - type: 'github'
33 | link: 'https://github.com/BaderLab'
34 |
35 | # Extensions
36 | markdown_extensions:
37 | - admonition
38 | - codehilite
39 | - pymdownx.superfences
40 | - pymdownx.details
41 | - pymdownx.emoji:
42 | emoji_generator: !!python/name:pymdownx.emoji.to_svg
43 | - toc:
44 | permalink: true
45 |
--------------------------------------------------------------------------------
/saber/__init__.py:
--------------------------------------------------------------------------------
import logging
import os

from datetime import datetime

# set Tensorflow logging level ('3' suppresses everything but fatal errors)
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

# if applicable, delete the existing log file to generate a fresh log file during each execution
try:
    os.remove("saber.log")
except OSError:
    # file did not exist (or could not be removed) -- basicConfig below appends in that case
    pass

# create the logger; every `logging.getLogger(__name__)` in the package inherits this config
logging.basicConfig(filename="saber.log",
                    level=logging.DEBUG,
                    format='%(name)s - %(levelname)s - %(message)s')

# log the date to start, followed by a separator line for readability
logging.info('Saber invocation: %s\n%s', datetime.now(), '=' * 75)
22 |
--------------------------------------------------------------------------------
/saber/cli/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BaderLab/saber/876be6bfdb1bc5b18cbcfa848c94b0d20c940f02/saber/cli/__init__.py
--------------------------------------------------------------------------------
/saber/cli/app.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | """Simple web service which exposes Saber's functionality via a RESTful API.
3 | """
4 | import argparse
5 | import logging
6 |
7 | from flask import Flask, jsonify, redirect, request
8 | from waitress import serve
9 |
10 | from .. import constants
11 | from ..utils import app_utils
12 |
# the Flask application; served by `waitress` in the __main__ block below
app = Flask(__name__)

# module-level logger; configuration is set up in saber/__init__.py
LOGGER = logging.getLogger(__name__)
16 |
@app.route('/')
def serve_api_docs():
    """Flask view that redirects to the Saber API docs from route '/'.

    Returns:
        A Flask redirect response pointing at the hosted API documentation.
    """
    return redirect('https://baderlab.github.io/saber-api-docs/')
22 |
@app.route('/annotate/text', methods=['POST'])
def annotate_text():
    """Annotates raw text received in a POST request.

    Returns:
        JSON formatted string containing the annotation.
    """
    # validated/parsed JSON payload of the incoming request
    request_json = app_utils.parse_request_json(request)

    annotation = predict(text=request_json['text'],
                         ents=request_json['ents'],
                         coref=request_json['coref'],
                         ground=request_json['ground'])

    return jsonify(annotation)
40 |
@app.route('/annotate/pmid', methods=['POST'])
def annotate_pmid():
    """Annotates the abstract of a document with the PubMed ID received in a POST request.

    Returns:
        JSON formatted string containing the annotation.
    """
    request_json = app_utils.parse_request_json(request)
    # fetch the title and abstract via the Entrez Utilities Web Service API
    title, abstract = app_utils.get_pubmed_text(request_json['pmid'])

    annotation = predict(text='{}\n{}'.format(title, abstract),
                         ents=request_json['ents'],
                         coref=request_json['coref'],
                         ground=request_json['ground'])

    return jsonify(annotation)
59 |
def predict(text, ents, coref=False, ground=False):
    """Annotates raw text (`text`) for entities according to their boolean value in `ents`.

    Args:
        text (str): Raw text to be annotated.
        ents (dict): Dictionary of entity, boolean pairs representing whether or not to annotate
            the text for the given entities (keys must be keys of the module-level `MODELS`).
        coref (bool): True if coreferences in `text` should be resolved before annotation.
            Defaults to False.
        ground (bool): True if annotated entities should be grounded to external resources.
            Defaults to False.

    Returns:
        Dictionary containing the annotated entities and processed text.

    Raises:
        ValueError: If no entity in `ents` is mapped to True, i.e. there is nothing to annotate.
    """
    annotations = []
    for ent, requested in ents.items():
        if requested:
            # TEMP: Weird solution to a weird bug
            # https://github.com/tensorflow/tensorflow/issues/14356#issuecomment-385962623
            with GRAPH.as_default():
                annotations.append(MODELS[ent].annotate(text, coref=coref, ground=ground))

    # previously `annotations[0]` raised an opaque IndexError when all values in `ents` were False
    if not annotations:
        err_msg = "At least one entity in `ents` must be mapped to True."
        LOGGER.error('ValueError: %s', err_msg)
        raise ValueError(err_msg)

    # if multiple models, combine annotations into one object
    final_annotation = annotations[0]
    if len(annotations) > 1:
        final_annotation['ents'] = app_utils.combine_annotations(annotations)

    return final_annotation
86 |
if __name__ == '__main__':
    # parse command line arguments
    parser = argparse.ArgumentParser(description='Saber web-service.')
    parser.add_argument('-p', '--port', help='Port number. Defaults to 5000', type=int, default=5000)
    args = vars(parser.parse_args())
    # load the pre-trained models once at start-up; `predict()` reads these module-level globals
    MODELS, GRAPH = app_utils.load_models(constants.ENTITIES)

    # serve with waitress, listening on all interfaces on the requested port
    serve(app, host='0.0.0.0', port=args['port'])
96 |
--------------------------------------------------------------------------------
/saber/cli/train.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | """A simple script to train a model with Saber.
3 |
4 | Run the script with:
5 | ```
6 | python -m saber.cli.train
7 | ```
8 | e.g.
9 | ```
10 | python -m saber.cli.train --dataset_folder ./datasets/NCBI_disease_BIO --epochs 25
11 | ```
12 | """
13 | import logging
14 |
15 | from ..config import Config
16 | from ..saber import Saber
17 |
18 |
def main():
    """Coordinates a complete training cycle, including reading in a config, loading dataset(s),
    training the model, and saving the model's weights.
    """
    config = Config(cli=True)
    saber = Saber(config)

    # load a pre-trained model, if one was provided, before loading the dataset(s)
    if config.pretrained_model:
        saber.load(config.pretrained_model)

    saber.load_dataset()

    # don't build a new model if pre-trained one was provided
    if not config.pretrained_model:
        # don't load embeddings if a pre-trained model was provided
        if config.pretrained_embeddings:
            saber.load_embeddings()
        saber.build()

    try:
        saber.train()
    except KeyboardInterrupt:
        print("\nQuitting Saber...")
        logging.warning('Saber was terminated early due to KeyboardInterrupt')
    finally:
        # save even on early termination so partial training progress is not lost
        if config.save_model:
            saber.save()
45 |
46 | if __name__ == '__main__':
47 | main()
48 |
--------------------------------------------------------------------------------
/saber/config.ini:
--------------------------------------------------------------------------------
1 | [mode]
2 | # Possible models: [MT-LSTM-CRF, ]
3 | model_name = MT-LSTM-CRF
4 | # If True, model is compressed and saved in output_folder at the end of training. Model weights are
5 | # taken from the last epoch.
6 | save_model = False
7 |
8 | [data]
9 | # You can specify multiple datasets by listing their paths, separated by a comma.
10 | dataset_folder = ./datasets/NCBI_Disease_BIO
11 | output_folder = ./output
12 | # Path to pre-trained model. To train a model from scratch, leave this blank.
13 | pretrained_model =
14 | # Path to pre-trained word embeddings. In order to use random initialization, leave this blank.
15 | # Note that you can leave this blank when loading a pre-trained model
16 | # (via 'Saber.load()') that was trained with pre-trained embeddings
17 | pretrained_embeddings =
18 |
19 | [model]
20 | # If pre-trained word embeddings are provided, word_embed_dim will be the same size as these
21 | # embeddings and this argument will be ignored.
22 | word_embed_dim = 200
23 | char_embed_dim = 30
24 |
25 | [training]
26 | # Values chosen for each hyperparameter represent sensible defaults that perform well across a wide
27 | # range of NLP tasks (POS tagging, Chunking, NER, etc.) and thus should only be changed in special
28 | # circumstances.
29 | optimizer = nadam
30 | activation = relu
31 | # Set to 0 to turn off gradient normalization.
32 | grad_norm = 1.0
33 | # For certain optimizers, these values are ignored. See compile_model() in
34 | # saber/utils/model_utils.py.
35 | learning_rate = 0.0
36 | decay = 0.0
37 |
38 | # Three dropout values must be specified (separated by a comma), corresponding to the dropout rate
39 | # to apply to the input, output and recurrent connections respectively. Must be a value between 0.0
40 | # and 1.0.
41 | dropout_rate = 0.3, 0.3, 0.1
42 |
43 | batch_size = 32
44 | # If a test partition is supplied at 'dataset_folder' (test.*) then this argument is ignored, and a
45 | # simple train/valid/test scheme is used. A valid partition (valid.*) may optionally be provided
46 | # along with the test partition. If none is found, 10% of examples are randomly selected.
47 | k_folds = 5
48 | epochs = 50
49 |
50 | # Matching criteria used when determining whether or not a prediction is a true-positive. Choices
51 | # are 'left' for left-boundary matching, 'right' for right-boundary matching and 'exact' for
52 | # exact-boundary matching.
53 | criteria = exact
54 |
55 | [advanced]
56 | verbose = False
57 | debug = False
58 | # If True, per-epoch logs which can be visualized with TensorBoard are written to output_folder
59 | # Note: These logs can be quite large.
60 | tensorboard = False
61 | # If True, then during training the models weights for each and every epoch will be saved.
62 | # Otherwise, weights are only saved for epochs that achieve a new best on validation loss.
63 | save_all_weights = False
64 | # If True, tokens that occur only once in the training dataset (hapax legomena) are replaced
65 | # with a special unknown token. This should result in faster loading times of pre-trained word
66 | # embeddings and faster training times.
67 | replace_rare_tokens = True
68 | # If True, then all pre-trained word embeddings provided via pretrained_embeddings are loaded.
69 | # Otherwise, only pre-trained embeddings for tokens found in the dataset(s) at dataset_folder are
70 | # loaded. For evaluation, it's best to leave this as False. For models that will be deployed, it's
71 | # best to set it to True.
72 | load_all_embeddings = False
73 | # If True, then pre-trained word embeddings will be fine-tuned with the other parameters of the
74 | # neural network during training. Generally, you should not set this to True unless you have a very
75 | # large training dataset.
76 | # NOTE: if 'pretrained_embeddings' are not provided, they will be randomly initialized and
77 | # fine-tuned during training, ignoring this argument.
78 | fine_tune_word_embeddings = False
79 | # TEMP. Set to true if variational dropout should be used.
80 | variational_dropout = False
81 |
--------------------------------------------------------------------------------
/saber/constants.py:
--------------------------------------------------------------------------------
"""Collection of constants used by Saber.
"""
from pkg_resources import resource_filename

__version__ = '0.1.0-alpha'

# DISPLACY OPTIONS
# entity colours used when rendering annotations
COLOURS = {'PRGE': 'linear-gradient(90deg, #aa9cfc, #fc9ce7)',
           'DISO': 'linear-gradient(90deg, #ef9a9a, #f44336)',
           'CHED': 'linear-gradient(90deg, #1DE9B6, #A7FFEB)',
           'LIVB': 'linear-gradient(90deg, #FF4081, #F8BBD0)',
           'CL': 'linear-gradient(90deg, #00E5FF, #84FFFF)',
           }
# entity options
OPTIONS = {'colors': COLOURS}

# SPECIAL TOKENS
# NOTE(review): these sentinels were all empty strings (bracketed tokens likely stripped by
# markup), which made PAD == UNK and collapsed INITIAL_MAPPING['word'] to a single key.
# Restored to distinct bracketed tokens.
UNK = '<UNK>'      # out-of-vocabulary token
PAD = '<PAD>'      # sequence pad token
START = '<START>'  # start-of-sentence token
END = '<END>'      # end-of-sentence token
OUTSIDE_TAG = 'O'  # 'outside' tag of the IOB, BIO, and IOBES tag formats

# MISC.
PAD_VALUE = 0  # value of sequence pad
# rarity threshold: sufficiently rare tokens are replaced with UNK
# TODO(review): confirm the exact comparison (< vs <=) against Preprocessor.replace_rare_tokens
NUM_RARE = 1
# mapping of special tokens to constants
INITIAL_MAPPING = {'word': {PAD: 0, UNK: 1}, 'tag': {PAD: 0}}
# keys into dictionaries containing information for different partitions of a dataset
PARTITIONS = ['train', 'valid', 'test']

# FILEPATHS / FILENAMES
# train, valid and test filename patterns
TRAIN_FILE = 'train.*'
VALID_FILE = 'valid.*'
TEST_FILE = 'test.*'
# pre-trained models; True means the web-service loads that model at start-up
ENTITIES = {'ANAT': False,
            'CHED': True,
            'DISO': True,
            'LIVB': False,
            'PRGE': True,
            'TRIG': False}
# Google Drive File IDs for the pre-trained models
PRETRAINED_MODELS = {
    'PRGE': '1xOmxpgNjQJK8OJSvih9wW5AITGQX6ODT',
    'DISO': '1qmrBuqz75KM57Ug5MiDBfp0d5H3S_5ih',
    'CHED': '13s9wvu3Mc8fG73w51KD8RArA31vsuL1c',
}
# relative path to pre-trained model directory
PRETRAINED_MODEL_DIR = resource_filename(__name__, 'pretrained_models')
MODEL_FILENAME = 'model_params.json'
WEIGHTS_FILENAME = 'model_weights.hdf5'
ATTRIBUTES_FILENAME = 'attributes.pickle'
CONFIG_FILENAME = 'config.ini'

# MODEL SETTINGS
# batch size to use when performing model prediction
PRED_BATCH_SIZE = 256
# max length of a sentence
MAX_SENT_LEN = 100
# max length of a character sequence (word)
MAX_CHAR_LEN = 25
# number of units in the LSTM layers
UNITS_WORD_LSTM = 200
UNITS_CHAR_LSTM = 200
UNITS_DENSE = UNITS_WORD_LSTM // 2
# possible models
MODEL_NAMES = ['mt-lstm-crf',]

# EXTRACT 2.0 API
# arguments passed in a get request to the EXTRACT 2.0 API to specify entity type
ENTITY_TYPES = {'CHED': -1, 'DISO': -26, 'LIVB': -2}
# the namespaces of the external resources that EXTRACT 2.0 grounds to
NAMESPACES = {'CHED': 'PubChem Compound',
              'DISO': 'Disease Ontology',
              'LIVB': 'NCBI Taxonomy',
              'PRGE': 'STRING',
              }

# RESTful API
# endpoint for Entrez Utilities Web Service API
EUTILS_API_ENDPOINT = ('https://eutils.ncbi.nlm.nih.gov/entrez/eutils/efetch.fcgi?retmode=xml&db='
                       'pubmed&id=')
# CONFIG
# every argument recognized by the Config object / command line interface
CONFIG_ARGS = ['model_name', 'save_model', 'dataset_folder', 'output_folder',
               'pretrained_model', 'pretrained_embeddings', 'word_embed_dim',
               'char_embed_dim', 'optimizer', 'activation', 'learning_rate', 'decay', 'grad_norm',
               'dropout_rate', 'batch_size', 'k_folds', 'epochs', 'criteria', 'verbose',
               'debug', 'save_all_weights', 'tensorboard', 'replace_rare_tokens',
               'load_all_embeddings', 'fine_tune_word_embeddings', 'variational_dropout']
93 |
--------------------------------------------------------------------------------
/saber/dataset.py:
--------------------------------------------------------------------------------
1 | """Contains the Dataset class, which handles the loading and storage of datasets.
2 | """
3 | import logging
4 | import os
5 | from itertools import chain
6 |
7 | from keras.utils import to_categorical
8 | from nltk.corpus.reader.conll import ConllCorpusReader
9 |
10 | from . import constants
11 | from .preprocessor import Preprocessor
12 | from .utils import data_utils, generic_utils
13 |
14 | LOGGER = logging.getLogger(__name__)
15 |
class Dataset(object):
    """A class for handling datasets. Expects datasets to be in tab-separated CoNLL format, where
    each line contains a token and its tag (separated by a tab) and each sentence is separated
    by a blank line.

    Example corpus:
    '''
    The O
    transcription O
    of O
    most O
    RP B-PRGE
    genes I-PRGE
    ...
    '''

    Args:
        directory (str): Path to directory containing CoNLL formatted dataset(s).
        replace_rare_tokens (bool): True if rare tokens should be replaced with a special unknown
            token. Threshold for considering tokens rare can be found at `saber.constants.NUM_RARE`.
        **kwargs: Arbitrary keyword arguments; each is set as an instance attribute.
    """
    def __init__(self, directory=None, replace_rare_tokens=True, **kwargs):
        self.directory = directory
        # don't load corpus unless directory was passed on object construction
        if self.directory is not None:
            # `self.directory` becomes a dict mapping partition names ('train'/'valid'/'test')
            # to filepaths (or None if the partition is absent) -- see its use in `load` et al.
            self.directory = data_utils.get_filepaths(directory)
            self.conll_parser = ConllCorpusReader(directory, '.conll', ('words', 'pos'))

        self.replace_rare_tokens = replace_rare_tokens

        # word, character and tag sequences from dataset (per partition)
        self.type_seq = {'train': None, 'valid': None, 'test': None}
        # mappings of word, characters, and tag types to unique integer IDs
        self.type_to_idx = {'word': None, 'char': None, 'tag': None}
        # reverse mapping of unique integer IDs to tag types
        self.idx_to_tag = None
        # same as type_seq but all words, characters and tags have been mapped to unique integer IDs
        self.idx_seq = {'train': None, 'valid': None, 'test': None}

        # allow callers to attach arbitrary attributes (e.g. debug flags)
        for key, value in kwargs.items():
            setattr(self, key, value)

    def load(self):
        """Coordinates the loading of a given data set at `self.directory`.

        For a given dataset in CoNLL format at `self.directory`, coordinates the loading of data and
        updates the appropriate instance attributes. Expects `self.directory` to be a directory
        containing a single file, `train.*` and optionally two additional files, `valid.*` and
        `test.*`.

        Raises:
            ValueError: If `self.directory` is None, i.e. no dataset directory was ever provided.
        """
        if self.directory is None:
            err_msg = "`Dataset.directory` is None; must be provided before call to `Dataset.load`"
            LOGGER.error('ValueError %s', err_msg)
            raise ValueError(err_msg)

        # unique words, chars and tags from CoNLL formatted dataset
        types = self._get_types()
        # map each word, char, and tag type to a unique integer
        self._get_idx_maps(types)

        # get word, char, and tag sequences from CoNLL formatted dataset
        self._get_type_seq()
        # get final representation used for training
        self.get_idx_seq()

        # useful during prediction / annotation
        self.idx_to_tag = generic_utils.reverse_dict(self.type_to_idx['tag'])

    def _get_types(self):
        """Collects the sets of all words, characters and tags in a CoNLL formatted dataset.

        For the CoNLL formatted dataset given at `self.directory`, returns the lists of all unique
        words (word types), characters (character types) and tags (tag types). Types are shared
        across all partitions, that is, word, char and tag types are collected from the train and,
        if provided, valid/test partitions found at `self.directory/train.*`,
        `self.directory/valid.*` and `self.directory/test.*`.

        Returns:
            A dictionary with keys 'word', 'char' and 'tag' mapped to lists of unique types, each
            seeded with the relevant special tokens from `saber.constants` (PAD and, where
            applicable, UNK).
        """
        types = {'word': [constants.PAD, constants.UNK],
                 'char': [constants.PAD, constants.UNK],
                 'tag': [constants.PAD],
                 }

        for _, filepath in self.directory.items():
            if filepath is not None:
                conll_file = os.path.basename(filepath)  # get name of conll file
                types['word'].extend(set(self.conll_parser.words(conll_file)))
                types['char'].extend(set(chain(*[list(w) for w in self.conll_parser.words(conll_file)])))
                types['tag'].extend(set([tag[-1] for tag in self.conll_parser.tagged_words(conll_file)]))

        # ensure that we have only unique types
        types['word'] = list(set(types['word']))
        types['char'] = list(set(types['char']))
        types['tag'] = list(set(types['tag']))

        return types

    def _get_type_seq(self):
        """Loads sequence data from a CoNLL format data set given at `self.directory`.

        For the CoNLL formatted dataset given at `self.directory`, updates `self.type_seq` with
        lists containing the word, character and tag sequences for the train and, if provided,
        valid/test partitions found at `self.directory/train.*`, `self.directory/valid.*` and
        `self.directory/test.*`.
        """
        for partition, filepath in self.directory.items():
            if filepath is not None:
                conll_file = os.path.basename(filepath)  # get name of conll file

                # collect sequence data
                sents = list(self.conll_parser.sents(conll_file))
                tagged_sents = list(self.conll_parser.tagged_sents(conll_file))

                # optionally swap rare tokens for a special unknown token
                word_seq = Preprocessor.replace_rare_tokens(sents) if self.replace_rare_tokens else sents
                # each word split into its list of characters
                char_seq = [[[c for c in w] for w in s] for s in sents]
                # last element of each (token, tag) pair is the tag
                tag_seq = [[t[-1] for t in s] for s in tagged_sents]

                # update the class attributes
                self.type_seq[partition] = {'word': word_seq, 'char': char_seq, 'tag': tag_seq}

    def _get_idx_maps(self, types, initial_mapping=None):
        """Updates `self.type_to_idx` with mappings from word, char and tag types to unique int IDs.

        Args:
            types (dict): Dictionary with keys 'word', 'char' and 'tag' mapped to lists of unique
                types, as returned by `_get_types`.
            initial_mapping (dict): Optional seed mapping; defaults to
                `constants.INITIAL_MAPPING`, which reserves low IDs for the special tokens.
        """
        initial_mapping = constants.INITIAL_MAPPING if initial_mapping is None else initial_mapping
        # generate type to index mappings
        self.type_to_idx['word'] = Preprocessor.type_to_idx(types['word'], initial_mapping['word'])
        # chars reuse the 'word' seed mapping: both need reserved PAD and UNK entries
        self.type_to_idx['char'] = Preprocessor.type_to_idx(types['char'], initial_mapping['word'])
        self.type_to_idx['tag'] = Preprocessor.type_to_idx(types['tag'], initial_mapping['tag'])

    def get_idx_seq(self):
        """Updates `self.idx_seq` with the final representation of the data used for training.

        Updates `self.idx_seq` with numpy arrays, by using `self.type_to_idx` to map all elements
        in `self.type_seq` to their corresponding integer IDs, for the train and, if provided,
        valid/test partitions found at `self.directory/train.*`, `self.directory/valid.*` and
        `self.directory/test.*`.
        """
        for partition, filepath in self.directory.items():
            if filepath is not None:
                self.idx_seq[partition] = {
                    'word': Preprocessor.get_type_idx_sequence(self.type_seq[partition]['word'],
                                                               self.type_to_idx['word'],
                                                               type_='word'),
                    # NOTE(review): the WORD sequence is passed with type_='char'; presumably
                    # `get_type_idx_sequence` splits words into characters itself when
                    # type_='char' -- confirm against preprocessor.py
                    'char': Preprocessor.get_type_idx_sequence(self.type_seq[partition]['word'],
                                                               self.type_to_idx['char'],
                                                               type_='char'),
                    'tag': Preprocessor.get_type_idx_sequence(self.type_seq[partition]['tag'],
                                                               self.type_to_idx['tag'],
                                                               type_='tag'),
                }
                # one-hot encode our targets
                self.idx_seq[partition]['tag'] = to_categorical(self.idx_seq[partition]['tag'])
170 |
--------------------------------------------------------------------------------
/saber/embeddings.py:
--------------------------------------------------------------------------------
1 | """Contains the Embedding class, which provides all code for working with pre-trained embeddings.
2 | """
3 | import numpy as np
4 | from gensim.models import KeyedVectors
5 |
6 | from . import constants
7 | from .preprocessor import Preprocessor
8 |
9 |
10 | class Embeddings(object):
11 | """A class for loading and working with pre-trained word embeddings.
12 |
13 | Args:
14 | filepath (str): Path to file which contains pre-trained word embeddings.
15 | token_map (dict): A dictionary which maps tokens to unique integer IDs.
16 | """
17 | def __init__(self, filepath, token_map, **kwargs):
18 | self.filepath = filepath
19 | self.token_map = token_map
20 |
21 | self.matrix = None # matrix containing row vectors for all embedded tokens
22 | self.num_found = None # number of loaded embeddings
23 | self.num_embed = None # final count of embedded words
24 | self.dimension = None # dimension of these embeddings
25 |
26 | for key, value in kwargs.items():
27 | setattr(self, key, value)
28 |
29 | def load(self, binary=True, load_all=False):
30 | """Coordinates the loading of pre-trained word embeddings.
31 |
32 | Creates an embedding matrix from the pre-trained word embeddings given at `self.filepath`,
33 | whose ith row corresponds to the word embedding for the word with value i in
34 | `self.token_map`. Updates the instance attributes `self.matrix, `self.num_found`,
35 | `self.dimension`, and `self.num_embed`.
36 |
37 | Args:
38 | binary (bool): True if pre-trained embeddings are in C binary format, False if they are
39 | in C text format. Defaults to True.
40 | load_all (bool): True if all embeddings should be loaded. False if only words that
41 | appear in `self.token_map` should be loaded. Defaults to False.
42 |
43 | Returns:
44 | The token map used to map the embeddings to an embedding matrix.
45 | """
46 | # prepare the embedding indices
47 | embedding_idx = self._prepare_embedding_index(binary)
48 | self.num_found, self.dimension = len(embedding_idx), len(list(embedding_idx.values())[0])
49 | self.matrix, type_to_idx = self._prepare_embedding_matrix(embedding_idx, load_all)
50 | self.num_embed = self.matrix.shape[0] # num of embedded words
51 |
52 | return type_to_idx
53 |
def _prepare_embedding_index(self, binary=True):
    """Returns an embedding index for pre-trained token embeddings.

    For pre-trained word embeddings given at `self.filepath`, returns a dictionary mapping
    words to their embedding (an 'embedding index'). If `self.debug` is True, only the first
    ten thousand vectors are loaded.

    Args:
        binary (bool): True if pre-trained embeddings are in C binary format, False if they are
            in C text format. Defaults to True.

    Returns:
        Dictionary mapping words to pre-trained word embeddings, known as an 'embedding index'.
    """
    # use getattr rather than poking at self.__dict__ directly: it is the idiomatic form and,
    # unlike __dict__.get, it also honours a class-level `debug` attribute
    limit = 10000 if getattr(self, "debug", False) else None
    vectors = KeyedVectors.load_word2vec_format(self.filepath, binary=binary, limit=limit)
    embedding_idx = {word: vectors[word] for word in vectors.vocab}

    return embedding_idx
73 |
74 | def _prepare_embedding_matrix(self, embedding_idx, load_all=False):
75 | """Returns an embedding matrix containing all pre-trained embeddings in `embedding_idx`.
76 |
77 | Creates an embedding matrix from `embedding_idx`, where the ith row contains the
78 | embedding for the word with value i in `self.token_map`. If no embedding exists for a given
79 | word in `embedding_idx`, the zero vector is used instead.
80 |
81 | Args:
82 | embedding_idx (dict): A Dictionary mapping words to their embeddings.
83 | load_all (bool): True if all embeddings should be loaded. False if only words that
84 | appear in `self.token_map` should be loaded. Defaults to False.
85 |
86 | Returns:
87 | A matrix whos ith row corresponds to the word embedding for the word with value i in
88 | `self.token_map`.
89 | """
90 | type_to_idx = None
91 | if load_all:
92 | # overwrite provided token map
93 | type_to_idx = self._generate_type_to_idx(embedding_idx)
94 | self.token_map = type_to_idx['word']
95 |
96 | # initialize the embeddings matrix
97 | embedding_matrix = np.zeros((len(self.token_map), self.dimension))
98 |
99 | # lookup embeddings for every word in the dataset
100 | for word, i in self.token_map.items():
101 | token_embedding = embedding_idx.get(word)
102 | if token_embedding is not None:
103 | # words not found in embedding index will be all-zeros.
104 | embedding_matrix[i] = token_embedding
105 |
106 | return embedding_matrix, type_to_idx
107 |
@classmethod
def _generate_type_to_idx(cls, embedding_idx):
    """Returns a dictionary mapping tokens in `embedding_idx` to unique integer IDs.

    Args:
        embedding_idx (dict): A dictionary mapping words to their embeddings.

    Returns:
        A dictionary with keys 'word' and 'char', mapping the word types and character types
        found in `embedding_idx` to unique integer IDs.
    """
    # first parameter renamed from `self` to `cls`: this is a @classmethod, so the first
    # argument received is the class, not an instance
    word_types = list(embedding_idx.keys())
    char_types = []
    for word in word_types:
        char_types.extend(list(word))
    char_types = list(set(char_types))

    # NOTE(review): the 'char' mapping is seeded with INITIAL_MAPPING['word'] as well —
    # confirm this is intentional rather than a copy-paste of the 'word' seed.
    type_to_idx = {
        'word': Preprocessor.type_to_idx(word_types, constants.INITIAL_MAPPING['word']),
        'char': Preprocessor.type_to_idx(char_types, constants.INITIAL_MAPPING['word'])
    }

    return type_to_idx
123 |
--------------------------------------------------------------------------------
/saber/models/__init__.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 |
3 | from . import multi_task_lstm_crf
4 |
--------------------------------------------------------------------------------
/saber/models/base_model.py:
--------------------------------------------------------------------------------
1 | """Contains the BaseModel class, the parent class to all Keras models in Saber.
2 | """
3 | import json
4 | import logging
5 |
6 | from keras import optimizers
7 | from keras.models import model_from_json
8 |
9 | LOGGER = logging.getLogger(__name__)
10 |
class BaseKerasModel(object):
    """Parent class of all Keras model classes implemented by Saber.

    Args:
        config (Config): Contains a set of harmonized arguments provided in a *.ini file and,
            optionally, from the command line.
        datasets (list): A list of Dataset objects tied to this instance.
        embeddings: Pre-trained word embeddings tied to this instance, if any.
    """
    def __init__(self, config, datasets, embeddings=None, **kwargs):
        self.config = config  # hyperparameters and model details
        self.datasets = datasets  # dataset(s) tied to this instance
        self.embeddings = embeddings  # pre-trained word embeddings tied to this instance
        self.models = []  # Keras model(s) tied to this instance

        # any additional keyword arguments become instance attributes
        for key, value in kwargs.items():
            setattr(self, key, value)

    def save(self, weights_filepath, model_filepath, model=0):
        """Save a model to disk.

        Saves a Keras model to disk, by saving its architecture as a json file at
        `model_filepath` and its weights as a hdf5 file at `weights_filepath`.

        Args:
            weights_filepath (str): Filepath to the models weights (.hdf5 file).
            model_filepath (str): Filepath to the models architecture (.json file).
            model (int): Which model from `self.models` to save. Defaults to 0.
        """
        with open(model_filepath, 'w') as f:
            model_json = self.models[model].to_json()
            # round-trip through json.loads so the file is pretty-printed with sorted keys
            json.dump(json.loads(model_json), f, sort_keys=True, indent=4)
        self.models[model].save_weights(weights_filepath)

    def load(self, weights_filepath, model_filepath):
        """Load a model from disk.

        Loads a Keras model from disk by loading its architecture from a json file at
        `model_filepath` and its weights from a hdf5 file at `weights_filepath`. The loaded
        model is appended to `self.models`.

        Args:
            weights_filepath (str): Filepath to the models weights (.hdf5 file).
            model_filepath (str): Filepath to the models architecture (.json file).
        """
        with open(model_filepath) as f:
            model = model_from_json(f.read())
            model.load_weights(weights_filepath)
            self.models.append(model)

    def prepare_data_for_training(self):
        """Returns a list containing the training data for each dataset at `self.datasets`.

        For each dataset at `self.datasets`, collects the data to be used for training.
        Each dataset is represented by a dictionary, where the keys 'x_' and 'y_' contain the
        inputs and targets for each partition 'train', 'valid', and 'test'. The 'valid' and
        'test' entries are None when the corresponding partition was not provided.

        Returns:
            A list of dictionaries, one per dataset, containing the training data.
        """
        training_data = []
        for ds in self.datasets:
            # collect train data, must be provided
            x_train = [ds.idx_seq['train']['word'], ds.idx_seq['train']['char']]
            y_train = ds.idx_seq['train']['tag']
            # collect valid and test data, may not be provided
            x_valid, y_valid, x_test, y_test = None, None, None, None
            if ds.idx_seq['valid'] is not None:
                x_valid = [ds.idx_seq['valid']['word'], ds.idx_seq['valid']['char']]
                y_valid = ds.idx_seq['valid']['tag']
            if ds.idx_seq['test'] is not None:
                x_test = [ds.idx_seq['test']['word'], ds.idx_seq['test']['char']]
                y_test = ds.idx_seq['test']['tag']

            training_data.append({'x_train': x_train, 'y_train': y_train, 'x_valid': x_valid,
                                  'y_valid': y_valid, 'x_test': x_test, 'y_test': y_test})

        return training_data

    def _compile(self, model, loss_function, optimizer, lr=0.01, decay=0.0, clipnorm=0.0):
        """Compiles a model specified with Keras.

        See https://keras.io/optimizers/ for more info on each optimizer.

        Args:
            model: Keras model object to compile.
            loss_function: Keras loss_function object to compile model with.
            optimizer (str): The optimizer to use during training.
            lr (float): Learning rate to use during training.
            decay (float): Per epoch decay rate.
            clipnorm (float): Gradient normalization threshold.

        Raises:
            ValueError: If `optimizer` is not one of the supported optimizer names.
        """
        # The parameters of these optimizers can be freely tuned.
        if optimizer == 'sgd':
            optimizer_ = optimizers.SGD(lr=lr, decay=decay, clipnorm=clipnorm)
        elif optimizer == 'adam':
            optimizer_ = optimizers.Adam(lr=lr, decay=decay, clipnorm=clipnorm)
        elif optimizer == 'adamax':
            optimizer_ = optimizers.Adamax(lr=lr, decay=decay, clipnorm=clipnorm)
        # It is recommended to leave the parameters of this optimizer at their
        # default values (except the learning rate, which can be freely tuned).
        # This optimizer is usually a good choice for recurrent neural networks.
        elif optimizer == 'rmsprop':
            optimizer_ = optimizers.RMSprop(lr=lr, clipnorm=clipnorm)
        # It is recommended to leave the parameters of these optimizers at their
        # default values.
        elif optimizer == 'adagrad':
            optimizer_ = optimizers.Adagrad(clipnorm=clipnorm)
        elif optimizer == 'adadelta':
            optimizer_ = optimizers.Adadelta(clipnorm=clipnorm)
        elif optimizer == 'nadam':
            optimizer_ = optimizers.Nadam(clipnorm=clipnorm)
        else:
            err_msg = "Argument for `optimizer` is invalid, got: {}".format(optimizer)
            LOGGER.error('ValueError %s', err_msg)
            raise ValueError(err_msg)

        model.compile(optimizer=optimizer_, loss=loss_function)
120 |
--------------------------------------------------------------------------------
/saber/models/multi_task_lstm_crf.py:
--------------------------------------------------------------------------------
1 | """Contains the Multi-task BiLSTM-CRF (MT-BILSTM-CRF) Keras model for sequence labelling.
2 | """
3 | import logging
4 |
5 | import tensorflow as tf
6 | from keras.layers import (LSTM, Bidirectional, Concatenate, Dense, Dropout,
7 | Embedding, SpatialDropout1D, TimeDistributed)
8 | from keras.models import Input, Model, model_from_json
9 | from keras.utils import multi_gpu_model
10 | from keras_contrib.layers.crf import CRF
11 | from keras_contrib.losses.crf_losses import crf_loss
12 |
13 | from .. import constants
14 | from .base_model import BaseKerasModel
15 |
16 | LOGGER = logging.getLogger(__name__)
17 |
class MultiTaskLSTMCRF(BaseKerasModel):
    """A Keras implementation of a BiLSTM-CRF for sequence labeling.

    A BiLSTM-CRF for NER implementation in Keras. Supports multi-task learning by default, just pass
    multiple Dataset objects via `datasets` to the constructor and the model will share the
    parameters of all layers, except for the final output layer, across all datasets, where each
    dataset represents a sequence labelling task.

    Args:
        config (Config): Contains a set of harmonized arguments provided in a *.ini file and,
            optionally, from the command line.
        datasets (list): A list of Dataset objects.
        embeddings (numpy.ndarray): A numpy array where ith row contains the vector embedding for
            the ith word type.

    References:
        - Guillaume Lample, Miguel Ballesteros, Sandeep Subramanian, Kazuya Kawakami, Chris Dyer.
          "Neural Architectures for Named Entity Recognition". Proceedings of NAACL 2016.
          https://arxiv.org/abs/1603.01360
    """
    def __init__(self, config, datasets, embeddings=None, **kwargs):
        super().__init__(config, datasets, embeddings, **kwargs)

    def load(self, weights_filepath, model_filepath):
        """Load a model from disk.

        Loads a model from disk by loading its architecture from a json file at `model_filepath`
        and its weights from a hdf5 file at `weights_filepath`. The loaded model is appended to
        `self.models`.

        Args:
            weights_filepath (str): Filepath to the models weights (.hdf5 file).
            model_filepath (str): Filepath to the models architecture (.json file).
        """
        with open(model_filepath) as f:
            # the CRF layer comes from keras_contrib, so it must be supplied as a custom object
            model = model_from_json(f.read(), custom_objects={'CRF': CRF})
            model.load_weights(weights_filepath)
            self.models.append(model)

    def specify(self):
        """Specifies a multi-task BiLSTM-CRF for sequence tagging using Keras.

        Implements a hybrid bidirectional long short-term memory network-conditional random
        field (BiLSTM-CRF) multi-task model for sequence tagging. One model per dataset is
        appended to `self.models`; all layers except the final CRF classifier are shared.
        """
        # specify any shared layers outside the for loop
        # word-level embedding layer
        if self.embeddings is None:
            word_embeddings = Embedding(input_dim=len(self.datasets[0].type_to_idx['word']) + 1,
                                        output_dim=self.config.word_embed_dim,
                                        mask_zero=True,
                                        name="word_embedding_layer")
        else:
            word_embeddings = Embedding(input_dim=self.embeddings.num_embed,
                                        output_dim=self.embeddings.dimension,
                                        mask_zero=True,
                                        weights=[self.embeddings.matrix],
                                        trainable=self.config.fine_tune_word_embeddings,
                                        name="word_embedding_layer")
        # character-level embedding layer
        char_embeddings = Embedding(input_dim=len(self.datasets[0].type_to_idx['char']) + 1,
                                    output_dim=self.config.char_embed_dim,
                                    mask_zero=True,
                                    name="char_embedding_layer")
        # char-level BiLSTM
        char_BiLSTM = TimeDistributed(Bidirectional(LSTM(constants.UNITS_CHAR_LSTM // 2)),
                                      name='character_BiLSTM')
        # word-level BiLSTMs
        word_BiLSTM_1 = Bidirectional(LSTM(units=constants.UNITS_WORD_LSTM // 2,
                                           return_sequences=True,
                                           dropout=self.config.dropout_rate['input'],
                                           recurrent_dropout=self.config.dropout_rate['recurrent']),
                                      name="word_BiLSTM_1")
        word_BiLSTM_2 = Bidirectional(LSTM(units=constants.UNITS_WORD_LSTM // 2,
                                           return_sequences=True,
                                           dropout=self.config.dropout_rate['input'],
                                           recurrent_dropout=self.config.dropout_rate['recurrent']),
                                      name="word_BiLSTM_2")

        # get all unique tag types across all datasets
        all_tags = [ds.type_to_idx['tag'] for ds in self.datasets]
        all_tags = set(x for l in all_tags for x in l)

        # feedforward before CRF, maps each time step to a vector
        dense_layer = TimeDistributed(Dense(len(all_tags), activation=self.config.activation),
                                      name='dense_layer')

        # specify model, taking into account the shared layers
        for dataset in self.datasets:
            # word-level embedding
            word_ids = Input(shape=(None, ), dtype='int32', name='word_id_inputs')
            word_embed = word_embeddings(word_ids)

            # character-level embedding
            char_ids = Input(shape=(None, None), dtype='int32', name='char_id_inputs')
            char_embed = char_embeddings(char_ids)

            # character-level BiLSTM + dropout. Spatial dropout applies the same dropout mask to all
            # timesteps which is necessary to implement variational dropout
            # (https://arxiv.org/pdf/1512.05287.pdf)
            char_embed = char_BiLSTM(char_embed)
            if self.config.variational_dropout:
                LOGGER.info('Used variational dropout')
                char_embed = SpatialDropout1D(self.config.dropout_rate['output'])(char_embed)

            # concatenate word- and char-level embeddings + dropout
            model = Concatenate()([word_embed, char_embed])
            model = Dropout(self.config.dropout_rate['output'])(model)

            # word-level BiLSTM + dropout
            model = word_BiLSTM_1(model)
            if self.config.variational_dropout:
                model = SpatialDropout1D(self.config.dropout_rate['output'])(model)
            model = word_BiLSTM_2(model)
            if self.config.variational_dropout:
                model = SpatialDropout1D(self.config.dropout_rate['output'])(model)

            # feedforward before CRF
            model = dense_layer(model)

            # CRF output layer, one per dataset (this is the only unshared layer)
            crf = CRF(len(dataset.type_to_idx['tag']), name='crf_classifier')
            output_layer = crf(model)

            # fully specified model
            # https://github.com/keras-team/keras/blob/bf1378f39d02b7d0b53ece5458f9275ac8208046/keras/utils/multi_gpu_utils.py
            with tf.device('/cpu:0'):
                model = Model(inputs=[word_ids, char_ids], outputs=[output_layer])
            self.models.append(model)

    def compile(self):
        """Compiles the BiLSTM-CRF.

        Compiles the Keras model(s) at `self.models`. If multiple GPUs are detected, a model
        capable of training on all of them is compiled.
        """
        for i in range(len(self.models)):
            # parallelize the model if multiple GPUs are available
            # https://github.com/keras-team/keras/pull/9226
            # `multi_gpu_model` raises ValueError when fewer than two GPUs are available, which is
            # the signal to fall back to single-device compilation. Catching Exception (instead of
            # a bare except) avoids swallowing KeyboardInterrupt/SystemExit.
            try:
                self.models[i] = multi_gpu_model(self.models[i])
                LOGGER.info('Compiling the model on multiple GPUs')
            except Exception:
                LOGGER.info('Compiling the model on a single CPU or GPU')

            self._compile(model=self.models[i],
                          loss_function=crf_loss,
                          optimizer=self.config.optimizer,
                          lr=self.config.learning_rate,
                          decay=self.config.decay,
                          clipnorm=self.config.grad_norm)

    def prepare_for_transfer(self, datasets):
        """Prepares the BiLSTM-CRF for transfer learning by recreating its last layer.

        Prepares the BiLSTM-CRF model(s) at `self.models` for transfer learning by removing their
        CRF classifiers and replacing them with un-trained CRF classifiers of the appropriate size
        (i.e. number of units equal to number of output tags) for the target datasets.

        Args:
            datasets (list): A list of target Dataset objects; replaces `self.datasets`.

        References:
            - https://stackoverflow.com/questions/41378461/how-to-use-models-from-keras-applications-for-transfer-learnig/41386444#41386444
        """
        self.datasets = datasets  # replace with target datasets
        models, self.models = self.models, []  # wipe models

        for dataset, model in zip(self.datasets, models):
            # remove the old CRF classifier and define a new one
            model.layers.pop()
            new_crf = CRF(len(dataset.type_to_idx['tag']), name='target_crf_classifier')
            # create the new model
            new_input = model.input
            new_output = new_crf(model.layers[-1].output)
            self.models.append(Model(new_input, new_output))

        self.compile()
--------------------------------------------------------------------------------
/saber/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BaderLab/saber/876be6bfdb1bc5b18cbcfa848c94b0d20c940f02/saber/tests/__init__.py
--------------------------------------------------------------------------------
/saber/tests/resources/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BaderLab/saber/876be6bfdb1bc5b18cbcfa848c94b0d20c940f02/saber/tests/resources/__init__.py
--------------------------------------------------------------------------------
/saber/tests/resources/dummy_config.ini:
--------------------------------------------------------------------------------
1 | [mode]
2 | # Possible models: [MT-LSTM-CRF, ]
3 | model_name = MT-LSTM-CRF
4 | # If True, model is compressed and saved in output_folder at the end of training. Model weights are
5 | # taken from the last epoch.
6 | save_model = False
7 |
8 | [data]
9 | # You can specify multiple datasets by listing their paths, separated by a comma.
10 | dataset_folder = saber/tests/resources/dummy_dataset_1
11 | output_folder = ../output
12 | # Path to pre-trained model. To train a model from scratch, leave this blank.
13 | pretrained_model =
14 | # Path to pre-trained word embeddings. In order to use random initialization, leave this blank.
15 | # Note that you can leave this blank when loading a pre-trained model
16 | # (via 'SequenceProcessor.load()') that was trained with pre-trained embeddings
17 | pretrained_embeddings = saber/tests/resources/dummy_word_embeddings/dummy_word_embeddings.txt
18 |
19 | [model]
20 | # If pre-trained word embeddings are provided, word_embed_dim will be the same size as these
21 | # embeddings and this argument will be ignored.
22 | word_embed_dim = 200
23 | char_embed_dim = 30
24 |
25 | [training]
26 | # Values chosen for each hyperparameter represent sensible defaults that perform well across a wide
27 | # range of NLP tasks (POS tagging, Chunking, NER, etc.) and thus should only be changed in special
28 | # circumstances.
29 | optimizer = nadam
30 | activation = relu
31 | # Set to 0 to turn off gradient normalization.
32 | grad_norm = 1.0
33 | # For certain optimizers, these values are ignored. See compile_model() in
34 | # saber/utils/model_utils.py.
35 | learning_rate = 0.0
36 | decay = 0.0
37 |
38 | # Three dropout values must be specified (separated by a comma), corresponding to the dropout rate
39 | # to apply to the input, output and recurrent connections respectively. Must be a value between 0.0
40 | # and 1.0.
41 | dropout_rate = 0.3, 0.3, 0.1
42 |
43 | batch_size = 32
44 | # If a test partition is supplied at 'dataset_folder' (test.*) then this argument is ignored, and a
45 | # simple train/valid/test scheme is used. A valid partition (valid.*) may optionally be provided
46 | # along with the test partition. If none is found, 10% of examples are randomly selected.
47 | k_folds = 2
48 | epochs = 50
49 |
50 | # Matching criteria used when determining whether or not a prediction is a true-positive. Choices
51 | # are 'left' for left-boundary matching, 'right' for right-boundary matching and 'exact' for
52 | # exact-boundary matching.
53 | criteria = exact
54 |
55 | [advanced]
56 | verbose = False
57 | debug = False
58 | # If True, per-epoch logs which can be visualized with TensorBoard are written to output_folder
59 | # Note: These logs can be quite large.
60 | tensorboard = False
61 | # If True, then during training the models weights for each and every epoch will be saved.
62 | # Otherwise, weights are only saved for epochs that achieve a new best on validation loss.
63 | save_all_weights = False
64 | # If True, tokens that occur less than 1 time in the training dataset (hapax legomenon) are replaced
65 | # with a special unknown token. This should result in faster loading times of pre-trained word
66 | # embeddings and faster training times.
67 | replace_rare_tokens = False
68 | # If True, then all pre-trained word embeddings provided via pretrained_embeddings are loaded.
69 | # Otherwise, only pre-trained embeddings for tokens found in the dataset(s) at dataset_folder are
70 | # loaded. For evaluation, it's best to leave this as False. For models that will be deployed, it's
71 | # best to set it to True.
72 | load_all_embeddings = False
73 | # If True, then pre-trained word embeddings will be fine-tuned with the other parameters of the
74 | # neural network during training. Generally, you should not set this to True unless you have a very
75 | # large training dataset.
76 | # NOTE: if 'pretrained_embeddings' are not provided, they will be randomly initialized and
77 | # fine-tuned during training, ignoring this argument.
78 | fine_tune_word_embeddings = False
79 | # TEMP. Set to true if variational dropout should be used.
80 | variational_dropout = False
81 |
--------------------------------------------------------------------------------
/saber/tests/resources/dummy_constants.py:
--------------------------------------------------------------------------------
"""Constants used by the unit tests.
"""
import logging
import os

import numpy as np
from pkg_resources import resource_filename

from ... import constants

LOGGER = logging.getLogger(__name__)

# relative paths for test resources
PATH_TO_DUMMY_DATASET_1 = resource_filename(__name__, 'dummy_dataset_1')
PATH_TO_DUMMY_DATASET_2 = resource_filename(__name__, 'dummy_dataset_2')
PATH_TO_DUMMY_CONFIG = resource_filename(__name__, 'dummy_config.ini')
PATH_TO_DUMMY_EMBEDDINGS = resource_filename(__name__, 'dummy_word_embeddings/dummy_word_embeddings.txt')

######################################### DUMMY EMBEDDINGS #########################################

# for testing embeddings
# NOTE(review): the first two keys of DUMMY_TOKEN_MAP and DUMMY_CHAR_MAP are both the empty
# string, so the second entry silently overwrites the first; presumably these were meant to be
# distinct PAD/UNK tokens (see constants.PAD / constants.UNK) — confirm.
DUMMY_TOKEN_MAP = {'': 0, '': 1, 'the': 2, 'quick': 3, 'brown': 4, 'fox': 5}
DUMMY_CHAR_MAP = {'': 0, '': 1, 'r': 2, 'u': 3, 'c': 4, 'f': 5, 'e': 6, 'o': 7, 'x': 8,
                  'h': 9, 'b': 10, 'n': 11, 'w': 12, 'i': 13, 't': 14, 'q': 15, 'k': 16}
# maps words to the vectors found in dummy_word_embeddings.txt
DUMMY_EMBEDDINGS_INDEX = {
    'the': [0.15580128, -0.07108746, 0.055198, -0.14199848, 0.0005317868],
    'quick': [-0.011208724, 0.21213274, -0.17233513, -0.4401193, 0.13930725],
    'brown': [0.12754257, -0.07938199, 0.083904505, -0.24103324, 0.0084449835],
    'fox': [0.2947119, 0.14794342, 0.10318808, 0.09019197, -0.24244581]
}
# expected embedding matrix for DUMMY_TOKEN_MAP; the first two rows are zero vectors for tokens
# with no pre-trained embedding
DUMMY_EMBEDDINGS_MATRIX = np.array([
    [0.0, 0.0, 0.0, 0.0, 0.0],
    [0.0, 0.0, 0.0, 0.0, 0.0],
    [0.15580128, -0.07108746, 0.055198, -0.14199848, 0.0005317868],
    [-0.011208724, 0.21213274, -0.17233513, -0.4401193, 0.13930725],
    [0.12754257, -0.07938199, 0.083904505, -0.24103324, 0.0084449835],
    [0.2947119, 0.14794342, 0.10318808, 0.09019197, -0.24244581]
])

########################################## DUMMY DATASET ##########################################

# word-, character- and tag-level sequences matching the sentences in dummy_dataset_1
DUMMY_WORD_SEQ = np.array([
    ['Human', 'APC2', 'maps', 'to', 'chromosome', '19p13', '.'],
    ['The', 'absence', 'of', 'functional', 'C7', 'activity', 'could', 'not', 'be', 'accounted',
     'for', 'on', 'the', 'basis', 'of', 'an', 'inhibitor', '.'],
])
DUMMY_CHAR_SEQ = np.array([
    [['H', 'u', 'm', 'a', 'n'], ['A', 'P', 'C', '2'], ['m', 'a', 'p', 's'], ['t', 'o'],
     ['c', 'h', 'r', 'o', 'm', 'o', 's', 'o', 'm', 'e'], ['1', '9', 'p', '1', '3'], ['.']],
    [['T', 'h', 'e'], ['a', 'b', 's', 'e', 'n', 'c', 'e'], ['o', 'f'],
     ['f', 'u', 'n', 'c', 't', 'i', 'o', 'n', 'a', 'l'], ['C', '7'],
     ['a', 'c', 't', 'i', 'v', 'i', 't', 'y'], ['c', 'o', 'u', 'l', 'd'], ['n', 'o', 't'],
     ['b', 'e'], ['a', 'c', 'c', 'o', 'u', 'n', 't', 'e', 'd'], ['f', 'o', 'r'], ['o', 'n'],
     ['t', 'h', 'e'], ['b', 'a', 's', 'i', 's'], ['o', 'f'], ['a', 'n'],
     ['i', 'n', 'h', 'i', 'b', 'i', 't', 'o', 'r'], ['.']]])
DUMMY_TAG_SEQ = np.array([
    ['O', 'O', 'O', 'O', 'O', 'O', 'O'],
    ['O', 'B-DISO', 'I-DISO', 'I-DISO', 'E-DISO', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O',
     'O', 'O', 'O'],
])
DUMMY_WORD_TYPES = ['Human', 'APC2', 'maps', 'to', 'chromosome', '19p13', '.', 'The', 'absence',
                    'of', 'functional', 'C7', 'activity', 'could', 'not', 'be', 'accounted',
                    'for', 'on', 'the', 'basis', 'an', 'inhibitor', constants.PAD, constants.UNK]
DUMMY_CHAR_TYPES = ['2', 's', 'c', 'T', 'd', 'e', 'H', 'h', 'a', 'b', 'v', 'C', 'm', 't', '9', 'p',
                    'r', '3', 'u', '.', 'o', '7', 'n', 'f', 'y', 'l', '1', 'i', 'A', 'P',
                    constants.PAD, constants.UNK]
DUMMY_TAG_TYPES = ['O', 'B-DISO', 'I-DISO', 'E-DISO', constants.PAD]

########################################### DUMMY CONFIG ###########################################

# Sections of the .ini file
CONFIG_SECTIONS = ['mode', 'data', 'model', 'training', 'advanced']

# Arg values before any processing
DUMMY_ARGS_NO_PROCESSING = {'model_name': 'MT-LSTM-CRF',
                            'save_model': 'False',
                            'dataset_folder': 'saber/tests/resources/dummy_dataset_1',
                            'output_folder': '../output',
                            'pretrained_model': '',
                            'pretrained_embeddings': ('saber/tests/resources/'
                                                      'dummy_word_embeddings/'
                                                      'dummy_word_embeddings.txt'),
                            'word_embed_dim': '200',
                            'char_embed_dim': '30',
                            'optimizer': 'nadam',
                            'activation': 'relu',
                            'learning_rate': '0.0',
                            'grad_norm': '1.0',
                            'decay': '0.0',
                            'dropout_rate': '0.3, 0.3, 0.1',
                            'batch_size': '32',
                            'k_folds': '2',
                            'epochs': '50',
                            'criteria': 'exact',
                            'verbose': 'False',
                            'debug': 'False',
                            'save_all_weights': 'False',
                            'tensorboard': 'False',
                            'replace_rare_tokens': 'False',
                            'load_all_embeddings': 'False',
                            'fine_tune_word_embeddings': 'False',
                            # TEMP
                            'variational_dropout': 'False',
                            }
# Final arg values when args provided in only config file
DUMMY_ARGS_NO_CLI_ARGS = {'model_name': 'mt-lstm-crf',
                          'save_model': False,
                          'dataset_folder': [PATH_TO_DUMMY_DATASET_1],
                          'output_folder': os.path.abspath('../output'),
                          'pretrained_model': '',
                          'pretrained_embeddings': PATH_TO_DUMMY_EMBEDDINGS,
                          'word_embed_dim': 200,
                          'char_embed_dim': 30,
                          'optimizer': 'nadam',
                          'activation': 'relu',
                          'learning_rate': 0.0,
                          'decay': 0.0,
                          'grad_norm': 1.0,
                          'dropout_rate': {'input': 0.3, 'output':0.3, 'recurrent': 0.1},
                          'batch_size': 32,
                          'k_folds': 2,
                          'epochs': 50,
                          'criteria': 'exact',
                          'verbose': False,
                          'debug': False,
                          'save_all_weights': False,
                          'tensorboard': False,
                          'replace_rare_tokens': False,
                          'load_all_embeddings': False,
                          'fine_tune_word_embeddings': False,
                          # TEMP
                          'variational_dropout': False,
                          }
# Final arg values when args provided in config file and from CLI
DUMMY_COMMAND_LINE_ARGS = {'optimizer': 'sgd',
                           'grad_norm': 1.0,
                           'learning_rate': 0.05,
                           'decay': 0.5,
                           'dropout_rate': [0.6, 0.6, 0.2],
                           # the dataset and embeddings are used for test purposes so they must
                           # point to the correct resources, this can be ensured by passing their
                           # paths here
                           'dataset_folder': [PATH_TO_DUMMY_DATASET_1],
                           'pretrained_embeddings': PATH_TO_DUMMY_EMBEDDINGS,
                           }
DUMMY_ARGS_WITH_CLI_ARGS = {'model_name': 'mt-lstm-crf',
                            'save_model': False,
                            'dataset_folder': [PATH_TO_DUMMY_DATASET_1],
                            'output_folder': os.path.abspath('../output'),
                            'pretrained_model': '',
                            'pretrained_embeddings': PATH_TO_DUMMY_EMBEDDINGS,
                            'word_embed_dim': 200,
                            'char_embed_dim': 30,
                            'optimizer': 'sgd',
                            'activation': 'relu',
                            'learning_rate': 0.05,
                            'decay': 0.5,
                            'grad_norm': 1.0,
                            'dropout_rate': {'input': 0.6, 'output': 0.6, 'recurrent': 0.2},
                            'batch_size': 32,
                            'k_folds': 2,
                            'epochs': 50,
                            'criteria': 'exact',
                            'verbose': False,
                            'debug': False,
                            'save_all_weights': False,
                            'tensorboard': False,
                            'replace_rare_tokens': False,
                            'load_all_embeddings': False,
                            'fine_tune_word_embeddings': False,
                            # TEMP
                            'variational_dropout': False,
                            }
########################################### WEB SERVICE ###########################################

# entity-type switches; per the section header these are used by the web service tests
DUMMY_ENTITIES = {'ANAT': False,
                  'CHED': True,
                  'DISO': False,
                  'LIVB': True,
                  'PRGE': True,
                  'TRIG': False,
                  }
--------------------------------------------------------------------------------
/saber/tests/resources/dummy_dataset_1/test.tsv:
--------------------------------------------------------------------------------
1 | Human O
2 | APC2 O
3 | maps O
4 | to O
5 | chromosome O
6 | 19p13 O
7 | . O
8 |
9 | The O
10 | absence B-DISO
11 | of I-DISO
12 | functional I-DISO
13 | C7 E-DISO
14 | activity O
15 | could O
16 | not O
17 | be O
18 | accounted O
19 | for O
20 | on O
21 | the O
22 | basis O
23 | of O
24 | an O
25 | inhibitor O
26 | . O
27 |
--------------------------------------------------------------------------------
/saber/tests/resources/dummy_dataset_1/train.tsv:
--------------------------------------------------------------------------------
1 | Human O
2 | APC2 O
3 | maps O
4 | to O
5 | chromosome O
6 | 19p13 O
7 | . O
8 |
9 | The O
10 | absence B-DISO
11 | of I-DISO
12 | functional I-DISO
13 | C7 E-DISO
14 | activity O
15 | could O
16 | not O
17 | be O
18 | accounted O
19 | for O
20 | on O
21 | the O
22 | basis O
23 | of O
24 | an O
25 | inhibitor O
26 | . O
27 |
--------------------------------------------------------------------------------
/saber/tests/resources/dummy_dataset_1/valid.tsv:
--------------------------------------------------------------------------------
1 | Human O
2 | APC2 O
3 | maps O
4 | to O
5 | chromosome O
6 | 19p13 O
7 | . O
8 |
9 | The O
10 | absence B-DISO
11 | of I-DISO
12 | functional I-DISO
13 | C7 E-DISO
14 | activity O
15 | could O
16 | not O
17 | be O
18 | accounted O
19 | for O
20 | on O
21 | the O
22 | basis O
23 | of O
24 | an O
25 | inhibitor O
26 | . O
27 |
--------------------------------------------------------------------------------
/saber/tests/resources/dummy_dataset_2/test.tsv:
--------------------------------------------------------------------------------
1 | Myoc B-PRGE
2 | alleles O
3 | do O
4 | not O
5 | associate O
6 | with O
7 | the O
8 | magnitude O
9 | of O
10 | IOP O
11 |
12 | Mutations O
13 | in O
14 | the O
15 | myocilin B-PRGE
16 | gene O
17 | ( O
18 | MYOC B-PRGE
19 | ) O
20 | cause O
21 | human O
22 | glaucoma O
23 | . O
24 |
--------------------------------------------------------------------------------
/saber/tests/resources/dummy_dataset_2/train.tsv:
--------------------------------------------------------------------------------
1 | Myoc B-PRGE
2 | alleles O
3 | do O
4 | not O
5 | associate O
6 | with O
7 | the O
8 | magnitude O
9 | of O
10 | IOP O
11 |
12 | Mutations O
13 | in O
14 | the O
15 | myocilin B-PRGE
16 | gene O
17 | ( O
18 | MYOC B-PRGE
19 | ) O
20 | cause O
21 | human O
22 | glaucoma O
23 | . O
24 |
--------------------------------------------------------------------------------
/saber/tests/resources/dummy_dataset_2/valid.tsv:
--------------------------------------------------------------------------------
1 | Myoc B-PRGE
2 | alleles O
3 | do O
4 | not O
5 | associate O
6 | with O
7 | the O
8 | magnitude O
9 | of O
10 | IOP O
11 |
12 | Mutations O
13 | in O
14 | the O
15 | myocilin B-PRGE
16 | gene O
17 | ( O
18 | MYOC B-PRGE
19 | ) O
20 | cause O
21 | human O
22 | glaucoma O
23 | . O
24 |
--------------------------------------------------------------------------------
/saber/tests/resources/dummy_word_embeddings/dummy_word_embeddings.txt:
--------------------------------------------------------------------------------
1 | 4 5
2 | the 0.15580128 -0.07108746 0.055198 -0.14199848 0.0005317868
3 | quick -0.011208724 0.21213274 -0.17233513 -0.4401193 0.13930725
4 | brown 0.12754257 -0.07938199 0.083904505 -0.24103324 0.0084449835
5 | fox 0.2947119 0.14794342 0.10318808 0.09019197 -0.24244581
6 |
--------------------------------------------------------------------------------
/saber/tests/resources/helpers.py:
--------------------------------------------------------------------------------
1 | """Any and all helper functions for Sabers unit tests.
2 | """
3 | import configparser
4 | import os
5 |
6 | from ... import constants
7 |
8 |
def assert_type_to_idx_as_expected(actual, expected):
    """Asserts that a `type_to_idx` mapping is as expected.

    Checks that `actual` contains the expected keys, that its values form a
    consecutive mapping of integers beginning at 0, and that the special
    initial-mapping tokens are present.
    """
    for key in ('word', 'char'):
        # every expected type must appear as a key in the actual mapping
        for type_ in expected[key]:
            assert type_ in actual[key]
        # values must be a consecutive mapping of integers in [0, len(mapping))
        valid_ids = range(0, len(actual[key]))
        for idx in actual[key].values():
            assert idx in valid_ids
    # special initial-mapping tokens must be present in both mappings
    # NOTE(review): the char mapping is checked against INITIAL_MAPPING['word'],
    # mirroring usage elsewhere in the test suite — confirm this is intended.
    for token in constants.INITIAL_MAPPING['word']:
        assert token in actual['word']
        assert token in actual['char']
23 |
def load_saved_config(filepath):
    """Load a saved config.ConfigParser object at 'filepath/config.ini'.

    Args:
        filepath (str): directory containing the saved config file 'config.ini'

    Returns:
        configparser.ConfigParser: parser populated from 'filepath/config.ini'.
    """
    parser = configparser.ConfigParser()
    parser.read(os.path.join(filepath, 'config.ini'))

    return parser
38 |
def unprocess_args(args):
    """Unprocesses processed config args.

    Given a dictionary of arguments (`args`), returns a dictionary where every
    value has been converted to a string representation: lists become
    comma-separated strings, dicts become comma-separated strings of their
    values, and everything else is passed through `str()`.

    Returns:
        dict: `args` with all values replaced by a str representation.
    """
    def _stringify(value):
        # lists and dict values are joined with ', ' to mirror the config format
        if isinstance(value, list):
            return ', '.join(value)
        if isinstance(value, dict):
            return ', '.join(str(v) for v in value.values())
        return str(value)

    return {arg: _stringify(value) for arg, value in args.items()}
61 |
--------------------------------------------------------------------------------
/saber/tests/test_app_utils.py:
--------------------------------------------------------------------------------
1 | """Any and all unit tests for the app_utils (saber/utils/app_utils.py).
2 | """
3 | import pytest
4 |
5 | from ..utils import app_utils
6 | from .resources.dummy_constants import *
7 |
8 | ############################################ UNIT TESTS ############################################
9 |
def test_get_pubmed_xml_errors():
    """Asserts that a call to `app_utils.get_pubmed_xml()` raises a ValueError when
    an invalid value for argument `pmid` is passed."""
    # none of these are valid PMIDs: wrong type, non-positive, or non-scalar
    for invalid_pmid in (["test"], "test", 0.0, 0, -1, (42,)):
        with pytest.raises(ValueError):
            app_utils.get_pubmed_xml(invalid_pmid)
18 |
def test_harmonize_entities():
    """Asserts that app_utils.harmonize_entities() returns the expected results."""
    all_off = {'ANAT': False, 'CHED': False, 'DISO': False,
               'LIVB': False, 'PRGE': False, 'TRIG': False}

    # single bool test: only PRGE requested on
    one_on_expected = dict(all_off, PRGE=True)
    assert app_utils.harmonize_entities(DUMMY_ENTITIES, {'PRGE': True}) == one_on_expected

    # multi bool test: two entities on, one explicitly off
    multi_on_test = {'PRGE': True, 'CHED': True, 'TRIG': False}
    multi_on_expected = dict(all_off, PRGE=True, CHED=True)
    assert app_utils.harmonize_entities(DUMMY_ENTITIES, multi_on_test) == multi_on_expected

    # null test: nothing requested, every entity defaults to off
    assert app_utils.harmonize_entities(DUMMY_ENTITIES, {}) == all_off
40 |
--------------------------------------------------------------------------------
/saber/tests/test_base_model.py:
--------------------------------------------------------------------------------
1 | """Any and all unit tests for the BaseKerasModel (saber/models/base_model.py).
2 | """
3 | import pytest
4 |
5 | from ..config import Config
6 | from ..dataset import Dataset
7 | from ..embeddings import Embeddings
8 | from ..models.base_model import BaseKerasModel
9 | from .resources.dummy_constants import *
10 |
11 | ######################################### PYTEST FIXTURES #########################################
12 |
@pytest.fixture
def dummy_config():
    """Returns a Config object parsed from the dummy config file."""
    config = Config(PATH_TO_DUMMY_CONFIG)
    return config
17 |
@pytest.fixture
def dummy_dataset_1():
    """Returns a dummy Dataset instance (dataset 1) after calling `Dataset.load()`.
    """
    # rare tokens are kept (not replaced) for the sake of testing
    ds = Dataset(directory=PATH_TO_DUMMY_DATASET_1, replace_rare_tokens=False)
    ds.load()
    return ds
27 |
@pytest.fixture
def dummy_dataset_2():
    """Returns a dummy Dataset instance (dataset 2) after calling `Dataset.load()`.
    """
    # rare tokens are kept (not replaced) for the sake of testing
    ds = Dataset(directory=PATH_TO_DUMMY_DATASET_2, replace_rare_tokens=False)
    ds.load()
    return ds
37 |
@pytest.fixture
def dummy_embeddings(dummy_dataset_1):
    """Returns an `Embeddings()` object AFTER its `.load()` method has been called.
    """
    # NOTE(review): token_map is built from idx_to_tag here — confirm this is
    # intended rather than a word-level mapping.
    embeddings = Embeddings(filepath=PATH_TO_DUMMY_EMBEDDINGS,
                            token_map=dummy_dataset_1.idx_to_tag)
    # load from the plain-text (non-binary) format, which is easier to test
    embeddings.load(binary=False)
    return embeddings
46 |
@pytest.fixture
def single_model(dummy_config, dummy_dataset_1, dummy_embeddings):
    """Returns a BaseKerasModel instance built from the default configuration,
    without embeddings."""
    # 'totally_arbitrary' exercises forwarding of arbitrary kwargs to __init__
    return BaseKerasModel(config=dummy_config,
                          datasets=[dummy_dataset_1],
                          totally_arbitrary='arbitrary')
55 |
@pytest.fixture
def single_model_embeddings(dummy_config, dummy_dataset_1, dummy_embeddings):
    """Returns a BaseKerasModel instance built from the default configuration,
    with loaded embeddings."""
    # 'totally_arbitrary' exercises forwarding of arbitrary kwargs to __init__
    return BaseKerasModel(config=dummy_config,
                          datasets=[dummy_dataset_1],
                          embeddings=dummy_embeddings,
                          totally_arbitrary='arbitrary')
66 |
67 | ############################################ UNIT TESTS ############################################
68 |
def test_compile_value_error(single_model):
    """Asserts that `BaseKerasModel._compile()` raises a ValueError when an invalid
    argument for `optimizer` is passed.
    """
    # the first two arguments are placeholders; only the optimizer is invalid here
    with pytest.raises(ValueError):
        single_model._compile('arbitrary', 'arbitrary', 'invalid')
75 |
def test_attributes_init_of_single_model(dummy_config, dummy_dataset_1, single_model):
    """Asserts instance attributes are initialized correctly when a single model is
    initialized without embeddings (`embeddings` attribute is None).
    """
    assert isinstance(single_model, BaseKerasModel)
    # attributes passed to __init__ are stored by reference
    for attr, expected in (('config', dummy_config), ('embeddings', None)):
        assert getattr(single_model, attr) is expected
    assert single_model.datasets[0] is dummy_dataset_1
    # other instance attributes
    assert single_model.models == []
    # arbitrary keyword arguments become instance attributes
    assert single_model.totally_arbitrary == 'arbitrary'
89 |
def test_attributes_init_of_single_model_embeddings(dummy_config, dummy_dataset_1,
                                                    dummy_embeddings, single_model_embeddings):
    """Asserts instance attributes are initialized correctly when a single model is
    initialized with embeddings (`embeddings` attribute is not None).
    """
    assert isinstance(single_model_embeddings, BaseKerasModel)
    # attributes passed to __init__ are stored by reference
    for attr, expected in (('config', dummy_config), ('embeddings', dummy_embeddings)):
        assert getattr(single_model_embeddings, attr) is expected
    assert single_model_embeddings.datasets[0] is dummy_dataset_1
    # other instance attributes
    assert single_model_embeddings.models == []
    # arbitrary keyword arguments become instance attributes
    assert single_model_embeddings.totally_arbitrary == 'arbitrary'
104 |
def test_prepare_data_for_training(dummy_dataset_1, single_model):
    """Assert that the values returned from a call to
    `BaseKerasModel.prepare_data_for_training()` are as expected.
    """
    # NOTE(review): `np` is provided by the dummy_constants wildcard import —
    # consider importing numpy explicitly in this module.
    training_data = single_model.prepare_data_for_training()
    partitions = ('train', 'valid', 'test')

    for data in training_data:
        for partition in partitions:
            # each item must contain x/y keys for every partition
            assert 'x_{}'.format(partition) in data
            assert 'y_{}'.format(partition) in data
            # x is [word idx sequence, char idx sequence]; y is the tag idx sequence
            idx_seq = dummy_dataset_1.idx_seq[partition]
            assert data['x_{}'.format(partition)] == [idx_seq['word'], idx_seq['char']]
            assert np.array_equal(data['y_{}'.format(partition)], idx_seq['tag'])
125 |
--------------------------------------------------------------------------------
/saber/tests/test_config.py:
--------------------------------------------------------------------------------
1 | """Contains any and all unit tests for the config.Config class (saber/config.py).
2 | """
3 | import pytest
4 |
5 | from .. import config
6 | from .resources.dummy_constants import *
7 | from .resources.helpers import *
8 |
9 | ######################################### PYTEST FIXTURES #########################################
10 |
@pytest.fixture
def config_no_cli_args():
    """Returns a config.Config object parsed from the dummy config file with no
    command line interface (CLI) args."""
    return config.Config(PATH_TO_DUMMY_CONFIG)
19 |
@pytest.fixture
def config_with_cli_args():
    """Returns a config.Config object parsed from the dummy config file with
    command line interface (CLI) args."""
    dummy_config = config.Config(PATH_TO_DUMMY_CONFIG)
    # hack: simulate arguments provided at the command line by assigning them
    # directly and then re-harmonizing the parsed config
    dummy_config.cli_args = DUMMY_COMMAND_LINE_ARGS
    dummy_config.harmonize_args(DUMMY_COMMAND_LINE_ARGS)
    return dummy_config
31 |
32 | ############################################ UNIT TESTS ############################################
33 |
def test_process_args_no_cli_args(config_no_cli_args):
    """Asserts the config.Config object contains the expected attributes after
    initialization without CLI args."""
    # the filepath should point at the dummy config file
    assert config_no_cli_args.filepath == PATH_TO_DUMMY_CONFIG
    # every section of the parsed file must match DUMMY_ARGS_NO_PROCESSING
    # (renamed local from `config` to avoid shadowing the imported module)
    parsed = config_no_cli_args.config
    for section in CONFIG_SECTIONS:
        for arg, value in parsed[section].items():
            assert value == DUMMY_ARGS_NO_PROCESSING[arg]
    # no CLI args were given
    assert config_no_cli_args.cli_args == {}
46 |
def test_process_args_with_cli_args(config_with_cli_args):
    """Asserts the config.Config object contains the expected attributes after
    initialization with CLI args."""
    # the filepath is resolved relative to this test module
    expected_filepath = os.path.join(os.path.dirname(os.path.abspath(__file__)),
                                     PATH_TO_DUMMY_CONFIG)
    assert config_with_cli_args.filepath == expected_filepath
    # every section of the parsed file must match DUMMY_ARGS_NO_PROCESSING
    parsed = config_with_cli_args.config
    for section in CONFIG_SECTIONS:
        for arg, value in parsed[section].items():
            assert value == DUMMY_ARGS_NO_PROCESSING[arg]
    # the CLI args must be recorded verbatim
    assert config_with_cli_args.cli_args == DUMMY_COMMAND_LINE_ARGS
60 |
def test_config_attributes_no_cli_args(config_no_cli_args):
    """Asserts that attributes of a config.Config object have the expected
    values/types after initialization, with NO command line arguments.
    """
    # every expected attribute must be present with the expected value
    for attr, expected in DUMMY_ARGS_NO_CLI_ARGS.items():
        assert getattr(config_no_cli_args, attr) == expected
68 |
def test_config_attributes_with_cli_args(config_with_cli_args):
    """Asserts that attributes of a config.Config object have the expected
    values/types after initialization, taking into account command line arguments,
    which take precedence over config arguments.
    """
    # command line arguments must have overwritten the config-file arguments
    for attr, expected in DUMMY_ARGS_WITH_CLI_ARGS.items():
        assert getattr(config_with_cli_args, attr) == expected
78 |
def test_resolve_filepath(config_no_cli_args):
    """Asserts that `Config._resolve_filepath()` returns the expected values for
    every combination of `filepath` and `cli_args` arguments.
    """
    resolve = config_no_cli_args._resolve_filepath
    explicit_filepath = 'arbitrary/filepath/to/config.ini'
    dummy_cli_args = {'config_filepath': explicit_filepath}

    # neither filepath nor cli_args: fall back to the packaged default config
    default_filepath = resource_filename(config.__name__, constants.CONFIG_FILENAME)
    assert resolve(filepath=None, cli_args={}) == default_filepath
    # only cli_args: use the filepath provided at the command line
    assert resolve(filepath=None, cli_args=dummy_cli_args) == explicit_filepath
    # only filepath: use it directly
    assert resolve(filepath=explicit_filepath, cli_args={}) == explicit_filepath
    # both provided: the explicit filepath argument wins over cli_args
    assert resolve(filepath=explicit_filepath, cli_args=dummy_cli_args) == explicit_filepath
103 |
def test_key_error(tmpdir):
    """Asserts that a KeyError is raised when a Config object is initialized with a
    value for `filepath` that does NOT contain a valid *.ini file.

    Args:
        tmpdir: pytest fixture providing an empty temporary directory.
    """
    # the empty temporary directory contains no config.ini, so parsing fails;
    # the unused local binding of the constructed object was removed
    with pytest.raises(KeyError):
        config.Config(tmpdir.strpath)
110 |
def test_save_no_cli_args(config_no_cli_args, tmpdir):
    """Asserts that a saved config file contains the correct arguments and values."""
    # save then reload the config from a pytest-managed temporary directory
    config_no_cli_args.save(tmpdir.strpath)
    saved_config = load_saved_config(tmpdir.strpath)
    # 'unprocess' the args so they can be compared to the saved (string) values
    unprocessed_args = unprocess_args(DUMMY_ARGS_NO_CLI_ARGS)
    # the saved config file must match the original arguments used to create it
    for section in CONFIG_SECTIONS:
        for arg, value in saved_config[section].items():
            assert unprocessed_args[arg] == value
123 |
def test_save_with_cli_args(config_with_cli_args, tmpdir):
    """Asserts that a saved config file contains the correct arguments and values,
    taking into account command line arguments, which take precedence over config
    arguments.
    """
    # save then reload the config from a pytest-managed temporary directory
    config_with_cli_args.save(tmpdir.strpath)
    saved_config = load_saved_config(tmpdir.strpath)
    # 'unprocess' the args so they can be compared to the saved (string) values
    unprocessed_args = unprocess_args(DUMMY_ARGS_WITH_CLI_ARGS)
    # the saved config file must match the original arguments used to create it
    for section in CONFIG_SECTIONS:
        for arg, value in saved_config[section].items():
            assert unprocessed_args[arg] == value
138 |
--------------------------------------------------------------------------------
/saber/tests/test_data_utils.py:
--------------------------------------------------------------------------------
1 | """Any and all unit tests for the data_utils (saber/utils/data_utils.py).
2 | """
3 | import numpy as np
4 | import pytest
5 |
6 | from ..config import Config
7 | from ..dataset import Dataset
8 | from ..utils import data_utils
9 | from .resources.dummy_constants import *
10 |
11 | ######################################### PYTEST FIXTURES #########################################
12 |
@pytest.fixture
def dummy_config():
    """Returns a Config object parsed from the dummy config file."""
    return Config(PATH_TO_DUMMY_CONFIG)
18 |
@pytest.fixture
def dummy_dataset_1():
    """Returns a dummy Dataset instance (dataset 1) after calling `Dataset.load()`.
    """
    # rare tokens are kept (not replaced) for the sake of testing
    ds = Dataset(directory=PATH_TO_DUMMY_DATASET_1, replace_rare_tokens=False)
    ds.load()
    return ds
28 |
@pytest.fixture
def dummy_dataset_2():
    """Returns a dummy Dataset instance (dataset 2) after calling `Dataset.load()`.
    """
    # rare tokens are kept (not replaced) for the sake of testing
    ds = Dataset(directory=PATH_TO_DUMMY_DATASET_2, replace_rare_tokens=False)
    ds.load()
    return ds
38 |
@pytest.fixture
def dummy_compound_dataset(dummy_config):
    """Returns a list of two dummy Dataset instances loaded as a compound dataset
    via `data_utils.load_compound_dataset()`.
    """
    # point the config at both dummy dataset folders to form a compound dataset;
    # rare tokens are kept (not replaced) for the sake of testing
    dummy_config.dataset_folder = [PATH_TO_DUMMY_DATASET_1, PATH_TO_DUMMY_DATASET_2]
    dummy_config.replace_rare_tokens = False
    return data_utils.load_compound_dataset(dummy_config)
48 |
@pytest.fixture(scope='session')
def dummy_dataset_paths_all(tmpdir_factory):
    """Creates and returns the path to a temporary dataset folder along with its
    train, valid, and test files.
    """
    dummy_dir = tmpdir_factory.mktemp('dummy_dataset')
    # each partition file needs content written or it won't exist on disk
    partition_paths = []
    for partition in ('train', 'valid', 'test'):
        partition_file = dummy_dir.join('{}.tsv'.format(partition))
        partition_file.write('arbitrary')
        partition_paths.append(partition_file.strpath)

    return (dummy_dir.strpath,) + tuple(partition_paths)
64 |
@pytest.fixture(scope='session')
def dummy_dataset_paths_no_valid(tmpdir_factory):
    """Creates and returns the path to a temporary dataset folder along with its
    train and test files (no valid partition).
    """
    dummy_dir = tmpdir_factory.mktemp('dummy_dataset')
    # each partition file needs content written or it won't exist on disk
    partition_paths = []
    for partition in ('train', 'test'):
        partition_file = dummy_dir.join('{}.tsv'.format(partition))
        partition_file.write('arbitrary')
        partition_paths.append(partition_file.strpath)

    return (dummy_dir.strpath,) + tuple(partition_paths)
78 |
79 | ############################################ UNIT TESTS ############################################
80 |
def test_get_filepaths_value_error(tmpdir):
    """Asserts that a ValueError is raised when `data_utils.get_filepaths()` is
    called on a directory containing no 'train.*' file.
    """
    # tmpdir is empty, so no train file can be found
    with pytest.raises(ValueError):
        data_utils.get_filepaths(tmpdir.strpath)
87 |
def test_get_filepaths_all(dummy_dataset_paths_all):
    """Asserts that `data_utils.get_filepaths()` returns the expected filepaths when
    all partitions (train/valid/test) are provided.
    """
    dataset_dir, train_filepath, valid_filepath, test_filepath = dummy_dataset_paths_all

    actual = data_utils.get_filepaths(dataset_dir)

    # all three partitions are present, so all three filepaths are returned
    assert actual == {'train': train_filepath,
                      'valid': valid_filepath,
                      'test': test_filepath}
100 |
def test_get_filepaths_no_valid(dummy_dataset_paths_no_valid):
    """Asserts that `data_utils.get_filepaths()` returns the expected filepaths when
    only the train and test partitions are provided.
    """
    dataset_dir, train_filepath, test_filepath = dummy_dataset_paths_no_valid

    actual = data_utils.get_filepaths(dataset_dir)

    # no valid partition exists, so its filepath must be None
    assert actual == {'train': train_filepath,
                      'valid': None,
                      'test': test_filepath}
113 |
def test_load_single_dataset(dummy_config, dummy_dataset_1):
    """Asserts that `data_utils.load_single_dataset()` returns the expected value.
    """
    actual = data_utils.load_single_dataset(dummy_config)
    expected = [dummy_dataset_1]

    # essentially redundant, but if we don't return a [Dataset] object then the
    # error message from the final test could be cryptic
    assert isinstance(actual, list)
    assert len(actual) == 1
    assert isinstance(actual[0], Dataset)
    # BUG FIX: the original asserted `dir(actual[0].__dict__) == dir(expected[0].__dict__)`,
    # but `dir()` of any dict lists the attributes of the dict *type*, so that
    # comparison was always True (a vacuous assertion). Compare the instance
    # attribute names of the two Dataset objects instead.
    assert sorted(actual[0].__dict__) == sorted(expected[0].__dict__)
128 |
def test_load_compound_dataset_unchanged_attributes(dummy_dataset_1,
                                                    dummy_dataset_2,
                                                    dummy_compound_dataset):
    """Asserts that attributes of `Dataset` objects which are expected to remain
    unchanged are unchanged after call to `data_utils.load_compound_dataset()`.
    """
    actual = dummy_compound_dataset
    expected = [dummy_dataset_1, dummy_dataset_2]

    # essentially redundant, but gives clearer failures for the tests below
    assert isinstance(actual, list)
    assert len(actual) == 2
    assert all(isinstance(ds, Dataset) for ds in actual)

    # these attributes must be unchanged by compound loading; check the first
    # and last dataset of the compound against the standalone datasets
    for i in (0, -1):
        assert actual[i].directory == expected[i].directory
        assert actual[i].replace_rare_tokens == expected[i].replace_rare_tokens
        assert actual[i].type_seq == expected[i].type_seq
        assert actual[i].type_to_idx['tag'] == expected[i].type_to_idx['tag']
        assert actual[i].idx_to_tag == expected[i].idx_to_tag
156 |
def test_load_compound_dataset_changed_attributes(dummy_dataset_1,
                                                  dummy_dataset_2,
                                                  dummy_compound_dataset):
    """Asserts that attributes of `Dataset` objects which are expected to be changed
    are changed after call to `data_utils.load_compound_dataset()`.

    The `dummy_dataset_1`/`dummy_dataset_2` fixtures are requested to mirror the
    companion test for unchanged attributes; the unused local `expected` that
    previously aliased them has been removed.
    """
    actual = dummy_compound_dataset

    # essentially redundant, but if we don't return a [Dataset, Dataset] object
    # then the error messages from the downstream tests could be cryptic
    assert isinstance(actual, list)
    assert len(actual) == 2
    assert all(isinstance(ds, Dataset) for ds in actual)

    # word- and char-level mappings must be shared across a compound dataset
    assert actual[0].type_to_idx['word'] == actual[-1].type_to_idx['word']
    assert actual[0].type_to_idx['char'] == actual[-1].type_to_idx['char']

    # TODO: Need to assert that all types in idx_seq map to the same integers
    # across the compound datasets
178 |
def test_setup_dataset_for_transfer(dummy_dataset_1, dummy_dataset_2):
    """Asserts that the `type_to_idx` attribute of a "source" dataset and a "target"
    dataset are as expected after call to `data_utils.setup_dataset_for_transfer()`.
    """
    source_type_to_idx = dummy_dataset_1.type_to_idx
    data_utils.setup_dataset_for_transfer(dummy_dataset_2, source_type_to_idx)

    # after transfer setup, the target shares the source's word/char mappings
    for type_ in ('word', 'char'):
        assert dummy_dataset_2.type_to_idx[type_] == source_type_to_idx[type_]
187 |
--------------------------------------------------------------------------------
/saber/tests/test_dataset.py:
--------------------------------------------------------------------------------
1 | """Contains any and all unit tests for the `Dataset` class (saber/dataset.py).
2 | """
3 | import os
4 |
5 | import numpy as np
6 | import pytest
7 | from nltk.corpus.reader.conll import ConllCorpusReader
8 |
9 | from .. import constants
10 | from ..dataset import Dataset
11 | from ..utils import generic_utils
12 | from .resources.dummy_constants import *
13 |
14 | # TODO (johngiorgi): Need to include tests for valid/test partitions
15 | # TODO (johngiorgi): Need to include tests for compound datasets
16 |
17 | ######################################### PYTEST FIXTURES #########################################
@pytest.fixture
def empty_dummy_dataset():
    """Returns a single dummy Dataset instance whose `load()` method has NOT been
    called.
    """
    # rare tokens are kept (not replaced) for the sake of testing;
    # 'totally_arbitrary' exercises forwarding of arbitrary kwargs to __init__
    return Dataset(directory=PATH_TO_DUMMY_DATASET_1, replace_rare_tokens=False,
                   totally_arbitrary='arbitrary')
26 |
@pytest.fixture
def loaded_dummy_dataset():
    """Returns a single dummy Dataset instance after calling `Dataset.load()`.
    """
    # rare tokens are kept (not replaced) for the sake of testing
    ds = Dataset(directory=PATH_TO_DUMMY_DATASET_1, replace_rare_tokens=False)
    ds.load()
    return ds
36 |
37 | ############################################ UNIT TESTS ############################################
38 |
39 | # Generic object tests
40 |
def test_attributes_after_initilization_of_dataset(empty_dummy_dataset):
    """Asserts instance attributes are initialized correctly when dataset is empty,
    i.e. `Dataset.load()` has not been called.

    (Note: the test name's 'initilization' typo is kept — renaming would change the
    pytest-collected identifier.)
    """
    empty_partitions = {'train': None, 'valid': None, 'test': None}

    # attributes that are passed to __init__
    for partition, filepath in empty_dummy_dataset.directory.items():
        assert filepath == os.path.join(PATH_TO_DUMMY_DATASET_1, '{}.tsv'.format(partition))
    assert not empty_dummy_dataset.replace_rare_tokens
    # other instance attributes
    assert empty_dummy_dataset.conll_parser.root == PATH_TO_DUMMY_DATASET_1
    assert empty_dummy_dataset.type_seq == empty_partitions
    assert empty_dummy_dataset.type_to_idx == {'word': None, 'char': None, 'tag': None}
    assert empty_dummy_dataset.idx_to_tag is None
    assert empty_dummy_dataset.idx_seq == empty_partitions
    # arbitrary keyword arguments become instance attributes
    assert empty_dummy_dataset.totally_arbitrary == 'arbitrary'
58 |
def test_value_error_load(empty_dummy_dataset):
    """Asserts that `Dataset.load()` raises a ValueError when `Dataset.directory` is
    None.
    """
    # clearing the directory forces the error to arise
    empty_dummy_dataset.directory = None
    with pytest.raises(ValueError):
        empty_dummy_dataset.load()
66 |
67 | # SINGLE DATASET
68 |
def test_get_types_single_dataset(empty_dummy_dataset):
    """Asserts that `Dataset._get_types()` returns the expected values.

    Args:
        empty_dummy_dataset: unloaded Dataset fixture.
    """
    actual = empty_dummy_dataset._get_types()
    expected = {'word': DUMMY_WORD_TYPES, 'char': DUMMY_CHAR_TYPES, 'tag': DUMMY_TAG_TYPES}

    # BUG FIX: the original asserted `actual['word'].sort() == v.sort()`, but
    # `list.sort()` returns None, so the assertion was always `None == None`
    # (vacuously True) and only ever looked at the 'word' key. Compare sorted
    # copies of each key's types instead.
    assert all(sorted(actual[k]) == sorted(v) for k, v in expected.items())
77 |
78 | # Tests on unloaded Dataset object (`Dataset.load()` was not called)
79 |
def test_get_type_seq_single_dataset_before_load(empty_dummy_dataset):
    """Asserts that `Dataset.type_seq` is updated as expected after call to
    `Dataset._get_type_seq()`.
    """
    empty_dummy_dataset._get_type_seq()
    train_seq = empty_dummy_dataset.type_seq['train']

    # the train partition sequences must match the dummy fixtures exactly
    for type_, expected in (('word', DUMMY_WORD_SEQ),
                            ('char', DUMMY_CHAR_SEQ),
                            ('tag', DUMMY_TAG_SEQ)):
        assert np.array_equal(train_seq[type_], expected)
89 |
def test_get_idx_maps_single_dataset_before_load(empty_dummy_dataset):
    """Asserts that `Dataset.type_to_idx` is updated as expected after successive
    calls to `Dataset._get_types()` and `Dataset._get_idx_maps()`.
    """
    empty_dummy_dataset._get_idx_maps(empty_dummy_dataset._get_types())

    expected_types = {'word': DUMMY_WORD_TYPES,
                      'char': DUMMY_CHAR_TYPES,
                      'tag': DUMMY_TAG_TYPES}
    for type_, expected in expected_types.items():
        mapping = empty_dummy_dataset.type_to_idx[type_]
        # the index mapping must be a contiguous sequence of numbers starting at 0
        assert generic_utils.is_consecutive(mapping.values())
        # the mapping must contain only the expected keys
        assert all(key in expected for key in mapping)
105 |
def test_get_idx_maps_single_dataset_before_load_special_tokens(empty_dummy_dataset):
    """Asserts that `Dataset.type_to_idx` contains the special tokens as keys with
    expected values after successive calls to `Dataset._get_types()` and
    `Dataset._get_idx_maps()`.
    """
    empty_dummy_dataset._get_idx_maps(empty_dummy_dataset._get_types())

    # special tokens must map to their fixed initial indices; note that the char
    # mapping shares its special tokens with INITIAL_MAPPING['word']
    for token, idx in constants.INITIAL_MAPPING['word'].items():
        assert empty_dummy_dataset.type_to_idx['word'][token] == idx
        assert empty_dummy_dataset.type_to_idx['char'][token] == idx
    for token, idx in constants.INITIAL_MAPPING['tag'].items():
        assert empty_dummy_dataset.type_to_idx['tag'][token] == idx
116 |
117 | def test_get_idx_seq_single_dataset_before_load(empty_dummy_dataset):
118 | """Asserts that `Dataset.idx_seq` is updated as expected after successive calls to
119 | `Dataset._get_type_seq()`, `Dataset._get_idx_maps()` and `Dataset.get_idx_seq()`.
120 | """
121 | types = empty_dummy_dataset._get_types()
122 | empty_dummy_dataset._get_type_seq()
123 | empty_dummy_dataset._get_idx_maps(types)
124 | empty_dummy_dataset.get_idx_seq()
125 |
126 | # as a workaround to testing this directly, just check that shapes are as expected
127 | expected_word_idx_shape = (len(DUMMY_WORD_SEQ), constants.MAX_SENT_LEN)
128 | expected_char_idx_shape = (len(DUMMY_WORD_SEQ), constants.MAX_SENT_LEN, constants.MAX_CHAR_LEN)
129 | expected_tag_idx_shape = (len(DUMMY_WORD_SEQ), constants.MAX_SENT_LEN, len(DUMMY_TAG_TYPES))
130 |
131 | assert all(empty_dummy_dataset.idx_seq[partition]['word'].shape == expected_word_idx_shape
132 | for partition in ['train', 'test', 'valid'])
133 | assert all(empty_dummy_dataset.idx_seq[partition]['char'].shape == expected_char_idx_shape
134 | for partition in ['train', 'test', 'valid'])
135 | assert all(empty_dummy_dataset.idx_seq[partition]['tag'].shape == expected_tag_idx_shape
136 | for partition in ['train', 'test', 'valid'])
137 |
138 | # tests on loaded Dataset object (`Dataset.load()` was called)
139 |
140 | def test_get_type_seq_single_dataset_after_load(loaded_dummy_dataset):
141 | """Asserts that `Dataset.type_seq` is updated as expected after call to `Dataset.load()`.
142 | """
143 | assert np.array_equal(loaded_dummy_dataset.type_seq['train']['word'], DUMMY_WORD_SEQ)
144 | assert np.array_equal(loaded_dummy_dataset.type_seq['train']['char'], DUMMY_CHAR_SEQ)
145 | assert np.array_equal(loaded_dummy_dataset.type_seq['train']['tag'], DUMMY_TAG_SEQ)
146 |
147 | def test_get_idx_maps_single_dataset_after_load(loaded_dummy_dataset):
148 |     """Asserts that `Dataset.type_to_idx` is updated as expected after call to `Dataset.load()`.
149 |     """
150 |     # ensure that index mapping is a contiguous sequence of numbers starting at 0
151 |     # (i.e. the mapping contains no gaps and no duplicated indices)
152 |     assert generic_utils.is_consecutive(loaded_dummy_dataset.type_to_idx['word'].values())
153 |     assert generic_utils.is_consecutive(loaded_dummy_dataset.type_to_idx['char'].values())
154 |     assert generic_utils.is_consecutive(loaded_dummy_dataset.type_to_idx['tag'].values())
155 |     # ensure that type to index mapping contains the expected keys
156 |     assert all(key in DUMMY_WORD_TYPES for key in loaded_dummy_dataset.type_to_idx['word'])
157 |     assert all(key in DUMMY_CHAR_TYPES for key in loaded_dummy_dataset.type_to_idx['char'])
158 |     assert all(key in DUMMY_TAG_TYPES for key in loaded_dummy_dataset.type_to_idx['tag'])
159 |
160 | def test_get_idx_maps_single_dataset_after_load_special_tokens(loaded_dummy_dataset):
161 |     """Asserts that `Dataset.type_to_idx` contains the special tokens as keys with expected values
162 |     after a call to `Dataset.load()`.
163 |     """
164 |     # assert special tokens are mapped to the correct indices
165 |     assert all(loaded_dummy_dataset.type_to_idx['word'][k] == v for k, v in constants.INITIAL_MAPPING['word'].items())
166 |     # NOTE(review): the 'char' map is checked against INITIAL_MAPPING['word'] — presumably chars
167 |     # share the word-level special tokens; confirm INITIAL_MAPPING has no separate 'char' entry
168 |     assert all(loaded_dummy_dataset.type_to_idx['char'][k] == v for k, v in constants.INITIAL_MAPPING['word'].items())
169 |     assert all(loaded_dummy_dataset.type_to_idx['tag'][k] == v for k, v in constants.INITIAL_MAPPING['tag'].items())
168 |
169 | def test_get_idx_seq_after_load(loaded_dummy_dataset):
170 | """Asserts that `Dataset.idx_seq` is updated as expected after calls to `Dataset.load()`.
171 | """
172 | # as a workaround to testing this directly, just check that shapes are as expected
173 | expected_word_idx_shape = (len(DUMMY_WORD_SEQ), constants.MAX_SENT_LEN)
174 | expected_char_idx_shape = (len(DUMMY_WORD_SEQ), constants.MAX_SENT_LEN, constants.MAX_CHAR_LEN)
175 | expected_tag_idx_shape = (len(DUMMY_WORD_SEQ), constants.MAX_SENT_LEN, len(DUMMY_TAG_TYPES))
176 |
177 | assert all(loaded_dummy_dataset.idx_seq[partition]['word'].shape == expected_word_idx_shape
178 | for partition in ['train', 'test', 'valid'])
179 | assert all(loaded_dummy_dataset.idx_seq[partition]['char'].shape == expected_char_idx_shape
180 | for partition in ['train', 'test', 'valid'])
181 | assert all(loaded_dummy_dataset.idx_seq[partition]['tag'].shape == expected_tag_idx_shape
182 | for partition in ['train', 'test', 'valid'])
183 |
184 | # COMPOUND DATASET
185 |
--------------------------------------------------------------------------------
/saber/tests/test_embeddings.py:
--------------------------------------------------------------------------------
1 | """Any and all unit tests for the Embeddings class (saber/embeddings.py).
2 | """
3 | import numpy as np
4 | import pytest
5 |
6 | from ..embeddings import Embeddings
7 | from .resources import helpers
8 | from .resources.dummy_constants import *
9 |
10 | # TODO (johngiorgi): write tests using a binary format file
11 | # TODO (johngiorgi): write tests to test for debug functionality
12 |
13 | ######################################### PYTEST FIXTURES #########################################
14 |
15 | @pytest.fixture
16 | def dummy_embedding_idx():
17 | """Returns embedding index from call to `Embeddings._prepare_embedding_index()`.
18 | """
19 | embeddings = Embeddings(filepath=PATH_TO_DUMMY_EMBEDDINGS, token_map=DUMMY_TOKEN_MAP)
20 | embedding_idx = embeddings._prepare_embedding_index(binary=False)
21 | return embedding_idx
22 |
23 | @pytest.fixture
24 | def dummy_embedding_matrix_and_type_to_idx():
25 | """Returns the `embedding_matrix` and `type_to_index` objects from call to
26 | `Embeddings._prepare_embedding_matrix(load_all=False)`.
27 | """
28 | embeddings = Embeddings(filepath=PATH_TO_DUMMY_EMBEDDINGS, token_map=DUMMY_TOKEN_MAP)
29 | embedding_idx = embeddings._prepare_embedding_index(binary=False)
30 | embeddings.num_found = len(embedding_idx)
31 | embeddings.dimension = len(list(embedding_idx.values())[0])
32 | embedding_matrix, type_to_idx = embeddings._prepare_embedding_matrix(embedding_idx, load_all=False)
33 | embeddings.num_embed = embedding_matrix.shape[0] # num of embedded words
34 |
35 | return embedding_matrix, type_to_idx
36 |
37 | @pytest.fixture
38 | def dummy_embedding_matrix_and_type_to_idx_load_all():
39 | """Returns the embedding matrix and type to index objects from call to
40 | `Embeddings._prepare_embedding_matrix(load_all=True)`.
41 | """
42 | # this should be different than DUMMY_TOKEN_MAP for a reliable test
43 | test = {"This": 0, "is": 1, "a": 2, "test": 3}
44 |
45 | embeddings = Embeddings(filepath=PATH_TO_DUMMY_EMBEDDINGS, token_map=test)
46 | embedding_idx = embeddings._prepare_embedding_index(binary=False)
47 | embeddings.num_found = len(embedding_idx)
48 | embeddings.dimension = len(list(embedding_idx.values())[0])
49 | embedding_matrix, type_to_idx = embeddings._prepare_embedding_matrix(embedding_idx, load_all=True)
50 | embeddings.num_embed = embedding_matrix.shape[0] # num of embedded words
51 |
52 | return embedding_matrix, type_to_idx
53 |
54 | @pytest.fixture
55 | def dummy_embeddings_before_load():
56 | """Returns an instance of an Embeddings() object BEFORE the `Embeddings.load()` method is
57 | called.
58 | """
59 | return Embeddings(filepath=PATH_TO_DUMMY_EMBEDDINGS,
60 | token_map=DUMMY_TOKEN_MAP,
61 | # to test passing of arbitrary keyword args to constructor
62 | totally_arbitrary='arbitrary')
63 |
64 | @pytest.fixture
65 | def dummy_embeddings_after_load():
66 | """Returns an instance of an Embeddings() object AFTER `Embeddings.load(load_all=False)` is
67 | called.
68 | """
69 | embeddings = Embeddings(filepath=PATH_TO_DUMMY_EMBEDDINGS, token_map=DUMMY_TOKEN_MAP)
70 | embeddings.load(binary=False, load_all=False) # txt file format is easier to test
71 | return embeddings
72 |
73 | @pytest.fixture
74 | def dummy_embeddings_after_load_with_load_all():
75 | """Returns an instance of an Embeddings() object AFTER `Embeddings.load(load_all=True)` is
76 | called.
77 | """
78 | # this should be different than DUMMY_TOKEN_MAP for a reliable test
79 | test = {"This": 0, "is": 1, "a": 2, "test": 3}
80 |
81 | embeddings = Embeddings(filepath=PATH_TO_DUMMY_EMBEDDINGS, token_map=test)
82 | embeddings.load(binary=False, load_all=True) # txt file format is easier to test
83 | return embeddings
84 |
85 | ############################################ UNIT TESTS ############################################
86 |
87 | def test_initialization(dummy_embeddings_before_load):
88 |     """Asserts that Embeddings object contains the expected attribute values after initialization.
89 |     """
90 |     # test attributes whose values are passed to the constructor
91 |     assert dummy_embeddings_before_load.filepath == PATH_TO_DUMMY_EMBEDDINGS
92 |     assert dummy_embeddings_before_load.token_map == DUMMY_TOKEN_MAP
93 |     # test attributes initialized with default values
94 |     assert dummy_embeddings_before_load.matrix is None
95 |     assert dummy_embeddings_before_load.num_found is None
96 |     assert dummy_embeddings_before_load.num_embed is None
97 |     assert dummy_embeddings_before_load.dimension is None
98 |     # test that we can pass arbitrary keyword arguments
99 |     assert dummy_embeddings_before_load.totally_arbitrary == 'arbitrary'
100 |
101 | def test_prepare_embedding_index(dummy_embedding_idx):
102 | """Asserts that we get the expected value back after call to
103 | `Embeddings._prepare_embedding_index()`.
104 | """
105 | # need to check keys and values differently
106 | assert dummy_embedding_idx.keys() == DUMMY_EMBEDDINGS_INDEX.keys()
107 | assert all(np.allclose(actual, expected) for actual, expected in
108 | zip(dummy_embedding_idx.values(), DUMMY_EMBEDDINGS_INDEX.values()))
109 |
110 | def test_prepare_embedding_matrix(dummy_embedding_matrix_and_type_to_idx):
111 | """Asserts that we get the expected value back after successive calls to
112 | `Embeddings._prepare_embedding_index()` and
113 | `Embeddings._prepare_embedding_matrix(load_all=False)`.
114 | """
115 | # expected_values
116 | embedding_matrix_expected, type_to_idx_expected = DUMMY_EMBEDDINGS_MATRIX, None
117 | # actual values
118 | embedding_matrix_actual, type_to_idx_actual = dummy_embedding_matrix_and_type_to_idx
119 |
120 | assert np.allclose(embedding_matrix_actual, embedding_matrix_expected)
121 | assert type_to_idx_actual is type_to_idx_expected
122 |
123 | def test_prepare_embedding_matrix_load_all(dummy_embedding_matrix_and_type_to_idx_load_all):
124 | """Asserts that we get the expected value back after successive calls to
125 | `Embeddings._prepare_embedding_index()` and
126 | `Embeddings._prepare_embedding_matrix(load_all=True)`.
127 | """
128 | # expected values
129 | embedding_matrix_expected = DUMMY_EMBEDDINGS_MATRIX
130 | type_to_idx_expected = {"word": DUMMY_TOKEN_MAP, "char": DUMMY_CHAR_MAP}
131 | # actual values
132 | embedding_matrix_actual, type_to_idx_actual = dummy_embedding_matrix_and_type_to_idx_load_all
133 |
134 | assert np.allclose(embedding_matrix_actual, embedding_matrix_expected)
135 | helpers.assert_type_to_idx_as_expected(actual=type_to_idx_actual, expected=type_to_idx_expected)
136 |
137 | def test_matrix_after_load(dummy_embeddings_after_load):
138 | """Asserts that pre-trained token embeddings are loaded correctly when
139 | `Embeddings.load(load_all=False)` is called."""
140 | assert np.allclose(dummy_embeddings_after_load.matrix, DUMMY_EMBEDDINGS_MATRIX)
141 |
142 | def test_matrix_after_load_with_load_all(dummy_embeddings_after_load):
143 | """Asserts that pre-trained token embeddings are loaded correctly when
144 | `Embeddings.load(load_all=True)` is called."""
145 | assert np.allclose(dummy_embeddings_after_load.matrix, DUMMY_EMBEDDINGS_MATRIX)
146 |
147 | def test_attributes_after_load(dummy_embedding_idx, dummy_embeddings_after_load):
148 | """Asserts that attributes of Embeddings object are updated as expected after
149 | `Embeddings.load(load_all=False)` is called.
150 | """
151 | # expected values
152 | num_found_expected = len(dummy_embedding_idx)
153 | dimension_expected = len(list(dummy_embedding_idx.values())[0])
154 | num_embed_expected = dummy_embeddings_after_load.matrix.shape[0]
155 | # actual values
156 | num_found_actual = dummy_embeddings_after_load.num_found
157 | dimension_actual = dummy_embeddings_after_load.dimension
158 | num_embed_actual = dummy_embeddings_after_load.num_embed
159 |
160 | assert num_found_expected == num_found_actual
161 | assert dimension_expected == dimension_actual
162 | assert num_embed_expected == num_embed_actual
163 |
164 | def test_attributes_after_load_with_load_all(dummy_embedding_idx,
165 | dummy_embeddings_after_load_with_load_all):
166 | """Asserts that attributes of Embeddings object are updated as expected after
167 | `Embeddings.load(load_all=True)` is called.
168 | """
169 | # expected values
170 | num_found_expected = len(dummy_embedding_idx)
171 | dimension_expected = len(list(dummy_embedding_idx.values())[0])
172 | num_embed_expected = dummy_embeddings_after_load_with_load_all.matrix.shape[0]
173 | # actual values
174 | num_found_actual = dummy_embeddings_after_load_with_load_all.num_found
175 | dimension_actual = dummy_embeddings_after_load_with_load_all.dimension
176 | num_embed_actual = dummy_embeddings_after_load_with_load_all.num_embed
177 |
178 | assert num_found_expected == num_found_actual
179 | assert dimension_expected == dimension_actual
180 | assert num_embed_expected == num_embed_actual
181 |
182 | def test_generate_type_to_idx(dummy_embeddings_before_load):
183 | """Asserts that the dictionary returned from 'Embeddings._generate_type_to_idx()' is as
184 | expected.
185 | """
186 | test = {'This': 0, 'is': 1, 'a': 2, 'test': 3}
187 |
188 | # expected values
189 | expected = {
190 | "word": list(test.keys()),
191 | "char": []
192 | }
193 | for word in expected['word']:
194 | expected['char'].extend(list(word))
195 | expected['char'] = list(set(expected['char']))
196 | # actual values
197 | actual = dummy_embeddings_before_load._generate_type_to_idx(test)
198 |
199 | helpers.assert_type_to_idx_as_expected(actual=actual, expected=expected)
200 |
--------------------------------------------------------------------------------
/saber/tests/test_generic_utils.py:
--------------------------------------------------------------------------------
1 | """Any and all unit tests for the generic_utils (saber/utils/generic_utils.py).
2 | """
3 | import os
4 |
5 | import pytest
6 |
7 | from ..config import Config
8 | from ..utils import generic_utils
9 | from .resources.dummy_constants import *
10 |
11 | ######################################### PYTEST FIXTURES #########################################
12 |
13 | @pytest.fixture(scope='session')
14 | def dummy_dir(tmpdir_factory):
15 | """Returns the path to a temporary directory.
16 | """
17 | dummy_dir = tmpdir_factory.mktemp('dummy_dir')
18 | return dummy_dir.strpath
19 |
20 | @pytest.fixture
21 | def dummy_config():
22 | """Returns an instance of a Config object."""
23 | dummy_config = Config(PATH_TO_DUMMY_CONFIG)
24 | return dummy_config
25 |
26 | ############################################ UNIT TESTS ############################################
27 |
28 | def test_is_consecutive_empty():
29 | """Asserts that `generic_utils.is_consecutive()` returns the expected value when passed an
30 | empty list.
31 | """
32 | test = []
33 |
34 | expected = True
35 | actual = generic_utils.is_consecutive(test)
36 |
37 | assert actual == expected
38 |
39 | def test_is_consecutive_simple_sorted_list_no_duplicates():
40 | """Asserts that `generic_utils.is_consecutive()` returns the expected value when passed a
41 | simple sorted list with no duplicates.
42 | """
43 | test_true = [0, 1, 2, 3, 4, 5]
44 | test_false = [1, 2, 3, 4, 5, 6]
45 |
46 | expected_true = True
47 | expected_false = False
48 |
49 | actual_true = generic_utils.is_consecutive(test_true)
50 | actual_false = generic_utils.is_consecutive(test_false)
51 |
52 | assert actual_true == expected_true
53 | assert actual_false == expected_false
54 |
55 | def test_is_consecutive_simple_unsorted_list_no_duplicates():
56 | """Asserts that `generic_utils.is_consecutive()` returns the expected value when passed a
57 | simple unsorted list with no duplicates.
58 | """
59 | test_true = [0, 1, 3, 2, 4, 5]
60 | test_false = [1, 2, 4, 3, 5, 6]
61 |
62 | expected_true = True
63 | expected_false = False
64 |
65 | actual_true = generic_utils.is_consecutive(test_true)
66 | actual_false = generic_utils.is_consecutive(test_false)
67 |
68 | assert actual_true == expected_true
69 | assert actual_false == expected_false
70 |
71 | def test_is_consecutive_simple_sorted_list_duplicates():
72 | """Asserts that `generic_utils.is_consecutive()` returns the expected value when passed a
73 | simple sorted list with duplicates.
74 | """
75 | test = [0, 1, 2, 3, 3, 4, 5]
76 |
77 | expected = False
78 | actual = generic_utils.is_consecutive(test)
79 |
80 | assert actual == expected
81 |
82 | def test_is_consecutive_simple_unsorted_list_duplicates():
83 | """Asserts that `generic_utils.is_consecutive()` returns the expected value when passed a
84 | simple unsorted list with duplicates.
85 | """
86 | test = [0, 1, 4, 3, 3, 2, 5]
87 |
88 | expected = False
89 | actual = generic_utils.is_consecutive(test)
90 |
91 | assert actual == expected
92 |
93 | def test_is_consecutive_simple_dict_no_duplicates():
94 |     """Asserts that `generic_utils.is_consecutive()` returns the expected value when passed a
95 |     dictionary's values with no duplicates.
96 |     """
97 |     test_true = {'a': 0, 'b': 1, 'c': 2, 'd': 3, 'e': 4, 'f': 5}
98 |     test_false = {'a': 1, 'b': 2, 'c': 3, 'd': 4, 'e': 5, 'f': 6}
99 |
100 |     expected_true = True
101 |     expected_false = False
102 |
103 |     actual_true = generic_utils.is_consecutive(test_true.values())
104 |     actual_false = generic_utils.is_consecutive(test_false.values())
105 |
106 |     assert actual_true == expected_true
107 |     assert actual_false == expected_false
108 |
109 | def test_is_consecutive_simple_dict_duplicates():
110 | """Asserts that `generic_utils.is_consecutive()` returns the expected value when passed a
111 | dictionaries values with duplicates.
112 | """
113 | test = {'a': 0, 'b': 1, 'c': 2, 'd': 3, 'e': 3, 'f': 4, 'g': 5}
114 |
115 | expected = False
116 | actual = generic_utils.is_consecutive(test)
117 |
118 | assert actual == expected
119 |
120 | def test_reverse_dict_empty():
121 |     """Asserts that `generic_utils.reverse_dict()` returns the expected value when given an
122 |     empty dictionary.
123 |     """
124 |     test = {}
125 |     expected = {}
126 |     actual = generic_utils.reverse_dict(test)
127 |
128 |     assert actual == expected
129 |
130 | def test_reverse_mapping_simple():
131 |     """Asserts that `generic_utils.reverse_dict()` returns the expected value when given a
132 |     simple dictionary.
133 |     """
134 |     test = {'a': 1, 'b': 2, 'c': 3}
135 |
136 |     expected = {1: 'a', 2: 'b', 3: 'c'}
137 |     actual = generic_utils.reverse_dict(test)
138 |
139 |     assert actual == expected
140 |
141 | def test_make_dir_new(tmpdir):
142 | """Assert that `generic_utils.make_dir()` creates a directory as expected when it does not
143 | already exist.
144 | """
145 | dummy_dirpath = os.path.join(tmpdir.strpath, 'dummy_dir')
146 | generic_utils.make_dir(dummy_dirpath)
147 | assert os.path.isdir(dummy_dirpath)
148 |
149 | def test_make_dir_exists(dummy_dir):
150 | """Assert that `generic_utils.make_dir()` fails silently when trying to create a directory that
151 | already exists.
152 | """
153 | generic_utils.make_dir(dummy_dir)
154 | assert os.path.isdir(dummy_dir)
155 |
156 | def test_clean_path():
157 | """Asserts that filepath returned by `generic_utils.clean_path()` is as expected.
158 | """
159 | test = ' this/is//a/test/ '
160 | expected = os.path.abspath('this/is/a/test')
161 |
162 | assert generic_utils.clean_path(test) == expected
163 |
164 | def test_decompress_model():
165 | """Asserts that `generic_utils.decompress_model()` decompresses a given directory.
166 | """
167 | pass
168 |
169 | def test_compress_model():
170 | """Asserts that `generic_utils.compress_model()` compresses a given directory.
171 | """
172 | pass
173 |
--------------------------------------------------------------------------------
/saber/tests/test_grounding_utils.py:
--------------------------------------------------------------------------------
1 | """Any and all unit tests for the grounding_utils (saber/utils/grounding_utils.py).
2 | """
3 | import copy
4 |
5 | import pytest
6 |
7 | from .. import constants
8 | from ..utils import grounding_utils
9 |
10 |
11 | @pytest.fixture
12 | def blank_annotation():
13 | """Returns an annotation with no identified entities.
14 | """
15 | annotation = {"ents": [],
16 | "text": "This is a test with no entities.",
17 | "title": ""}
18 | return annotation
19 |
20 | @pytest.fixture
21 | def ched_annotation():
22 | """Returns an annotation with chemical entities (CHED) identified.
23 | """
24 | annotation = {"ents": [{"text": "glucose", "label": "CHED", "start": 0, "end": 0},
25 | {"text": "fructose", "label": "CHED", "start": 0, "end": 0}],
26 | "text": "glucose and fructose",
27 | "title": ""}
28 |
29 | return annotation
30 |
31 | @pytest.fixture
32 | def diso_annotation():
33 | """Returns an annotation with disease entities (DISO) identified.
34 | """
35 | annotation = {"ents": [{"text": "cancer", "label": "DISO", "start": 0, "end": 0},
36 | {"text": "cystic fibrosis", "label": "DISO", "start": 0, "end": 0}],
37 | "text": "cancer and cystic fibrosis",
38 | "title": ""}
39 |
40 | return annotation
41 |
42 | @pytest.fixture
43 | def livb_annotation():
44 | """Returns an annotation with species entities (LIVB) identified.
45 | """
46 | annotation = {"ents": [{"text": "mouse", "label": "LIVB", "start": 0, "end": 0},
47 | {"text": "human", "label": "LIVB", "start": 0, "end": 0}],
48 | "text": "mouse and human",
49 | "title": ""}
50 |
51 | return annotation
52 |
53 | @pytest.fixture
54 | def prge_annotation():
55 | """Returns an annotation with protein/gene entities (PRGE) identified.
56 | """
57 | annotation = {"ents": [{"text": "p53", "label": "PRGE", "start": 0, "end": 0},
58 | {"text": "MK2", "label": "PRGE", "start": 0, "end": 0}],
59 | "text": "p53 and MK2",
60 | "title": ""}
61 |
62 | return annotation
63 |
64 | def test_ground_no_entites(blank_annotation):
65 | """Asserts that `grounding_utils.ground()` returns the expected value for a simple example with
66 | no identified entities.
67 | """
68 |
69 | actual = grounding_utils.ground(blank_annotation)
70 | expected = blank_annotation
71 |
72 | assert actual == expected
73 |
74 | def test_ground_chemicals(ched_annotation):
75 | """Asserts that `grounding_utils.ground()` returns the expected value for a simple example with
76 | chemical entities.
77 | """
78 | actual = grounding_utils.ground(copy.deepcopy(ched_annotation))
79 |
80 | # create expected value
81 | glucose_xrefs = [
82 | {'namespace': constants.NAMESPACES['CHED'], 'id': 'CIDs00005793'},
83 | {'namespace': constants.NAMESPACES['CHED'], 'id': 'CIDs10954115'},
84 | {'namespace': constants.NAMESPACES['CHED'], 'id': 'CIDs53782692'},
85 | ]
86 | fructose_xrefs = [{'namespace': constants.NAMESPACES['CHED'], 'id': 'CIDs00439709'}]
87 |
88 | ched_annotation['ents'][0].update(xrefs=glucose_xrefs)
89 | ched_annotation['ents'][1].update(xrefs=fructose_xrefs)
90 |
91 | expected = ched_annotation
92 |
93 | assert actual == expected
94 |
95 | def test_ground_diso(diso_annotation):
96 | """Asserts that `grounding_utils.ground()` returns the expected value for a simple example with
97 | disease entities.
98 | """
99 | actual = grounding_utils.ground(copy.deepcopy(diso_annotation))
100 |
101 | # create expected value
102 | cancer_xrefs = [{'namespace': constants.NAMESPACES['DISO'], 'id': 'DOID:162'}]
103 | cystic_fibrosis_xrefs = [{'namespace': constants.NAMESPACES['DISO'], 'id': 'DOID:1485'}]
104 |
105 | diso_annotation['ents'][0].update(xrefs=cancer_xrefs)
106 | diso_annotation['ents'][1].update(xrefs=cystic_fibrosis_xrefs)
107 |
108 | expected = diso_annotation
109 |
110 | assert actual == expected
111 |
112 | def test_ground_livb(livb_annotation):
113 | """Asserts that `grounding_utils.ground()` returns the expected value for a simple example with
114 | species entities.
115 | """
116 | actual = grounding_utils.ground(copy.deepcopy(livb_annotation))
117 |
118 | # create expected value
119 | mouse_xrefs = [
120 | {'namespace': constants.NAMESPACES['LIVB'], 'id': '10090'},
121 | {'namespace': constants.NAMESPACES['LIVB'], 'id': '10088'},
122 | ]
123 | human_xrefs = [{'namespace': constants.NAMESPACES['LIVB'], 'id': '9606'}]
124 |
125 | livb_annotation['ents'][0].update(xrefs=mouse_xrefs)
126 | livb_annotation['ents'][1].update(xrefs=human_xrefs)
127 |
128 | expected = livb_annotation
129 |
130 | assert actual == expected
131 |
132 | def test_ground_prge(prge_annotation):
133 |     """Asserts that `grounding_utils.ground()` returns the expected value for a simple example with
134 |     protein/gene entities.
135 |     """
136 |     actual = grounding_utils.ground(copy.deepcopy(prge_annotation))
137 |
138 |     # create expected value
139 |     p53_xrefs = [
140 |         {'namespace': constants.NAMESPACES['PRGE'], 'id': 'ENSP00000269305', 'organism-id': '9606'}
141 |     ]
142 |     mk2_xrefs = [
143 |         {'namespace': constants.NAMESPACES['PRGE'], 'id': 'ENSP00000356070', 'organism-id': '9606'},
144 |         {'namespace': constants.NAMESPACES['PRGE'], 'id': 'ENSP00000433109', 'organism-id': '9606'},
145 |     ]
146 |
147 |     prge_annotation['ents'][0].update(xrefs=p53_xrefs)
148 |     prge_annotation['ents'][1].update(xrefs=mk2_xrefs)
149 |
150 |     expected = prge_annotation
151 |
152 |     assert actual == expected
153 |
--------------------------------------------------------------------------------
/saber/tests/test_metrics.py:
--------------------------------------------------------------------------------
1 | """Any and all unit tests for the Metrics class (saber/metrics.py).
2 | """
3 | import pytest
4 |
5 | from .. import constants
6 | from ..config import Config
7 | from ..dataset import Dataset
8 | from ..metrics import Metrics
9 | from ..utils import model_utils
10 | from .resources.dummy_constants import *
11 |
12 | PATH_TO_METRICS_OUTPUT = 'totally/arbitrary'
13 |
14 | ######################################### PYTEST FIXTURES #########################################
15 |
16 | @pytest.fixture
17 | def dummy_config():
18 | """Returns an instance of a Config object."""
19 | dummy_config = Config(PATH_TO_DUMMY_CONFIG)
20 | return dummy_config
21 |
22 | @pytest.fixture
23 | def dummy_dataset():
24 | """Returns a single dummy Dataset instance after calling Dataset.load().
25 | """
26 | # Don't replace rare tokens for the sake of testing
27 | dataset = Dataset(directory=PATH_TO_DUMMY_DATASET_1, replace_rare_tokens=False)
28 | dataset.load()
29 |
30 | return dataset
31 |
32 | @pytest.fixture
33 | def dummy_output_dir(tmpdir, dummy_config):
34 | """Returns list of output directories."""
35 | # make sure top-level directory is the pytest tmpdir
36 | dummy_config.output_folder = tmpdir.strpath
37 | output_dirs = model_utils.prepare_output_directory(dummy_config)
38 |
39 | return output_dirs
40 |
41 | @pytest.fixture
42 | def dummy_training_data(dummy_dataset):
43 | """Returns training data from `dummy_dataset`.
44 | """
45 | training_data = {'x_train': [dummy_dataset.idx_seq['train']['word'],
46 | dummy_dataset.idx_seq['train']['char']],
47 | 'x_valid': None,
48 | 'x_test': None,
49 | 'y_train': dummy_dataset.idx_seq['train']['tag'],
50 | 'y_valid': None,
51 | 'y_test': None,
52 | }
53 |
54 | return training_data
55 |
56 | @pytest.fixture
57 | def dummy_metrics(dummy_config, dummy_dataset, dummy_training_data, dummy_output_dir):
58 | """Returns an instance of Metrics.
59 | """
60 | metrics = Metrics(config=dummy_config,
61 | training_data=dummy_training_data,
62 | index_map=dummy_dataset.idx_to_tag,
63 | output_dir=dummy_output_dir,
64 | # to test passing of arbitrary keyword args to constructor
65 | totally_arbitrary='arbitrary')
66 | return metrics
67 |
68 | ############################################ UNIT TESTS ############################################
69 |
70 | def test_attributes_after_initilization(dummy_config,
71 | dummy_dataset,
72 | dummy_output_dir,
73 | dummy_training_data,
74 | dummy_metrics):
75 | """Asserts instance attributes are initialized correctly when Metrics object is initialized."""
76 | # attributes that are passed to __init__
77 | assert dummy_metrics.config is dummy_config
78 | assert dummy_metrics.training_data is dummy_training_data
79 | assert dummy_metrics.index_map is dummy_dataset.idx_to_tag
80 | assert dummy_metrics.output_dir == dummy_output_dir
81 | # other instance attributes
82 | assert dummy_metrics.current_epoch == 0
83 | assert dummy_metrics.performance_metrics == {p: [] for p in constants.PARTITIONS}
84 | # test that we can pass arbitrary keyword arguments
85 | assert dummy_metrics.totally_arbitrary == 'arbitrary'
86 |
87 | def test_precision_recall_f1_support_value_error():
88 | """Asserts that call to `Metrics.get_precision_recall_f1_support` raises a `ValueError` error
89 | when an invalid value for parameter `criteria` is passed."""
90 | # these are totally arbitrary
91 | y_true = [('test', 0, 3), ('test', 4, 7), ('test', 8, 11)]
92 | y_pred = [('test', 0, 3), ('test', 4, 7), ('test', 8, 11)]
93 |
94 | # anything but 'exact', 'left', or 'right' should throw an error
95 | invalid_args = ['right ', 'LEFT', 'eXact', 0, []]
96 |
97 | for arg in invalid_args:
98 | with pytest.raises(ValueError):
99 | Metrics.get_precision_recall_f1_support(y_true, y_pred, criteria=arg)
100 |
--------------------------------------------------------------------------------
/saber/tests/test_model_utils.py:
--------------------------------------------------------------------------------
1 | """Any and all unit tests for the model_utils (saber/utils/model_utils.py).
2 | """
3 | import os
4 |
5 | import pytest
6 | from keras.callbacks import ModelCheckpoint, TensorBoard
7 |
8 | from ..config import Config
9 | from ..utils import model_utils
10 | from .resources.dummy_constants import *
11 |
12 | ######################################### PYTEST FIXTURES #########################################
13 |
14 | @pytest.fixture
15 | def dummy_config():
16 | """Returns an instance of a Config object."""
17 | return Config(PATH_TO_DUMMY_CONFIG)
18 |
19 | @pytest.fixture
20 | def dummy_output_dir(tmpdir, dummy_config):
21 | """Returns list of output directories."""
22 | # make sure top-level directory is the pytest tmpdir
23 | dummy_config.output_folder = tmpdir.strpath
24 | output_dirs = model_utils.prepare_output_directory(dummy_config)
25 |
26 | return output_dirs
27 |
28 | ############################################ UNIT TESTS ############################################
29 |
def test_prepare_output_directory(dummy_config, dummy_output_dir):
    """Assert that `model_utils.prepare_output_directory()` creates the expected directories
    with the expected content.
    """
    # TODO (johngiorgi): need to test the actual output of the function!
    for dir_ in dummy_output_dir:
        # each expected directory exists...
        assert os.path.isdir(dir_)
        # ...and contains a copy of the config file
        assert os.path.isfile(os.path.join(dir_, 'config.ini'))
39 |
def test_prepare_pretrained_model_dir(dummy_config):
    """Asserts that the filepath returned by `model_utils.prepare_pretrained_model_dir()` is as
    expected.
    """
    # use the basename of the first dataset folder as the model sub-directory name
    dataset = os.path.basename(dummy_config.dataset_folder[0])
    # NOTE(review): `constants` is not imported by name in this module; presumably it arrives via
    # the star-import of `dummy_constants` — confirm.
    expected = os.path.join(dummy_config.output_folder, constants.PRETRAINED_MODEL_DIR, dataset)
    assert model_utils.prepare_pretrained_model_dir(dummy_config) == expected
46 |
def test_setup_checkpoint_callback(dummy_config, dummy_output_dir):
    """Check that we get the expected results from call to
    `model_utils.setup_checkpoint_callback()`.
    """
    simple_actual = model_utils.setup_checkpoint_callback(dummy_config, dummy_output_dir)
    blank_actual = model_utils.setup_checkpoint_callback(dummy_config, [])

    # one ModelCheckpoint callback is expected per output directory (i.e. per dataset)
    assert len(simple_actual) == len(dummy_output_dir)
    for callback in simple_actual:
        assert isinstance(callback, ModelCheckpoint)

    # an empty list of directories yields an empty list of callbacks
    assert blank_actual == []
61 |
def test_setup_tensorboard_callback(dummy_output_dir):
    """Check that we get the expected results from call to
    `model_utils.setup_tensorboard_callback()`.
    """
    simple_actual = model_utils.setup_tensorboard_callback(dummy_output_dir)
    blank_actual = model_utils.setup_tensorboard_callback([])

    # one TensorBoard callback is expected per output directory (i.e. per dataset)
    assert len(simple_actual) == len(dummy_output_dir)
    for callback in simple_actual:
        assert isinstance(callback, TensorBoard)

    # an empty list of directories yields an empty list of callbacks
    assert blank_actual == []
76 |
def test_setup_metrics_callback():
    """TODO (reviewer note): placeholder — no assertions are made for
    `model_utils.setup_metrics_callback()` yet."""
    pass
81 |
def test_setup_callbacks(dummy_config, dummy_output_dir):
    """Check that we get the expected results from call to
    `model_utils.setup_callbacks()`.
    """
    # setup callbacks with config.tensorboard == True
    dummy_config.tensorboard = True
    with_tensorboard_actual = model_utils.setup_callbacks(dummy_config, dummy_output_dir)
    # setup callbacks with config.tensorboard == False
    dummy_config.tensorboard = False
    without_tensorboard_actual = model_utils.setup_callbacks(dummy_config, dummy_output_dir)

    # NOTE(review): `blank_actual` is a literal empty list, so the final assertion is vacuous
    # (`[] == []`) and never exercises `setup_callbacks()` with blank input; it likely should be
    # `model_utils.setup_callbacks(dummy_config, [])` — confirm expected return shape and fix.
    blank_actual = []

    # should get as many Callback objects as datasets
    assert all([len(x) == len(dummy_output_dir) for x in with_tensorboard_actual])
    assert all([len(x) == len(dummy_output_dir) for x in without_tensorboard_actual])

    # all objects in returned list should be of expected type
    assert all([isinstance(x, ModelCheckpoint) for x in with_tensorboard_actual[0]])
    assert all([isinstance(x, TensorBoard) for x in with_tensorboard_actual[1]])
    assert all([isinstance(x, ModelCheckpoint) for x in without_tensorboard_actual[0]])

    # blank input should return blank list
    assert blank_actual == []
106 |
def test_precision_recall_f1_support():
    """Asserts that model_utils.precision_recall_f1_support returns the expected values."""
    # arbitrary true positive / false positive / false negative counts
    true_pos, false_pos, false_neg = 100, 10, 20

    precision = true_pos / (true_pos + false_pos)
    recall = true_pos / (true_pos + false_neg)
    f1 = 2 * precision * recall / (precision + recall)
    support = true_pos + false_neg

    # f1 when FP == 0 (precision is exactly 1.) and when FN == 0 (recall is exactly 1.)
    f1_perfect_precision = 2 * 1. * recall / (1. + recall)
    f1_perfect_recall = 2 * precision * 1. / (precision + 1.)

    # no null counts
    assert model_utils.precision_recall_f1_support(true_pos, false_pos, false_neg) == \
        (precision, recall, f1, support)
    # TP == 0 zeroes out precision, recall and f1
    assert model_utils.precision_recall_f1_support(0, false_pos, false_neg) == \
        (0., 0., 0., false_neg)
    # FP == 0 gives perfect precision
    assert model_utils.precision_recall_f1_support(true_pos, 0, false_neg) == \
        (1., recall, f1_perfect_precision, support)
    # FN == 0 gives perfect recall
    assert model_utils.precision_recall_f1_support(true_pos, false_pos, 0) == \
        (precision, 1., f1_perfect_recall, true_pos)
    # all counts null
    assert model_utils.precision_recall_f1_support(0, 0, 0) == (0., 0., 0., 0)
131 |
--------------------------------------------------------------------------------
/saber/tests/test_multi_task_lstm_crf.py:
--------------------------------------------------------------------------------
1 | """Any and all unit tests for the MultiTaskLSTMCRF (saber/models/multi_task_lstm_crf.py).
2 | """
3 | import pytest
4 | from keras.engine.training import Model
5 |
6 | from ..config import Config
7 | from ..dataset import Dataset
8 | from ..embeddings import Embeddings
9 | from ..models.base_model import BaseKerasModel
10 | from ..models.multi_task_lstm_crf import MultiTaskLSTMCRF
11 | from .resources.dummy_constants import *
12 |
13 | ######################################### PYTEST FIXTURES #########################################
14 |
@pytest.fixture
def dummy_config():
    """A `Config` instance built from the dummy config file."""
    config = Config(PATH_TO_DUMMY_CONFIG)
    return config
19 |
@pytest.fixture
def dummy_dataset_1():
    """A loaded `Dataset` built from the first dummy dataset directory."""
    # rare tokens are NOT replaced, to keep assertions deterministic for testing
    loaded_dataset = Dataset(directory=PATH_TO_DUMMY_DATASET_1, replace_rare_tokens=False)
    loaded_dataset.load()
    return loaded_dataset
29 |
@pytest.fixture
def dummy_dataset_2():
    """A loaded `Dataset` built from the second dummy dataset directory."""
    # rare tokens are NOT replaced, to keep assertions deterministic for testing
    loaded_dataset = Dataset(directory=PATH_TO_DUMMY_DATASET_2, replace_rare_tokens=False)
    loaded_dataset.load()
    return loaded_dataset
39 |
@pytest.fixture
def dummy_embeddings(dummy_dataset_1):
    """Returns an instance of an `Embeddings()` object AFTER the `.load()` method is called.
    """
    # NOTE(review): `token_map` is given the dataset's *tag* mapping (`idx_to_tag`); for word
    # embeddings a word-level mapping would normally be expected — confirm this is intentional.
    embeddings = Embeddings(filepath=PATH_TO_DUMMY_EMBEDDINGS,
                            token_map=dummy_dataset_1.idx_to_tag)
    embeddings.load(binary=False) # txt file format is easier to test
    return embeddings
48 |
@pytest.fixture
def single_model(dummy_config, dummy_dataset_1, dummy_embeddings):
    """A `MultiTaskLSTMCRF` built from the dummy config and a single dataset (no embeddings)."""
    # `totally_arbitrary` exercises forwarding of unknown keyword args to the constructor
    return MultiTaskLSTMCRF(config=dummy_config,
                            datasets=[dummy_dataset_1],
                            totally_arbitrary='arbitrary')
57 |
@pytest.fixture
def single_model_specify(single_model):
    """Same as the `single_model` fixture, but with `MultiTaskLSTMCRF.specify()` already
    called."""
    single_model.specify()
    return single_model
65 |
@pytest.fixture
def single_model_embeddings(dummy_config, dummy_dataset_1, dummy_embeddings):
    """A `MultiTaskLSTMCRF` built from the dummy config, a single dataset and loaded
    embeddings."""
    # `totally_arbitrary` exercises forwarding of unknown keyword args to the constructor
    return MultiTaskLSTMCRF(config=dummy_config,
                            datasets=[dummy_dataset_1],
                            embeddings=dummy_embeddings,
                            totally_arbitrary='arbitrary')
76 |
@pytest.fixture
def single_model_embeddings_specify(single_model_embeddings):
    """Same as the `single_model_embeddings` fixture, but with `MultiTaskLSTMCRF.specify()`
    already called."""
    single_model_embeddings.specify()
    return single_model_embeddings
84 |
85 | ############################################ UNIT TESTS ############################################
86 |
def test_attributes_init_of_single_model(dummy_config, dummy_dataset_1, single_model):
    """Asserts instance attributes are initialized correctly when single `MultiTaskLSTMCRF` model is
    initialized without embeddings (`embeddings` attribute is None.)
    """
    model = single_model  # shorten test statements

    assert isinstance(model, (MultiTaskLSTMCRF, BaseKerasModel))
    # attributes that are passed to __init__
    assert model.config is dummy_config
    assert model.datasets[0] is dummy_dataset_1
    assert model.embeddings is None
    # other instance attributes
    assert model.models == []
    # arbitrary keyword arguments should be attached as attributes
    assert model.totally_arbitrary == 'arbitrary'
100 |
def test_attributes_init_of_single_model_specify(dummy_config, dummy_dataset_1, single_model_specify):
    """Asserts instance attributes are initialized correctly when single `MultiTaskLSTMCRF`
    model is initialized without embeddings (`embeddings` attribute is None) and
    `MultiTaskLSTMCRF.specify()` has been called.
    """
    model = single_model_specify  # shorten test statements

    assert isinstance(model, (MultiTaskLSTMCRF, BaseKerasModel))
    # attributes that are passed to __init__
    assert model.config is dummy_config
    assert model.datasets[0] is dummy_dataset_1
    assert model.embeddings is None
    # after specify(), every entry in `models` is a compiled Keras Model
    for keras_model in model.models:
        assert isinstance(keras_model, Model)
    # arbitrary keyword arguments should be attached as attributes
    assert model.totally_arbitrary == 'arbitrary'
115 |
def test_attributes_init_of_single_model_embeddings(dummy_config, dummy_dataset_1,
                                                    dummy_embeddings, single_model_embeddings):
    """Asserts instance attributes are initialized correctly when single `MultiTaskLSTMCRF` model is
    initialized with embeddings (`embeddings` attribute is not None.)
    """
    model = single_model_embeddings  # shorten test statements

    assert isinstance(model, (MultiTaskLSTMCRF, BaseKerasModel))
    # attributes that are passed to __init__
    assert model.config is dummy_config
    assert model.datasets[0] is dummy_dataset_1
    assert model.embeddings is dummy_embeddings
    # other instance attributes
    assert model.models == []
    # arbitrary keyword arguments should be attached as attributes
    assert model.totally_arbitrary == 'arbitrary'
130 |
def test_attributes_init_of_single_model_embeddings_specify(dummy_config,
                                                            dummy_dataset_1,
                                                            dummy_embeddings,
                                                            single_model_embeddings_specify):
    """Asserts instance attributes are initialized correctly when single MultiTaskLSTMCRF
    model is initialized with embeddings (`embeddings` attribute is not None) and
    `MultiTaskLSTMCRF.specify()` has been called.
    """
    model = single_model_embeddings_specify  # shorten test statements

    assert isinstance(model, (MultiTaskLSTMCRF, BaseKerasModel))
    # attributes that are passed to __init__
    assert model.config is dummy_config
    assert model.datasets[0] is dummy_dataset_1
    assert model.embeddings is dummy_embeddings
    # after specify(), every entry in `models` is a compiled Keras Model
    for keras_model in model.models:
        assert isinstance(keras_model, Model)
    # arbitrary keyword arguments should be attached as attributes
    assert model.totally_arbitrary == 'arbitrary'
148 |
def test_crf_after_transfer(single_model_specify, dummy_dataset_2):
    """Asserts that `MultiTaskLSTMCRF.prepare_for_transfer()` replaces the CRF output layer of
    each model with a new one, by checking the `name` attribute of the final layer before and
    after the transfer.
    """
    test_model = single_model_specify  # shorten test statements

    # record output layer names before transfer
    names_before_transfer = [model.layers[-1].name for model in test_model.models]
    # record output layer names after transfer
    test_model.prepare_for_transfer([dummy_dataset_2])
    names_after_transfer = [model.layers[-1].name for model in test_model.models]

    assert names_before_transfer == ['crf_classifier']
    assert names_after_transfer == ['target_crf_classifier']
167 |
--------------------------------------------------------------------------------
/saber/tests/test_preprocessor.py:
--------------------------------------------------------------------------------
1 | """Contains any and all unit tests for the `Preprocessor` class (saber/preprocessor.py).
2 | """
3 | import en_coref_md
4 | import pytest
5 |
6 | from .. import constants
7 | from ..preprocessor import Preprocessor
8 |
9 | ######################################### PYTEST FIXTURES #########################################
10 |
@pytest.fixture
def preprocessor():
    """A fresh `Preprocessor` instance."""
    instance = Preprocessor()
    return instance
15 |
@pytest.fixture
def nlp():
    """Returns the `en_coref_md` spaCy NLP model."""
    return en_coref_md.load()
20 |
21 | ############################################ UNIT TESTS ############################################
22 |
def test_process_text(preprocessor, nlp):
    """Asserts that call to Preprocessor._process_text() returns the expected
    results."""
    # simple two-sentence document: expect per-sentence token lists plus character offsets
    simple_doc = nlp("Simple example. With two sentences!")
    expected_tokens = [['Simple', 'example', '.'], ['With', 'two', 'sentences', '!']]
    expected_offsets = [[(0, 6), (7, 14), (14, 15)],
                        [(16, 20), (21, 24), (25, 34), (34, 35)]]
    assert preprocessor._process_text(simple_doc) == (expected_tokens, expected_offsets)

    # the empty string yields empty token and offset lists
    assert preprocessor._process_text(nlp("")) == ([], [])
37 |
def test_type_to_idx_value_error():
    """Asserts that `Preprocessor.type_to_idx()` raises a ValueError when an invalid
    `initial_mapping` argument is passed."""
    # build the input OUTSIDE the `pytest.raises` block, so that only the call under test can
    # satisfy the expected exception (previously the dict literal was constructed inside the
    # block, which could mask an error raised before the call)
    invalid_input = {'a': 0, 'b': 2, 'c': 3}
    with pytest.raises(ValueError):
        Preprocessor.type_to_idx([], initial_mapping=invalid_input)
44 |
def test_type_to_idx_empty_input():
    """Asserts that `Preprocessor.type_to_idx()` returns an empty mapping when an empty list is
    passed as input."""
    assert Preprocessor.type_to_idx([]) == {}
52 |
def test_type_to_idx_simple_input():
    """Asserts that `Preprocessor.type_to_idx()` returns the expected mapping when a simple list
    of strings is passed as input."""
    tokens = ["This", "is", "a", "test", "."]
    expected_mapping = {'This': 0, 'is': 1, 'a': 2, 'test': 3, '.': 4}

    assert Preprocessor.type_to_idx(tokens) == expected_mapping
61 |
def test_type_to_idx_intial_mapping():
    """Asserts that `Preprocessor.type_to_idx()` returns the expected mapping when both a simple
    list of strings and an `initial_mapping` argument are supplied."""
    tokens = ["This", "is", "a", "test", "."]
    seed_mapping = {'This': 0, 'is': 1}

    actual = Preprocessor.type_to_idx(tokens, initial_mapping=seed_mapping)

    # entries from the initial mapping keep their indices; new types are appended after them
    assert actual == {'This': 0, 'is': 1, 'a': 2, 'test': 3, '.': 4}
72 |
def test_get_type_to_idx_sequence():
    """TODO (reviewer note): incomplete — `simple_actual` is computed but never compared against
    `simple_expected`, so this test currently asserts nothing."""
    simple_seq = ["This", "is", "a", "test", ".", constants.UNK]
    simple_type_to_idx = Preprocessor.type_to_idx(simple_seq)
    simple_expected = [0, 1, 2, 3, 4]
    simple_actual = Preprocessor.get_type_idx_sequence(simple_seq, type_to_idx=simple_type_to_idx)

    pass
81 |
def test_chunk_entities():
    """Asserts that call to Preprocessor.chunk_entities() returns the
    expected results."""
    # (tag sequence, expected chunks) pairs
    cases = [
        # a multi-token and a single-token entity of the same type
        (['B-PRGE', 'I-PRGE', 'O', 'B-PRGE'], [('PRGE', 0, 2), ('PRGE', 3, 4)]),
        # entities of two different types
        (['B-LIVB', 'I-LIVB', 'O', 'B-PRGE'], [('LIVB', 0, 2), ('PRGE', 3, 4)]),
        # I- tags with no opening B- tag do not form a valid chunk
        (['O', 'I-CHED', 'I-CHED', 'O'], []),
        # the empty sequence
        ([], []),
    ]

    for tag_seq, expected in cases:
        assert Preprocessor.chunk_entities(tag_seq) == expected
101 |
def test_sterilize():
    """Asserts that call to Preprocessor.sterilize() returns the
    expected results."""
    # (input text, expected sterilized text) pairs
    cases = [
        # leading and trailing whitespace is stripped
        (" This is an easy test. ", "This is an easy test."),
        # inline spacing errors are repaired
        ("This is a test with improper spacing. ", "This is a test with improper spacing."),
        # the empty string is returned unchanged
        ("", ""),
    ]

    for text, expected in cases:
        assert Preprocessor.sterilize(text) == expected
118 |
--------------------------------------------------------------------------------
/saber/tests/test_text_utils.py:
--------------------------------------------------------------------------------
1 | import en_coref_md
2 | import pytest
3 |
4 | from ..utils import text_utils
5 |
6 | ######################################### PYTEST FIXTURES #########################################
7 |
@pytest.fixture
def nlp():
    """A spaCy nlp object whose default tokenizer has been swapped for the modified
    biomedical one."""
    model = en_coref_md.load()
    model.tokenizer = text_utils.biomedical_tokenizer(model)
    return model
15 |
16 | ############################################ UNIT TESTS ############################################
17 |
def test_biomedical_tokenizer(nlp):
    """Asserts that call to customized spaCy tokenizer returns the expected results.
    """
    # (input text, expected tokens) pairs: generic cases first
    generic_cases = [
        # the empty string
        ("", []),
        # simple test with no complexities
        ("This is an easy test.", ["This", "is", "an", "easy", "test", "."]),
        # complicated test with some important edge cases
        (("This test's tokenizers handeling of very-tricky situations, 3X, "
          "more/or/less."),
         ["This", "test", "'", "s", "tokenizers", "handeling", "of",
          "very", "-", "tricky", "situations", ",", "3X", ",", "more", "/", "or",
          "/", "less", "."]),
    ]
    # these tests were taken straight from training data
    training_data_cases = [
        (("The results have shown that the degradation product p-choloroaniline is not "
          "a significant factor in chlorhexidine-digluconate associated erosive "
          "cystitis."),
         ['The', 'results', 'have', 'shown', 'that', 'the', 'degradation',
          'product', 'p', '-', 'choloroaniline', 'is', 'not', 'a', 'significant',
          'factor', 'in', 'chlorhexidine', '-', 'digluconate', 'associated',
          'erosive', 'cystitis', '.']),
        (("Rats were treated with seven day intravenous infusion of fucoidan "
          "(30 micrograms h-1) or vehicle."),
         ['Rats', 'were', 'treated', 'with', 'seven', 'day', 'intravenous',
          'infusion', 'of', 'fucoidan', '(', '30', 'micrograms', 'h', '-', '1',
          ')', 'or', 'vehicle', '.']),
        (("Methanoregula formicica sp. nov., a methane-producing archaeon isolated from "
          "methanogenic sludge."),
         ['Methanoregula', 'formicica', 'sp', '.', 'nov', '.', ',', 'a',
          'methane', '-', 'producing', 'archaeon', 'isolated', 'from',
          'methanogenic', 'sludge', '.']),
        (("Here we report the cloning, expression, and biochemical characterization of "
          "the 32-kDa subunit of human (h) TFIID, termed hTAFII32."),
         ['Here', 'we', 'report', 'the', 'cloning', ',', 'expression', ',',
          'and', 'biochemical', 'characterization', 'of', 'the', '32', '-',
          'kDa', 'subunit', 'of', 'human', '(', 'h', ')', 'TFIID', ',', 'termed',
          'hTAFII32', '.']),
    ]

    for text, expected in generic_cases + training_data_cases:
        assert [token.text for token in nlp(text)] == expected
68 |
--------------------------------------------------------------------------------
/saber/tests/test_trainer.py:
--------------------------------------------------------------------------------
1 | ######################################### PYTEST FIXTURES #########################################
2 |
3 | ############################################ UNIT TESTS ############################################
4 |
--------------------------------------------------------------------------------
/saber/trainer.py:
--------------------------------------------------------------------------------
1 | """Contains the Trainer class, which coordinates the training of Keras ML models for Saber.
2 | """
3 | import logging
4 | import random
5 |
6 | from .utils import data_utils, model_utils
7 |
8 | LOGGER = logging.getLogger(__name__)
9 |
class Trainer(object):
    """A class for co-ordinating the training of Keras model(s).

    Args:
        config (Config): A Config object which contains a set of harmonized arguments provided in
            a *.ini file and, optionally, from the command line.
        datasets (list): A list containing one or more Dataset objects.
        model (BaseModel): The model to train.
    """
    def __init__(self, config, datasets, model):
        self.config = config # hyperparameters and model details
        self.datasets = datasets # dataset(s) tied to this instance
        self.model = model # model tied to this instance

        # one output directory and one group of callbacks per dataset/model
        self.output_dir = model_utils.prepare_output_directory(self.config)
        self.callbacks = model_utils.setup_callbacks(self.config, self.output_dir)
        self.training_data = self.model.prepare_data_for_training()

    def train(self):
        """Co-ordinates the training of Keras model(s) at `self.model.models`.

        Coordinates the training of one or more Keras models (given at `self.model.models`). If a
        valid or test set is provided (`Dataset.directory['valid']` or `Dataset.directory['test']`
        are not None) a simple train/valid/test strategy is used. Otherwise, cross-validation is
        used.
        """
        # TODO: ugly, is there a better way to check for this? what if dif ds follow dif schemes?
        # NOTE(review): only the first dataset is inspected, so all datasets are assumed to follow
        # the same partitioning scheme — confirm.
        if (self.training_data[0]['x_valid'] is not None or
                self.training_data[0]['x_test'] is not None):
            self._train_valid_test()
        else:
            self._cross_validation()

    def _train_valid_test(self):
        """Trains a Keras model with a standard train/valid/test strategy.

        Trains a Keras model (`self.model.models`), or models in the case of multi-task learning
        (`self.model.models` is a list of Keras models) using a simple train/valid/test strategy.
        Minimally expects a train partition and one or both of valid and test partitions to be
        supplied in the Dataset objects at `self.datasets`.
        """
        print('Using train/test/valid strategy...')
        LOGGER.info('Using a train/test/valid strategy for training')
        # use 10% of train data as validation data if no validation data provided
        if self.training_data[0]['x_valid'] is None:
            self.training_data = data_utils.collect_valid_data(self.training_data)
        # get list of Keras Callback objects for computing/storing metrics
        metrics = model_utils.setup_metrics_callback(config=self.config,
                                                     datasets=self.datasets,
                                                     training_data=self.training_data,
                                                     output_dir=self.output_dir)
        # training loop: each global epoch trains every model for one epoch, visiting the
        # datasets/models in a random order
        for epoch in range(self.config.epochs):
            print('Global epoch: {}/{}\n{}'.format(epoch + 1, self.config.epochs, '-' * 20))
            # get a random ordering of the dataset/model indices
            ds_idx = random.sample(range(0, len(self.datasets)), len(self.datasets))
            for i in ds_idx:
                self.model.models[i].fit(x=self.training_data[i]['x_train'],
                                         y=self.training_data[i]['y_train'],
                                         batch_size=self.config.batch_size,
                                         callbacks=[cb[i] for cb in self.callbacks] + [metrics[i]],
                                         validation_data=(self.training_data[i]['x_valid'],
                                                          self.training_data[i]['y_valid']),
                                         verbose=1,
                                         # required for Keras to properly display current epoch
                                         initial_epoch=epoch,
                                         epochs=epoch + 1)

    def _cross_validation(self):
        """Trains a Keras model with a cross-validation strategy.

        Trains a Keras model (self.model.models) or models in the case of multi-task learning
        (self.model.models is a list of Keras models) using a `self.config.k_folds`-fold
        cross-validation strategy. Expects only a train partition to be supplied in
        `self.training_data`; per-fold train/valid partitions are generated here.
        """
        print('Using {}-fold cross-validation strategy...'.format(self.config.k_folds))
        LOGGER.info('Using a %s-fold cross-validation strategy for training', self.config.k_folds)
        # get the train/valid partitioned data for all datasets and all folds
        self.training_data = data_utils.collect_cv_data(self.training_data, self.config.k_folds)
        # training loop
        for fold in range(self.config.k_folds):
            # get list of Keras Callback objects for computing/storing metrics
            metrics = model_utils.setup_metrics_callback(config=self.config,
                                                         datasets=self.datasets,
                                                         training_data=self.training_data,
                                                         output_dir=self.output_dir,
                                                         fold=fold)
            for epoch in range(self.config.epochs):
                train_info = (fold + 1, self.config.k_folds, epoch + 1, self.config.epochs)
                print('Fold: {}/{}; Global epoch: {}/{}\n{}'.format(*train_info, '-' * 30))
                # get a random ordering of the dataset/model indices
                ds_idx = random.sample(range(0, len(self.datasets)), len(self.datasets))
                for i in ds_idx:
                    self.model.models[i].fit(
                        x=self.training_data[i][fold]['x_train'],
                        y=self.training_data[i][fold]['y_train'],
                        batch_size=self.config.batch_size,
                        callbacks=[cb[i] for cb in self.callbacks] + [metrics[i]],
                        validation_data=(self.training_data[i][fold]['x_valid'],
                                         self.training_data[i][fold]['y_valid']),
                        verbose=1,
                        # required for Keras to properly display current epoch
                        initial_epoch=epoch,
                        epochs=epoch + 1)

            # clear and rebuild the model at end of each fold (except for the last fold)
            if fold < self.config.k_folds - 1:
                self._reset_model()

    def _reset_model(self):
        """Clears and rebuilds the model at the end of a cross-validation fold.

        Resetting ensures each fold trains from freshly-specified and compiled models rather than
        weights carried over from the previous fold.
        """
        # destroys current TF graph and creates new one, useful for avoiding clutter from old models
        # K.clear_session()
        self.model.models = []
        self.model.specify()
        self.model.compile()
142 |
--------------------------------------------------------------------------------
/saber/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BaderLab/saber/876be6bfdb1bc5b18cbcfa848c94b0d20c940f02/saber/utils/__init__.py
--------------------------------------------------------------------------------
/saber/utils/app_utils.py:
--------------------------------------------------------------------------------
1 | """A collection of web-service-related helper/utility functions.
2 | """
3 | import logging
4 | import os.path
5 | import traceback
6 | import xml.etree.ElementTree as ET
7 | from urllib.error import HTTPError
8 | from urllib.request import urlopen
9 |
10 | import tensorflow as tf
11 |
12 | from .. import constants
13 | from ..saber import Saber
14 | from ..utils import generic_utils
15 |
16 | # TODO: Need better error handeling here
17 | LOGGER = logging.getLogger(__name__)
18 |
def get_pubmed_xml(pmid):
    """Uses the Entrez Utilities Web Service API to fetch XML representation of PubMed document.

    Args:
        pmid (int): the PubMed ID of the abstract to fetch

    Returns:
        response from the Entrez Utilities Web Service API, or None if the HTTP request failed.

    Raises:
        ValueError: if `pmid` is not an integer, or has a value less than 1.
        AssertionError: if the requested PubMed ID (`pmid`) and the returned PubMed ID do not
            match.
    """
    if not isinstance(pmid, int):
        err_msg = "Argument 'pmid' must be of type {}, not {}.".format(int, type(pmid))
        LOGGER.error('ValueError %s', err_msg)
        raise ValueError(err_msg)
    if pmid < 1:
        err_msg = "Argument 'pmid' must have a value of 1 or greater. Got {}".format(pmid)
        LOGGER.error('ValueError %s', err_msg)
        raise ValueError(err_msg)

    # BUG FIX: `response` was previously unbound on the HTTPError path, so the final `return`
    # raised a NameError; it now defaults to None so a failed request returns None.
    response = None
    try:
        request = '{}{}'.format(constants.EUTILS_API_ENDPOINT, pmid)
        response = urlopen(request).read()
    except HTTPError:
        err_msg = ("HTTP Error 400: Bad Request was returned. Check that the supplied value for "
                   "'pmid' ({}) is a valid PubMed ID.".format(pmid))
        traceback.print_exc()
        LOGGER.error('HTTPError %s', err_msg)
        print(err_msg)
    else:
        root = get_root(response)
        response_pmid = root.find('PubmedArticle').find('MedlineCitation').find('PMID').text
        # ensure that requested and returned pubmed ids are the same
        if not int(response_pmid) == pmid:
            err_msg = ('Requested PubMed ID and PubMed ID returned by Entrez Utilities Web Service '
                       'API do not match.')
            LOGGER.error('AssertionError %s', err_msg)
            raise AssertionError(err_msg)

    return response
62 |
def get_pubmed_text(pmid):
    """Returns the abstract title and text for a given PubMed id using the the Entrez Utilities Web
    Service API.

    Args:
        pmid (int): the PubMed ID of the abstract to fetch

    Returns:
        two-tuple containing the abstract title and text for PubMed ID 'pmid'
    """
    xml = get_pubmed_xml(pmid)
    root = get_root(xml)
    # TODO: There has got to be a better way to do this.
    # walk down the XML tree to the Article node, then pull out the title and abstract text
    article = root.find('PubmedArticle').find('MedlineCitation').find('Article')
    abstract_title = article.find('ArticleTitle').text
    abstract_text = article.find('Abstract').find('AbstractText').text

    return abstract_title, abstract_text
81 |
def get_root(xml):
    """Parses an XML string and returns its root element.

    Args:
        xml (str): a string containing the contents of an XML file.

    Returns:
        The root element of the parsed XML document `xml`.
    """
    root = ET.fromstring(xml)
    return root
92 |
def load_models(ents):
    """Loads a model for each entity in `ents`.

    Given a dict of entity (str) -> bool pairs, loads a pre-trained Saber model for every
    entity whose value is True.

    Args:
        ents (dict): A dictionary where the keys correspond to entities and the values are
            booleans indicating whether a model for that entity should be loaded.

    Returns:
        Two-tuple (models, graph): `models` maps each loaded entity name to a Saber object with
        its pre-trained model loaded; `graph` is the default TensorFlow graph.
    """
    models = {}
    for ent in (name for name, requested in ents.items() if requested):
        # create a Saber instance and load the pre-trained model for this entity
        loaded = Saber()
        loaded.load(ent)
        models[ent] = loaded
    # TEMP: weird workaround for a weird TF bug — keep a handle on the default graph.
    # https://github.com/tensorflow/tensorflow/issues/14356#issuecomment-385962623
    graph = tf.get_default_graph()

    return models, graph
116 |
def harmonize_entities(default_ents, requested_ents):
    """Harmonizes two dictionaries representing default and requested entities.

    Produces a dict with exactly the keys of `default_ents`: every entity starts out False and
    is then overridden by the value found in `requested_ents`, when present. Entities appearing
    only in `requested_ents` are ignored.

    Args:
        default_ents (dict): entity (str) -> bool pairs naming every supported entity.
        requested_ents (dict): entity (str) -> bool pairs representing which entities the
            caller wants predicted.

    Returns:
        dict with the keys of `default_ents`; values come from `requested_ents` where available
        and default to False otherwise.
    """
    harmonized = {ent: False for ent in default_ents}
    harmonized.update({ent: on for ent, on in requested_ents.items() if ent in harmonized})

    return harmonized
143 |
def parse_request_json(request):
    """Returns a dictionary of data parsed from a JSON payload passed in a POST request to Saber.
    """
    payload = request.get_json(force=True)

    # 'text', 'pmid' and 'ents' default to None; the boolean flags default to False
    parsed_request_json = {field: payload.get(field, None) for field in ('text', 'pmid', 'ents')}
    parsed_request_json['coref'] = payload.get('coref', False)
    parsed_request_json['ground'] = payload.get('ground', False)

    # decide which entities to annotate: the caller's choices override the defaults
    requested_ents = parsed_request_json['ents']
    if requested_ents is None:
        parsed_request_json['ents'] = constants.ENTITIES
    else:
        parsed_request_json['ents'] = harmonize_entities(constants.ENTITIES, requested_ents)

    return parsed_request_json
164 |
def combine_annotations(annotations):
    """Given a list of annotations made by a Saber model, combines all annotations in one list.

    Args:
        annotations (list): a list of annotation dicts returned by a Saber model, each with an
            'ents' key.

    Returns:
        a flat list containing every entity from every annotation in `annotations`.
    """
    return [ent for ann in annotations for ent in ann['ents']]
179 |
--------------------------------------------------------------------------------
/saber/utils/generic_utils.py:
--------------------------------------------------------------------------------
1 | """A collection of generic helper/utility functions.
2 | """
3 | import errno
4 | import logging
5 | import os
6 | import shutil
7 |
8 | from setuptools.archive_util import unpack_archive
9 |
10 | LOGGER = logging.getLogger(__name__)
11 |
def is_consecutive(lst):
    """Returns True if `lst` contains all numbers from 0 to `len(lst)` with no duplicates.
    """
    return all(value == index for index, value in enumerate(sorted(lst)))
16 |
def reverse_dict(mapping):
    """Returns a dictionary composed of the reverse v, k pairs of a dictionary `mapping`.
    """
    # zip values with keys; on duplicate values the last key wins, exactly as in
    # a {v: k for ...} comprehension
    return dict(zip(mapping.values(), mapping.keys()))
21 |
def make_dir(directory):
    """Creates a directory at `directory`, including parents, if it does not already exist.

    Args:
        directory (str): path of the directory to create.

    Raises:
        OSError: if the directory cannot be created — notably FileExistsError when the path
            exists but is not a directory (the old errno.EEXIST check silently swallowed
            that case).
    """
    # exist_ok=True (Python 3.2+) replaces the manual try/except errno.EEXIST idiom
    os.makedirs(directory, exist_ok=True)
31 |
def clean_path(filepath):
    """Returns normalized and absolutized `filepath`.
    """
    # strip surrounding whitespace for string inputs; path-like objects pass through
    if isinstance(filepath, str):
        filepath = filepath.strip()
    normalized = os.path.normpath(filepath)
    return os.path.abspath(normalized)
37 |
def extract_directory(directory):
    """Extracts bz2 compressed directory at `directory` if directory is compressed.

    If no directory exists at `directory`, assumes an archive named `directory`.tar.bz2 sits
    alongside it and unpacks that archive into the parent directory.
    """
    if os.path.isdir(directory):
        return

    parent_dir, _ = os.path.split(os.path.abspath(directory))

    print('Unzipping...', end=' ', flush=True)
    unpack_archive('{}.tar.bz2'.format(directory), extract_dir=parent_dir)
46 |
def compress_directory(directory):
    """Compresses a given directory using bz2 compression.

    Creates `directory`.tar.bz2 next to `directory` and then removes the uncompressed
    directory.

    Args:
        directory (str): path of the directory to compress.

    Raises:
        ValueError: if no directory at `directory` exists or if `directory`.tar.bz2 already exists.
    """
    # clean/normalize directory
    directory = os.path.abspath(os.path.normcase(os.path.normpath(directory)))

    # raise ValueError if directory.tar.bz2 already exists or if directory not valid
    output_filepath = '{}.tar.bz2'.format(directory)
    if os.path.exists(output_filepath):
        err_msg = "{} already exists.".format(output_filepath)
        LOGGER.error('ValueError %s', err_msg)
        raise ValueError(err_msg)
    if not os.path.exists(directory):
        err_msg = "File or directory at 'directory' does not exist."
        LOGGER.error('ValueError %s', err_msg)
        raise ValueError(err_msg)

    # create bz2 compressed directory, remove uncompressed directory.
    # os.path.dirname replaces the fragile ''.join(os.path.split(...)[:-1]) idiom
    root_dir = os.path.dirname(directory)
    base_dir = os.path.basename(directory)
    shutil.make_archive(base_name=directory, format='bztar', root_dir=root_dir, base_dir=base_dir)
    shutil.rmtree(directory)
72 |
--------------------------------------------------------------------------------
/saber/utils/grounding_utils.py:
--------------------------------------------------------------------------------
1 | """A collection of helper/utility functions for grounding entities.
2 | """
3 | import logging
4 |
5 | import requests
6 |
7 | from .. import constants
8 |
9 | LOGGER = logging.getLogger(__name__)
10 |
def ground(annotation):
    """Maps entities in `annotation` to unique identifiers in an external database or ontology.

    For each entry in `annotation[ents]`, the text representing the annotation (`ent['text']`) is
    mapped to a unique identifier in an external database or ontology (if such a unique identifier
    is found). Each annotation in `annotation` is updated with an 'xrefs' key which contains a
    list of dictionaries with information representing the mapping. `annotation` is modified in
    place and also returned.

    This function relies on the EXTRACT API to perform the mapping.

    Args:
        annotation (dict): A dict containing a list of annotations at key 'ents'. Each annotation
            is expected to have keys 'text' and 'label'.

    Returns:
        `annotation`, with an 'xrefs' key added to every entity that was grounded.

    Resources:
        - EXTRACT 2.0 API: https://extract.jensenlab.org/
    """
    request = 'https://tagger.jensenlab.org/GetEntities?format=tsv&document='

    # bucket the annotations made by Saber by entity label; only these four
    # labels are sent to the EXTRACT API
    annotations = {
        'CHED': [ent for ent in annotation['ents'] if ent['label'] == 'CHED'],
        'DISO': [ent for ent in annotation['ents'] if ent['label'] == 'DISO'],
        'LIVB': [ent for ent in annotation['ents'] if ent['label'] == 'LIVB'],
        'PRGE': [ent for ent in annotation['ents'] if ent['label'] == 'PRGE'],
    }

    for label, anns in annotations.items():
        if anns:
            # append to the GET request the '+'-joined texts to ground, along with their
            # entity type when one is mapped for this label
            current_request = '{}{}'.format(request, '+'.join([ann['text'] for ann in anns]))
            if label in constants.ENTITY_TYPES:
                current_request += '&entity_types={}'.format(constants.ENTITY_TYPES[label])

            # request to EXTRACT 2.0 API; the TSV response has one matched entity per line
            response = requests.get(current_request).text
            entries = [entry.split('\t') for entry in response.split('\n')] if response else []

            # maps matched text (first TSV column) -> list of xref dicts
            xrefs = {}

            # collect unique identifiers returned by EXTRACT 2.0 API (last TSV column)
            for entry in entries:
                xref = {'namespace': constants.NAMESPACES[label], 'id': entry[-1]}
                # in the future, the EXTRACT 2.0 API will assign organism ids to PRGE labels
                if label == 'PRGE':
                    xref['organism-id'] = entry[1]

                if entry[0] in xrefs:
                    xrefs[entry[0]].append(xref)
                else:
                    xrefs[entry[0]] = [xref]

            # update annotations whose text was grounded with an 'xrefs' field
            for ann in anns:
                if ann['text'] in xrefs:
                    ann.update(xrefs=xrefs[ann['text']])

    return annotation
69 |
--------------------------------------------------------------------------------
/saber/utils/model_utils.py:
--------------------------------------------------------------------------------
1 | """A collection of model-related helper/utility functions.
2 | """
3 | import os
4 | from time import strftime
5 |
6 | from keras.callbacks import ModelCheckpoint, TensorBoard
7 |
8 | from .. import constants
9 | from ..metrics import Metrics
10 | from .generic_utils import make_dir
11 |
12 | # I/O
13 |
def prepare_output_directory(config):
    """Create per-dataset output directories under `config.output_folder`.

    For each dataset in `config.dataset_folder`, creates a directory:

        config.output_folder/[<all_dataset_names>/]<dataset_name>/train_session_<timestamp>

    The intermediate <all_dataset_names> level (every dataset name joined by underscores) is
    only created when training on more than one dataset; with a single dataset it is collapsed
    away. A copy of the config file used to train the model (`config`) is saved into each
    train session directory.

    Args:
        config (Config): A Config object which contains a set of harmonized arguments provided in
            a *.ini file and, optionally, from the command line.

    Returns:
        a list of directory paths to the created train session subdirectories, one for each
        dataset in `config.dataset_folder`.
    """
    output_folder = config.output_folder
    # with multiple datasets, nest all output under one directory named after
    # every dataset, joined by underscores
    if len(config.dataset_folder) > 1:
        joined_names = '_'.join(os.path.basename(ds) for ds in config.dataset_folder)
        output_folder = os.path.join(output_folder, joined_names)
        make_dir(output_folder)

    output_dirs = []
    for dataset in config.dataset_folder:
        # timestamped subdirectory for this train session, inside a directory
        # named after the dataset
        train_session_dir = strftime("train_session_%a_%b_%d_%I_%M_%S").lower()
        dataset_train_session_dir = os.path.join(
            output_folder, os.path.basename(dataset), train_session_dir)
        output_dirs.append(dataset_train_session_dir)
        make_dir(dataset_train_session_dir)

        # save a copy of the config used for this train session
        config.save(dataset_train_session_dir)

    return output_dirs
62 |
def prepare_pretrained_model_dir(config):
    """Returns path to top-level directory to save a pre-trained model.

    The returned path is
    `config.output_folder/constants.PRETRAINED_MODEL_DIR/<dataset_names>`, where
    <dataset_names> is every dataset name in `config.dataset_folder` joined by underscores.

    Args:
        config (Config): A Config object which contains a set of harmonized arguments provided
            in a *.ini file and, optionally, from the command line.

    Returns:
        Full path to save a pre-trained model, based on `config.output_folder` and
        `config.dataset_folder`.
    """
    dataset_names = [os.path.basename(ds) for ds in config.dataset_folder]
    return os.path.join(config.output_folder, constants.PRETRAINED_MODEL_DIR,
                        '_'.join(dataset_names))
83 |
84 | # Callbacks
85 |
def setup_checkpoint_callback(config, output_dir):
    """Sets up per epoch model checkpointing.

    Creates one Keras ModelCheckpoint callback per output directory in `output_dir`
    (corresponding to individual datasets).

    Args:
        config (Config): A Config object which contains a set of harmonized arguments provided
            in a *.ini file and, optionally, from the command line.
        output_dir (lst): A list of output directories, one for each dataset.

    Returns:
        A list of Keras ModelCheckpoint callbacks, one per directory in `output_dir`.
    """
    checkpointers = []
    for directory in output_dir:
        if config.save_all_weights:
            filename = 'weights_epoch_{epoch:03d}_val_loss_{val_loss:.4f}.hdf5'
        else:
            # when keeping only the best weights, the filepath must stay constant so
            # each improvement overwrites the previous checkpoint
            filename = 'weights_best_epoch.hdf5'

        checkpointers.append(ModelCheckpoint(filepath=os.path.join(directory, filename),
                                             monitor='val_loss',
                                             save_best_only=not config.save_all_weights,
                                             save_weights_only=True))

    return checkpointers
113 |
def setup_tensorboard_callback(output_dir):
    """Setup logs for use with TensorBoard.

    Creates one Keras TensorBoard callback per directory in `output_dir`; each callback writes
    its logs to a `tensorboard_logs` subdirectory at the top level of that directory.

    Args:
        output_dir (lst): A list of output directories, one for each dataset.

    Returns:
        A list of Keras TensorBoard callback objects for logging TensorBoard visualizations.

    Example:
        >>> tensorboard --logdir=/path_to_tensorboard_logs
    """
    return [TensorBoard(log_dir=os.path.join(directory, 'tensorboard_logs'))
            for directory in output_dir]
137 |
def setup_metrics_callback(config, datasets, training_data, output_dir, fold=None):
    """Creates Keras Metrics Callback objects, one for each dataset in `datasets`.

    Args:
        config (Config): Contains a set of harmonized arguments provided in a *.ini file and,
            optionally, from the command line.
        datasets (list): A list of Dataset objects.
        training_data (dict): A dictionary containing training data (inputs and targets).
        output_dir (list): List of directories to save model output to, one for each model.
        fold (int): The current fold in k-fold cross-validation. Defaults to None.

    Returns:
        A list of Metric objects, one for each dataset in `datasets`.
    """
    metrics = []
    for i, dataset in enumerate(datasets):
        # outside of cross-validation the whole partition is evaluated; otherwise
        # only the current fold is
        if fold is None:
            eval_data = training_data[i]
        else:
            eval_data = training_data[i][fold]
        metrics.append(Metrics(config=config,
                               training_data=eval_data,
                               index_map=dataset.idx_to_tag,
                               output_dir=output_dir[i],
                               fold=fold))

    return metrics
163 |
def setup_callbacks(config, output_dir):
    """Returns a list of Keras Callback objects to use during training.

    Always includes model checkpointing; TensorBoard logging is added only when
    `config.tensorboard` is truthy.

    Args:
        config (Config): A Config object which contains a set of harmonized arguments provided in
            a *.ini file and, optionally, from the command line.
        output_dir (list): A list of filepaths, one for each dataset in `self.datasets`.

    Returns:
        A list of Keras Callback objects to use during training.
    """
    # model checkpointing is always on
    callbacks = [setup_checkpoint_callback(config, output_dir)]
    # tensorboard logging is opt-in
    if config.tensorboard:
        callbacks.append(setup_tensorboard_callback(output_dir))

    return callbacks
183 |
184 | # Evaluation metrics
185 |
def precision_recall_f1_support(true_positives, false_positives, false_negatives):
    """Returns the precision, recall, F1 and support from TP, FP and FN counts.

    Args:
        true_positives (int): Number of true-positives predicted by classifier.
        false_positives (int): Number of false-positives predicted by classifier.
        false_negatives (int): Number of false-negatives predicted by classifier.

    Returns:
        Four-tuple containing (precision, recall, f1, support). Each metric is 0. when its
        denominator would be zero.
    """
    if true_positives > 0:
        precision = true_positives / (true_positives + false_positives)
        recall = true_positives / (true_positives + false_negatives)
    else:
        precision, recall = 0., 0.

    pr_sum = precision + recall
    f1_score = (2 * precision * recall / pr_sum) if pr_sum > 0 else 0.
    support = true_positives + false_negatives

    return precision, recall, f1_score, support
207 |
208 | # Saving/loading
209 |
def load_pretrained_model(config, datasets, weights_filepath, model_filepath):
    """Loads a pre-trained Keras model from its pre-trained weights and architecture files.

    Loads a pre-trained Keras model given by its pre-trained weights (`weights_filepath`) and
    architecture files (`model_filepath`). The type of model to load is specified in
    `config.model_name`.

    Args:
        config (Config): A Config object which contains a set of harmonized arguments provided
            in a *.ini file and, optionally, from the command line.
        datasets (Dataset): A list of Dataset objects.
        weights_filepath (str): A filepath to the weights of a pre-trained Keras model.
        model_filepath (str): A filepath to the architecture of a pre-trained Keras model.

    Returns:
        A pre-trained Keras model.

    Raises:
        ValueError: if `config.model_name` is not a recognized model name.
    """
    if config.model_name == 'mt-lstm-crf':
        from ..models.multi_task_lstm_crf import MultiTaskLSTMCRF
        model = MultiTaskLSTMCRF(config, datasets)
    else:
        # previously an unrecognized name fell through to `return model` and crashed
        # with an UnboundLocalError; fail loudly with a clear message instead
        err_msg = "Unsupported model name: '{}'".format(config.model_name)
        raise ValueError(err_msg)

    model.load(weights_filepath, model_filepath)
    model.compile()

    return model
234 |
--------------------------------------------------------------------------------
/saber/utils/text_utils.py:
--------------------------------------------------------------------------------
1 | """A collection of helper/utility functions for processing text.
2 | """
3 | import re
4 |
5 | from spacy.tokenizer import Tokenizer
6 |
7 | # NERsuite-like tokenization: alnum sequences preserved as single
8 | # tokens, rest are single-character tokens.
9 | # https://github.com/spyysalo/standoff2conll/blob/master/common.py
10 | INFIX_RE = re.compile(r'''([0-9a-zA-Z]+|[^0-9a-zA-Z])''')
11 |
# https://spacy.io/usage/linguistic-features#native-tokenizers
def biomedical_tokenizer(nlp):
    """Builds a spaCy Tokenizer customized for better handling of biomedical text.

    Uses `INFIX_RE` as the infix finditer, so alphanumeric runs remain single tokens and
    every other character becomes its own token (NERsuite-like tokenization).
    """
    tokenizer = Tokenizer(nlp.vocab, infix_finditer=INFIX_RE.finditer)
    return tokenizer
18 |
--------------------------------------------------------------------------------
/setup.cfg:
--------------------------------------------------------------------------------
1 | [aliases]
2 | test=pytest
3 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
"""Packaging/installation script for Saber (setuptools entry point)."""
import setuptools

# use the project README verbatim as the long description shown on PyPI
with open("README.md", "r") as fh:
    long_description = fh.read()

setuptools.setup(
    name="saber",
    version="0.1.0-alpha",
    author="John Giorgi",
    author_email="johnmgiorgi@gmail.com",
    license="MIT",
    description="Saber: Sequence Annotator for Biomedical Entities and Relations",
    long_description=long_description,
    long_description_content_type="text/markdown",
    url="https://github.com/BaderLab/saber",
    python_requires='>=3.5',
    packages=setuptools.find_packages(),
    classifiers=[
        "Development Status :: 3 - Alpha",
        "Framework :: Flask",
        "License :: OSI Approved :: MIT License",
        "Programming Language :: Python :: 3 :: Only",
        "Programming Language :: Python :: 3.5",
        "Programming Language :: Python :: 3.6",
        "Programming Language :: Python :: 3.7",
        "Operating System :: OS Independent",
    ],
    keywords=[
        'Natural Language Processing',
        'Named Entity Recognition',
    ],
    install_requires=[
        'scikit-learn>=0.20.1',
        'tensorflow>=1.12.0',
        'Flask>=1.0.2',
        'waitress>=1.1.0',
        'keras==2.2.4',
        'PTable>=0.9.2',
        'spacy>=2.0.11, <=2.0.13',
        'gensim>=3.4.0',
        'nltk>=3.3',
        'googledrivedownloader>=0.3',
        'google-compute-engine',
        'msgpack==0.5.6',
        # direct (PEP 508) URL requirements: CRF layer and coreference model
        'keras-contrib @ git+https://www.github.com/keras-team/keras-contrib.git',
        'en-coref-md @ https://github.com/huggingface/neuralcoref-models/releases/download/en_coref_md-3.0.0/en_coref_md-3.0.0.tar.gz',
    ],
    include_package_data=True,
    # allows us to install + run tests with `python setup.py test`
    # https://docs.pytest.org/en/latest/goodpractices.html#integrating-with-setuptools-python-setup-py-test-pytest-runner
    setup_requires=['pytest-runner'],
    tests_require=['pytest'],
    zip_safe=False,
)
55 |
--------------------------------------------------------------------------------
/tox.ini:
--------------------------------------------------------------------------------
1 | [tox]
2 | envlist =
3 | manifest
4 | pyroma
5 | py
6 |
7 | [testenv]
8 | commands = pytest --cov=saber -v
9 | deps =
10 | pytest
11 | pytest-cov
12 | coveralls
13 |
14 | [testenv:manifest]
15 | deps = check-manifest
16 | skip_install = true
17 | commands = check-manifest
18 |
19 | [testenv:pyroma]
20 | deps =
21 | pygments
22 | pyroma
23 | skip_install = true
24 | commands = pyroma --min=10 .
25 | description = Run the pyroma tool to check the project's package friendliness.
26 |
27 | [testenv:coverage-clean]
28 | deps = coverage
29 | skip_install = true
30 | commands = coverage erase
31 |
32 | [testenv:coverage-report]
33 | deps = coverage
34 | skip_install = true
35 | commands =
36 | coverage combine
37 | coverage report
38 |
39 | ####################
40 | # Deployment tools #
41 | ####################
42 | [testenv:build]
43 | skip_install = true
44 | deps =
45 | wheel
46 | setuptools
47 | commands =
48 | python setup.py -q sdist bdist_wheel
49 |
50 | [testenv:release]
51 | skip_install = true
52 | deps =
53 | {[testenv:build]deps}
54 | twine >= 1.5.0
55 | commands =
56 | {[testenv:build]commands}
57 | twine upload --skip-existing dist/*
58 |
--------------------------------------------------------------------------------