├── .nojekyll ├── docs ├── .nojekyll ├── index.html ├── requirements.txt ├── _build │ ├── html │ │ ├── objects.inv │ │ ├── _static │ │ │ ├── file.png │ │ │ ├── minus.png │ │ │ ├── plus.png │ │ │ ├── fonts │ │ │ │ ├── Lato-Bold.ttf │ │ │ │ ├── Inconsolata.ttf │ │ │ │ ├── Lato-Regular.ttf │ │ │ │ ├── Lato │ │ │ │ │ ├── lato-bold.eot │ │ │ │ │ ├── lato-bold.ttf │ │ │ │ │ ├── lato-bold.woff │ │ │ │ │ ├── lato-bold.woff2 │ │ │ │ │ ├── lato-italic.eot │ │ │ │ │ ├── lato-italic.ttf │ │ │ │ │ ├── lato-italic.woff │ │ │ │ │ ├── lato-italic.woff2 │ │ │ │ │ ├── lato-regular.eot │ │ │ │ │ ├── lato-regular.ttf │ │ │ │ │ ├── lato-regular.woff │ │ │ │ │ ├── lato-regular.woff2 │ │ │ │ │ ├── lato-bolditalic.eot │ │ │ │ │ ├── lato-bolditalic.ttf │ │ │ │ │ ├── lato-bolditalic.woff │ │ │ │ │ └── lato-bolditalic.woff2 │ │ │ │ ├── Inconsolata-Bold.ttf │ │ │ │ ├── RobotoSlab-Bold.ttf │ │ │ │ ├── Inconsolata-Regular.ttf │ │ │ │ ├── RobotoSlab-Regular.ttf │ │ │ │ ├── fontawesome-webfont.eot │ │ │ │ ├── fontawesome-webfont.ttf │ │ │ │ ├── fontawesome-webfont.woff │ │ │ │ ├── fontawesome-webfont.woff2 │ │ │ │ └── RobotoSlab │ │ │ │ │ ├── roboto-slab-v7-bold.eot │ │ │ │ │ ├── roboto-slab-v7-bold.ttf │ │ │ │ │ ├── roboto-slab-v7-bold.woff │ │ │ │ │ ├── roboto-slab-v7-bold.woff2 │ │ │ │ │ ├── roboto-slab-v7-regular.eot │ │ │ │ │ ├── roboto-slab-v7-regular.ttf │ │ │ │ │ ├── roboto-slab-v7-regular.woff │ │ │ │ │ └── roboto-slab-v7-regular.woff2 │ │ │ ├── css │ │ │ │ ├── fonts │ │ │ │ │ ├── lato-bold.woff │ │ │ │ │ ├── lato-bold.woff2 │ │ │ │ │ ├── lato-normal.woff │ │ │ │ │ ├── lato-normal.woff2 │ │ │ │ │ ├── Roboto-Slab-Bold.woff │ │ │ │ │ ├── lato-bold-italic.woff │ │ │ │ │ ├── Roboto-Slab-Bold.woff2 │ │ │ │ │ ├── Roboto-Slab-Regular.woff │ │ │ │ │ ├── fontawesome-webfont.eot │ │ │ │ │ ├── fontawesome-webfont.ttf │ │ │ │ │ ├── fontawesome-webfont.woff │ │ │ │ │ ├── lato-bold-italic.woff2 │ │ │ │ │ ├── lato-normal-italic.woff │ │ │ │ │ ├── lato-normal-italic.woff2 │ │ │ │ │ ├── Roboto-Slab-Regular.woff2 │ │ │ │ │ └── fontawesome-webfont.woff2 │ │ │ │ └── badge_only.css │ │ │ ├── documentation_options.js │ │ │ └── js │ │ │ │ ├── badge_only.js │ │ │ │ ├── html5shiv.min.js │ │ │ │ ├── html5shiv-printshiv.min.js │ │ │ │ └── theme.js │ │ ├── .buildinfo │ │ ├── _sources │ │ │ ├── wellcomeml.io.epmc.rst.txt │ │ │ ├── wellcomeml.spacy.rst.txt │ │ │ ├── wellcomeml.metrics.rst.txt │ │ │ ├── wellcomeml.io.rst.txt │ │ │ ├── wellcomeml.rst.txt │ │ │ ├── wellcomeml.viz.rst.txt │ │ │ ├── wellcomeml.datasets.rst.txt │ │ │ ├── modules.md.txt │ │ │ ├── index.rst.txt │ │ │ └── wellcomeml.ml.rst.txt │ │ └── search.html │ └── doctrees │ │ ├── index.doctree │ │ ├── modules.doctree │ │ ├── clustering.doctree │ │ ├── environment.pickle │ │ ├── examples.doctree │ │ ├── wellcomeml.doctree │ │ ├── wellcomeml.io.doctree │ │ ├── wellcomeml.ml.doctree │ │ ├── wellcomeml.viz.doctree │ │ ├── wellcomeml.spacy.doctree │ │ ├── wellcomeml.datasets.doctree │ │ ├── wellcomeml.io.epmc.doctree │ │ └── wellcomeml.metrics.doctree ├── wellcomeml.io.epmc.rst ├── wellcomeml.spacy.rst ├── Makefile ├── wellcomeml.metrics.rst ├── wellcomeml.io.rst ├── wellcomeml.rst ├── wellcomeml.viz.rst ├── make.bat ├── wellcomeml.datasets.rst ├── index.rst ├── conf.py └── wellcomeml.ml.rst ├── tests ├── __init__.py ├── test_data │ ├── mock_s3_contents.json.gz │ ├── mock_winer_CoarseNE.tar.bz2 │ ├── mock_winer_Documents.tar.bz2 │ ├── mock_winer_document.vocab │ ├── test_conll │ └── test_jsonl.jsonl ├── test_logger.py ├── common.py ├── test_sent2vec.py ├── test_heatmap.py ├── 
test_palettes.py ├── datasets │ ├── test_conll.py │ └── test_winer.py ├── test_clustering_visualisation.py ├── test_doc2vec.py ├── test_vectorizer.py ├── test_frequency_vectorizer.py ├── ml │ └── test_keras_utils.py ├── test_extras.py ├── test_spacy_classifier.py ├── test_bert_vectorizer.py ├── test_io.py ├── test_clustering.py ├── test_entity_linking.py ├── io │ └── epmc │ │ └── test_client.py ├── metrics │ └── test_f1.py ├── test_transformers_tokenizer.py ├── test_keras_vectorizer.py ├── test_spacy_entity_linking.py ├── test_s3_policy_data.py ├── test_bert_classifier.py ├── test_spacy.py └── test_ner_spacy.py ├── wellcomeml ├── ml │ ├── __init__.py │ ├── constants.py │ ├── sent2vec_vectorizer.py │ ├── attention.py │ ├── bert_vectorizer.py │ ├── frequency_vectorizer.py │ ├── vectorizer.py │ └── voting_classifier.py ├── viz │ ├── __init__.py │ └── palettes.py ├── io │ ├── epmc │ │ └── __init__.py │ ├── __init__.py │ ├── io.py │ └── s3_policy_data.py ├── datasets │ ├── __init__.py │ ├── hoc.py │ ├── download.py │ └── conll.py ├── spacy │ ├── __init__.py │ └── spacy_doc_to_prodigy.py ├── metrics │ ├── __init__.py │ ├── f1.py │ └── ner_classification_report.py ├── __init__.py ├── __version__.py ├── __main__.py ├── logger.py └── utils.py ├── .flake8 ├── requirements_test.txt ├── pull_request_template.md ├── pytest.ini ├── examples ├── spacy_classifier.py ├── bert_classifier.py ├── epmc_client.py ├── bert_embeddings.py ├── policy_docs_from_s3.py ├── cnn_classifier.py ├── bilstm_classifier.py ├── visualize_clusters.py ├── sent2vec_vectorizer.py ├── doc2vec.py ├── bert_text_similarity_fine_tune.py ├── heatmap.py ├── text_clustering.py ├── entity_linking.py └── voting_classifier_ensemble.py ├── tox.ini ├── .gitignore ├── codecov.yml ├── .travis.yml ├── WINDOWS_USERS.md ├── .github └── workflows │ └── main.yml ├── LICENSE ├── create_release.sh ├── setup.py └── Makefile /.nojekyll: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /docs/.nojekyll: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /wellcomeml/ml/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /wellcomeml/ml/constants.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /wellcomeml/viz/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /wellcomeml/io/epmc/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.flake8: -------------------------------------------------------------------------------- 1 | [flake8] 2 | max-line-length = 99 3 | -------------------------------------------------------------------------------- /requirements_test.txt: -------------------------------------------------------------------------------- 1 | pytest 2 
| pytest-cov 3 | codecov 4 | ipython 5 | tox 6 | -------------------------------------------------------------------------------- /docs/index.html: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /docs/requirements.txt: -------------------------------------------------------------------------------- 1 | recommonmark 2 | sphinx 3 | sphinx-rtd-theme 4 | sphinx-markdown-tables 5 | -------------------------------------------------------------------------------- /wellcomeml/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from wellcomeml.datasets.hoc import load_hoc 2 | 3 | __all__ = ["load_hoc"] 4 | -------------------------------------------------------------------------------- /docs/_build/html/objects.inv: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wellcometrust/WellcomeML/HEAD/docs/_build/html/objects.inv -------------------------------------------------------------------------------- /docs/_build/doctrees/index.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wellcometrust/WellcomeML/HEAD/docs/_build/doctrees/index.doctree -------------------------------------------------------------------------------- /docs/_build/html/_static/file.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wellcometrust/WellcomeML/HEAD/docs/_build/html/_static/file.png -------------------------------------------------------------------------------- /docs/_build/html/_static/minus.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wellcometrust/WellcomeML/HEAD/docs/_build/html/_static/minus.png -------------------------------------------------------------------------------- /docs/_build/html/_static/plus.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wellcometrust/WellcomeML/HEAD/docs/_build/html/_static/plus.png -------------------------------------------------------------------------------- /wellcomeml/spacy/__init__.py: -------------------------------------------------------------------------------- 1 | from .spacy_doc_to_prodigy import SpacyDocToProdigy 2 | 3 | __all__ = ['SpacyDocToProdigy'] 4 | -------------------------------------------------------------------------------- /docs/_build/doctrees/modules.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wellcometrust/WellcomeML/HEAD/docs/_build/doctrees/modules.doctree -------------------------------------------------------------------------------- /docs/_build/doctrees/clustering.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wellcometrust/WellcomeML/HEAD/docs/_build/doctrees/clustering.doctree -------------------------------------------------------------------------------- /docs/_build/doctrees/environment.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wellcometrust/WellcomeML/HEAD/docs/_build/doctrees/environment.pickle -------------------------------------------------------------------------------- 
/docs/_build/doctrees/examples.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wellcometrust/WellcomeML/HEAD/docs/_build/doctrees/examples.doctree -------------------------------------------------------------------------------- /docs/_build/doctrees/wellcomeml.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wellcometrust/WellcomeML/HEAD/docs/_build/doctrees/wellcomeml.doctree -------------------------------------------------------------------------------- /tests/test_data/mock_s3_contents.json.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wellcometrust/WellcomeML/HEAD/tests/test_data/mock_s3_contents.json.gz -------------------------------------------------------------------------------- /docs/_build/doctrees/wellcomeml.io.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wellcometrust/WellcomeML/HEAD/docs/_build/doctrees/wellcomeml.io.doctree -------------------------------------------------------------------------------- /docs/_build/doctrees/wellcomeml.ml.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wellcometrust/WellcomeML/HEAD/docs/_build/doctrees/wellcomeml.ml.doctree -------------------------------------------------------------------------------- /docs/_build/doctrees/wellcomeml.viz.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wellcometrust/WellcomeML/HEAD/docs/_build/doctrees/wellcomeml.viz.doctree -------------------------------------------------------------------------------- /docs/_build/html/_static/fonts/Lato-Bold.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wellcometrust/WellcomeML/HEAD/docs/_build/html/_static/fonts/Lato-Bold.ttf -------------------------------------------------------------------------------- /tests/test_data/mock_winer_CoarseNE.tar.bz2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wellcometrust/WellcomeML/HEAD/tests/test_data/mock_winer_CoarseNE.tar.bz2 -------------------------------------------------------------------------------- /tests/test_data/mock_winer_Documents.tar.bz2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wellcometrust/WellcomeML/HEAD/tests/test_data/mock_winer_Documents.tar.bz2 -------------------------------------------------------------------------------- /docs/_build/doctrees/wellcomeml.spacy.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wellcometrust/WellcomeML/HEAD/docs/_build/doctrees/wellcomeml.spacy.doctree -------------------------------------------------------------------------------- /docs/_build/html/_static/fonts/Inconsolata.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wellcometrust/WellcomeML/HEAD/docs/_build/html/_static/fonts/Inconsolata.ttf -------------------------------------------------------------------------------- /docs/_build/doctrees/wellcomeml.datasets.doctree: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/wellcometrust/WellcomeML/HEAD/docs/_build/doctrees/wellcomeml.datasets.doctree -------------------------------------------------------------------------------- /docs/_build/doctrees/wellcomeml.io.epmc.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wellcometrust/WellcomeML/HEAD/docs/_build/doctrees/wellcomeml.io.epmc.doctree -------------------------------------------------------------------------------- /docs/_build/doctrees/wellcomeml.metrics.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wellcometrust/WellcomeML/HEAD/docs/_build/doctrees/wellcomeml.metrics.doctree -------------------------------------------------------------------------------- /docs/_build/html/_static/css/fonts/lato-bold.woff: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wellcometrust/WellcomeML/HEAD/docs/_build/html/_static/css/fonts/lato-bold.woff -------------------------------------------------------------------------------- /docs/_build/html/_static/fonts/Lato-Regular.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wellcometrust/WellcomeML/HEAD/docs/_build/html/_static/fonts/Lato-Regular.ttf -------------------------------------------------------------------------------- /docs/_build/html/_static/fonts/Lato/lato-bold.eot: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wellcometrust/WellcomeML/HEAD/docs/_build/html/_static/fonts/Lato/lato-bold.eot -------------------------------------------------------------------------------- /docs/_build/html/_static/fonts/Lato/lato-bold.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wellcometrust/WellcomeML/HEAD/docs/_build/html/_static/fonts/Lato/lato-bold.ttf -------------------------------------------------------------------------------- /pull_request_template.md: -------------------------------------------------------------------------------- 1 | Description 2 | --- 3 | 4 | Checklist 5 | --- 6 | 7 | - [ ] Added link to Github issue or Notion card 8 | - [ ] Added tests 9 | -------------------------------------------------------------------------------- /docs/_build/html/_static/css/fonts/lato-bold.woff2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wellcometrust/WellcomeML/HEAD/docs/_build/html/_static/css/fonts/lato-bold.woff2 -------------------------------------------------------------------------------- /docs/_build/html/_static/css/fonts/lato-normal.woff: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wellcometrust/WellcomeML/HEAD/docs/_build/html/_static/css/fonts/lato-normal.woff -------------------------------------------------------------------------------- /docs/_build/html/_static/fonts/Inconsolata-Bold.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wellcometrust/WellcomeML/HEAD/docs/_build/html/_static/fonts/Inconsolata-Bold.ttf -------------------------------------------------------------------------------- 
/docs/_build/html/_static/fonts/Lato/lato-bold.woff: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wellcometrust/WellcomeML/HEAD/docs/_build/html/_static/fonts/Lato/lato-bold.woff -------------------------------------------------------------------------------- /docs/_build/html/_static/fonts/Lato/lato-bold.woff2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wellcometrust/WellcomeML/HEAD/docs/_build/html/_static/fonts/Lato/lato-bold.woff2 -------------------------------------------------------------------------------- /docs/_build/html/_static/fonts/Lato/lato-italic.eot: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wellcometrust/WellcomeML/HEAD/docs/_build/html/_static/fonts/Lato/lato-italic.eot -------------------------------------------------------------------------------- /docs/_build/html/_static/fonts/Lato/lato-italic.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wellcometrust/WellcomeML/HEAD/docs/_build/html/_static/fonts/Lato/lato-italic.ttf -------------------------------------------------------------------------------- /docs/_build/html/_static/fonts/RobotoSlab-Bold.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wellcometrust/WellcomeML/HEAD/docs/_build/html/_static/fonts/RobotoSlab-Bold.ttf -------------------------------------------------------------------------------- /docs/_build/html/_static/css/fonts/lato-normal.woff2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wellcometrust/WellcomeML/HEAD/docs/_build/html/_static/css/fonts/lato-normal.woff2 -------------------------------------------------------------------------------- /docs/_build/html/_static/fonts/Inconsolata-Regular.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wellcometrust/WellcomeML/HEAD/docs/_build/html/_static/fonts/Inconsolata-Regular.ttf -------------------------------------------------------------------------------- /docs/_build/html/_static/fonts/Lato/lato-italic.woff: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wellcometrust/WellcomeML/HEAD/docs/_build/html/_static/fonts/Lato/lato-italic.woff -------------------------------------------------------------------------------- /docs/_build/html/_static/fonts/Lato/lato-italic.woff2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wellcometrust/WellcomeML/HEAD/docs/_build/html/_static/fonts/Lato/lato-italic.woff2 -------------------------------------------------------------------------------- /docs/_build/html/_static/fonts/Lato/lato-regular.eot: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wellcometrust/WellcomeML/HEAD/docs/_build/html/_static/fonts/Lato/lato-regular.eot -------------------------------------------------------------------------------- /docs/_build/html/_static/fonts/Lato/lato-regular.ttf: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/wellcometrust/WellcomeML/HEAD/docs/_build/html/_static/fonts/Lato/lato-regular.ttf -------------------------------------------------------------------------------- /docs/_build/html/_static/fonts/Lato/lato-regular.woff: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wellcometrust/WellcomeML/HEAD/docs/_build/html/_static/fonts/Lato/lato-regular.woff -------------------------------------------------------------------------------- /docs/_build/html/_static/fonts/Lato/lato-regular.woff2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wellcometrust/WellcomeML/HEAD/docs/_build/html/_static/fonts/Lato/lato-regular.woff2 -------------------------------------------------------------------------------- /docs/_build/html/_static/fonts/RobotoSlab-Regular.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wellcometrust/WellcomeML/HEAD/docs/_build/html/_static/fonts/RobotoSlab-Regular.ttf -------------------------------------------------------------------------------- /docs/_build/html/_static/fonts/fontawesome-webfont.eot: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wellcometrust/WellcomeML/HEAD/docs/_build/html/_static/fonts/fontawesome-webfont.eot -------------------------------------------------------------------------------- /docs/_build/html/_static/fonts/fontawesome-webfont.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wellcometrust/WellcomeML/HEAD/docs/_build/html/_static/fonts/fontawesome-webfont.ttf -------------------------------------------------------------------------------- /docs/_build/html/_static/css/fonts/Roboto-Slab-Bold.woff: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wellcometrust/WellcomeML/HEAD/docs/_build/html/_static/css/fonts/Roboto-Slab-Bold.woff -------------------------------------------------------------------------------- /docs/_build/html/_static/css/fonts/lato-bold-italic.woff: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wellcometrust/WellcomeML/HEAD/docs/_build/html/_static/css/fonts/lato-bold-italic.woff -------------------------------------------------------------------------------- /docs/_build/html/_static/fonts/Lato/lato-bolditalic.eot: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wellcometrust/WellcomeML/HEAD/docs/_build/html/_static/fonts/Lato/lato-bolditalic.eot -------------------------------------------------------------------------------- /docs/_build/html/_static/fonts/Lato/lato-bolditalic.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wellcometrust/WellcomeML/HEAD/docs/_build/html/_static/fonts/Lato/lato-bolditalic.ttf -------------------------------------------------------------------------------- /docs/_build/html/_static/fonts/Lato/lato-bolditalic.woff: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wellcometrust/WellcomeML/HEAD/docs/_build/html/_static/fonts/Lato/lato-bolditalic.woff 
-------------------------------------------------------------------------------- /docs/_build/html/_static/fonts/fontawesome-webfont.woff: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wellcometrust/WellcomeML/HEAD/docs/_build/html/_static/fonts/fontawesome-webfont.woff -------------------------------------------------------------------------------- /docs/_build/html/_static/fonts/fontawesome-webfont.woff2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wellcometrust/WellcomeML/HEAD/docs/_build/html/_static/fonts/fontawesome-webfont.woff2 -------------------------------------------------------------------------------- /docs/_build/html/_static/css/fonts/Roboto-Slab-Bold.woff2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wellcometrust/WellcomeML/HEAD/docs/_build/html/_static/css/fonts/Roboto-Slab-Bold.woff2 -------------------------------------------------------------------------------- /docs/_build/html/_static/css/fonts/Roboto-Slab-Regular.woff: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wellcometrust/WellcomeML/HEAD/docs/_build/html/_static/css/fonts/Roboto-Slab-Regular.woff -------------------------------------------------------------------------------- /docs/_build/html/_static/css/fonts/fontawesome-webfont.eot: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wellcometrust/WellcomeML/HEAD/docs/_build/html/_static/css/fonts/fontawesome-webfont.eot -------------------------------------------------------------------------------- /docs/_build/html/_static/css/fonts/fontawesome-webfont.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wellcometrust/WellcomeML/HEAD/docs/_build/html/_static/css/fonts/fontawesome-webfont.ttf -------------------------------------------------------------------------------- /docs/_build/html/_static/css/fonts/fontawesome-webfont.woff: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wellcometrust/WellcomeML/HEAD/docs/_build/html/_static/css/fonts/fontawesome-webfont.woff -------------------------------------------------------------------------------- /docs/_build/html/_static/css/fonts/lato-bold-italic.woff2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wellcometrust/WellcomeML/HEAD/docs/_build/html/_static/css/fonts/lato-bold-italic.woff2 -------------------------------------------------------------------------------- /docs/_build/html/_static/css/fonts/lato-normal-italic.woff: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wellcometrust/WellcomeML/HEAD/docs/_build/html/_static/css/fonts/lato-normal-italic.woff -------------------------------------------------------------------------------- /docs/_build/html/_static/css/fonts/lato-normal-italic.woff2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wellcometrust/WellcomeML/HEAD/docs/_build/html/_static/css/fonts/lato-normal-italic.woff2 -------------------------------------------------------------------------------- 
/docs/_build/html/_static/fonts/Lato/lato-bolditalic.woff2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wellcometrust/WellcomeML/HEAD/docs/_build/html/_static/fonts/Lato/lato-bolditalic.woff2 -------------------------------------------------------------------------------- /docs/_build/html/_static/css/fonts/Roboto-Slab-Regular.woff2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wellcometrust/WellcomeML/HEAD/docs/_build/html/_static/css/fonts/Roboto-Slab-Regular.woff2 -------------------------------------------------------------------------------- /docs/_build/html/_static/css/fonts/fontawesome-webfont.woff2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wellcometrust/WellcomeML/HEAD/docs/_build/html/_static/css/fonts/fontawesome-webfont.woff2 -------------------------------------------------------------------------------- /tests/test_logger.py: -------------------------------------------------------------------------------- 1 | from wellcomeml.logger import logger 2 | 3 | 4 | def test_logging(): 5 | """Tests the logger name""" 6 | assert logger.name == 'wellcomeml.logger' 7 | -------------------------------------------------------------------------------- /docs/_build/html/_static/fonts/RobotoSlab/roboto-slab-v7-bold.eot: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wellcometrust/WellcomeML/HEAD/docs/_build/html/_static/fonts/RobotoSlab/roboto-slab-v7-bold.eot -------------------------------------------------------------------------------- /docs/_build/html/_static/fonts/RobotoSlab/roboto-slab-v7-bold.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wellcometrust/WellcomeML/HEAD/docs/_build/html/_static/fonts/RobotoSlab/roboto-slab-v7-bold.ttf -------------------------------------------------------------------------------- /docs/_build/html/_static/fonts/RobotoSlab/roboto-slab-v7-bold.woff: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wellcometrust/WellcomeML/HEAD/docs/_build/html/_static/fonts/RobotoSlab/roboto-slab-v7-bold.woff -------------------------------------------------------------------------------- /docs/_build/html/_static/fonts/RobotoSlab/roboto-slab-v7-bold.woff2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wellcometrust/WellcomeML/HEAD/docs/_build/html/_static/fonts/RobotoSlab/roboto-slab-v7-bold.woff2 -------------------------------------------------------------------------------- /docs/_build/html/_static/fonts/RobotoSlab/roboto-slab-v7-regular.eot: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wellcometrust/WellcomeML/HEAD/docs/_build/html/_static/fonts/RobotoSlab/roboto-slab-v7-regular.eot -------------------------------------------------------------------------------- /docs/_build/html/_static/fonts/RobotoSlab/roboto-slab-v7-regular.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wellcometrust/WellcomeML/HEAD/docs/_build/html/_static/fonts/RobotoSlab/roboto-slab-v7-regular.ttf 
-------------------------------------------------------------------------------- /docs/_build/html/_static/fonts/RobotoSlab/roboto-slab-v7-regular.woff: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wellcometrust/WellcomeML/HEAD/docs/_build/html/_static/fonts/RobotoSlab/roboto-slab-v7-regular.woff -------------------------------------------------------------------------------- /docs/_build/html/_static/fonts/RobotoSlab/roboto-slab-v7-regular.woff2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wellcometrust/WellcomeML/HEAD/docs/_build/html/_static/fonts/RobotoSlab/roboto-slab-v7-regular.woff2 -------------------------------------------------------------------------------- /tests/test_data/mock_winer_document.vocab: -------------------------------------------------------------------------------- 1 | this 123 2 | is 345 3 | a 777 4 | sentence 33 5 | about 12 6 | james 43 7 | bond 3 8 | another 8 9 | just 2 10 | street 99 11 | jane 14 12 | oxford 9 -------------------------------------------------------------------------------- /wellcomeml/metrics/__init__.py: -------------------------------------------------------------------------------- 1 | from .f1 import f1_loss, f1_metric 2 | from .ner_classification_report import ner_classification_report 3 | 4 | __all__ = ["ner_classification_report", "f1_metric", "f1_loss"] 5 | -------------------------------------------------------------------------------- /tests/common.py: -------------------------------------------------------------------------------- 1 | import os.path 2 | 3 | 4 | def get_path(p): 5 | return os.path.join( 6 | os.path.dirname(__file__), 7 | p 8 | ) 9 | 10 | 11 | TEST_JSONL = get_path('test_data/test_jsonl.jsonl') 12 | -------------------------------------------------------------------------------- /pytest.ini: -------------------------------------------------------------------------------- 1 | [pytest] 2 | addopts = --strict-markers 3 | markers = 4 | integration: integration tests 5 | bert: tests that use bert (usually heavy tests) 6 | extras: tests that will install/uninstall a library/load/reload a module 7 | -------------------------------------------------------------------------------- /wellcomeml/io/__init__.py: -------------------------------------------------------------------------------- 1 | from .io import read_jsonl, write_jsonl 2 | from .epmc.client import EPMCClient 3 | from .s3_policy_data import PolicyDocumentsDownloader 4 | 5 | __all__ = ['read_jsonl', 'write_jsonl', 'PolicyDocumentsDownloader', 'EPMCClient'] 6 | -------------------------------------------------------------------------------- /docs/_build/html/.buildinfo: -------------------------------------------------------------------------------- 1 | # Sphinx build info version 1 2 | # This file hashes the configuration used when building these files. When it is not found, a full rebuild will be done. 
3 | config: 52103617ea8272af3d7f8ac1d726093b 4 | tags: 645f666f9bcd5a90fca523b33c5a78b7 5 | -------------------------------------------------------------------------------- /examples/spacy_classifier.py: -------------------------------------------------------------------------------- 1 | from wellcomeml.ml.spacy_classifier import SpacyClassifier 2 | 3 | X = ["One, three", "one", "two, three"] 4 | Y = [[1, 0, 1], [1, 0, 0], [0, 1, 1]] 5 | 6 | spacy_classifier = SpacyClassifier() 7 | spacy_classifier.fit(X, Y) 8 | print(spacy_classifier.score(X, Y)) 9 | -------------------------------------------------------------------------------- /examples/bert_classifier.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from wellcomeml.ml.bert_classifier import BertClassifier 4 | 5 | X = ["Hot and cold", "Hot", "Cold"] 6 | Y = np.array([[1, 1], [1, 0], [0, 1]]) 7 | 8 | bert = BertClassifier(batch_size=8) 9 | bert.fit(X, Y) 10 | print(bert.score(X, Y)) 11 | -------------------------------------------------------------------------------- /tox.ini: -------------------------------------------------------------------------------- 1 | [tox] 2 | envlist = py37, py38 3 | 4 | [testenv] 5 | deps = 6 | -r requirements_test.txt 7 | .[all] 8 | 9 | commands = python -m spacy download en_core_web_sm 10 | pytest -m '{env:TEST_SUITE:}' -s -v --durations=0 --disable-warnings --tb=line --cov=wellcomeml ./tests 11 | 12 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | dist 2 | **/build/* 3 | *.egg-info/* 4 | **/*.egg-info/* 5 | 6 | *.DS_Store 7 | *.pyc 8 | *~ 9 | **/.ipynb_checkpoints* 10 | 11 | # Credentials 12 | 13 | # IDE dirs 14 | *.idea/ 15 | 16 | **/.mypy_cache/ 17 | **/.envrc 18 | *.coverage 19 | notebooks 20 | logs/* 21 | 22 | .tox/ 23 | .python-version 24 | 25 | sent2vec 26 | -------------------------------------------------------------------------------- /codecov.yml: -------------------------------------------------------------------------------- 1 | coverage: 2 | precision: 2 3 | round: down 4 | status: 5 | project: 6 | default: 7 | target: auto 8 | threshold: 10% 9 | patch: 10 | default: 11 | enabled: no 12 | if_not_found: success 13 | changes: 14 | default: 15 | enabled: no 16 | if_not_found: success 17 | -------------------------------------------------------------------------------- /wellcomeml/__init__.py: -------------------------------------------------------------------------------- 1 | from .__version__ import ( 2 | __name__, __version__, __description__, __url__, 3 | __author__, __author_email__, __license__ 4 | ) 5 | from .logger import logger 6 | 7 | __all__ = [ 8 | '__name__', '__version__', '__description__', '__url__', 9 | '__author__', '__author_email__', '__license__', 'logger' 10 | ] 11 | -------------------------------------------------------------------------------- /examples/epmc_client.py: -------------------------------------------------------------------------------- 1 | from wellcomeml.io.epmc.client import EPMCClient 2 | 3 | client = EPMCClient(max_retries=3) 4 | session = client.requests_session() 5 | pmid = "34215990" 6 | 7 | references = client.get_references(session, pmid) 8 | print(f"Found {len(references)} references") 9 | 10 | result = client.search_by_pmid(session, pmid) 11 | print(f"Found pub with keys {result.keys()}") 12 | 
-------------------------------------------------------------------------------- /wellcomeml/__version__.py: -------------------------------------------------------------------------------- 1 | __name__ = "wellcomeml" 2 | __version__ = "2.0.3" 3 | __description__ = """Utilities for managing nlp models and for 4 | processing text-related data at the Wellcome Trust""" 5 | __url__ = "https://github.com/wellcometrust/wellcomeml/tree/main" 6 | __author__ = "Wellcome Trust Data Science Team" 7 | __author_email__ = "Grp_datalabs-datascience@Wellcomecloud.onmicrosoft.com" 8 | __license__ = "MIT" 9 | -------------------------------------------------------------------------------- /docs/_build/html/_static/documentation_options.js: -------------------------------------------------------------------------------- 1 | var DOCUMENTATION_OPTIONS = { 2 | URL_ROOT: document.getElementById("documentation_options").getAttribute('data-url_root'), 3 | VERSION: '2.0.3', 4 | LANGUAGE: 'None', 5 | COLLAPSE_INDEX: false, 6 | BUILDER: 'html', 7 | FILE_SUFFIX: '.html', 8 | LINK_SUFFIX: '.html', 9 | HAS_SOURCE: true, 10 | SOURCELINK_SUFFIX: '.txt', 11 | NAVIGATION_WITH_KEYS: false 12 | }; -------------------------------------------------------------------------------- /docs/wellcomeml.io.epmc.rst: -------------------------------------------------------------------------------- 1 | wellcomeml.io.epmc package 2 | ========================== 3 | 4 | Submodules 5 | ---------- 6 | 7 | wellcomeml.io.epmc.client module 8 | -------------------------------- 9 | 10 | .. automodule:: wellcomeml.io.epmc.client 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | Module contents 16 | --------------- 17 | 18 | .. automodule:: wellcomeml.io.epmc 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | -------------------------------------------------------------------------------- /docs/_build/html/_sources/wellcomeml.io.epmc.rst.txt: -------------------------------------------------------------------------------- 1 | wellcomeml.io.epmc package 2 | ========================== 3 | 4 | Submodules 5 | ---------- 6 | 7 | wellcomeml.io.epmc.client module 8 | -------------------------------- 9 | 10 | .. automodule:: wellcomeml.io.epmc.client 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | Module contents 16 | --------------- 17 | 18 | .. automodule:: wellcomeml.io.epmc 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | -------------------------------------------------------------------------------- /docs/wellcomeml.spacy.rst: -------------------------------------------------------------------------------- 1 | wellcomeml.spacy package 2 | ======================== 3 | 4 | Submodules 5 | ---------- 6 | 7 | wellcomeml.spacy.spacy\_doc\_to\_prodigy module 8 | ----------------------------------------------- 9 | 10 | .. automodule:: wellcomeml.spacy.spacy_doc_to_prodigy 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | Module contents 16 | --------------- 17 | 18 | .. 
automodule:: wellcomeml.spacy 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | -------------------------------------------------------------------------------- /docs/_build/html/_sources/wellcomeml.spacy.rst.txt: -------------------------------------------------------------------------------- 1 | wellcomeml.spacy package 2 | ======================== 3 | 4 | Submodules 5 | ---------- 6 | 7 | wellcomeml.spacy.spacy\_doc\_to\_prodigy module 8 | ----------------------------------------------- 9 | 10 | .. automodule:: wellcomeml.spacy.spacy_doc_to_prodigy 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | Module contents 16 | --------------- 17 | 18 | .. automodule:: wellcomeml.spacy 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | dist: bionic 2 | 3 | language: python 4 | 5 | python: 6 | - 3.7 7 | - 3.8 8 | 9 | install: 10 | - pip install -r requirements_test.txt 11 | - pip install tox-travis 12 | 13 | env: 14 | jobs: 15 | - TEST_SUITE='bert' 16 | - TEST_SUITE='not bert' 17 | 18 | global: 19 | - TF_CPP_MIN_LOG_LEVEL=2 20 | - DISABLE_DIRECT_IMPORTS=1 21 | 22 | script: 23 | - tox 24 | - pip freeze 25 | 26 | cache: pip 27 | 28 | branches: 29 | only: 30 | - main 31 | - feature/visualisation 32 | 33 | after_success: 34 | - python -m codecov 35 | -------------------------------------------------------------------------------- /examples/bert_embeddings.py: -------------------------------------------------------------------------------- 1 | from sklearn.metrics.pairwise import cosine_similarity 2 | from sklearn.pipeline import Pipeline 3 | from sklearn.svm import SVC 4 | 5 | from wellcomeml.ml.bert_vectorizer import BertVectorizer 6 | 7 | 8 | X = [ 9 | "Elizabeth is the queen of England", 10 | "Felipe is the king of Spain", 11 | "I like to travel", 12 | ] 13 | y = [1, 1, 0] 14 | 15 | vectorizer = BertVectorizer() 16 | X_transformed = vectorizer.fit_transform(X) 17 | print(cosine_similarity(X_transformed)) 18 | 19 | pipeline = Pipeline([("bert", BertVectorizer()), ("svm", SVC(kernel="linear"))]) 20 | pipeline.fit(X, y) 21 | print(pipeline.score(X, y)) 22 | -------------------------------------------------------------------------------- /examples/policy_docs_from_s3.py: -------------------------------------------------------------------------------- 1 | import boto3 2 | 3 | from wellcomeml.io import PolicyDocumentsDownloader 4 | 5 | word_list = ["malaria"] 6 | 7 | s3 = boto3.client("s3") 8 | policy_s3 = PolicyDocumentsDownloader( 9 | s3=s3, 10 | bucket_name="datalabs-dev", 11 | dir_path="reach-airflow/output/policy/parsed-pdfs", 12 | ) 13 | hash_dicts = policy_s3.get_hashes(word_list=word_list) 14 | 15 | hash_list = [hash_dict["file_hash"] for hash_dict in hash_dicts] 16 | 17 | print(hash_list[0:10]) 18 | 19 | documents = policy_s3.download(hash_list=hash_list[0:10]) 20 | 21 | # Get the first 100 characters of the text from 22 | # these documents 23 | print([d["text"][0:100] for d in documents]) 24 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 
6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = . 9 | BUILDDIR = _build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /docs/wellcomeml.metrics.rst: -------------------------------------------------------------------------------- 1 | wellcomeml.metrics package 2 | ========================== 3 | 4 | Submodules 5 | ---------- 6 | 7 | wellcomeml.metrics.f1 module 8 | ---------------------------- 9 | 10 | .. automodule:: wellcomeml.metrics.f1 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | wellcomeml.metrics.ner\_classification\_report module 16 | ----------------------------------------------------- 17 | 18 | .. automodule:: wellcomeml.metrics.ner_classification_report 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | 23 | Module contents 24 | --------------- 25 | 26 | .. automodule:: wellcomeml.metrics 27 | :members: 28 | :undoc-members: 29 | :show-inheritance: 30 | -------------------------------------------------------------------------------- /examples/cnn_classifier.py: -------------------------------------------------------------------------------- 1 | from wellcomeml.ml.cnn import CNNClassifier 2 | from wellcomeml.ml.keras_vectorizer import KerasVectorizer 3 | from sklearn.pipeline import Pipeline 4 | 5 | import numpy as np 6 | 7 | X = ["One", "three", "one", "two", "four"] 8 | Y = np.array([1, 0, 1, 0, 0]) 9 | 10 | cnn_pipeline = Pipeline([("vec", KerasVectorizer()), ("clf", CNNClassifier())]) 11 | cnn_pipeline.fit(X, Y) 12 | print(cnn_pipeline.score(X, Y)) 13 | 14 | X = ["One, three", "one", "two, three"] 15 | Y = np.array([[1, 0, 1], [1, 0, 0], [0, 1, 1]]) 16 | 17 | cnn_pipeline = Pipeline( 18 | [("vec", KerasVectorizer()), ("clf", CNNClassifier(multilabel=True))] 19 | ) 20 | cnn_pipeline.fit(X, Y) 21 | print(cnn_pipeline.score(X, Y)) 22 | -------------------------------------------------------------------------------- /docs/_build/html/_sources/wellcomeml.metrics.rst.txt: -------------------------------------------------------------------------------- 1 | wellcomeml.metrics package 2 | ========================== 3 | 4 | Submodules 5 | ---------- 6 | 7 | wellcomeml.metrics.f1 module 8 | ---------------------------- 9 | 10 | .. automodule:: wellcomeml.metrics.f1 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | wellcomeml.metrics.ner\_classification\_report module 16 | ----------------------------------------------------- 17 | 18 | .. automodule:: wellcomeml.metrics.ner_classification_report 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | 23 | Module contents 24 | --------------- 25 | 26 | .. 
automodule:: wellcomeml.metrics 27 | :members: 28 | :undoc-members: 29 | :show-inheritance: 30 | -------------------------------------------------------------------------------- /tests/test_sent2vec.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from wellcomeml.ml.sent2vec_vectorizer import Sent2VecVectorizer 4 | 5 | 6 | @pytest.mark.skip(reason="Consumes too much memory") 7 | def test_fit_transform(): 8 | X = [ 9 | "Malaria is a disease that kills people", 10 | "Heart problems comes first in the global burden of disease", 11 | "Wellcome also funds policy and culture research" 12 | ] 13 | sent2vec = Sent2VecVectorizer("sent2vec_wiki_unigrams") 14 | sent2vec.fit(X) 15 | X_vec = sent2vec.transform(X) 16 | assert X_vec.shape == (3, 600) 17 | 18 | 19 | def test_fit(): 20 | with pytest.raises(NotImplementedError): 21 | sent2vec = Sent2VecVectorizer() 22 | sent2vec.fit(["Custom data"]) 23 | -------------------------------------------------------------------------------- /examples/bilstm_classifier.py: -------------------------------------------------------------------------------- 1 | from wellcomeml.ml.bilstm import BiLSTMClassifier 2 | from wellcomeml.ml.keras_vectorizer import KerasVectorizer 3 | from sklearn.pipeline import Pipeline 4 | 5 | import numpy as np 6 | 7 | X = ["One", "three", "one", "two", "four"] 8 | Y = np.array([1, 0, 1, 0, 0]) 9 | 10 | bilstm_pipeline = Pipeline([("vec", KerasVectorizer()), ("clf", BiLSTMClassifier())]) 11 | bilstm_pipeline.fit(X, Y) 12 | print(bilstm_pipeline.score(X, Y)) 13 | 14 | X = ["One, three", "one", "two, three"] 15 | Y = np.array([[1, 0, 1], [1, 0, 0], [0, 1, 1]]) 16 | 17 | bilstm_pipeline = Pipeline( 18 | [("vec", KerasVectorizer()), ("clf", BiLSTMClassifier(multilabel=True))] 19 | ) 20 | bilstm_pipeline.fit(X, Y) 21 | print(bilstm_pipeline.score(X, Y)) 22 | -------------------------------------------------------------------------------- /examples/visualize_clusters.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | from wellcomeml.ml.clustering import TextClustering 3 | from wellcomeml.viz.visualize_clusters import visualize_clusters 4 | 5 | url = "https://datalabs-public.s3.eu-west-2.amazonaws.com/" \ 6 | "datasets/epmc/random_sample.csv" 7 | data = pd.read_csv(url) 8 | 9 | text = list(data['text']) 10 | 11 | clustering = TextClustering(embedding='tf-idf', reducer='umap', params={ 12 | 'reducer': {'min_dist': 0.1, 'n_neighbors': 10}, 13 | 'vectorizer': {'min_df': 0.0002}, 14 | 'clustering': {'min_samples': 20, 'eps': 0.2} 15 | }) 16 | 17 | clustering.fit(text) 18 | 19 | visualize_clusters(clustering, 0.05, 0.8, output_in_notebook=False, 20 | output_file_path="test.html") 21 | -------------------------------------------------------------------------------- /docs/wellcomeml.io.rst: -------------------------------------------------------------------------------- 1 | wellcomeml.io package 2 | ===================== 3 | 4 | Subpackages 5 | ----------- 6 | 7 | .. toctree:: 8 | :maxdepth: 5 9 | 10 | wellcomeml.io.epmc 11 | 12 | Submodules 13 | ---------- 14 | 15 | wellcomeml.io.io module 16 | ----------------------- 17 | 18 | .. automodule:: wellcomeml.io.io 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | 23 | wellcomeml.io.s3\_policy\_data module 24 | ------------------------------------- 25 | 26 | .. 
automodule:: wellcomeml.io.s3_policy_data 27 | :members: 28 | :undoc-members: 29 | :show-inheritance: 30 | 31 | Module contents 32 | --------------- 33 | 34 | .. automodule:: wellcomeml.io 35 | :members: 36 | :undoc-members: 37 | :show-inheritance: 38 | -------------------------------------------------------------------------------- /WINDOWS_USERS.md: -------------------------------------------------------------------------------- 1 | # Installation on Windows machines 2 | 3 | ## Requirements 4 | This package requires an up-to-date Windows 10, as well as the following: 5 | - Visual Studio Build Tools 2019 with `Desktop Development with C++` installed 6 | - Python 3.8 installed at the root of your machine (the Makefile will look for it in C://Python38) 7 | - Administrator rights 8 | - GNU Make 9 | - The Makefile will assume the `OS` environment variable is set to its default value 10 | - (Optional but recommended) Cygwin 11 | 12 | 13 | ## Installation 14 | Run the following Makefile target: 15 | `make virtualenv` 16 | 17 | ## Tests 18 | Running the tests might take a while on the first run, as you will need to download some models and build a few libraries. 19 | `make test` -------------------------------------------------------------------------------- /docs/_build/html/_sources/wellcomeml.io.rst.txt: -------------------------------------------------------------------------------- 1 | wellcomeml.io package 2 | ===================== 3 | 4 | Subpackages 5 | ----------- 6 | 7 | .. toctree:: 8 | :maxdepth: 5 9 | 10 | wellcomeml.io.epmc 11 | 12 | Submodules 13 | ---------- 14 | 15 | wellcomeml.io.io module 16 | ----------------------- 17 | 18 | .. automodule:: wellcomeml.io.io 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | 23 | wellcomeml.io.s3\_policy\_data module 24 | ------------------------------------- 25 | 26 | .. automodule:: wellcomeml.io.s3_policy_data 27 | :members: 28 | :undoc-members: 29 | :show-inheritance: 30 | 31 | Module contents 32 | --------------- 33 | 34 | .. automodule:: wellcomeml.io 35 | :members: 36 | :undoc-members: 37 | :show-inheritance: 38 | -------------------------------------------------------------------------------- /examples/sent2vec_vectorizer.py: -------------------------------------------------------------------------------- 1 | from sklearn.metrics.pairwise import cosine_similarity 2 | from sklearn.linear_model import SGDClassifier 3 | from sklearn.pipeline import Pipeline 4 | 5 | 6 | from wellcomeml.ml.sent2vec_vectorizer import Sent2VecVectorizer 7 | 8 | X = [ 9 | "Malaria is a disease spread by mosquitos", 10 | "HIV is a virus that causes a disease named AIDS", 11 | "Trump is the president of USA", 12 | ] 13 | 14 | sent2vec = Sent2VecVectorizer("sent2vec_wiki_unigrams") 15 | X_transformed = sent2vec.fit_transform(X) 16 | print(cosine_similarity(X_transformed)) 17 | 18 | y = [1, 1, 0] 19 | 20 | model = Pipeline( 21 | [ 22 | ("sent2vec", Sent2VecVectorizer("sent2vec_wiki_unigrams")), 23 | ("sgd", SGDClassifier()), 24 | ] 25 | ) 26 | model.fit(X, y) -------------------------------------------------------------------------------- /docs/wellcomeml.rst: -------------------------------------------------------------------------------- 1 | wellcomeml package 2 | ================== 3 | 4 | Subpackages 5 | ----------- 6 | 7 | .. 
toctree:: 8 | :maxdepth: 5 9 | 10 | wellcomeml.datasets 11 | wellcomeml.io 12 | wellcomeml.metrics 13 | wellcomeml.ml 14 | wellcomeml.spacy 15 | wellcomeml.viz 16 | 17 | Submodules 18 | ---------- 19 | 20 | wellcomeml.logger module 21 | ------------------------ 22 | 23 | .. automodule:: wellcomeml.logger 24 | :members: 25 | :undoc-members: 26 | :show-inheritance: 27 | 28 | wellcomeml.utils module 29 | ----------------------- 30 | 31 | .. automodule:: wellcomeml.utils 32 | :members: 33 | :undoc-members: 34 | :show-inheritance: 35 | 36 | Module contents 37 | --------------- 38 | 39 | .. automodule:: wellcomeml 40 | :members: 41 | :undoc-members: 42 | :show-inheritance: 43 | -------------------------------------------------------------------------------- /examples/doc2vec.py: -------------------------------------------------------------------------------- 1 | from sklearn.metrics.pairwise import cosine_similarity 2 | from sklearn.linear_model import SGDClassifier 3 | from sklearn.pipeline import Pipeline 4 | 5 | 6 | from wellcomeml.ml.doc2vec_vectorizer import Doc2VecVectorizer 7 | 8 | X = [ 9 | "Malaria is a disease spread by mosquitos", 10 | "HIV is a virus that causes a disease named AIDS", 11 | "Trump is the president of USA", 12 | ] 13 | 14 | doc2vec = Doc2VecVectorizer(min_count=1, vector_size=8, sample=0, negative=1) 15 | X_transformed = doc2vec.fit_transform(X) 16 | print(cosine_similarity(X_transformed)) 17 | 18 | y = [1, 1, 0] 19 | 20 | model = Pipeline( 21 | [ 22 | ("doc2vec", Doc2VecVectorizer(min_count=1, vector_size=8)), 23 | ("sgd", SGDClassifier()), 24 | ] 25 | ) 26 | model.fit(X, y) 27 | print(model.score(X, y)) 28 | -------------------------------------------------------------------------------- /docs/_build/html/_sources/wellcomeml.rst.txt: -------------------------------------------------------------------------------- 1 | wellcomeml package 2 | ================== 3 | 4 | Subpackages 5 | ----------- 6 | 7 | .. toctree:: 8 | :maxdepth: 5 9 | 10 | wellcomeml.datasets 11 | wellcomeml.io 12 | wellcomeml.metrics 13 | wellcomeml.ml 14 | wellcomeml.spacy 15 | wellcomeml.viz 16 | 17 | Submodules 18 | ---------- 19 | 20 | wellcomeml.logger module 21 | ------------------------ 22 | 23 | .. automodule:: wellcomeml.logger 24 | :members: 25 | :undoc-members: 26 | :show-inheritance: 27 | 28 | wellcomeml.utils module 29 | ----------------------- 30 | 31 | .. automodule:: wellcomeml.utils 32 | :members: 33 | :undoc-members: 34 | :show-inheritance: 35 | 36 | Module contents 37 | --------------- 38 | 39 | .. automodule:: wellcomeml 40 | :members: 41 | :undoc-members: 42 | :show-inheritance: 43 | -------------------------------------------------------------------------------- /docs/wellcomeml.viz.rst: -------------------------------------------------------------------------------- 1 | wellcomeml.viz package 2 | ====================== 3 | 4 | Submodules 5 | ---------- 6 | 7 | wellcomeml.viz.colors module 8 | ---------------------------- 9 | 10 | .. automodule:: wellcomeml.viz.colors 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | wellcomeml.viz.palettes module 16 | ------------------------------ 17 | 18 | .. automodule:: wellcomeml.viz.palettes 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | 23 | wellcomeml.viz.visualize\_clusters module 24 | ----------------------------------------- 25 | 26 | .. 
automodule:: wellcomeml.viz.visualize_clusters 27 | :members: 28 | :undoc-members: 29 | :show-inheritance: 30 | 31 | Module contents 32 | --------------- 33 | 34 | .. automodule:: wellcomeml.viz 35 | :members: 36 | :undoc-members: 37 | :show-inheritance: 38 | -------------------------------------------------------------------------------- /docs/_build/html/_sources/wellcomeml.viz.rst.txt: -------------------------------------------------------------------------------- 1 | wellcomeml.viz package 2 | ====================== 3 | 4 | Submodules 5 | ---------- 6 | 7 | wellcomeml.viz.colors module 8 | ---------------------------- 9 | 10 | .. automodule:: wellcomeml.viz.colors 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | wellcomeml.viz.palettes module 16 | ------------------------------ 17 | 18 | .. automodule:: wellcomeml.viz.palettes 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | 23 | wellcomeml.viz.visualize\_clusters module 24 | ----------------------------------------- 25 | 26 | .. automodule:: wellcomeml.viz.visualize_clusters 27 | :members: 28 | :undoc-members: 29 | :show-inheritance: 30 | 31 | Module contents 32 | --------------- 33 | 34 | .. automodule:: wellcomeml.viz 35 | :members: 36 | :undoc-members: 37 | :show-inheritance: 38 | -------------------------------------------------------------------------------- /docs/_build/html/_static/js/badge_only.js: -------------------------------------------------------------------------------- 1 | !function(e){var t={};function r(n){if(t[n])return t[n].exports;var o=t[n]={i:n,l:!1,exports:{}};return e[n].call(o.exports,o,o.exports,r),o.l=!0,o.exports}r.m=e,r.c=t,r.d=function(e,t,n){r.o(e,t)||Object.defineProperty(e,t,{enumerable:!0,get:n})},r.r=function(e){"undefined"!=typeof Symbol&&Symbol.toStringTag&&Object.defineProperty(e,Symbol.toStringTag,{value:"Module"}),Object.defineProperty(e,"__esModule",{value:!0})},r.t=function(e,t){if(1&t&&(e=r(e)),8&t)return e;if(4&t&&"object"==typeof e&&e&&e.__esModule)return e;var n=Object.create(null);if(r.r(n),Object.defineProperty(n,"default",{enumerable:!0,value:e}),2&t&&"string"!=typeof e)for(var o in e)r.d(n,o,function(t){return e[t]}.bind(null,o));return n},r.n=function(e){var t=e&&e.__esModule?function(){return e.default}:function(){return e};return r.d(t,"a",t),t},r.o=function(e,t){return Object.prototype.hasOwnProperty.call(e,t)},r.p="",r(r.s=4)}({4:function(e,t,r){}}); -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=. 11 | set BUILDDIR=_build 12 | 13 | if "%1" == "" goto help 14 | 15 | %SPHINXBUILD% >NUL 2>NUL 16 | if errorlevel 9009 ( 17 | echo. 18 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 19 | echo.installed, then set the SPHINXBUILD environment variable to point 20 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 21 | echo.may add the Sphinx directory to PATH. 22 | echo. 
23 | echo.If you don't have Sphinx installed, grab it from 24 | echo.http://sphinx-doc.org/ 25 | exit /b 1 26 | ) 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /examples/bert_text_similarity_fine_tune.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pandas as pd 3 | 4 | from wellcomeml.ml.bert_semantic_equivalence import SemanticEquivalenceClassifier 5 | 6 | data_file_path = os.path.join( 7 | os.path.dirname(__file__), "data/text_similarity_sample_100_pairs.csv" 8 | ) 9 | 10 | # Reads sample data and formats it 11 | df = pd.read_csv(data_file_path) 12 | 13 | X = df[["text_1", "text_2"]].values.tolist() 14 | y = df["label"].values 15 | 16 | # Define the classifier and fit it for 1 epoch 17 | classifier = SemanticEquivalenceClassifier( 18 | pretrained="scibert", batch_size=8, eval_batch_size=16 19 | ) 20 | 21 | classifier.fit(X, y, epochs=1) 22 | 23 | test_pair = ( 24 | "the FCC will not request personal identifying information ", 25 | "personal information will not be requested by the FCC", 26 | ) 27 | 28 | score_related = classifier.predict_proba([test_pair]) 29 | 30 | print(f"Sentences are probably related with score {score_related[0][1]}.") 31 | -------------------------------------------------------------------------------- /.github/workflows/main.yml: -------------------------------------------------------------------------------- 1 | # This is a basic workflow to help you get started with Actions 2 | 3 | name: CI 4 | 5 | # Controls when the action will run. Triggers the workflow on push or pull request 6 | # events but only for the main branch 7 | on: 8 | push: 9 | branches: [ main ] 10 | pull_request: 11 | branches: [ main ] 12 | 13 | # A workflow run is made up of one or more jobs that can run sequentially or in parallel 14 | jobs: 15 | flake8_check: 16 | runs-on: ubuntu-latest 17 | 18 | steps: 19 | - name: Setup Python 20 | uses: actions/setup-python@v1 21 | with: 22 | python-version: 3.7.12 23 | architecture: x64 24 | - name: Checkout Master 25 | uses: actions/checkout@master 26 | - name: Install flake8 27 | run: pip install flake8 28 | - name: Run flake8 29 | uses: suo/flake8-github-action@v1 30 | with: 31 | checkName: flake8_check 32 | env: 33 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} 34 | -------------------------------------------------------------------------------- /wellcomeml/__main__.py: -------------------------------------------------------------------------------- 1 | import subprocess 2 | import sys 3 | 4 | 5 | def download(download_target): 6 | if download_target == "models": 7 | subprocess.run([ 8 | 'python', '-m', 'spacy', 'download', 'en_core_web_sm']) 9 | elif download_target == "deeplearning-models": 10 | subprocess.run([ 11 | 'python', '-m', 'spacy', 'download', 'en_core_web_trf']) 12 | elif download_target == "non_pypi_packages": 13 | # This is a workaround to pin sent2vec 14 | sent_2_vec_commit = 'f00a1b67f4330e5be99e7cc31ac28df94deed9ac' 15 | 16 | subprocess.run([ 17 | 'pip', 'install', f'git+https://github.com/epfml/sent2vec.git@{sent_2_vec_commit}']) 18 | else: 19 | print(f"{download_target} is not one of models, deeplearning-models, non_pypi_packages") 20 | 21 | 22 | if __name__ == '__main__': 23 | command = sys.argv.pop(1) 24 | if command != "download": 25 | print("Only available 
command is download") 26 | exit() 27 | 28 | download_target = sys.argv.pop(1) 29 | download(download_target) 30 | -------------------------------------------------------------------------------- /docs/wellcomeml.datasets.rst: -------------------------------------------------------------------------------- 1 | wellcomeml.datasets package 2 | =========================== 3 | 4 | Submodules 5 | ---------- 6 | 7 | wellcomeml.datasets.conll module 8 | -------------------------------- 9 | 10 | .. automodule:: wellcomeml.datasets.conll 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | wellcomeml.datasets.download module 16 | ----------------------------------- 17 | 18 | .. automodule:: wellcomeml.datasets.download 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | 23 | wellcomeml.datasets.hoc module 24 | ------------------------------ 25 | 26 | .. automodule:: wellcomeml.datasets.hoc 27 | :members: 28 | :undoc-members: 29 | :show-inheritance: 30 | 31 | wellcomeml.datasets.winer module 32 | -------------------------------- 33 | 34 | .. automodule:: wellcomeml.datasets.winer 35 | :members: 36 | :undoc-members: 37 | :show-inheritance: 38 | 39 | Module contents 40 | --------------- 41 | 42 | .. automodule:: wellcomeml.datasets 43 | :members: 44 | :undoc-members: 45 | :show-inheritance: 46 | -------------------------------------------------------------------------------- /wellcomeml/datasets/hoc.py: -------------------------------------------------------------------------------- 1 | import random 2 | import csv 3 | import os 4 | 5 | from wellcomeml.datasets.download import check_cache_and_download 6 | 7 | 8 | def load_split(data_path): 9 | X = [] 10 | Y = [] 11 | with open(data_path) as f: 12 | csvreader = csv.DictReader(f, delimiter="\t") 13 | for line in csvreader: 14 | X.append(line["sentence"]) 15 | Y.append(line["labels"].split(",")) 16 | return X, Y 17 | 18 | 19 | def load_hoc(split="train", shuffle=True): 20 | path = check_cache_and_download("hoc") 21 | 22 | if split == "train": 23 | train_data_path = os.path.join(path, "train.tsv") 24 | X, Y = load_split(train_data_path) 25 | elif split == "test": 26 | test_data_path = os.path.join(path, "test.tsv") 27 | X, Y = load_split(test_data_path) 28 | else: 29 | raise ValueError(f"Split argument {split} is not one of train or test") 30 | 31 | if shuffle: 32 | data = list(zip(X, Y)) 33 | random.shuffle(data) 34 | X, Y = zip(*data) 35 | 36 | return X, Y 37 | -------------------------------------------------------------------------------- /examples/heatmap.py: -------------------------------------------------------------------------------- 1 | from wellcomeml.viz.elements import plot_heatmap 2 | 3 | # Fake co_occurrence matrix of the concepts "Data Science", "Machine Learning" and "Science" 4 | 5 | co_occurrence = [ 6 | {"concept_1": "Data Science", "concept_2": "Machine Learning", "value": 1, "abbr": "DS/ML"}, 7 | {"concept_1": "Machine Learning", "concept_2": "Data Science", "value": 0.3, "abbr": "ML/DS"}, 8 | {"concept_1": "Science", "concept_2": "Data Science", "value": 1, "abbr": "S/DS"}, 9 | {"concept_1": "Science", "concept_2": "Machine Learning", "value": 1, "abbr": "S/ML"}, 10 | {"concept_1": "Data Science", "concept_2": "Science", "value": 0.05, "abbr": "DS/S"}, 11 | {"concept_1": "Machine Learning", "concept_2": "Science", "value": 0.01, "abbr": "ML/S"} 12 | ] 13 | 14 | # Plot in blue 15 | plot_heatmap(co_occurrence, file='test-blue.html', color="Blue Lagoon", 16 | metadata_to_display=[("Abbreviation", "abbr")]) 
17 | 18 | # Plot in gold 19 | plot_heatmap(co_occurrence, file='test-gold.html', color="Tahiti Gold", 20 | metadata_to_display=[("Abbreviation", "abbr")]) 21 | -------------------------------------------------------------------------------- /tests/test_heatmap.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from wellcomeml.viz.elements import plot_heatmap 4 | 5 | 6 | def test_heatmap(tmp_path): 7 | 8 | co_occurrence = [ 9 | {"concept_1": "Data Science", "concept_2": "Machine Learning", "value": 1, "abbr": "DS/ML"}, 10 | {"concept_1": "Machine Learning", "concept_2": "Data Science", "value": 0.3, 11 | "abbr": "ML/DS"}, 12 | {"concept_1": "Science", "concept_2": "Data Science", "value": 1, "abbr": "S/DS"}, 13 | {"concept_1": "Science", "concept_2": "Machine Learning", "value": 1, "abbr": "S/ML"}, 14 | {"concept_1": "Data Science", "concept_2": "Science", "value": 0.05, "abbr": "DS/S"}, 15 | {"concept_1": "Machine Learning", "concept_2": "Science", "value": 0.01, "abbr": "ML/S"} 16 | ] 17 | 18 | # Plot in blue 19 | path = os.path.join(tmp_path, 'test.html') 20 | 21 | plot_heatmap(co_occurrence, file=path, color="Blue Lagoon", 22 | metadata_to_display=[("Abbreviation", "abbr")]) 23 | 24 | # Asserts that it created the file correctly. 25 | assert os.path.exists(path) 26 | -------------------------------------------------------------------------------- /docs/_build/html/_sources/wellcomeml.datasets.rst.txt: -------------------------------------------------------------------------------- 1 | wellcomeml.datasets package 2 | =========================== 3 | 4 | Submodules 5 | ---------- 6 | 7 | wellcomeml.datasets.conll module 8 | -------------------------------- 9 | 10 | .. automodule:: wellcomeml.datasets.conll 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | wellcomeml.datasets.download module 16 | ----------------------------------- 17 | 18 | .. automodule:: wellcomeml.datasets.download 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | 23 | wellcomeml.datasets.hoc module 24 | ------------------------------ 25 | 26 | .. automodule:: wellcomeml.datasets.hoc 27 | :members: 28 | :undoc-members: 29 | :show-inheritance: 30 | 31 | wellcomeml.datasets.winer module 32 | -------------------------------- 33 | 34 | .. automodule:: wellcomeml.datasets.winer 35 | :members: 36 | :undoc-members: 37 | :show-inheritance: 38 | 39 | Module contents 40 | --------------- 41 | 42 | .. automodule:: wellcomeml.datasets 43 | :members: 44 | :undoc-members: 45 | :show-inheritance: 46 | -------------------------------------------------------------------------------- /examples/text_clustering.py: -------------------------------------------------------------------------------- 1 | from wellcomeml.ml.clustering import TextClustering 2 | 3 | # This is bad clustering 4 | 5 | cluster = TextClustering( 6 | reducer='umap', n_kw=2, 7 | params={'clustering': {'min_samples': 2, 'eps': 3}} 8 | ) 9 | 10 | X = ['Wellcome Trust', 11 | 'The Wellcome Trust', 12 | 'Sir Henry Wellcome', 13 | 'Francis Crick', 14 | 'Crick Institute', 15 | 'Francis Harry Crick'] 16 | 17 | cluster.fit(X) 18 | print("Not very good clusters:") 19 | print([(x, cluster) for x, cluster in zip(X, cluster.cluster_ids)]) 20 | 21 | # This is a better one. 
Let's optimise for silhouette 22 | 23 | param_grid = { 24 | 'reducer': {'min_dist': [0.0, 0.2], 25 | 'n_neighbors': [2, 3, 5], 26 | 'metric': ['cosine', 'euclidean']}, 27 | 'clustering': {'min_samples': [2, 5], 28 | 'eps': [0.5, 1, 1.5]} 29 | } 30 | 31 | best_params = cluster.optimise(X, param_grid=param_grid, verbose=1) 32 | 33 | print("Awesome clusters:") 34 | print([(x, cluster) for x, cluster in zip(X, cluster.cluster_ids)]) 35 | print("Keywords:") 36 | print(cluster.cluster_kws) 37 | -------------------------------------------------------------------------------- /tests/test_palettes.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # coding: utf-8 3 | 4 | from wellcomeml.viz import palettes 5 | from wellcomeml.viz import colors 6 | 7 | 8 | def test_palette_sizes(): 9 | assert len(palettes.Wellcome11) == 11 10 | assert len(palettes.Wellcome33) == 33 11 | assert len(palettes.Wellcome33Shades) == 33 12 | 13 | 14 | def test_consistency_linearised(): 15 | assert set(palettes.Wellcome33) == set(palettes.Wellcome33Shades) 16 | 17 | 18 | def test_consistency_matrix(): 19 | linearised_full = [color for group in palettes.WellcomeMatrix 20 | for color in group] 21 | assert set(palettes.Wellcome33) == set(linearised_full) 22 | 23 | 24 | def test_hex_rgb_consistency(): 25 | for color in colors.NAMED_COLORS_LARGE_DICT.values(): 26 | from_rgb = "#" + "".join(f"{component:02x}" for component in color.rgb) 27 | assert from_rgb.lower() == color.hex.lower() 28 | 29 | for color in colors.NAMED_COLORS_DICT.values(): 30 | from_rgb = "#" + "".join(f"{component:02x}" for component in color.rgb) 31 | assert from_rgb.lower() == color.hex.lower() 32 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Wellcome Trust Data Labs Team 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /tests/datasets/test_conll.py: -------------------------------------------------------------------------------- 1 | from wellcomeml.datasets.conll import _load_data_spacy, load_conll 2 | import pytest 3 | 4 | 5 | def test_length(): 6 | X, Y = _load_data_spacy("tests/test_data/test_conll", inc_outside=True) 7 | 8 | assert len(X) == len(Y) and len(X) == 4 9 | 10 | 11 | def test_entity(): 12 | X, Y = _load_data_spacy("tests/test_data/test_conll", inc_outside=False) 13 | 14 | start = Y[0][0]["start"] 15 | end = Y[0][0]["end"] 16 | 17 | assert X[0][start:end] == "LEICESTERSHIRE" 18 | 19 | 20 | def test_no_outside_entities(): 21 | X, Y = _load_data_spacy("tests/test_data/test_conll", inc_outside=False) 22 | 23 | outside_entities = [ 24 | entity for entities in Y for entity in entities if entity["label"] == "O" 25 | ] 26 | 27 | assert len(outside_entities) == 0 28 | 29 | 30 | def test_load_conll(): 31 | X, y = load_conll(dataset="test_conll") 32 | 33 | assert isinstance(X, tuple) 34 | assert isinstance(y, tuple) 35 | assert len(X) == 4 36 | assert len(y) == 4 37 | 38 | 39 | def test_load_conll_raises_KeyError(): 40 | with pytest.raises(KeyError): 41 | load_conll(split="wrong_argument") 42 | -------------------------------------------------------------------------------- /tests/test_clustering_visualisation.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from wellcomeml.ml.clustering import TextClustering 4 | from wellcomeml.viz.visualize_clusters import visualize_clusters 5 | 6 | 7 | def test_output_html(tmp_path): 8 | """Tests that the output html is generated correctly by the clustering function""" 9 | 10 | # This will be the file to which the visualisation is written 11 | temporary_file = os.path.join(tmp_path, 'test-cluster.html') 12 | 13 | # Run clustering on small dummy data (see test_clustering.py) 14 | cluster = TextClustering(embedding_random_state=42, 15 | reducer_random_state=43, 16 | clustering_random_state=44) 17 | 18 | X = ['Wellcome Trust', 19 | 'The Wellcome Trust', 20 | 'Sir Henry Wellcome', 21 | 'Francis Crick', 22 | 'Crick Institute', 23 | 'Francis Harry Crick'] 24 | 25 | cluster.fit(X) 26 | 27 | # Run the visualisation function with output_file=temporary_file 28 | visualize_clusters(clustering=cluster, output_file_path=temporary_file, radius=0.01, 29 | alpha=0.5, output_in_notebook=False) 30 | 31 | # Assert that the html was generated correctly 32 | assert os.path.exists(temporary_file) 33 | -------------------------------------------------------------------------------- /tests/test_doc2vec.py: -------------------------------------------------------------------------------- 1 | from wellcomeml.ml.doc2vec_vectorizer import Doc2VecVectorizer 2 | 3 | 4 | def test_fit_transform(): 5 | X = [ 6 | "Wellcome trust gives grants", 7 | "Covid is an infectious disease", 8 | "Sourdough bread is delicious", 9 | "Zoom is not so cool", 10 | "Greece is the best country", 11 | "Waiting for the vaccine" 12 | ] 13 | doc2vec = Doc2VecVectorizer(vector_size=8, epochs=2) 14 | X_vec = doc2vec.fit_transform(X) 15 | assert X_vec.shape == (6, 8) 16 | 17 | 18 | def test_score(): 19 | # It is quite difficult to construct a test where the score is reliably high 20 | # so we fall back to testing that the score produced is a number from 0 to 1. 
21 | # It would be even better to test that the loss is decreasing, but gensim does not expose loss 22 | X = [ 23 | "Covid is a disease that can kill you", 24 | "HIV is another disease that can kill", 25 | "Wellcome trust funds covid and hiv research", 26 | "Wellcome trust is similar to NIH in the US, in that they both fund research", 27 | "NIH gives the most money for research every year" 28 | ] 29 | doc2vec = Doc2VecVectorizer(min_count=1, vector_size=4, negative=5, epochs=100) 30 | doc2vec.fit(X) 31 | score = doc2vec.score(X) 32 | assert 0 <= score <= 1 33 | -------------------------------------------------------------------------------- /tests/test_vectorizer.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | import pytest 3 | 4 | from wellcomeml.ml.vectorizer import Vectorizer 5 | 6 | 7 | def test_bert_dispatch(): 8 | X = ["This is a sentence"] 9 | 10 | text_vectorizer = Vectorizer(embedding='bert') 11 | X_embed = text_vectorizer.fit_transform(X) 12 | 13 | assert(X_embed.shape == (1, 768)) 14 | 15 | 16 | def test_tf_idf_dispatch(): 17 | X = ['Sentence Lacking Stopwords'] 18 | 19 | text_vectorizer = Vectorizer(embedding='tf-idf') 20 | X_embed = text_vectorizer.fit_transform(X) 21 | 22 | assert (X_embed.shape == (1, 3)) 23 | 24 | 25 | def test_wrong_model_dispatch_error(): 26 | with pytest.raises(ValueError): 27 | Vectorizer(embedding='embedding_that_doesnt_exist') 28 | 29 | 30 | def test_vectorizer_that_does_not_have_save(monkeypatch): 31 | X = ['This is a sentence'] 32 | 33 | vec = Vectorizer() 34 | 35 | X_embed = vec.fit_transform(X) 36 | 37 | monkeypatch.delattr(vec.vectorizer.__class__, 'save_transformed', raising=True) 38 | monkeypatch.delattr(vec.vectorizer.__class__, 'load_transformed', raising=True) 39 | 40 | with pytest.raises(NotImplementedError): 41 | vec.save_transformed(path='fake_path.npy', X_transformed=X_embed) 42 | 43 | with pytest.raises(NotImplementedError): 44 | vec.load_transformed(path='fake_path.npy') 45 | -------------------------------------------------------------------------------- /tests/test_frequency_vectorizer.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | import pytest 3 | 4 | from wellcomeml.ml.frequency_vectorizer import WellcomeTfidf 5 | 6 | 7 | def test_tf_idf_dispatch(): 8 | X = ['Sentence Lacking Stopwords'] 9 | 10 | text_vectorizer = WellcomeTfidf() 11 | X_embed = text_vectorizer.fit_transform(X) 12 | 13 | assert (X_embed.shape == (1, 3)) 14 | 15 | 16 | def test_save_and_load(tmpdir): 17 | tmpfile = tmpdir.join('test.npz') 18 | 19 | X = ["This is a sentence"*100] 20 | 21 | vec = WellcomeTfidf() 22 | 23 | X_embed = vec.fit_transform(X) 24 | 25 | vec.save_transformed(str(tmpfile), X_embed) 26 | 27 | X_loaded = vec.load_transformed(str(tmpfile)) 28 | 29 | assert (X_loaded != X_embed).sum() == 0 30 | 31 | 32 | def test_fit_transform_and_transform(): 33 | X = [ 34 | "This is a sentence", 35 | "This is another one", 36 | "This is a third sentence", 37 | "Wellcome is a global charitable foundation", 38 | "We want everyone to benefit from science's potential to improve health and save lives." 
39 | ] 40 | 41 | text_vectorizer = WellcomeTfidf() 42 | X_embed = text_vectorizer.fit_transform(X) 43 | 44 | X_embed_2 = text_vectorizer.transform(X) 45 | 46 | # Asserts that the result of transform is almost the same as fit transform 47 | assert (X_embed-X_embed_2).sum() == pytest.approx(0, abs=1e-6) 48 | -------------------------------------------------------------------------------- /wellcomeml/ml/sent2vec_vectorizer.py: -------------------------------------------------------------------------------- 1 | """ 2 | Vectorizer that exposes sklearn interface to sent2vec 3 | paper and codebase. https://github.com/epfml/sent2vec 4 | """ 5 | from wellcomeml.utils import check_cache_and_download, throw_extra_import_message 6 | 7 | required_modules = 'sklearn' 8 | required_extras = 'core' 9 | 10 | try: 11 | from sklearn.base import TransformerMixin, BaseEstimator 12 | except ImportError as e: 13 | throw_extra_import_message(error=e, required_modules=required_modules, 14 | required_extras=required_extras) 15 | 16 | 17 | class Sent2VecVectorizer(BaseEstimator, TransformerMixin): 18 | def __init__(self, pretrained=None): 19 | self.pretrained = pretrained 20 | 21 | def fit(self, *_): 22 | 23 | try: 24 | import sent2vec 25 | except ImportError: 26 | from wellcomeml.__main__ import download 27 | 28 | download("non_pypi_packages") 29 | import sent2vec 30 | 31 | if self.pretrained: 32 | model_path = check_cache_and_download(self.pretrained) 33 | self.model = sent2vec.Sent2vecModel() 34 | self.model.load_model(model_path) 35 | else: 36 | # Custom training not yet implemented 37 | raise NotImplementedError( 38 | "Fit only implemented for loading pretrained models" 39 | ) 40 | return self 41 | 42 | def transform(self, X): 43 | return self.model.embed_sentences(X) 44 | -------------------------------------------------------------------------------- /wellcomeml/metrics/f1.py: -------------------------------------------------------------------------------- 1 | """ Whole test set and batchwise f1 metrics for use with keras 2 | 3 | Adapted from: https://www.kaggle.com/rejpalcz/best-loss-function-for-f1-score-metric 4 | """ 5 | from wellcomeml.utils import throw_extra_import_message 6 | 7 | required_modules = 'tensorflow' 8 | required_extras = 'tensorflow' 9 | 10 | try: 11 | import tensorflow as tf 12 | import tensorflow.keras.backend as K 13 | except ImportError as e: 14 | throw_extra_import_message(e, required_modules, required_extras) 15 | 16 | 17 | def f1_metric(y_true, y_pred): 18 | """Calculate batchwise macro f1 19 | 20 | >>> model.compile( 21 | >>> loss="binary_crossentropy", 22 | >>> optimizer="adam", 23 | >>> metrics=["accuracy", f1_metric] 24 | >>> ) 25 | """ 26 | y_true = K.cast(y_true, "float") 27 | y_pred = K.cast(y_pred, "float") 28 | tp = K.sum(K.cast(y_true * y_pred, "float"), axis=0) 29 | fp = K.sum(K.cast((1 - y_true) * y_pred, "float"), axis=0) 30 | fn = K.sum(K.cast(y_true * (1 - y_pred), "float"), axis=0) 31 | p = tp / (tp + fp + K.epsilon()) 32 | r = tp / (tp + fn + K.epsilon()) 33 | f1 = 2 * p * r / (p + r + K.epsilon()) 34 | f1 = tf.where(tf.math.is_nan(f1), tf.zeros_like(f1), f1) 35 | 36 | return K.mean(f1) 37 | 38 | 39 | def f1_loss(y_true, y_pred): 40 | """Generate batchwise macro f1 loss 41 | 42 | >>> model.compile( 43 | >>> loss=f1_loss, 44 | >>> optimizer="adam", 45 | >>> metrics=["accuracy"] 46 | >>> ) 47 | """ 48 | return 1 - f1_metric(y_true, y_pred) 49 | -------------------------------------------------------------------------------- /tests/test_data/test_conll: 
-------------------------------------------------------------------------------- 1 | -DOCSTART- -X- O O 2 | 3 | CRICKET NNP I-NP O 4 | - : O O 5 | LEICESTERSHIRE NNP I-NP I-ORG 6 | TAKE NNP I-NP O 7 | OVER IN I-PP O 8 | AT NNP I-NP O 9 | TOP NNP I-NP O 10 | AFTER NNP I-NP O 11 | INNINGS NNP I-NP O 12 | VICTORY NN I-NP O 13 | . . O O 14 | 15 | West NNP I-NP I-MISC 16 | Indian NNP I-NP I-MISC 17 | all-rounder NN I-NP O 18 | Phil NNP I-NP I-PER 19 | Simmons NNP I-NP I-PER 20 | took VBD I-VP O 21 | four CD I-NP O 22 | for IN I-PP O 23 | 38 CD I-NP O 24 | on IN I-PP O 25 | Friday NNP I-NP O 26 | as IN I-PP O 27 | Leicestershire NNP I-NP I-ORG 28 | beat VBD I-VP O 29 | Somerset NNP I-NP I-ORG 30 | by IN I-PP O 31 | an DT I-NP O 32 | innings NN I-NP O 33 | and CC O O 34 | 39 CD I-NP O 35 | runs NNS I-NP O 36 | in IN I-PP O 37 | two CD I-NP O 38 | days NNS I-NP O 39 | to TO I-VP O 40 | take VB I-VP O 41 | over IN I-PP O 42 | at IN B-PP O 43 | the DT I-NP O 44 | head NN I-NP O 45 | of IN I-PP O 46 | the DT I-NP O 47 | county NN I-NP O 48 | championship NN I-NP O 49 | . . O O 50 | 51 | -DOCSTART- -X- O O 52 | 53 | Result NN I-NP O 54 | and CC O O 55 | close VB I-VP O 56 | of IN I-PP O 57 | play NN I-NP O 58 | scores NNS I-NP O 59 | in IN I-PP O 60 | English JJ I-NP I-MISC 61 | county NN I-NP O 62 | championship NN I-NP O 63 | matches NNS I-NP O 64 | on IN I-PP O 65 | Friday NNP I-NP O 66 | : : O O 67 | 68 | Leicester JJ I-NP I-LOC 69 | : : O O 70 | Leicestershire VB I-VP I-ORG 71 | beat VB I-VP O 72 | Somerset NNP I-NP I-ORG 73 | by IN I-PP O 74 | an DT I-NP O 75 | innings NN I-NP O 76 | and CC O O 77 | 39 CD I-NP O 78 | runs NNS I-NP O 79 | . . O O 80 | 81 | -------------------------------------------------------------------------------- /wellcomeml/io/io.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # coding: utf-8 3 | 4 | """ 5 | Utilities for loading and saving data from various formats 6 | """ 7 | import logging 8 | import json 9 | 10 | logger = logging.getLogger(__name__) 11 | 12 | 13 | def write_jsonl(input_data, output_file): 14 | """ 15 | Write a list of dicts to jsonl (line delimited json) 16 | 17 | Output format will look like: 18 | 19 | ``` 20 | {'a': 0} 21 | {'b': 1} 22 | {'c': 2} 23 | {'d': 3} 24 | ``` 25 | 26 | Args: 27 | input_data(dict): A list of dicts to be written to json. 28 | output_file(str): Filename to which the jsonl will be saved. 29 | """ 30 | 31 | with open(output_file, 'w') as fb: 32 | 33 | # Check if a dict (and convert to list if so) 34 | 35 | if isinstance(input_data, dict): 36 | input_data = [value for key, value in input_data.items()] 37 | 38 | # Write out to jsonl file 39 | 40 | logger.debug('Writing %s lines to %s', len(input_data), output_file) 41 | 42 | for i in input_data: 43 | json_ = json.dumps(i) + '\n' 44 | fb.write(json_) 45 | 46 | 47 | def _yield_jsonl(file_name): 48 | for row in open(file_name, "r"): 49 | yield json.loads(row) 50 | 51 | 52 | def read_jsonl(input_file): 53 | """Create a list of dicts from a jsonl file 54 | 55 | Args: 56 | input_file(str): File to be loaded. 
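Returns:
            list: A list of dicts, one per line of the input file.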
57 | """ 58 | 59 | out = list(_yield_jsonl(input_file)) 60 | 61 | logger.debug('Read %s lines from %s', len(out), input_file) 62 | 63 | return out 64 | -------------------------------------------------------------------------------- /wellcomeml/logger.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # coding: utf-8 3 | 4 | """ 5 | Set up shared logger 6 | """ 7 | import logging 8 | import warnings 9 | import os 10 | 11 | 12 | def get_numeric_level(level): 13 | if isinstance(level, str): 14 | level = getattr(logging, level.upper(), 10) 15 | return level 16 | 17 | 18 | def build_logger(logging_level, name): 19 | numeric_level = get_numeric_level(logging_level) 20 | 21 | logger = logging.getLogger(name) 22 | 23 | logging.basicConfig( 24 | format="%(asctime)s %(name)s %(levelname)s: %(message)s", 25 | datefmt="%Y-%m-%d %H:%M:%S", 26 | level=numeric_level, 27 | ) 28 | 29 | return logger 30 | 31 | 32 | DEFAULT_LOGGING_LEVEL = "INFO" 33 | LOGGING_LEVEL = os.getenv("LOGGING_LEVEL", DEFAULT_LOGGING_LEVEL) 34 | LOGGING_LEVEL = get_numeric_level(LOGGING_LEVEL) 35 | 36 | logger = build_logger(logging_level=LOGGING_LEVEL, name=__name__) 37 | 38 | external_logging_level = { 39 | 'transformers': LOGGING_LEVEL, 40 | 'tensorflow': LOGGING_LEVEL, 41 | 'gensim': LOGGING_LEVEL, 42 | 'sklearn': LOGGING_LEVEL, 43 | 'spacy': get_numeric_level('ERROR'), # Spacy is a bit annoying with some initialisation logs 44 | 'torch': LOGGING_LEVEL, 45 | 'tokenizers': LOGGING_LEVEL 46 | } 47 | 48 | for package, level in external_logging_level.items(): 49 | logging.getLogger(package).setLevel(level) 50 | 51 | os.environ["TF_CPP_MIN_LOG_LEVEL"] = os.environ.get( 52 | "TF_CPP_MIN_LOG_LEVEL", str(LOGGING_LEVEL // 10) 53 | ) 54 | 55 | if LOGGING_LEVEL >= 40: # ERROR 56 | warnings.filterwarnings("ignore") 57 | -------------------------------------------------------------------------------- /tests/ml/test_keras_utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # coding: utf-8 3 | 4 | import os 5 | 6 | import numpy as np 7 | import pytest 8 | import tensorflow as tf 9 | 10 | from wellcomeml.ml.keras_utils import Metrics 11 | 12 | 13 | @pytest.fixture(scope="module") 14 | def tmpdir(tmpdir_factory): 15 | return tmpdir_factory.mktemp("test_f1") 16 | 17 | 18 | @pytest.fixture(scope="module") 19 | def data(): 20 | X_train = np.random.random((100, 10)) 21 | y_train = np.random.random(100).astype(int) 22 | 23 | X_test = np.random.random((100, 10)) 24 | y_test = np.random.random(100).astype(int) 25 | 26 | return {"X_train": X_train, "y_train": y_train, "X_test": X_test, "y_test": y_test} 27 | 28 | 29 | @pytest.fixture(scope="module") 30 | def model(): 31 | inputs = tf.keras.Input(shape=(10,)) 32 | x = tf.keras.layers.Dense(128, activation="relu")(inputs) 33 | outputs = tf.keras.layers.Dense(1, "sigmoid")(x) 34 | model = tf.keras.Model(inputs=inputs, outputs=outputs) 35 | 36 | return model 37 | 38 | 39 | def test_metrics_callback(data, model, tmpdir): 40 | 41 | history_path = os.path.join(tmpdir, "test_f1.csv") 42 | 43 | model.compile( 44 | loss="binary_crossentropy", 45 | optimizer="adam", 46 | metrics=["accuracy"], 47 | ) 48 | 49 | metrics = Metrics( 50 | validation_data=(data["X_test"], data["y_test"]), history_path=history_path 51 | ) 52 | 53 | model.fit( 54 | data["X_train"], 55 | data["y_train"], 56 | epochs=5, 57 | validation_data=(data["X_test"], data["y_test"]), 58 | batch_size=1024, 59 | 
verbose=0, 60 | callbacks=[metrics], 61 | ) 62 | 63 | assert os.path.exists(history_path) 64 | -------------------------------------------------------------------------------- /tests/test_extras.py: -------------------------------------------------------------------------------- 1 | 2 | import sys 3 | from unittest import mock 4 | 5 | from importlib import import_module, reload 6 | 7 | import pytest 8 | 9 | extra_checks = { 10 | 'tensorflow': [ 11 | 'wellcomeml.ml.attention', 12 | 'wellcomeml.ml.bert_semantic_equivalence', 13 | 'wellcomeml.ml.bilstm', 14 | 'wellcomeml.ml.cnn', 15 | 'wellcomeml.ml.keras_utils', 16 | # 'wellcomeml.ml.keras_vectorizer', 17 | # Not working properly yet - reloading the module causes a different error than ImportError 18 | 'wellcomeml.ml.similarity_entity_linking' 19 | ], 20 | 'torch': [ 21 | # 'wellcomeml.ml.bert_vectorizer', 22 | # Not working properly yet - reloading the module causes a different error than ImportError 23 | 'wellcomeml.ml.spacy_classifier', 24 | 'wellcomeml.ml.similarity_entity_linking', 25 | # 'wellcomeml.ml.bert_semantic_equivalence' 26 | ], 27 | 'spacy': [ 28 | 'wellcomeml.ml.spacy_classifier', 29 | 'wellcomeml.ml.spacy_entity_linking', 30 | 'wellcomeml.ml.spacy_knowledge_base' 31 | ] 32 | } 33 | 34 | module_extra_pairs = [ 35 | (module_name, extra_name) 36 | for extra_name, module_name_list in extra_checks.items() 37 | for module_name in module_name_list 38 | ] 39 | 40 | 41 | @pytest.mark.extras 42 | @pytest.mark.parametrize("module_name,extra_name", module_extra_pairs) 43 | def test_dependencies(module_name, extra_name): 44 | """ Tests that importing the module, in the absence of the extra, throws an error """ 45 | 46 | with mock.patch.dict(sys.modules, {extra_name: None}): 47 | with pytest.raises(ImportError): 48 | _tmp_module = import_module(module_name) 49 | if module_name in sys.modules: 50 | reload(_tmp_module) 51 | -------------------------------------------------------------------------------- /tests/test_spacy_classifier.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | 3 | import numpy as np 4 | 5 | from wellcomeml.ml.spacy_classifier import SpacyClassifier 6 | 7 | 8 | def test_multilabel(): 9 | X = [ 10 | "One and two", 11 | "One only", 12 | "Three and four, nothing else", 13 | "Two nothing else", 14 | "Two and three" 15 | ] 16 | Y = np.array([ 17 | [1, 1, 0, 0], 18 | [1, 0, 0, 0], 19 | [0, 0, 1, 1], 20 | [0, 1, 0, 0], 21 | [0, 1, 1, 0] 22 | ]) 23 | 24 | model = SpacyClassifier() 25 | model.fit(X, Y) 26 | assert model.score(X, Y) > 0.2 # > 0.3 fails sometimes 27 | assert model.predict(X).shape == (5, 4) 28 | 29 | 30 | def test_multilabel_Y_list(): 31 | X = [ 32 | "One and two", 33 | "One only", 34 | "Three and four, nothing else", 35 | "Two nothing else", 36 | "Two and three" 37 | ] 38 | Y = [ 39 | [1, 1, 0, 0], 40 | [1, 0, 0, 0], 41 | [0, 0, 1, 1], 42 | [0, 1, 0, 0], 43 | [0, 1, 1, 0] 44 | ] 45 | 46 | model = SpacyClassifier() 47 | model.fit(X, Y) 48 | assert model.score(X, Y) > 0.2 # > 0.3 fails sometimes 49 | assert model.predict(X).shape == (5, 4) 50 | 51 | 52 | def test_partial_fit(): 53 | X = [ 54 | "One and two", 55 | "One only", 56 | "Three and four, nothing else", 57 | "Two nothing else", 58 | "Two and three" 59 | ] 60 | Y = [ 61 | [1, 1, 0, 0], 62 | [1, 0, 0, 0], 63 | [0, 0, 1, 1], 64 | [0, 1, 0, 0], 65 | [0, 1, 1, 0] 66 | ] 67 | 68 | model = SpacyClassifier() 69 | for x, y in zip(X, Y): 70 | model.partial_fit([x], [y]) 71 | assert model.score(X, Y) > 0.2 72 |
assert model.predict(X).shape == (5, 4) 73 | -------------------------------------------------------------------------------- /wellcomeml/spacy/spacy_doc_to_prodigy.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # coding: utf-8 3 | from wellcomeml.utils import throw_extra_import_message 4 | 5 | required_modules = 'spacy' 6 | required_extras = 'spacy' 7 | 8 | try: 9 | import spacy 10 | except ImportError as e: 11 | throw_extra_import_message(error=e, required_modules=required_modules, required_extras=required_extras) 12 | 13 | 14 | class SpacyDocToProdigy: 15 | """Convert spacy documents into prodigy format 16 | """ 17 | 18 | def run(self, docs): 19 | """ 20 | Cycle through docs and return prodigy docs. 21 | """ 22 | 23 | return list(self.return_one_prodigy_doc(doc) for doc in docs) 24 | 25 | def return_one_prodigy_doc(self, doc): 26 | """Given one spacy document, return a prodigy style dict 27 | 28 | Args: 29 | doc (spacy.tokens.doc.Doc): A spacy document 30 | 31 | Returns: 32 | dict: Prodigy style document 33 | 34 | """ 35 | 36 | if not isinstance(doc, spacy.tokens.doc.Doc): 37 | raise TypeError("doc must be of type spacy.tokens.doc.Doc") 38 | 39 | text = doc.text 40 | spans = [] 41 | tokens = [] 42 | 43 | for token in doc: 44 | tokens.append({ 45 | "text": token.text, 46 | "start": token.idx, 47 | "end": token.idx + len(token.text), 48 | "id": token.i, 49 | }) 50 | 51 | for ent in doc.ents: 52 | spans.append({ 53 | "token_start": ent.start, 54 | "token_end": ent.end, 55 | "start": ent.start_char, 56 | "end": ent.end_char, 57 | "label": ent.label_, 58 | }) 59 | 60 | out = { 61 | "text": text, 62 | "spans": spans, 63 | "tokens": tokens, 64 | } 65 | 66 | return out 67 | -------------------------------------------------------------------------------- /wellcomeml/datasets/download.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import tarfile 4 | 5 | import boto3 6 | from botocore import UNSIGNED 7 | from botocore.client import Config 8 | 9 | logger = logging.getLogger(__name__) 10 | 11 | DATA_DIR = os.path.expanduser("~/.cache/wellcomeml/data") 12 | DATA_DISPATCH = { 13 | "hoc": { 14 | "bucket": "datalabs-public", 15 | "path": "datasets/hoc/hoc.tar", 16 | "file_name": "hoc.tar", 17 | }, 18 | "winer": { 19 | "bucket": "datalabs-public", 20 | "path": "datasets/ner/winer.tar", 21 | "file_name": "winer.tar", 22 | }, 23 | "conll": { 24 | "bucket": "datalabs-public", 25 | "path": "datasets/ner/conll.tar", 26 | "file_name": "conll.tar", 27 | }, 28 | "test_conll": { 29 | "bucket": "datalabs-public", 30 | "path": "datasets/ner/test_conll.tar.gz", 31 | "file_name": "test_conll.tar.gz", 32 | }, 33 | } 34 | 35 | 36 | def check_cache_and_download(dataset_name): 37 | """ Checks if dataset_name is cached and returns the complete path""" 38 | os.makedirs(DATA_DIR, exist_ok=True) 39 | 40 | dataset_path = os.path.join(DATA_DIR, dataset_name) 41 | if not os.path.exists(dataset_path): 42 | logger.info(f"Could not find dataset {dataset_name}. 
Downloading from S3") 43 | 44 | # The following allows to download from S3 without AWS credentials 45 | s3 = boto3.client("s3", config=Config(signature_version=UNSIGNED)) 46 | tmp_file = os.path.join(DATA_DIR, DATA_DISPATCH[dataset_name]["file_name"]) 47 | 48 | s3.download_file( 49 | DATA_DISPATCH[dataset_name]["bucket"], 50 | DATA_DISPATCH[dataset_name]["path"], 51 | tmp_file, 52 | ) 53 | 54 | tar = tarfile.open(tmp_file) 55 | tar.extractall(path=DATA_DIR) 56 | tar.close() 57 | 58 | os.remove(tmp_file) 59 | 60 | return dataset_path 61 | -------------------------------------------------------------------------------- /tests/test_data/test_jsonl.jsonl: -------------------------------------------------------------------------------- 1 | {"text": "a b c\n a b c", "tokens": [{"text": "a", "start": 0, "end": 1, "id": 0}, {"text": "b", "start": 2, "end": 3, "id": 1}, {"text": "c", "start": 4, "end": 5, "id": 2}, {"text": "\n ", "start": 5, "end": 7, "id": 3}, {"text": "a", "start": 7, "end": 8, "id": 4}, {"text": "b", "start": 9, "end": 10, "id": 5}, {"text": "c", "start": 11, "end": 12, "id": 6}], "spans": [{"start": 2, "end": 3, "token_start": 1, "token_end": 2, "label": "b"}, {"start": 4, "end": 5, "token_start": 2, "token_end": 3, "label": "i"}, {"start": 7, "end": 8, "token_start": 4, "token_end": 5, "label": "i"}, {"start": 9, "end": 10, "token_start": 5, "token_end": 6, "label": "e"}]} 2 | {"text": "a b c\n a b c", "tokens": [{"text": "a", "start": 0, "end": 1, "id": 0}, {"text": "b", "start": 2, "end": 3, "id": 1}, {"text": "c", "start": 4, "end": 5, "id": 2}, {"text": "\n ", "start": 5, "end": 7, "id": 3}, {"text": "a", "start": 7, "end": 8, "id": 4}, {"text": "b", "start": 9, "end": 10, "id": 5}, {"text": "c", "start": 11, "end": 12, "id": 6}], "spans": [{"start": 2, "end": 3, "token_start": 1, "token_end": 2, "label": "b"}, {"start": 4, "end": 5, "token_start": 2, "token_end": 3, "label": "i"}, {"start": 7, "end": 8, "token_start": 4, "token_end": 5, "label": "i"}, {"start": 9, "end": 10, "token_start": 5, "token_end": 6, "label": "e"}]} 3 | {"text": "a b c\n a b c", "tokens": [{"text": "a", "start": 0, "end": 1, "id": 0}, {"text": "b", "start": 2, "end": 3, "id": 1}, {"text": "c", "start": 4, "end": 5, "id": 2}, {"text": "\n ", "start": 5, "end": 7, "id": 3}, {"text": "a", "start": 7, "end": 8, "id": 4}, {"text": "b", "start": 9, "end": 10, "id": 5}, {"text": "c", "start": 11, "end": 12, "id": 6}], "spans": [{"start": 2, "end": 3, "token_start": 1, "token_end": 2, "label": "b"}, {"start": 4, "end": 5, "token_start": 2, "token_end": 3, "label": "i"}, {"start": 7, "end": 8, "token_start": 4, "token_end": 5, "label": "i"}, {"start": 9, "end": 10, "token_start": 5, "token_end": 6, "label": "e"}]} 4 | -------------------------------------------------------------------------------- /docs/_build/html/_sources/modules.md.txt: -------------------------------------------------------------------------------- 1 | .. 
_modules: 2 | 3 | Main modules and description 4 | ============================== 5 | 6 | |Module|Description|Extras needed| 7 | |---|---|---| 8 | | wellcomeml.ml.attention | Classes that implement keras layers for attention/self-attention | tensorflow | 9 | | wellcomeml.ml.bert_classifier | Classifier to facilitate fine-tuning bert/scibert | tensorflow | 10 | | wellcomeml.ml.bert_semantic_equivalence | Classifier to learn semantic equivalence between pairs of documents | tensorflow | 11 | | wellcomeml.ml.bert_vectorizer | Text vectorizer based on bert/scibert | tensorflow | 12 | | wellcomeml.ml.bilstm | BILSTM Text classifier | tensorflow | 13 | | wellcomeml.ml.clustering | Text clustering pipeline | NA | 14 | | wellcomeml.ml.cnn | CNN Text Classifier | tensorflow | 15 | | wellcomeml.ml.doc2vec_vectorizer | Text vectorizer based on doc2vec | NA | 16 | | wellcomeml.ml.frequency_vectorizer | Text vectorizer based on TF-IDF | NA | 17 | | wellcomeml.ml.keras_utils | Utils for computing metrics during training | tensorflow | 18 | | wellcomeml.ml.keras_vectorizer | Text vectorizer based on Keras | tensorflow | 19 | | wellcomeml.ml.sent2vec_vectorizer | Text vectorizer based on Sent2Vec | (Requires sent2vec, a non-pypi package) | 20 | | wellcomeml.ml.similarity_entity_linking | A class to find most similar documents to a sentence in a corpus | tensorflow | 21 | | wellcomeml.ml.spacy_classifier | A text classifier based on spacy | spacy | 22 | | wellcomeml.ml.spacy_entity_linking | Similar to similarity_entity_linking, but uses spacy | spacy | 23 | | wellcomeml.ml.spacy_knowledge_base | Creates a knowledge base of entities, based on [spacy](https://spacy.io/usage/training#entity-linker) | spacy | 24 | | wellcomeml.ml.spacy_ner | Named entity recognition classifier based on spacy | spacy | 25 | | wellcomeml.ml.transformers_tokenizer | Bespoke tokenizer based on transformers | transformers | 26 | | wellcomeml.ml.vectorizer | Abstract class for vectorizers | NA | 27 | | wellcomeml.ml.voting_classifier | Meta-classifier based on majority voting | NA | 28 | -------------------------------------------------------------------------------- /create_release.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | git fetch && git rebase origin/main 3 | 4 | VERSION=$(python setup.py --version) 5 | VIRTUALENV=build/virtualenv 6 | 7 | echo "\x1B[31m******* DANGER ********\x1B[0m" 8 | echo Creating a new release for v$VERSION 9 | echo This is going to upload files to AWS, create a GitHub release and publish the package to PyPI 10 | read -p 'Are you sure you want to proceed (y/n)? ' PROCEED 11 | 12 | if [[ ! 
$PROCEED =~ ^[Yy]$ ]] 13 | then 14 | exit 1 15 | fi 16 | 17 | $VIRTUALENV/bin/python3 setup.py sdist bdist_wheel 18 | aws s3 sync dist/ s3://datalabs-packages/wellcomeml 19 | aws s3 cp --recursive --acl public-read dist/ s3://datalabs-public/wellcomeml 20 | $VIRTUALENV/bin/python -m twine upload --repository pypi --username $TWINE_USERNAME --password $TWINE_PASSWORD dist/* 21 | 22 | 23 | curl --request POST \ 24 | --url https://api.github.com/repos/wellcometrust/wellcomeml/releases \ 25 | --header 'authorization: token '$GITHUB_TOKEN'' \ 26 | --header 'content-type: application/json' \ 27 | --data '{ 28 | "tag_name": "v'$VERSION'", 29 | "target_commitish": "main", 30 | "name": "v'$VERSION'", 31 | "prerelease": false 32 | }' 33 | 34 | 35 | RELEASE_ID=$(curl -XGET --silent "https://api.github.com/repos/wellcometrust/WellcomeML/releases/tags/v$VERSION" | jq .id) 36 | 37 | cd dist/ 38 | 39 | curl --request POST --silent --header "Authorization: token $GITHUB_TOKEN" -H "Content-Type: $(file -b --mime-type wellcomeml-$VERSION.tar.gz)" --data-binary @wellcomeml-$VERSION.tar.gz --url "https://uploads.github.com/repos/wellcometrust/WellcomeML/releases/$RELEASE_ID/assets?name=wellcomeml-$VERSION.tar.gz" 40 | curl --request POST --silent --header "Authorization: token $GITHUB_TOKEN" -H "Content-Type: $(file -b --mime-type wellcomeml-$VERSION.tar.gz)" --data-binary @wellcomeml-$VERSION.tar.gz --url "https://uploads.github.com/repos/wellcometrust/WellcomeML/releases/$RELEASE_ID/assets?name=wellcomeml-$VERSION-py3-none-any.whl" 41 | 42 | echo "Release created" 43 | echo "Please change the description at https://github.com/wellcometrust/WellcomeML/releases/tag/v$VERSION" 44 | -------------------------------------------------------------------------------- /docs/index.rst: -------------------------------------------------------------------------------- 1 | .. WellcomeML documentation master file, created by 2 | sphinx-quickstart on Mon Jun 22 09:56:00 2020. 3 | 4 | WellcomeML's documentation! 5 | ====================================== 6 | 7 | Current release: |release| 8 | 9 | This package contains common utility functions for usual tasks at Wellcome Data 10 | Labs, in particular functionalities for processing, embedding and classifying text data. 11 | This includes 12 | 13 | * An intuitive sklearn-like API wrapping text vectorizers, such as Doc2vec, Bert, Scibert 14 | * Common API for off-the-shelf classifiers to allow quick iteration (e.g. Frequency Vectorizer, Bert, Scibert, basic CNN, BiLSTM) 15 | * Utils to download and convert academic text datasets for benchmark 16 | 17 | Check :ref:`examples` for some examples and :ref:`clustering` for clustering-specific documentation. 18 | 19 | Quickstart 20 | ------------------------------------- 21 | 22 | In order to install the latest release, with all the deep learning functionalities:: 23 | 24 | pip install wellcomeml[deep-learning] 25 | 26 | For a quicker installation that only includes certain frequency vectorisers, the io operations 27 | and the spacy-to-prodigy conversions:: 28 | 29 | pip install wellcomeml 30 | 31 | 32 | Development 33 | ------------------------------------- 34 | 35 | For installing the latest main branch:: 36 | 37 | pip install "wellcomeml[deep-learning] @ git+https://github.com/wellcometrust/WellcomeML.git" 38 | 39 | If you want to contribute, please refer to the issues and documentation in the main `github repository `_ 40 | 41 | Contact 42 | -------- 43 | To contact us, you can open an issue in the main github repository or email `Data Labs `_. 
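Example
--------

As a quick check that an installation works, here is a minimal sketch that embeds a few documents with the TF-IDF frequency vectorizer (it assumes the core dependencies are installed; the class lives in ``wellcomeml.ml.frequency_vectorizer``)::

    from wellcomeml.ml.frequency_vectorizer import WellcomeTfidf

    X = [
        "Malaria is a disease spread by mosquitos",
        "HIV is a virus that causes a disease named AIDS",
    ]

    # Fit the vectorizer and embed the documents as a TF-IDF matrix
    vec = WellcomeTfidf()
    X_embed = vec.fit_transform(X)
    print(X_embed.shape)  # one row per document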
44 | 45 | Indices and tables 46 | ================== 47 | 48 | * :ref:`genindex` 49 | * :ref:`modindex` 50 | * :ref:`search` 51 | 52 | .. toctree:: 53 | :hidden: 54 | 55 | self 56 | 57 | .. toctree:: 58 | :maxdepth: 1 59 | :caption: Contents: 60 | :hidden: 61 | 62 | Examples 63 | List of main modules and descriptions 64 | Clustering text with WellcomeML 65 | Core library documentation 66 | -------------------------------------------------------------------------------- /wellcomeml/metrics/ner_classification_report.py: -------------------------------------------------------------------------------- 1 | from wellcomeml.utils import throw_extra_import_message 2 | 3 | required_module = 'nervaluate' 4 | required_extras = 'core' 5 | 6 | try: 7 | from nervaluate import Evaluator 8 | except ImportError as e: 9 | throw_extra_import_message(e, required_module, required_extras) 10 | 11 | 12 | def ner_classification_report(y_true, y_pred, groups, tags): 13 | """ 14 | Evaluate the model's performance for each grouping of data 15 | for the NER labels given in 'tags' 16 | 17 | Input: 18 | y_pred: a list of predicted entities 19 | y_true: a list of gold entities 20 | groups: (str) the group each of the pred or gold entities belong to 21 | 22 | Output: 23 | report: evaluation metrics for each group 24 | in a nice format for printing 25 | """ 26 | 27 | unique_groups = sorted(set(groups)) 28 | outputs = [] 29 | 30 | for group in unique_groups: 31 | pred_doc_entities = [y_pred[i] for i, g in enumerate(groups) if g == group] 32 | true_doc_entities = [y_true[i] for i, g in enumerate(groups) if g == group] 33 | 34 | evaluator = Evaluator( 35 | true_doc_entities, 36 | pred_doc_entities, 37 | tags=tags 38 | ) 39 | results, _ = evaluator.evaluate() 40 | 41 | output_dict = { 42 | 'precision (partial)': results['partial']['precision'], 43 | 'recall (partial)': results['partial']['recall'], 44 | 'f1-score': results['partial']['f1'], 45 | 'support': len(pred_doc_entities) 46 | } 47 | output = [group] 48 | output.extend(list(output_dict.values())) 49 | outputs.append(output) 50 | 51 | headers = output_dict.keys() 52 | 53 | width = max(len(cn) for cn in unique_groups) 54 | head_fmt = '{:>{width}s} ' + ' {:>17}' * len(headers) 55 | report = head_fmt.format('', *headers, width=width) 56 | report += '\n\n' 57 | row_fmt = '{:>{width}s} ' + ' {:>17.{digits}f}' * 3 + ' {:>17}\n' 58 | 59 | for row in outputs: 60 | report += row_fmt.format(*row, width=width, digits=3) 61 | 62 | report += '\n' 63 | 64 | return report 65 | -------------------------------------------------------------------------------- /docs/_build/html/_sources/index.rst.txt: -------------------------------------------------------------------------------- 1 | .. WellcomeML documentation master file, created by 2 | sphinx-quickstart on Mon Jun 22 09:56:00 2020. 3 | 4 | WellcomeML's documentation! 5 | ====================================== 6 | 7 | Current release: |release| 8 | 9 | This package contains common utility functions for usual tasks at Wellcome Data 10 | Labs, in particular functionalities for processing, embedding and classifying text data. 11 | This includes 12 | 13 | * An intuitive sklearn-like API wrapping text vectorizers, such as Doc2vec, Bert, Scibert 14 | * Common API for off-the-shelf classifiers to allow quick iteration (e.g. 
Frequency Vectorizer, Bert, Scibert, basic CNN, BiLSTM) 15 | * Utils to download and convert academic text datasets for benchmark 16 | 17 | Check :ref:`examples` for some examples and :ref:`clustering` for clustering-specific documentation. 18 | 19 | Quickstart 20 | ------------------------------------- 21 | 22 | In order to install the latest release, with all the deep learning functionalities:: 23 | 24 | pip install wellcomeml[deep-learning] 25 | 26 | For a quicker installation that only includes certain frequency vectorisers, the io operations 27 | and the spacy-to-prodigy conversions:: 28 | 29 | pip install wellcomeml 30 | 31 | 32 | Development 33 | ------------------------------------- 34 | 35 | For installing the latest main branch:: 36 | 37 | pip install "wellcomeml[deep-learning] @ git+https://github.com/wellcometrust/WellcomeML.git" 38 | 39 | If you want to contribute, please refer to the issues and documentation in the main `github repository `_ 40 | 41 | Contact 42 | -------- 43 | To contact us, you can open an issue in the main github repository or email `Data Labs `_. 44 | 45 | Indices and tables 46 | ================== 47 | 48 | * :ref:`genindex` 49 | * :ref:`modindex` 50 | * :ref:`search` 51 | 52 | .. toctree:: 53 | :hidden: 54 | 55 | self 56 | 57 | .. toctree:: 58 | :maxdepth: 1 59 | :caption: Contents: 60 | :hidden: 61 | 62 | Examples 63 | List of main modules and descriptions 64 | Clustering text with WellcomeML 65 | Core library documentation 66 | -------------------------------------------------------------------------------- /tests/test_bert_vectorizer.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | import pytest 3 | 4 | from wellcomeml.ml.bert_vectorizer import BertVectorizer 5 | 6 | EMBEDDING_TYPES = [ 7 | "mean_second_to_last", 8 | "mean_last", 9 | "sum_last", 10 | "mean_last_four", 11 | "pooler" 12 | ] 13 | 14 | 15 | @pytest.fixture(scope='module') 16 | def vec(): 17 | vectorizer = BertVectorizer() 18 | 19 | vectorizer.fit() 20 | return vectorizer 21 | 22 | 23 | @pytest.mark.bert 24 | def test_fit_transform_works(vec): 25 | X = ["This is a sentence"] 26 | 27 | assert vec.fit_transform(X).shape == (1, 768) 28 | 29 | 30 | @pytest.mark.bert 31 | def test_embed_two_sentences(vec): 32 | X = [ 33 | "This is a sentence", 34 | "This is another one" 35 | ] 36 | 37 | for embedding in EMBEDDING_TYPES: 38 | vec.sentence_embedding = embedding 39 | X_embed = vec.transform(X, verbose=False) 40 | assert X_embed.shape == (2, 768) 41 | 42 | 43 | @pytest.mark.bert 44 | def test_embed_long_sentence(vec): 45 | X = ["This is a sentence"*500] 46 | 47 | for embedding in EMBEDDING_TYPES: 48 | vec.sentence_embedding = embedding 49 | X_embed = vec.transform(X, verbose=False) 50 | assert X_embed.shape == (1, 768) 51 | 52 | 53 | @pytest.mark.bert 54 | def test_embed_scibert(): 55 | X = ["This is a sentence"] 56 | vec = BertVectorizer(pretrained='scibert') 57 | vec.fit() 58 | 59 | for embedding in EMBEDDING_TYPES: 60 | vec.sentence_embedding = embedding 61 | X_embed = vec.transform(X, verbose=False) 62 | assert X_embed.shape == (1, 768) 63 | 64 | 65 | @pytest.mark.skip("Reason: Build killed or stalls. 
Issue #200") 66 | def test_save_and_load(tmpdir): 67 | tmpfile = tmpdir.join('test.npy') 68 | 69 | X = ["This is a sentence"] 70 | for pretrained in ['bert', 'scibert']: 71 | for embedding in EMBEDDING_TYPES: 72 | vec = BertVectorizer( 73 | pretrained=pretrained, 74 | sentence_embedding=embedding 75 | ) 76 | X_embed = vec.fit_transform(X, verbose=False) 77 | 78 | vec.save_transformed(str(tmpfile), X_embed) 79 | 80 | X_loaded = vec.load_transformed(str(tmpfile)) 81 | 82 | assert (X_loaded != X_embed).sum() == 0 83 | -------------------------------------------------------------------------------- /examples/entity_linking.py: -------------------------------------------------------------------------------- 1 | from wellcomeml.ml.similarity_entity_linking import SimilarityEntityLinker 2 | 3 | entities_kb = { 4 | "Michelle Williams (actor)": ( 5 | "American actress. She is the recipient of several accolades, including two Golden Globe" 6 | " Awards and a Primetime Emmy Award, in addition to nominations for four Academy Awards " 7 | "and one Tony Award." 8 | ), 9 | "Michelle Williams (musician)": ( 10 | "American entertainer. She rose to fame in the 2000s as a member of R&B girl group " 11 | "Destiny's Child, one of the best-selling female groups of all time with over 60 " 12 | "million records, of which more than 35 million copies sold with the trio lineup " 13 | "with Williams." 14 | ), 15 | "id_3": " ", 16 | } 17 | 18 | stopwords = ["the", "and", "if", "in", "a"] 19 | 20 | train_data = [ 21 | ( 22 | ( 23 | "After Destiny's Child's disbanded in 2006, Michelle Williams released her first " 24 | "pop album, Unexpected (2008)," 25 | ), 26 | {"id": "Michelle Williams (musician)"}, 27 | ), 28 | ( 29 | ( 30 | "On Broadway, Michelle Williams starred in revivals of the musical Cabaret in 2014 " 31 | "and the drama Blackbird in 2016, for which she received a nomination for the Tony " 32 | "Award for Best Actress in a Play." 33 | ), 34 | {"id": "Michelle Williams (actor)"}, 35 | ), 36 | ( 37 | "Franklin would have ideally been awarded a Nobel Prize in Chemistry", 38 | {"id": "No ID"}, 39 | ), 40 | ] 41 | 42 | entity_linker = SimilarityEntityLinker(stopwords=stopwords, embedding="tf-idf") 43 | entity_linker.fit(entities_kb) 44 | tfidf_predictions = entity_linker.predict( 45 | train_data, similarity_threshold=0.1, no_id_col="No ID" 46 | ) 47 | 48 | entity_linker = SimilarityEntityLinker(stopwords=stopwords, embedding="bert") 49 | entity_linker.fit(entities_kb) 50 | bert_predictions = entity_linker.predict( 51 | train_data, similarity_threshold=0.1, no_id_col="No ID" 52 | ) 53 | 54 | print("TF-IDF Predictions:") 55 | for i, (sentence, _) in enumerate(train_data): 56 | print(sentence) 57 | print(tfidf_predictions[i]) 58 | 59 | print("BERT Predictions:") 60 | for i, (sentence, _) in enumerate(train_data): 61 | print(sentence) 62 | print(bert_predictions[i]) 63 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | 4 | import setuptools 5 | 6 | here = os.path.abspath(os.path.dirname(__file__)) 7 | 8 | """ 9 | Load data from the__versions__.py module. Change version, etc in 10 | that module, and it will be automatically populated here. 
This allows us to 11 | access the module version, etc. from inside Python with 12 | 13 | Examples: 14 | 15 | >>> from wellcomeml.common import about 16 | >>> about['__version__'] 17 | 2019.10.0 18 | 19 | """ 20 | 21 | about = {}  # type: dict 22 | version_path = os.path.join(here, 'wellcomeml', '__version__.py') 23 | with open(version_path, 'r') as f: 24 | exec(f.read(), about) 25 | 26 | with open('README.md', 'r') as f: 27 | long_description = f.read() 28 | 29 | extras = { 30 | 'core': [ 31 | 'scikit-learn', 32 | 'scipy', 33 | 'umap-learn', 34 | 'gensim', 35 | 'bokeh', 36 | 'pandas', 37 | 'nervaluate' 38 | ], 39 | 'transformers': [ 40 | 'transformers', 41 | 'tokenizers' 42 | ], 43 | 'tensorflow': [ 44 | 'tensorflow==2.4.0', 45 | 'tensorflow-addons', 46 | 'numpy>=1.19.2,<1.20' 47 | ], 48 | 'torch': [ 49 | 'torch' 50 | ], 51 | 'spacy': [ 52 | 'spacy[lookups]==3.0.6', 53 | 'click>=7.0,<8.0' 54 | ], 55 | } 56 | 57 | # Allow users to install 'all' if they wish 58 | extras['all'] = [dep for dep_list in extras.values() for dep in dep_list] 59 | 60 | setuptools.setup( 61 | name=about['__name__'], 62 | version=about['__version__'], 63 | author=about['__author__'], 64 | author_email=about['__author_email__'], 65 | description=about['__description__'].replace('\n', ''), 66 | long_description=long_description, 67 | long_description_content_type='text/markdown', 68 | url=about['__url__'], 69 | license=about['__license__'], 70 | packages=setuptools.find_packages(include=["wellcomeml*"]), 71 | classifiers=[ 72 | 'Programming Language :: Python :: 3', 73 | 'Operating System :: OS Independent', 74 | ], 75 | install_requires=[ 76 | 'boto3', 77 | 'twine', 78 | 'cython', 79 | 'tqdm' 80 | ], 81 | extras_require=extras, 82 | tests_require=[ 83 | 'pytest', 84 | 'flake8', 85 | 'black', 86 | 'pytest-cov', 87 | 'tox' 88 | ] 89 | ) 90 | -------------------------------------------------------------------------------- /wellcomeml/viz/palettes.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | from wellcomeml.viz.colors import WellcomeColor,\ 4 | NAMED_COLORS_DICT, NAMED_COLORS_LARGE_DICT 5 | from typing import List 6 | from types import ModuleType 7 | 8 | # The palette module uses a trick to prevent users of the palettes from 9 | # accidentally changing them, by defining each palette as a class property. 
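# A minimal usage sketch (an illustration, not code from this module): once
# the sys.modules swap at the bottom of this file has run, importing the
# module returns the _WellcomePalette instance defined below, so each palette
# resolves through a read-only property and is rebuilt on every access:
#
#     from wellcomeml.viz import palettes
#     colors = palettes.Wellcome11  # fresh list built by the property getter
#     palettes.Wellcome11 = []      # raises AttributeError (no property setter)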
10 | 11 | 12 | class _WellcomePalette(ModuleType): 13 | """Represents a Wellcome palette""" 14 | 15 | __all__: List[str] = [] 16 | 17 | @property 18 | def Wellcome11(self) -> List[WellcomeColor]: 19 | """ 20 | Full Wellcome categorical palette as in 21 | 22 | https://company-57536.frontify.com/d/gFEfjydViLRJ/ 23 | wellcome-brand-book#/visuals/dataviz-elements-and-rationale 24 | 25 | Returns: 26 | List of Wellcome colors 27 | 28 | """ 29 | return list(NAMED_COLORS_DICT.values()) 30 | 31 | @property 32 | def Wellcome33Shades(self) -> List[WellcomeColor]: 33 | """Wellcome33 palette with 30% shade decrements""" 34 | return list(NAMED_COLORS_LARGE_DICT.values()) 35 | 36 | @property 37 | def WellcomeMatrix(self) -> List[List[WellcomeColor]]: 38 | """Matrix palette""" 39 | return [ 40 | [self.Wellcome33Shades[i], 41 | self.Wellcome33Shades[i + 1], 42 | self.Wellcome33Shades[i + 2]] 43 | for i in range(0, len(self.Wellcome33Shades), 3) 44 | ] 45 | 46 | @property 47 | def Wellcome33(self) -> List[WellcomeColor]: 48 | """Linearised matrix palette with no repeated adjacent color""" 49 | return [color for color, _, _ in self.WellcomeMatrix] + \ 50 | [color for _, color, _ in self.WellcomeMatrix] + \ 51 | [color for _, _, color in self.WellcomeMatrix] 52 | 53 | @property 54 | def WellcomeBackground(self) -> WellcomeColor: 55 | """Wellcome background color""" 56 | return WellcomeColor("Backgrounds", "#F2F4F6", (242, 244, 246)) 57 | 58 | @property 59 | def WellcomeNoData(self) -> WellcomeColor: 60 | """Wellcome color for 'noise', missing data or the 'other' category""" 61 | return WellcomeColor("No data", "#CCD8DD", (204, 216, 221)) 62 | 63 | 64 | # Transfers the class properties to module-level variables 65 | _mod = _WellcomePalette('wellcomeml.viz.palettes') 66 | _mod.__doc__ = __doc__ 67 | _mod.__all__ = dir(_mod) 68 | sys.modules['wellcomeml.viz.palettes'] = _mod 69 | 70 | del _mod, sys 71 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | .DEFAULT_GOAL := all 2 | 3 | VIRTUALENV := build/virtualenv 4 | 5 | ifeq ($(OS), Windows_NT) 6 | # for CYGWIN*|MINGW32*|MSYS*|MINGW* 7 | PYTHON_VERSION := C://Python38/python 8 | VENV_BIN := $(VIRTUALENV)/Scripts 9 | else 10 | PYTHON_VERSION := python3.8 11 | VENV_BIN := $(VIRTUALENV)/bin 12 | endif 13 | 14 | $(VIRTUALENV)/.installed: 15 | @if [ -d $(VIRTUALENV) ]; then rm -rf $(VIRTUALENV); fi 16 | @mkdir -p $(VIRTUALENV) 17 | $(PYTHON_VERSION) -m venv $(VIRTUALENV) 18 | $(VENV_BIN)/pip3 install --upgrade pip 19 | $(VENV_BIN)/pip3 install -r requirements_test.txt 20 | $(VENV_BIN)/pip3 install -r docs/requirements.txt # Installs requirements for the docs 21 | $(VENV_BIN)/pip3 install -e .[tensorflow,spacy,torch,core,transformers] 22 | $(VENV_BIN)/pip3 install hdbscan --no-cache-dir --no-binary :all: --no-build-isolation 23 | touch $@ 24 | 25 | .PHONY: update-docs 26 | update-docs: 27 | $(VENV_BIN)/sphinx-apidoc --no-toc -d 5 -H WellcomeML -o ./docs -f wellcomeml 28 | . 
$(VENV_BIN)/activate && cd docs && make html 29 | 30 | .PHONY: virtualenv 31 | virtualenv: $(VIRTUALENV)/.installed 32 | 33 | .PHONY: dist 34 | dist: update-docs 35 | ./create_release.sh 36 | 37 | # Spacy is required for testing spacy_to_prodigy 38 | 39 | $(VIRTUALENV)/.models: 40 | $(VENV_BIN)/python -m spacy download en_core_web_sm 41 | touch $@ 42 | 43 | $(VIRTUALENV)/.deep_learning_models: 44 | # $(VENV_BIN)/python -m spacy download en_trf_bertbaseuncased_lg 45 | touch $@ 46 | 47 | $(VIRTUALENV)/.non_pypi_packages: 48 | # Install from local git directory - pip install [address] fails on Windows 49 | git clone https://github.com/epfml/sent2vec.git 50 | cd sent2vec && git checkout f00a1b67f4330e5be99e7cc31ac28df94deed9ac && $(VENV_BIN)/pip install . # Install latest compatible sent2vec 51 | @rm -rf sent2vec 52 | touch $@ 53 | 54 | .PHONY: download_models 55 | download_models: $(VIRTUALENV)/.installed $(VIRTUALENV)/.models 56 | 57 | .PHONY: download_deep_learning_models 58 | download_deep_learning_models: $(VIRTUALENV)/.models $(VIRTUALENV)/.deep_learning_models 59 | 60 | .PHONY: download_nonpypi_packages 61 | download_nonpypi_packages: $(VIRTUALENV)/.installed $(VIRTUALENV)/.non_pypi_packages 62 | 63 | .PHONY: test 64 | test: $(VIRTUALENV)/.models $(VIRTUALENV)/.deep_learning_models $(VIRTUALENV)/.non_pypi_packages 65 | $(VENV_BIN)/tox 66 | 67 | .PHONY: test-integrations 68 | test-integrations: 69 | $(VENV_BIN)/pytest -m "integration" -s -v --disable-warnings --tb=line ./tests 70 | 71 | .PHONY: run_codecov 72 | run_codecov: 73 | $(VENV_BIN)/python -m codecov 74 | 75 | all: virtualenv test 76 | -------------------------------------------------------------------------------- /examples/voting_classifier_ensemble.py: -------------------------------------------------------------------------------- 1 | from sklearn.feature_extraction.text import CountVectorizer 2 | from sklearn.multiclass import OneVsRestClassifier 3 | from sklearn.linear_model import SGDClassifier 4 | from sklearn.naive_bayes import MultinomialNB 5 | from sklearn.pipeline import Pipeline 6 | 7 | from wellcomeml.ml.voting_classifier import WellcomeVotingClassifier 8 | 9 | X = [ 10 | "One two", 11 | "One", 12 | "Three and two", 13 | "Three" 14 | ] 15 | Y = [ 16 | [1, 1, 0], 17 | [1, 0, 0], 18 | [0, 1, 1], 19 | [0, 0, 1] 20 | ] 21 | 22 | vec = CountVectorizer() 23 | vec.fit(X) 24 | 25 | X_vec = vec.transform(X) 26 | 27 | sgd = OneVsRestClassifier(SGDClassifier(loss="log")) 28 | nb = OneVsRestClassifier(MultinomialNB()) 29 | 30 | sgd.fit(X_vec, Y) 31 | nb.fit(X_vec, Y) 32 | 33 | voting_classifier = WellcomeVotingClassifier( 34 | estimators=[sgd, nb], voting="soft", multilabel=True 35 | ) 36 | 37 | Y_pred = voting_classifier.predict(X_vec) 38 | print(Y_pred) 39 | 40 | Y = [1, 0, 1, 0] 41 | 42 | sgd = SGDClassifier(loss="log") 43 | nb = MultinomialNB() 44 | 45 | sgd.fit(X_vec, Y) 46 | nb.fit(X_vec, Y) 47 | 48 | voting_classifier = WellcomeVotingClassifier( 49 | estimators=[sgd, nb], voting="soft" 50 | ) 51 | 52 | Y_pred = voting_classifier.predict(X_vec) 53 | print(Y_pred) 54 | 55 | Y = [ 56 | [1, 1, 0], 57 | [1, 0, 0], 58 | [0, 1, 1], 59 | [0, 0, 1] 60 | ] 61 | 62 | pipe1 = Pipeline( 63 | [ 64 | ('count_vect', CountVectorizer()), 65 | ('sgd', OneVsRestClassifier(SGDClassifier(loss="log"))) 66 | ] 67 | ) 68 | pipe2 = Pipeline( 69 | [ 70 | ('count_vect', CountVectorizer()), 71 | ('nb', OneVsRestClassifier(MultinomialNB())) 72 | ] 73 | ) 74 | 75 | pipe1.fit(X, Y) 76 | pipe2.fit(X, Y) 77 | 78 | voting_classifier = 
WellcomeVotingClassifier( 79 | estimators=[pipe1, pipe2], voting="soft", multilabel=True 80 | ) 81 | Y_pred = voting_classifier.predict(X) 82 | print(Y_pred) 83 | 84 | Y = [1, 0, 1, 0] 85 | 86 | pipe1 = Pipeline( 87 | [ 88 | ('count_vect', CountVectorizer()), 89 | ('sgd', SGDClassifier(loss="log")) 90 | ] 91 | ) 92 | pipe2 = Pipeline( 93 | [ 94 | ('count_vect', CountVectorizer()), 95 | ('nb', MultinomialNB()) 96 | ] 97 | ) 98 | 99 | pipe1.fit(X, Y) 100 | pipe2.fit(X, Y) 101 | 102 | voting_classifier = WellcomeVotingClassifier( 103 | estimators=[pipe1, pipe2], voting="soft" 104 | ) 105 | 106 | Y_pred = voting_classifier.predict(X) 107 | print(Y_pred) 108 | -------------------------------------------------------------------------------- /docs/conf.py: -------------------------------------------------------------------------------- 1 | # Configuration file for the Sphinx documentation builder. 2 | # 3 | # This file only contains a selection of the most common options. For a full 4 | # list see the documentation: 5 | # https://www.sphinx-doc.org/en/master/usage/configuration.html 6 | 7 | # -- Path setup -------------------------------------------------------------- 8 | 9 | # If extensions (or modules to document with autodoc) are in another directory, 10 | # add these directories to sys.path here. If the directory is relative to the 11 | # documentation root, use os.path.abspath to make it absolute, like shown here. 12 | # 13 | import os 14 | import sys 15 | # sys.path.insert(0, os.path.abspath('.')) 16 | 17 | sys.path.insert(0, os.path.abspath('../')) 18 | 19 | 20 | # -- Project information ----------------------------------------------------- 21 | 22 | project = 'WellcomeML' 23 | copyright = '2021, Wellcome Data Labs' 24 | author = 'Wellcome Data Labs' 25 | 26 | about = {}  # type: dict 27 | here = os.path.abspath(os.path.dirname(__file__)) 28 | version_path = os.path.join(here, '../wellcomeml', '__version__.py') 29 | with open(version_path, 'r') as f: 30 | exec(f.read(), about) 31 | 32 | # The full version, including alpha/beta/rc tags 33 | release = about["__version__"] 34 | 35 | 36 | # -- General configuration --------------------------------------------------- 37 | 38 | # Add any Sphinx extension module names here, as strings. They can be 39 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 40 | # ones. 41 | 42 | autodoc_inherit_docstrings = False 43 | 44 | extensions = [ 45 | "sphinx_rtd_theme", 46 | "sphinx.ext.napoleon", 47 | "sphinx_markdown_tables" 48 | ] 49 | 50 | 51 | source_suffix = ['.rst'] 52 | 53 | 54 | # Add any paths that contain templates here, relative to this directory. 55 | templates_path = ['_templates'] 56 | 57 | # List of patterns, relative to source directory, that match files and 58 | # directories to ignore when looking for source files. 59 | # This pattern also affects html_static_path and html_extra_path. 60 | exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store'] 61 | 62 | 63 | # -- Options for HTML output ------------------------------------------------- 64 | 65 | # The theme to use for HTML and HTML Help pages. See the documentation for 66 | # a list of builtin themes. 67 | # 68 | html_theme = 'sphinx_rtd_theme' 69 | 70 | # Add any paths that contain custom static files (such as style sheets) here, 71 | # relative to this directory. They are copied after the builtin static files, 72 | # so a file named "default.css" will overwrite the builtin "default.css". 
73 | html_static_path = ['_static'] 74 | -------------------------------------------------------------------------------- /wellcomeml/utils.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import tarfile 3 | import os 4 | 5 | import boto3 6 | from botocore import UNSIGNED 7 | from botocore.client import Config 8 | 9 | logger = logging.getLogger(__name__) 10 | 11 | MODELS_DIR = os.path.expanduser("~/.cache/wellcomeml/models") 12 | 13 | MODEL_DISPATCH = { 14 | 'scibert_scivocab_uncased': { 15 | "bucket": "ai2-s2-research", 16 | "path": "scibert/huggingface_pytorch/scibert_scivocab_uncased.tar", 17 | "file_name": "scibert_scivocab_uncased.tar" 18 | }, 19 | 'scibert_scivocab_cased': { 20 | "bucket": "ai2-s2-research", 21 | "path": "scibert/huggingface_pytorch/scibert_scivocab_cased.tar", 22 | "file_name": "scibert_scivocab_cased.tar" 23 | }, 24 | 'biosent2vec': { 25 | "bucket": "datalabs-public", 26 | "path": "models/ncbi-nlp/biosent2vec.bin", 27 | "file_name": "biosent2vec.bin", 28 | }, 29 | 'sent2vec_wiki_unigrams': { 30 | "bucket": "datalabs-public", 31 | "path": "models/epfml/wiki_unigrams.bin", 32 | "file_name": "wiki_unigrams.bin", 33 | } 34 | } 35 | 36 | 37 | def check_cache_and_download(model_name): 38 | """ Checks if model_name is cached and returns the complete path""" 39 | os.makedirs(MODELS_DIR, exist_ok=True) 40 | 41 | FILE_NAME = MODEL_DISPATCH[model_name]['file_name'] 42 | _, FILE_EXT = FILE_NAME.split(".") 43 | 44 | model_path = os.path.join(MODELS_DIR, model_name if FILE_EXT == "tar" else FILE_NAME) 45 | if not os.path.exists(model_path): 46 | logger.info(f"Could not find model {model_name}. Downloading from S3") 47 | 48 | # The following allows downloading from S3 without AWS credentials 49 | s3 = boto3.client('s3', config=Config(signature_version=UNSIGNED)) 50 | tmp_file = os.path.join(MODELS_DIR, FILE_NAME) 51 | 52 | s3.download_file(MODEL_DISPATCH[model_name]['bucket'], 53 | MODEL_DISPATCH[model_name]['path'], tmp_file) 54 | 55 | if FILE_EXT == 'tar': 56 | tar = tarfile.open(tmp_file) 57 | tar.extractall(path=MODELS_DIR) 58 | tar.close() 59 | 60 | os.remove(tmp_file) 61 | 62 | return model_path 63 | 64 | 65 | def throw_extra_import_message(error, extras, required_modules): 66 | """Safely raises an import error if it is due to missing extras, re-raising it otherwise""" 67 | if error.name in required_modules.split(','): 68 | raise ImportError(f"To use this class/module you need to install wellcomeml with {extras} " 69 | f"extras, e.g. 
pip install wellcomeml[{extras}]") 70 | else: 71 | raise error 72 | -------------------------------------------------------------------------------- /docs/_build/html/_static/js/html5shiv.min.js: -------------------------------------------------------------------------------- 1 | /** 2 | * @preserve HTML5 Shiv 3.7.3 | @afarkas @jdalton @jon_neal @rem | MIT/GPL2 Licensed 3 | */ 4 | !function(a,b){function c(a,b){var c=a.createElement("p"),d=a.getElementsByTagName("head")[0]||a.documentElement;return c.innerHTML="x",d.insertBefore(c.lastChild,d.firstChild)}function d(){var a=t.elements;return"string"==typeof a?a.split(" "):a}function e(a,b){var c=t.elements;"string"!=typeof c&&(c=c.join(" ")),"string"!=typeof a&&(a=a.join(" ")),t.elements=c+" "+a,j(b)}function f(a){var b=s[a[q]];return b||(b={},r++,a[q]=r,s[r]=b),b}function g(a,c,d){if(c||(c=b),l)return c.createElement(a);d||(d=f(c));var e;return e=d.cache[a]?d.cache[a].cloneNode():p.test(a)?(d.cache[a]=d.createElem(a)).cloneNode():d.createElem(a),!e.canHaveChildren||o.test(a)||e.tagUrn?e:d.frag.appendChild(e)}function h(a,c){if(a||(a=b),l)return a.createDocumentFragment();c=c||f(a);for(var e=c.frag.cloneNode(),g=0,h=d(),i=h.length;i>g;g++)e.createElement(h[g]);return e}function i(a,b){b.cache||(b.cache={},b.createElem=a.createElement,b.createFrag=a.createDocumentFragment,b.frag=b.createFrag()),a.createElement=function(c){return t.shivMethods?g(c,a,b):b.createElem(c)},a.createDocumentFragment=Function("h,f","return function(){var n=f.cloneNode(),c=n.createElement;h.shivMethods&&("+d().join().replace(/[\w\-:]+/g,function(a){return b.createElem(a),b.frag.createElement(a),'c("'+a+'")'})+");return n}")(t,b.frag)}function j(a){a||(a=b);var d=f(a);return!t.shivCSS||k||d.hasCSS||(d.hasCSS=!!c(a,"article,aside,dialog,figcaption,figure,footer,header,hgroup,main,nav,section{display:block}mark{background:#FF0;color:#000}template{display:none}")),l||i(a,d),a}var k,l,m="3.7.3-pre",n=a.html5||{},o=/^<|^(?:button|map|select|textarea|object|iframe|option|optgroup)$/i,p=/^(?:a|b|code|div|fieldset|h1|h2|h3|h4|h5|h6|i|label|li|ol|p|q|span|strong|style|table|tbody|td|th|tr|ul)$/i,q="_html5shiv",r=0,s={};!function(){try{var a=b.createElement("a");a.innerHTML="",k="hidden"in a,l=1==a.childNodes.length||function(){b.createElement("a");var a=b.createDocumentFragment();return"undefined"==typeof a.cloneNode||"undefined"==typeof a.createDocumentFragment||"undefined"==typeof a.createElement}()}catch(c){k=!0,l=!0}}();var t={elements:n.elements||"abbr article aside audio bdi canvas data datalist details dialog figcaption figure footer header hgroup main mark meter nav output picture progress section summary template time video",version:m,shivCSS:n.shivCSS!==!1,supportsUnknownElements:l,shivMethods:n.shivMethods!==!1,type:"default",shivDocument:j,createElement:g,createDocumentFragment:h,addElements:e};a.html5=t,j(b),"object"==typeof module&&module.exports&&(module.exports=t)}("undefined"!=typeof window?window:this,document); -------------------------------------------------------------------------------- /tests/test_io.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # coding: utf-8 3 | 4 | import os 5 | import tempfile 6 | 7 | import pytest 8 | 9 | from wellcomeml.io import read_jsonl, write_jsonl 10 | 11 | from .common import TEST_JSONL 12 | 13 | 14 | @pytest.fixture(scope="module") 15 | def temp_file(): 16 | temp_file, temp_file_name = tempfile.mkstemp() 17 | 18 | return temp_file_name 19 | 20 | 21 | 
def test_read_jsonl(): 22 | 23 | expected = [{ 24 | "text": "a b c\n a b c", 25 | "tokens": [ 26 | {'text': 'a', 'start': 0, 'end': 1, 'id': 0}, 27 | {'text': 'b', 'start': 2, 'end': 3, 'id': 1}, 28 | {'text': 'c', 'start': 4, 'end': 5, 'id': 2}, 29 | {'text': '\n ', 'start': 5, 'end': 7, 'id': 3}, 30 | {'text': 'a', 'start': 7, 'end': 8, 'id': 4}, 31 | {'text': 'b', 'start': 9, 'end': 10, 'id': 5}, 32 | {'text': 'c', 'start': 11, 'end': 12, 'id': 6} 33 | ], 34 | "spans": [ 35 | {'start': 2, 'end': 3, 'token_start': 1, "token_end": 2, "label": "b"}, 36 | {'start': 4, 'end': 5, 'token_start': 2, "token_end": 3, "label": "i"}, 37 | {'start': 7, 'end': 8, 'token_start': 4, "token_end": 5, "label": "i"}, 38 | {'start': 9, 'end': 10, 'token_start': 5, "token_end": 6, "label": "e"}, 39 | ] 40 | }] 41 | 42 | expected = expected * 3 43 | 44 | actual = read_jsonl(TEST_JSONL) 45 | assert expected == actual 46 | 47 | 48 | def test_write_jsonl(temp_file): 49 | 50 | expected = [{ 51 | "text": "a b c\n a b c", 52 | "tokens": [ 53 | {'text': 'a', 'start': 0, 'end': 1, 'id': 0}, 54 | {'text': 'b', 'start': 2, 'end': 3, 'id': 1}, 55 | {'text': 'c', 'start': 4, 'end': 5, 'id': 2}, 56 | {'text': '\n ', 'start': 5, 'end': 7, 'id': 3}, 57 | {'text': 'a', 'start': 7, 'end': 8, 'id': 4}, 58 | {'text': 'b', 'start': 9, 'end': 10, 'id': 5}, 59 | {'text': 'c', 'start': 11, 'end': 12, 'id': 6} 60 | ], 61 | "spans": [ 62 | {'start': 2, 'end': 3, 'token_start': 1, "token_end": 2, "label": "b"}, 63 | {'start': 4, 'end': 5, 'token_start': 2, "token_end": 3, "label": "i"}, 64 | {'start': 7, 'end': 8, 'token_start': 4, "token_end": 5, "label": "i"}, 65 | {'start': 9, 'end': 10, 'token_start': 5, "token_end": 6, "label": "e"}, 66 | ] 67 | }] 68 | 69 | expected = expected * 3 70 | 71 | write_jsonl(expected, temp_file) 72 | actual = read_jsonl(temp_file) 73 | 74 | assert expected == actual 75 | 76 | # Clean up 77 | 78 | if os.path.isfile(temp_file): 79 | os.remove(temp_file) 80 | -------------------------------------------------------------------------------- /tests/test_clustering.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from wellcomeml.ml.clustering import TextClustering 4 | 5 | 6 | @pytest.mark.parametrize("reducer,cluster_reduced", [("tsne", True), 7 | ("umap", True), 8 | ("umap", False)]) 9 | def test_full_pipeline(reducer, cluster_reduced, tmp_path): 10 | cluster = TextClustering(reducer=reducer, cluster_reduced=cluster_reduced, 11 | embedding_random_state=42, 12 | reducer_random_state=43, 13 | clustering_random_state=44) 14 | 15 | X = ['Wellcome Trust', 16 | 'The Wellcome Trust', 17 | 'Sir Henry Wellcome', 18 | 'Francis Crick', 19 | 'Crick Institute', 20 | 'Francis Harry Crick'] 21 | 22 | cluster.fit(X) 23 | 24 | assert len(cluster.cluster_kws) == len(cluster.cluster_ids) == 6 25 | 26 | cluster.save(folder=tmp_path) 27 | 28 | cluster_new = TextClustering() 29 | cluster_new.load(folder=tmp_path) 30 | 31 | # Asserts all coordinates of the loaded points are equal 32 | assert (cluster_new.embedded_points != cluster.embedded_points).sum() == 0 33 | assert (cluster_new.reduced_points != cluster.reduced_points).sum() == 0 34 | assert cluster_new.reducer_class.__class__ == cluster.reducer_class.__class__ 35 | assert cluster_new.clustering_class.__class__ == cluster.clustering_class.__class__ 36 | 37 | 38 | @pytest.mark.parametrize("reducer", ["tsne", "umap"]) 39 | def test_parameter_search(reducer): 40 | cluster = TextClustering(reducer=reducer) 41 | X = 
['Wellcome Trust', 42 | 'The Wellcome Trust', 43 | 'Sir Henry Wellcome', 44 | 'Francis Crick', 45 | 'Crick Institute', 46 | 'Francis Harry Crick'] 47 | 48 | param_grid = { 49 | 'reducer': {'min_dist': [0.0], 50 | 'n_neighbors': [2], 51 | 'metric': ['cosine', 'euclidean']}, 52 | 'clustering': {'min_samples': [2], 53 | 'eps': [0.5]} 54 | } 55 | 56 | best_params = cluster.optimise(X, param_grid=param_grid, 57 | verbose=1, 58 | max_noise=1) 59 | 60 | # Asserts it found a parameter 61 | assert best_params is not None 62 | # Asserts the cross-validation results are returned correctly 63 | assert len(cluster.optimise_results['mean_test_silhouette']) == \ 64 | len(cluster.optimise_results['params']) 65 | # Asserts that silhouette is at least positive (for umap! - tsne does 66 | # not work) 67 | if reducer != "tsne": 68 | assert cluster.silhouette > 0 69 | -------------------------------------------------------------------------------- /tests/datasets/test_winer.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import os 3 | 4 | from wellcomeml.datasets.winer import create_train_test, _load_data_spacy 5 | 6 | 7 | @pytest.fixture(scope="module") 8 | def define_paths(): 9 | 10 | NE_path = "tests/test_data/mock_winer_CoarseNE.tar.bz2" 11 | docs_path = "tests/test_data/mock_winer_Documents.tar.bz2" 12 | vocab_path = "tests/test_data/mock_winer_document.vocab" 13 | 14 | # These will be generated and then deleted in these tests 15 | train_processed_path = "tests/test_data/temp_train_sample.txt" 16 | test_processed_path = "tests/test_data/temp_test_sample.txt" 17 | 18 | return (NE_path, docs_path, vocab_path, train_processed_path, test_processed_path) 19 | 20 | 21 | def test_train_test_documents(define_paths): 22 | 23 | ( 24 | NE_path, 25 | docs_path, 26 | vocab_path, 27 | train_processed_path, 28 | test_processed_path, 29 | ) = define_paths 30 | # Create the train/test data 31 | n_sample = 2 32 | prop_train = 0.5 33 | # There are 4 article IDs with entities in the sample data 34 | expected_train_size = round(prop_train * 4) 35 | 36 | create_train_test( 37 | NE_path, 38 | vocab_path, 39 | docs_path, 40 | train_processed_path, 41 | test_processed_path, 42 | n_sample, 43 | prop_train, 44 | rand_seed=42, 45 | ) 46 | 47 | docs_IDs = ["ID 1", "ID 2", "ID 3", "ID 4"] 48 | 49 | with open(train_processed_path, "r") as file: 50 | train_text = file.read() 51 | with open(test_processed_path, "r") as file: 52 | test_text = file.read() 53 | 54 | train_ids = [d for d in docs_IDs if d in train_text] 55 | test_ids = [d for d in docs_IDs if d in test_text] 56 | 57 | assert ( 58 | len(train_ids) == expected_train_size 59 | and len(test_ids) == (4 - expected_train_size) 60 | and len(set(train_ids + test_ids)) == 4 61 | ) 62 | os.remove(train_processed_path) 63 | os.remove(test_processed_path) 64 | 65 | 66 | def test_length(): 67 | X, Y = _load_data_spacy("tests/test_data/test_winer.txt", inc_outside=True) 68 | 69 | assert len(X) == len(Y) and len(X) == 168 70 | 71 | 72 | def test_entity(): 73 | X, Y = _load_data_spacy("tests/test_data/test_winer.txt", inc_outside=False) 74 | 75 | start = Y[101][1]["start"] 76 | end = Y[101][1]["end"] 77 | 78 | assert X[101][start:end] == "Spain" 79 | 80 | 81 | def test_no_outside_entities(): 82 | X, Y = _load_data_spacy("tests/test_data/test_winer.txt", inc_outside=False) 83 | 84 | outside_entities = [ 85 | entity for entities in Y for entity in entities if entity["label"] == "O" 86 | ] 87 | 88 | assert len(outside_entities) == 0 89 | 
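# For reference, a minimal sketch of the spaCy-style (X, Y) format that
# _load_data_spacy returns and that the assertions above rely on. The values
# below are illustrative and hypothetical, not taken from test_winer.txt:
#
#     X = ["Spain is in Europe "]
#     Y = [[{"start": 0, "end": 5, "label": "I-LOC"}]]
#     entity = Y[0][0]
#     assert X[0][entity["start"]:entity["end"]] == "Spain"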
-------------------------------------------------------------------------------- /tests/test_entity_linking.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from wellcomeml.ml.similarity_entity_linking import SimilarityEntityLinker 3 | 4 | 5 | @pytest.fixture(scope="module") 6 | def entities_kb(): 7 | return { 8 | 'id_1': "American actress. She is the recipient of several accolades, including two Golden" 9 | " Globe Awards and a Primetime Emmy Award, in addition to nominations for four" 10 | " Academy Awards and one Tony Award.", 11 | 'id_2': "American entertainer. She rose to fame in the 2000s as a member of R&B girl group" 12 | " Destiny's Child, one of the best-selling female groups of all time with over" 13 | " 60 million records, of which more than 35 million copies sold with the trio" 14 | " lineup with Williams.", 15 | 'id_3': " " 16 | } 17 | 18 | 19 | @pytest.fixture(scope="module") 20 | def stopwords(): 21 | return ['the', 'and', 'if', 'in', 'a'] 22 | 23 | 24 | @pytest.fixture(scope="module") 25 | def train_data(): 26 | return [ 27 | ("After Destiny's Child's disbanded in 2006, Michelle Williams released her first " 28 | "pop album, Unexpected (2008),", {'id': 'id_2'}), 29 | ("On Broadway, Michelle Williams starred in revivals of the musical Cabaret in 2014" 30 | " and the drama Blackbird in 2016, for which she received a nomination for the Tony Award" 31 | " for Best Actress in a Play.", {'id': 'id_1'}), 32 | ("Franklin would have ideally been awarded a Nobel Prize in Chemistry", {'id': 'No ID'}) 33 | ] 34 | 35 | 36 | def test_clean_kb(entities_kb, stopwords): 37 | 38 | entity_linker = SimilarityEntityLinker(stopwords=stopwords) 39 | knowledge_base = entity_linker._clean_kb(entities_kb) 40 | 41 | assert len(knowledge_base) == 2 42 | 43 | 44 | def test_optimise_threshold(entities_kb, stopwords, train_data): 45 | entity_linker = SimilarityEntityLinker(stopwords=stopwords) 46 | entity_linker.fit(entities_kb) 47 | entity_linker.optimise_threshold(train_data, id_col='id', no_id_col='No ID') 48 | optimal_threshold = entity_linker.optimal_threshold 49 | 50 | assert isinstance(optimal_threshold, float) 51 | 52 | 53 | def test_predict_lowthreshold(entities_kb, stopwords, train_data): 54 | entity_linker = SimilarityEntityLinker(stopwords=stopwords) 55 | entity_linker.fit(entities_kb) 56 | predictions = entity_linker.predict(train_data, similarity_threshold=0.1, no_id_col='No ID') 57 | 58 | assert predictions == ['id_2', 'id_1', 'No ID'] 59 | 60 | 61 | def test_predict_highthreshold(entities_kb, stopwords, train_data): 62 | entity_linker = SimilarityEntityLinker(stopwords=stopwords) 63 | entity_linker.fit(entities_kb) 64 | predictions = entity_linker.predict(train_data, similarity_threshold=1.0, no_id_col='No ID') 65 | 66 | assert predictions == ['No ID', 'No ID', 'No ID'] 67 | -------------------------------------------------------------------------------- /tests/io/epmc/test_client.py: -------------------------------------------------------------------------------- 1 | from unittest.mock import MagicMock 2 | import pytest 3 | 4 | from wellcomeml.io.epmc.client import EPMCClient 5 | 6 | 7 | @pytest.fixture 8 | def epmc_client(): 9 | return EPMCClient( 10 | max_retries=3 11 | ) 12 | 13 | 14 | def test_search(epmc_client): 15 | epmc_client._execute_query = MagicMock() 16 | epmc_client.search( 17 | "session", "query", result_type="not core", 18 | page_size=15, only_first=False 19 | ) 20 | 21 | expected_params = { 22 | "query": "query", 
23 | "format": "json", 24 | "resultType": "not core", 25 | "pageSize": 15 26 | } 27 | epmc_client._execute_query.assert_called_with("session", expected_params, False) 28 | 29 | 30 | def test_search_by_pmid(epmc_client): 31 | epmc_client.search = MagicMock(return_value="results") 32 | epmc_client.search_by_pmid("session", "pmid") 33 | epmc_client.search.assert_called_with("session", "ext_id:pmid", only_first=True) 34 | 35 | 36 | def test_search_by_doi(epmc_client): 37 | epmc_client.search = MagicMock(return_value="results") 38 | epmc_client.search_by_doi("session", "doi") 39 | epmc_client.search.assert_called_with("session", "doi:doi", only_first=True) 40 | 41 | 42 | def test_search_by_pmcid(epmc_client): 43 | epmc_client.search = MagicMock(return_value="results") 44 | epmc_client.search_by_pmcid("session", "PMCID0") 45 | epmc_client.search.assert_called_with("session", "pmcid:PMCID0", only_first=True) 46 | 47 | 48 | def test_search_by_invalid_pmcid(epmc_client): 49 | epmc_client.search = MagicMock(return_value="results") 50 | with pytest.raises(ValueError): 51 | epmc_client.search_by_pmcid("session", "pmcid") 52 | 53 | 54 | def test_get_full_text(epmc_client): 55 | epmc_client._get_response_content = MagicMock(return_value="content") 56 | epmc_client.get_full_text("session", "pmid") 57 | 58 | epmc_endpoint = epmc_client.api_endpoint 59 | epmc_client._get_response_content.assert_called_with( 60 | "session", 61 | f"{epmc_endpoint}/pmid/fullTextXML" 62 | ) 63 | 64 | 65 | def test_get_references(epmc_client): 66 | epmc_client._get_response_json = MagicMock(return_value={"references": []}) 67 | epmc_client.get_references("session", "pmid") 68 | 69 | epmc_endpoint = epmc_client.api_endpoint 70 | params = {"format": "json", "page": 1, "pageSize": 1000} 71 | epmc_client._get_response_json.assert_called_with( 72 | "session", 73 | f"{epmc_endpoint}/MED/pmid/references", 74 | params 75 | ) 76 | 77 | 78 | def test_get_citations(epmc_client): 79 | epmc_client._get_response_json = MagicMock(return_value={"references": []}) 80 | epmc_client.get_citations("session", "pmid") 81 | 82 | epmc_endpoint = epmc_client.api_endpoint 83 | params = {"format": "json", "page": 1, "pageSize": 1000} 84 | epmc_client._get_response_json.assert_called_with( 85 | "session", 86 | f"{epmc_endpoint}/MED/pmid/citations", 87 | params 88 | ) 89 | -------------------------------------------------------------------------------- /tests/metrics/test_f1.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # coding: utf-8 3 | 4 | import numpy as np 5 | import pytest 6 | import tensorflow as tf 7 | 8 | from wellcomeml.metrics import f1_loss, f1_metric 9 | 10 | 11 | @pytest.fixture(scope="module") 12 | def tmpdir(tmpdir_factory): 13 | return tmpdir_factory.mktemp("test_f1") 14 | 15 | 16 | @pytest.fixture(scope="module") 17 | def data(): 18 | X_train = np.random.random((100, 10)) 19 | y_train = np.random.random(100).astype(int) 20 | 21 | X_test = np.random.random((100, 10)) 22 | y_test = np.random.random(100).astype(int) 23 | 24 | return {"X_train": X_train, "y_train": y_train, "X_test": X_test, "y_test": y_test} 25 | 26 | 27 | @pytest.fixture(scope="module") 28 | def model(): 29 | inputs = tf.keras.Input(shape=(10,)) 30 | x = tf.keras.layers.Dense(128, activation="relu")(inputs) 31 | outputs = tf.keras.layers.Dense(1, "sigmoid")(x) 32 | model = tf.keras.Model(inputs=inputs, outputs=outputs) 33 | 34 | return model 35 | 36 | 37 | def test_f1_metric_all_true(): 38 | 39 | y_true 
= [0, 1, 1, 0] 40 | y_pred = [0, 1, 1, 0] 41 | 42 | f1 = f1_metric(y_true, y_pred) 43 | 44 | assert isinstance(f1, tf.Tensor) 45 | assert f1 == 1.0 46 | 47 | 48 | def test_f1_metric_all_false(): 49 | 50 | y_true = [0, 1, 1, 0] 51 | y_pred = [0, 0, 0, 0] 52 | 53 | f1 = f1_metric(y_true, y_pred) 54 | 55 | assert isinstance(f1, tf.Tensor) 56 | assert f1 == 0.0 57 | 58 | 59 | def test_f1_metric_poor_recall(): 60 | 61 | y_true = [0, 1, 1, 0] 62 | y_pred = [0, 1, 0, 0] 63 | 64 | f1 = f1_metric(y_true, y_pred) 65 | 66 | assert isinstance(f1, tf.Tensor) 67 | assert f1 == 0.66666657 68 | 69 | 70 | def test_f1_metric_poor_precision(): 71 | 72 | y_true = [0, 1, 1, 0] 73 | y_pred = [1, 0, 0, 0] 74 | 75 | f1 = f1_metric(y_true, y_pred) 76 | 77 | assert isinstance(f1, tf.Tensor) 78 | assert f1 == 0.0 79 | 80 | 81 | def test_f1_metric(data, model): 82 | """ Test whether the f1_metrics are output""" 83 | 84 | model.compile( 85 | loss="binary_crossentropy", 86 | optimizer="adam", 87 | metrics=[f1_metric], 88 | ) 89 | 90 | history = model.fit( 91 | data["X_train"], 92 | data["y_train"], 93 | epochs=5, 94 | validation_data=(data["X_test"], data["y_test"]), 95 | batch_size=1024, 96 | verbose=0, 97 | ) 98 | 99 | assert set(history.history.keys()) == set( 100 | ["loss", "f1_metric", "val_loss", "val_f1_metric"] 101 | ) 102 | 103 | 104 | def test_f1_loss(data, model): 105 | """ Test to see if it runs, don't test the loss itself """ 106 | 107 | model.compile( 108 | loss=f1_loss, 109 | optimizer="adam", 110 | metrics=["accuracy"], 111 | ) 112 | 113 | model.fit( 114 | data["X_train"], 115 | data["y_train"], 116 | epochs=5, 117 | validation_data=(data["X_test"], data["y_test"]), 118 | batch_size=1024, 119 | verbose=0, 120 | ) 121 | -------------------------------------------------------------------------------- /docs/_build/html/_static/css/badge_only.css: -------------------------------------------------------------------------------- 1 | .fa:before{-webkit-font-smoothing:antialiased}.clearfix{*zoom:1}.clearfix:after,.clearfix:before{display:table;content:""}.clearfix:after{clear:both}@font-face{font-family:FontAwesome;font-style:normal;font-weight:400;src:url(fonts/fontawesome-webfont.eot?674f50d287a8c48dc19ba404d20fe713?#iefix) format("embedded-opentype"),url(fonts/fontawesome-webfont.woff2?af7ae505a9eed503f8b8e6982036873e) format("woff2"),url(fonts/fontawesome-webfont.woff?fee66e712a8a08eef5805a46892932ad) format("woff"),url(fonts/fontawesome-webfont.ttf?b06871f281fee6b241d60582ae9369b9) format("truetype"),url(fonts/fontawesome-webfont.svg?912ec66d7572ff821749319396470bde#FontAwesome) format("svg")}.fa:before{font-family:FontAwesome;font-style:normal;font-weight:400;line-height:1}.fa:before,a .fa{text-decoration:inherit}.fa:before,a .fa,li .fa{display:inline-block}li .fa-large:before{width:1.875em}ul.fas{list-style-type:none;margin-left:2em;text-indent:-.8em}ul.fas li .fa{width:.8em}ul.fas li .fa-large:before{vertical-align:baseline}.fa-book:before,.icon-book:before{content:"\f02d"}.fa-caret-down:before,.icon-caret-down:before{content:"\f0d7"}.fa-caret-up:before,.icon-caret-up:before{content:"\f0d8"}.fa-caret-left:before,.icon-caret-left:before{content:"\f0d9"}.fa-caret-right:before,.icon-caret-right:before{content:"\f0da"}.rst-versions{position:fixed;bottom:0;left:0;width:300px;color:#fcfcfc;background:#1f1d1d;font-family:Lato,proxima-nova,Helvetica Neue,Arial,sans-serif;z-index:400}.rst-versions a{color:#2980b9;text-decoration:none}.rst-versions .rst-badge-small{display:none}.rst-versions 
.rst-current-version{padding:12px;background-color:#272525;display:block;text-align:right;font-size:90%;cursor:pointer;color:#27ae60}.rst-versions .rst-current-version:after{clear:both;content:"";display:block}.rst-versions .rst-current-version .fa{color:#fcfcfc}.rst-versions .rst-current-version .fa-book,.rst-versions .rst-current-version .icon-book{float:left}.rst-versions .rst-current-version.rst-out-of-date{background-color:#e74c3c;color:#fff}.rst-versions .rst-current-version.rst-active-old-version{background-color:#f1c40f;color:#000}.rst-versions.shift-up{height:auto;max-height:100%;overflow-y:scroll}.rst-versions.shift-up .rst-other-versions{display:block}.rst-versions .rst-other-versions{font-size:90%;padding:12px;color:grey;display:none}.rst-versions .rst-other-versions hr{display:block;height:1px;border:0;margin:20px 0;padding:0;border-top:1px solid #413d3d}.rst-versions .rst-other-versions dd{display:inline-block;margin:0}.rst-versions .rst-other-versions dd a{display:inline-block;padding:6px;color:#fcfcfc}.rst-versions.rst-badge{width:auto;bottom:20px;right:20px;left:auto;border:none;max-width:300px;max-height:90%}.rst-versions.rst-badge .fa-book,.rst-versions.rst-badge .icon-book{float:none;line-height:30px}.rst-versions.rst-badge.shift-up .rst-current-version{text-align:right}.rst-versions.rst-badge.shift-up .rst-current-version .fa-book,.rst-versions.rst-badge.shift-up .rst-current-version .icon-book{float:left}.rst-versions.rst-badge>.rst-current-version{width:auto;height:30px;line-height:30px;padding:0 6px;display:block;text-align:center}@media screen and (max-width:768px){.rst-versions{width:85%;display:none}.rst-versions.shift{display:block}} -------------------------------------------------------------------------------- /tests/test_transformers_tokenizer.py: -------------------------------------------------------------------------------- 1 | import tempfile 2 | import pickle 3 | 4 | import pytest 5 | 6 | from wellcomeml.ml.transformers_tokenizer import TransformersTokenizer 7 | 8 | 9 | texts = [ 10 | "This is a test", 11 | "Another sentence", 12 | "Don't split" 13 | ] 14 | 15 | 16 | @pytest.fixture(scope="module") 17 | def tokenizer(): 18 | tokenizer = TransformersTokenizer() 19 | tokenizer.fit(texts) 20 | return tokenizer 21 | 22 | 23 | @pytest.fixture(scope="module") 24 | def tmp_path(): 25 | with tempfile.TemporaryDirectory() as tmp_dir: 26 | tmp_path = f"{tmp_dir}/tokenizer.json" 27 | yield tmp_path 28 | 29 | 30 | def test_tokenize(tokenizer): 31 | tokens = tokenizer.tokenize("This is a test") 32 | assert len(tokens) == 4 33 | assert type(tokens[0]) == str 34 | 35 | 36 | def test_tokenize_batch(tokenizer): 37 | tokens = tokenizer.tokenize(["This is a test", "test"]) 38 | assert len(tokens) == 2 39 | 40 | 41 | def test_encode(tokenizer): 42 | token_ids = tokenizer.encode("This is a test") 43 | assert len(token_ids) == 4 44 | assert type(token_ids[0]) == int 45 | 46 | 47 | def test_encode_batch(tokenizer): 48 | token_ids = tokenizer.encode(["This is a test", "test"]) 49 | assert len(token_ids) == 2 50 | 51 | 52 | def test_decode(tokenizer): 53 | token_ids = tokenizer.encode("This is a test") 54 | text = tokenizer.decode(token_ids) 55 | assert text == "this is a test" 56 | 57 | 58 | def test_decode_batch(tokenizer): 59 | token_ids = tokenizer.encode(["This is a test", "test"]) 60 | texts = tokenizer.decode(token_ids) 61 | assert texts == ["this is a test", "test"] 62 | 63 | 64 | def test_decode_empty(tokenizer): 65 | assert tokenizer.decode([]) == "" 66 | 67 | 68 | 
def test_unknown_token(tokenizer): 69 | tokens = tokenizer.tokenize("I have not seen this before") 70 | assert "[UNK]" in tokens 71 | 72 | 73 | def test_save(tokenizer, tmp_path): 74 | tokenizer.save(tmp_path) 75 | 76 | loaded_tokenizer = TransformersTokenizer() 77 | loaded_tokenizer.load(tmp_path) 78 | tokens = loaded_tokenizer.tokenize("This is a test") 79 | assert len(tokens) == 4 80 | 81 | 82 | def test_pickle(tokenizer, tmp_path): 83 | with open(tmp_path, "wb") as f: 84 | f.write(pickle.dumps(tokenizer)) 85 | 86 | with open(tmp_path, "rb") as f: 87 | unpickled_tokenizer = pickle.loads(f.read()) 88 | 89 | tokens = unpickled_tokenizer.tokenize("This is a test") 90 | assert len(tokens) == 4 91 | 92 | 93 | def test_bpe_model(): 94 | tokenizer = TransformersTokenizer(model="bpe") 95 | tokenizer.fit(texts) 96 | tokens = tokenizer.tokenize("This is a test") 97 | assert len(tokens) == 4 98 | 99 | 100 | def test_lowercase(): 101 | tokenizer = TransformersTokenizer(lowercase=False) 102 | tokenizer.fit(texts) 103 | tokens = tokenizer.tokenize("This is a test") 104 | assert tokens[0] == "This" 105 | 106 | 107 | def test_vocab_size(): 108 | tokenizer = TransformersTokenizer(vocab_size=30) 109 | tokenizer.fit(texts) 110 | vocab = tokenizer.vocab 111 | print(vocab) 112 | assert len(vocab) == 30 113 | -------------------------------------------------------------------------------- /tests/test_keras_vectorizer.py: -------------------------------------------------------------------------------- 1 | import tempfile 2 | import os 3 | 4 | import pytest 5 | 6 | from wellcomeml.ml.keras_vectorizer import KerasVectorizer, KerasTokenizer 7 | 8 | 9 | @pytest.fixture 10 | def tokenizer(): 11 | tokenizer = KerasTokenizer() 12 | tokenizer.fit([ 13 | "This is a test", 14 | "Another sentence", 15 | "Don't split" 16 | ]) 17 | return tokenizer 18 | 19 | 20 | def test_vanilla(): 21 | X = ["One", "Two", "Three Four"] 22 | 23 | keras_vectorizer = KerasVectorizer() 24 | X_vec = keras_vectorizer.fit_transform(X) 25 | 26 | assert X_vec.shape[0] == 3 27 | assert X_vec.shape[1] == 2 28 | assert X_vec.max() == 5  # 5 tokens: 4 words plus OOV 29 | 30 | 31 | def test_sequence_length(): 32 | X = ["One", "Two", "Three"] 33 | 34 | sequence_length = 5 35 | keras_vectorizer = KerasVectorizer(sequence_length=sequence_length) 36 | X_vec = keras_vectorizer.fit_transform(X) 37 | 38 | assert X_vec.shape[1] == sequence_length 39 | 40 | 41 | def test_vocab_size(): 42 | X = ["One", "Two", "Three"] 43 | 44 | vocab_size = 1 45 | keras_vectorizer = KerasVectorizer(vocab_size=vocab_size) 46 | X_vec = keras_vectorizer.fit_transform(X) 47 | 48 | assert X_vec.max() == vocab_size 49 | 50 | 51 | def test_build_embedding_matrix(): 52 | 53 | X = ["One", "Two", "Three"] 54 | 55 | vocab_size = 1 56 | keras_vectorizer = KerasVectorizer(vocab_size=vocab_size) 57 | keras_vectorizer.fit(X) 58 | 59 | with tempfile.TemporaryDirectory() as tmp_dir: 60 | embeddings_path = os.path.join(tmp_dir, "embeddings.csv") 61 | embeddings = [ 62 | "one 0 1 0 0 0", 63 | "two 0 0 1 0 0", 64 | "three 0 0 0 1 0", 65 | "four 0 0 0 0 1", 66 | ] 67 | with open(embeddings_path, "w") as embeddings_path_tmp: 68 | for line in embeddings: 69 | embeddings_path_tmp.write(line) 70 | embeddings_path_tmp.write("\n") 71 | embedding_matrix = keras_vectorizer.build_embedding_matrix( 72 | embeddings_name_or_path=embeddings_path 73 | ) 74 | 75 | assert embedding_matrix.shape == (5, 5) 76 | 77 | 78 | def test_build_embedding_matrix_word_vectors(): 79 | 80 | X = ["One", "Two", "Three"] 81 | 82 | 
vocab_size = 1 83 | keras_vectorizer = KerasVectorizer(vocab_size=vocab_size) 84 | keras_vectorizer.fit(X) 85 | 86 | embedding_matrix = keras_vectorizer.build_embedding_matrix( 87 | embeddings_name_or_path="glove-twitter-25" 88 | ) 89 | 90 | assert embedding_matrix.shape == (5, 25) 91 | 92 | 93 | def test_infer_from_data(): 94 | X = ["One", "Two words", "Three words here"] 95 | 96 | keras_vectorizer = KerasVectorizer() 97 | keras_vectorizer.fit(X) 98 | 99 | assert keras_vectorizer.sequence_length == 3 100 | 101 | 102 | def test_keras_tokenizer_decode(tokenizer): 103 | token_ids = tokenizer.encode("This is a test") 104 | text = tokenizer.decode(token_ids) 105 | assert text == "this is a test" 106 | 107 | 108 | def test_keras_tokenizer_decode_batch(tokenizer): 109 | token_ids = tokenizer.encode(["This is", "a test"]) 110 | texts = tokenizer.decode(token_ids) 111 | assert texts == ["this is", "a test"] 112 | 113 | 114 | def test_keras_tokenizer_decode_empty(tokenizer): 115 | assert tokenizer.decode([]) == "" 116 | -------------------------------------------------------------------------------- /wellcomeml/ml/attention.py: -------------------------------------------------------------------------------- 1 | from wellcomeml.utils import throw_extra_import_message 2 | 3 | try: 4 | import tensorflow as tf 5 | except ImportError as e: 6 | throw_extra_import_message(error=e, required_modules='tensorflow', extras='tensorflow') 7 | 8 | 9 | class SelfAttention(tf.keras.layers.Layer): 10 | """https://papers.nips.cc/paper/7181-attention-is-all-you-need.pdf""" 11 | 12 | def __init__(self, attention_dim=20): 13 | super(SelfAttention, self).__init__() 14 | self.attention_dim = attention_dim 15 | 16 | def build(self, input_shape): 17 | self.WQ = self.add_weight( 18 | shape=(input_shape[-1], self.attention_dim), 19 | trainable=True, 20 | initializer="uniform", 21 | ) 22 | self.WK = self.add_weight( 23 | shape=(input_shape[-1], self.attention_dim), 24 | trainable=True, 25 | initializer="uniform", 26 | ) 27 | self.WV = self.add_weight( 28 | shape=(input_shape[-1], input_shape[-1]), 29 | trainable=True, 30 | initializer="uniform", 31 | ) 32 | 33 | def call(self, X): 34 | """ 35 | In: (batch_size, sequence_length, embedding_dimension) 36 | Out: (batch_size, sequence_length, embedding_dimension) 37 | """ 38 | Q = tf.matmul(X, self.WQ) 39 | K = tf.matmul(X, self.WK) 40 | V = tf.matmul(X, self.WV) 41 | 42 | attention_scores = tf.nn.softmax(tf.matmul(Q, tf.transpose(K, perm=[0, 2, 1]))) 43 | return tf.matmul(attention_scores, V) 44 | 45 | 46 | class FeedForwardAttention(tf.keras.layers.Layer): 47 | """https://colinraffel.com/publications/iclr2016feed.pdf""" 48 | 49 | def __init__(self): 50 | super(FeedForwardAttention, self).__init__() 51 | 52 | def build(self, input_shape): 53 | self.W = self.add_weight( 54 | shape=(input_shape[-1], 1), trainable=True, initializer="uniform" 55 | ) 56 | 57 | def call(self, X): 58 | """ 59 | In: (batch_size, sequence_length, embedding_dimension) 60 | Out: (batch_size, embedding_dimension) 61 | """ 62 | e = tf.math.tanh(tf.matmul(X, self.W)) 63 | attention_scores = tf.nn.softmax(e) 64 | return tf.matmul(tf.transpose(X, perm=[0, 2, 1]), attention_scores) 65 | 66 | 67 | class HierarchicalAttention(tf.keras.layers.Layer): 68 | """https://www.aclweb.org/anthology/N16-1174/""" 69 | 70 | def __init__(self, attention_heads='same'): 71 | super(HierarchicalAttention, self).__init__() 72 | self.attention_heads = attention_heads 73 | 74 | def build(self, input_shape): 75 | if self.attention_heads 
== 'same': 76 | nb_attention_heads = input_shape[-2] 77 | else: 78 | nb_attention_heads = self.attention_heads 79 | self.attention_matrix = self.add_weight( 80 | shape=(input_shape[-1], nb_attention_heads), 81 | trainable=True, 82 | initializer="uniform", 83 | name="attention_matrix" 84 | ) 85 | 86 | def call(self, X): 87 | """ 88 | In: (batch_size, sequence_length, embedding_dimension) 89 | Out: (batch_size, sequence_length, embedding_dimension) 90 | """ 91 | attention_scores = tf.nn.softmax( 92 | tf.math.tanh(tf.matmul(X, self.attention_matrix)) 93 | ) 94 | return tf.matmul(tf.transpose(attention_scores, perm=[0, 2, 1]), X) 95 | -------------------------------------------------------------------------------- /tests/test_spacy_entity_linking.py: -------------------------------------------------------------------------------- 1 | import tempfile 2 | import pytest 3 | 4 | from wellcomeml.ml.spacy_knowledge_base import SpacyKnowledgeBase 5 | from wellcomeml.ml.spacy_entity_linking import SpacyEntityLinker 6 | 7 | 8 | @pytest.fixture(scope="module") 9 | def entities(): 10 | # A dict of each entity, its description and its corpus frequency 11 | return { 12 | "id_1": ( 13 | "American actress. She is the recipient of several accolades, including two Golden" 14 | " Globe Awards and a Primetime Emmy Award, in addition to nominations for four" 15 | " Academy Awards and one Tony Award.", 16 | 0.1, 17 | ), 18 | "id_2": ( 19 | "American entertainer. She rose to fame in the 2000s as a member of R&B girl group" 20 | " Destiny's Child, one of the best-selling female groups of all time with over" 21 | " 60 million records, of which more than 35 million copies sold with the trio" 22 | " lineup with Williams.", 23 | 0.05, 24 | ), 25 | } 26 | 27 | 28 | @pytest.fixture(scope="module") 29 | def list_aliases(): 30 | # A list of dicts for each entity 31 | # probabilities are 'prior probabilities' and must sum to < 1 32 | return [ 33 | { 34 | "alias": "Michelle Williams", 35 | "entities": ["id_1", "id_2"], 36 | "probabilities": [0.7, 0.3], 37 | } 38 | ] 39 | 40 | 41 | @pytest.fixture(scope="module") 42 | def data(): 43 | return [ 44 | ( 45 | "After Destiny's Child's disbanded in 2006. Michelle Williams released her first " 46 | "pop album, Unexpected (2008),", 47 | {"links": {(43, 60): {"id_1": 0.0, "id_2": 1.0}}}, 48 | ), 49 | ( 50 | "On Broadway, Michelle Williams starred in revivals of the musical Cabaret in 2014" 51 | " and the drama Blackbird in 2016."
 52 | " For which she received a nomination for the Tony Award" 53 | " for Best Actress in a Play.", 54 | {"links": {(13, 30): {"id_1": 1.0, "id_2": 0.0}}}, 55 | ), 56 | ] 57 | 58 | 59 | def test_kb_train(entities, list_aliases): 60 | 61 | kb = SpacyKnowledgeBase(kb_model="en_core_web_sm") 62 | kb.train(entities, list_aliases) 63 | 64 | assert sorted(kb.kb.get_entity_strings()) == ["id_1", "id_2"] 65 | assert kb.kb.get_alias_strings() == ["Michelle Williams"] 66 | 67 | 68 | def test_el_train(entities, list_aliases, data): 69 | 70 | with tempfile.TemporaryDirectory() as tmp_dir: 71 | temp_kb = SpacyKnowledgeBase(kb_model="en_core_web_sm") 72 | temp_kb.train(entities, list_aliases) 73 | temp_kb.save(tmp_dir) 74 | el = SpacyEntityLinker(tmp_dir, print_output=False) 75 | el.train(data) 76 | 77 | assert "entity_linker" in el.nlp.pipe_names 78 | 79 | 80 | def test_el_predict(entities, list_aliases, data): 81 | 82 | with tempfile.TemporaryDirectory() as tmp_dir: 83 | temp_kb = SpacyKnowledgeBase(kb_model="en_core_web_sm") 84 | temp_kb.train(entities, list_aliases) 85 | temp_kb.save(tmp_dir) 86 | el = SpacyEntityLinker(tmp_dir, print_output=False) 87 | el.train(data) 88 | predicted_ids = el.predict(data) 89 | 90 | entity_ids = temp_kb.kb.get_entity_strings() 91 | bad_entity_ids = [p for p in predicted_ids if p[0] not in entity_ids] 92 | 93 | assert len(predicted_ids) == 2 94 | assert len(bad_entity_ids) == 0 95 | -------------------------------------------------------------------------------- /wellcomeml/datasets/conll.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | 4 | from wellcomeml.datasets.download import check_cache_and_download 5 | 6 | 7 | def _load_data_spacy(data_path, inc_outside=True): 8 | """ 9 | Load data in Spacy format: 10 | X = list of sentences (plural) / documents ['the cat ...', 'some dog...', ...] 11 | Y = list of list of entity tags for each sentence 12 | [[{'start': 36, 'end': 46, 'label': 'PERSON'}, {..}, ..], ... ] 13 | inc_outside = False: don't include non-entities in the output 14 | 15 | Raw format is: 16 | '-DOCSTART- -X- O O\n\nEU NNP I-NP I-ORG\nrejects VBZ I-VP O\nGerman JJ I-NP I-MISC...' 
17 | where each article is separated by '-DOCSTART- -X- O O\n', 18 | each sentence is separated by a blank line, 19 | and the entity information is in the form 20 | 'EU NNP I-NP I-ORG' (A word, a part-of-speech (POS) tag, 21 | a syntactic chunk tag and the named entity tag) 22 | """ 23 | 24 | X = [] 25 | Y = [] 26 | with open(data_path) as f: 27 | articles = f.read().split("-DOCSTART- -X- O O\n\n") 28 | articles = articles[1:]  # The first will be blank 29 | 30 | for article in articles: 31 | # Each sentence in the article is separated by a blank line 32 | sentences = article.split("\n\n") 33 | 34 | for sentence in sentences: 35 | char_i = 0  # A counter for the entity start and end character indices 36 | sentence_text = "" 37 | sentence_tags = [] 38 | entities = sentence.split("\n") 39 | 40 | for entity in entities: 41 | # Due to the splitting on '\n' sometimes we are left with empty elements 42 | 43 | if len(entity) != 0: 44 | token, _, _, tag = entity.split(" ") 45 | sentence_text += token + " " 46 | 47 | if tag != "O" or inc_outside: 48 | sentence_tags.append( 49 | { 50 | "start": char_i, 51 | "end": char_i + len(token), 52 | "label": tag, 53 | } 54 | ) 55 | char_i += len(token) + 1  # plus 1 for the space separating 56 | 57 | if sentence_tags != []: 58 | X.append(sentence_text) 59 | Y.append(sentence_tags) 60 | 61 | return X, Y 62 | 63 | 64 | def load_conll(split="train", shuffle=True, inc_outside=True, dataset: str = "conll"): 65 | """Load the conll dataset 66 | 67 | Args: 68 | split(str): Which split of the data to collect, one of ["train", "test", 69 | "evaluate"]. 70 | shuffle(bool): Should the data be shuffled with random.shuffle? 71 | inc_outside(bool): Should outside (non-entity) tokens be included? 72 | dataset(str): Which dataset to load. This defaults to "conll" and should 73 | only be altered for test purposes in which case it should be set to 74 | "test_conll". 
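    Example (a sketch: the data is downloaded on first use, and with
    shuffle=False the first training sentence is expected to look like
    the following):

        >>> X, Y = load_conll(split="train", shuffle=False)
        >>> X[0]
        'EU rejects German call to boycott British lamb . '
        >>> Y[0][0]
        {'start': 0, 'end': 2, 'label': 'I-ORG'}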
75 | """ 76 | path = check_cache_and_download(dataset) 77 | 78 | map = {"train": "eng.train", "test": "eng.testa", "evaluate": "eng.testb"} 79 | 80 | try: 81 | data_path = os.path.join(path, map[split]) 82 | X, Y = _load_data_spacy(data_path, inc_outside=inc_outside) 83 | except KeyError: 84 | raise KeyError(f"Split argument {split} is not one of train, test or evaluate") 85 | 86 | if shuffle: 87 | data = list(zip(X, Y)) 88 | random.shuffle(data) 89 | X, Y = zip(*data) 90 | 91 | return X, Y 92 | -------------------------------------------------------------------------------- /tests/test_s3_policy_data.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # coding: utf-8 3 | import boto3 4 | from botocore.stub import Stubber 5 | 6 | from wellcomeml.io import s3_policy_data 7 | 8 | 9 | def stubber_responses(stubber, mock_hash_file=None): 10 | 11 | list_buckets_response = { 12 | "Contents": [ 13 | { 14 | "Key": "good/path/file1.json" 15 | }, 16 | { 17 | "Key": "bad/path/file2.json" 18 | } 19 | ] 20 | } 21 | expected_params = {'Bucket': 'datalabs-dev'} 22 | stubber.add_response('list_objects_v2', list_buckets_response, expected_params) 23 | 24 | if mock_hash_file: 25 | get_object_response = { 26 | "Body": mock_hash_file 27 | } 28 | expected_params = {'Bucket': 'datalabs-dev', 'Key': 'good/path/file1.json'} 29 | stubber.add_response('get_object', get_object_response, expected_params) 30 | 31 | return stubber 32 | 33 | 34 | def policy_downloader(s3): 35 | return s3_policy_data.PolicyDocumentsDownloader( 36 | s3=s3, 37 | bucket_name="datalabs-dev", 38 | dir_path="good/path" 39 | ) 40 | 41 | 42 | def test_get_keys(): 43 | 44 | s3 = boto3.client('s3') 45 | stubber = Stubber(s3) 46 | stubber = stubber_responses(stubber) 47 | 48 | with stubber: 49 | policy_s3 = policy_downloader(s3) 50 | pdf_keys = policy_s3.pdf_keys 51 | 52 | assert pdf_keys == ['good/path/file1.json'] 53 | 54 | 55 | def test_get_hashes_with_word(): 56 | 57 | s3 = boto3.client('s3') 58 | stubber = Stubber(s3) 59 | 60 | with open('tests/test_data/mock_s3_contents.json.gz', 'rb') as mock_hash_file: 61 | stubber = stubber_responses(stubber, mock_hash_file) 62 | 63 | with stubber: 64 | policy_s3 = policy_downloader(s3) 65 | hash_dicts = policy_s3.get_hashes(word_list=['the']) 66 | hash_list = [hash_dict['file_hash'] for hash_dict in hash_dicts] 67 | 68 | assert hash_list == ['x002'] 69 | 70 | 71 | def test_get_hashes(): 72 | 73 | s3 = boto3.client('s3') 74 | stubber = Stubber(s3) 75 | 76 | with open('tests/test_data/mock_s3_contents.json.gz', 'rb') as mock_hash_file: 77 | stubber = stubber_responses(stubber, mock_hash_file) 78 | 79 | with stubber: 80 | policy_s3 = policy_downloader(s3) 81 | hash_dicts = policy_s3.get_hashes() 82 | hash_list = [hash_dict['file_hash'] for hash_dict in hash_dicts] 83 | hash_list.sort() 84 | 85 | assert hash_list == ['x001', 'x002'] 86 | 87 | 88 | def test_download_all_hash(): 89 | 90 | s3 = boto3.client('s3') 91 | stubber = Stubber(s3) 92 | 93 | with open('tests/test_data/mock_s3_contents.json.gz', 'rb') as mock_hash_file: 94 | stubber = stubber_responses(stubber, mock_hash_file) 95 | 96 | with stubber: 97 | policy_s3 = policy_downloader(s3) 98 | documents = policy_s3.download(hash_list=None) 99 | 100 | document_hashes = [document['file_hash'] for document in documents] 101 | document_hashes.sort() 102 | 103 | assert document_hashes == ['x001', 'x002'] 104 | 105 | 106 | def test_download_one_hash(): 107 | 108 | s3 = boto3.client('s3') 109 | 
stubber = Stubber(s3) 110 | 111 | with open('tests/test_data/mock_s3_contents.json.gz', 'rb') as mock_hash_file: 112 | stubber = stubber_responses(stubber, mock_hash_file) 113 | 114 | with stubber: 115 | policy_s3 = policy_downloader(s3) 116 | documents = policy_s3.download(hash_list=['x002']) 117 | 118 | document_hashes = [document['file_hash'] for document in documents] 119 | assert document_hashes == ['x002'] 120 | -------------------------------------------------------------------------------- /tests/test_bert_classifier.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | import pytest 3 | import tempfile 4 | 5 | import numpy as np 6 | 7 | from wellcomeml.ml.bert_classifier import BertClassifier 8 | 9 | 10 | @pytest.fixture(scope="module") 11 | def multilabel_bert(): 12 | model = BertClassifier() 13 | model._init_model(num_labels=4) 14 | 15 | return model 16 | 17 | 18 | @pytest.mark.bert 19 | def test_multilabel(multilabel_bert): 20 | X = [ 21 | "One and two", 22 | "One only", 23 | "Three and four, nothing else", 24 | "Two nothing else", 25 | "Two and three" 26 | ] 27 | Y = np.array([ 28 | [1, 1, 0, 0], 29 | [1, 0, 0, 0], 30 | [0, 0, 1, 1], 31 | [0, 1, 0, 0], 32 | [0, 1, 1, 0] 33 | ]) 34 | 35 | model = multilabel_bert 36 | model.fit(X, Y) 37 | Y_pred = model.predict(X) 38 | Y_prob_pred = model.predict_proba(X) 39 | assert Y_pred.sum() != 0 40 | assert Y_pred.sum() != Y.size 41 | assert Y_prob_pred.max() <= 1 42 | assert Y_prob_pred.min() >= 0 43 | assert Y_pred.shape == Y.shape 44 | assert Y_prob_pred.shape == Y.shape 45 | assert model.losses[0] > model.losses[-1] 46 | 47 | 48 | @pytest.mark.bert 49 | def test_multiclass(): 50 | X = [ 51 | "One oh yes", 52 | "Two noo", 53 | "Three ok", 54 | "one fantastic", 55 | "two bad" 56 | ] 57 | Y = np.array([ 58 | [1, 0, 0], 59 | [0, 1, 0], 60 | [0, 0, 1], 61 | [1, 0, 0], 62 | [0, 1, 0] 63 | ]) 64 | 65 | model = BertClassifier(multilabel=False) 66 | model.fit(X, Y) 67 | Y_pred = model.predict(X) 68 | Y_prob_pred = model.predict_proba(X) 69 | assert Y_pred.sum() != 0 70 | assert Y_pred.sum() != Y.size 71 | assert Y_prob_pred.max() <= 1 72 | assert Y_prob_pred.min() >= 0 73 | assert Y_pred.shape == Y.shape 74 | assert Y_prob_pred.shape == Y.shape 75 | assert model.losses[0] > model.losses[-1] 76 | 77 | 78 | @pytest.mark.bert 79 | def test_scibert(): 80 | X = [ 81 | "One and two", 82 | "One only", 83 | "Three and four, nothing else", 84 | "Two nothing else", 85 | "Two and three" 86 | ] 87 | Y = np.array([ 88 | [1, 1, 0, 0], 89 | [1, 0, 0, 0], 90 | [0, 0, 1, 1], 91 | [0, 1, 0, 0], 92 | [0, 1, 1, 0] 93 | ]) 94 | 95 | model = BertClassifier(pretrained="scibert") 96 | model.fit(X, Y) 97 | Y_pred = model.predict(X) 98 | Y_prob_pred = model.predict_proba(X) 99 | assert Y_pred.sum() != 0 100 | assert Y_pred.sum() != Y.size 101 | assert Y_prob_pred.max() <= 1 102 | assert Y_prob_pred.min() >= 0 103 | assert Y_pred.shape == Y.shape 104 | assert Y_prob_pred.shape == Y.shape 105 | assert model.losses[0] > model.losses[-1] 106 | 107 | 108 | @pytest.mark.bert 109 | def test_save_load(multilabel_bert): 110 | X = [ 111 | "One and two", 112 | "One only", 113 | "Three and four, nothing else", 114 | "Two nothing else", 115 | "Two and three" 116 | ] 117 | Y = np.array([ 118 | [1, 1, 0, 0], 119 | [1, 0, 0, 0], 120 | [0, 0, 1, 1], 121 | [0, 1, 0, 0], 122 | [0, 1, 1, 0] 123 | ]) 124 | 125 | model = multilabel_bert 126 | model.epochs = 1 # Only need to fit 1 epoch here really, because we're testing save 127 |
model.fit(X, Y) 128 | 129 | with tempfile.TemporaryDirectory() as tmp_path: 130 | model.save(tmp_path) 131 | loaded_model = BertClassifier() 132 | loaded_model.load(tmp_path) 133 | 134 | Y_pred = loaded_model.predict(X) 135 | Y_prob_pred = loaded_model.predict_proba(X) 136 | assert Y_prob_pred.sum() >= 0 137 | assert Y_pred.shape == Y.shape 138 | -------------------------------------------------------------------------------- /tests/test_spacy.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # coding: utf-8 3 | 4 | import en_core_web_sm 5 | import pytest 6 | 7 | from wellcomeml.spacy.spacy_doc_to_prodigy import SpacyDocToProdigy 8 | 9 | 10 | @pytest.fixture(scope="module") 11 | def nlp(): 12 | return en_core_web_sm.load() 13 | 14 | 15 | def test_return_one_prodigy_doc_fails_if_passed_wrong_type(): 16 | 17 | with pytest.raises(TypeError): 18 | wrong_format = [ 19 | "this is the text", 20 | {"entities": [[0, 1, "PERSON"], [2, 4, "COMPANY"]]}, 21 | ] 22 | 23 | spacy_to_prodigy = SpacyDocToProdigy() 24 | spacy_to_prodigy.return_one_prodigy_doc(wrong_format) 25 | 26 | 27 | def test_SpacyDocToProdigy(nlp): 28 | 29 | # https://www.theguardian.com/world/2019/oct/30/pinochet-economic-model-current-crisis-chile 30 | before = nlp("After 12 days of mass demonstrations, rioting and human rights violations," 31 | " the government of President Sebastián Piñera must now find a way out of the" 32 | " crisis that has engulfed Chile.") 33 | 34 | stp = SpacyDocToProdigy() 35 | actual = stp.run([before]) 36 | 37 | expected = [ 38 | { 39 | 'text': 'After 12 days of mass demonstrations, rioting and human rights violations,' 40 | ' the government of President Sebastián Piñera must now find a way out of the' 41 | ' crisis that has engulfed Chile.', 42 | 'spans': [ 43 | {'token_start': 1, 'token_end': 3, 'start': 6, 'end': 13, 'label': 'DATE'}, 44 | {'token_start': 17, 'token_end': 19, 'start': 103, 'end': 119, 'label': 'PERSON'}, 45 | {'token_start': 31, 'token_end': 32, 'start': 176, 'end': 181, 'label': 'GPE'} 46 | ], 47 | 'tokens': [ 48 | {'text': 'After', 'start': 0, 'end': 5, 'id': 0}, 49 | {'text': '12', 'start': 6, 'end': 8, 'id': 1}, 50 | {'text': 'days', 'start': 9, 'end': 13, 'id': 2}, 51 | {'text': 'of', 'start': 14, 'end': 16, 'id': 3}, 52 | {'text': 'mass', 'start': 17, 'end': 21, 'id': 4}, 53 | {'text': 'demonstrations', 'start': 22, 'end': 36, 'id': 5}, 54 | {'text': ',', 'start': 36, 'end': 37, 'id': 6}, 55 | {'text': 'rioting', 'start': 38, 'end': 45, 'id': 7}, 56 | {'text': 'and', 'start': 46, 'end': 49, 'id': 8}, 57 | {'text': 'human', 'start': 50, 'end': 55, 'id': 9}, 58 | {'text': 'rights', 'start': 56, 'end': 62, 'id': 10}, 59 | {'text': 'violations', 'start': 63, 'end': 73, 'id': 11}, 60 | {'text': ',', 'start': 73, 'end': 74, 'id': 12}, 61 | {'text': 'the', 'start': 75, 'end': 78, 'id': 13}, 62 | {'text': 'government', 'start': 79, 'end': 89, 'id': 14}, 63 | {'text': 'of', 'start': 90, 'end': 92, 'id': 15}, 64 | {'text': 'President', 'start': 93, 'end': 102, 'id': 16}, 65 | {'text': 'Sebastián', 'start': 103, 'end': 112, 'id': 17}, 66 | {'text': 'Piñera', 'start': 113, 'end': 119, 'id': 18}, 67 | {'text': 'must', 'start': 120, 'end': 124, 'id': 19}, 68 | {'text': 'now', 'start': 125, 'end': 128, 'id': 20}, 69 | {'text': 'find', 'start': 129, 'end': 133, 'id': 21}, 70 | {'text': 'a', 'start': 134, 'end': 135, 'id': 22}, 71 | {'text': 'way', 'start': 136, 'end': 139, 'id': 23}, 72 | {'text': 'out', 'start': 140, 'end': 143, 
'id': 24}, 73 | {'text': 'of', 'start': 144, 'end': 146, 'id': 25}, 74 | {'text': 'the', 'start': 147, 'end': 150, 'id': 26}, 75 | {'text': 'crisis', 'start': 151, 'end': 157, 'id': 27}, 76 | {'text': 'that', 'start': 158, 'end': 162, 'id': 28}, 77 | {'text': 'has', 'start': 163, 'end': 166, 'id': 29}, 78 | {'text': 'engulfed', 'start': 167, 'end': 175, 'id': 30}, 79 | {'text': 'Chile', 'start': 176, 'end': 181, 'id': 31}, 80 | {'text': '.', 'start': 181, 'end': 182, 'id': 32} 81 | ] 82 | } 83 | ] 84 | 85 | assert expected == actual 86 | -------------------------------------------------------------------------------- /docs/_build/html/_static/js/html5shiv-printshiv.min.js: -------------------------------------------------------------------------------- 1 | /** 2 | * @preserve HTML5 Shiv 3.7.3-pre | @afarkas @jdalton @jon_neal @rem | MIT/GPL2 Licensed 3 | */ 4 | !function(a,b){function c(a,b){var c=a.createElement("p"),d=a.getElementsByTagName("head")[0]||a.documentElement;return c.innerHTML="x",d.insertBefore(c.lastChild,d.firstChild)}function d(){var a=y.elements;return"string"==typeof a?a.split(" "):a}function e(a,b){var c=y.elements;"string"!=typeof c&&(c=c.join(" ")),"string"!=typeof a&&(a=a.join(" ")),y.elements=c+" "+a,j(b)}function f(a){var b=x[a[v]];return b||(b={},w++,a[v]=w,x[w]=b),b}function g(a,c,d){if(c||(c=b),q)return c.createElement(a);d||(d=f(c));var e;return e=d.cache[a]?d.cache[a].cloneNode():u.test(a)?(d.cache[a]=d.createElem(a)).cloneNode():d.createElem(a),!e.canHaveChildren||t.test(a)||e.tagUrn?e:d.frag.appendChild(e)}function h(a,c){if(a||(a=b),q)return a.createDocumentFragment();c=c||f(a);for(var e=c.frag.cloneNode(),g=0,h=d(),i=h.length;i>g;g++)e.createElement(h[g]);return e}function i(a,b){b.cache||(b.cache={},b.createElem=a.createElement,b.createFrag=a.createDocumentFragment,b.frag=b.createFrag()),a.createElement=function(c){return y.shivMethods?g(c,a,b):b.createElem(c)},a.createDocumentFragment=Function("h,f","return function(){var n=f.cloneNode(),c=n.createElement;h.shivMethods&&("+d().join().replace(/[\w\-:]+/g,function(a){return b.createElem(a),b.frag.createElement(a),'c("'+a+'")'})+");return n}")(y,b.frag)}function j(a){a||(a=b);var d=f(a);return!y.shivCSS||p||d.hasCSS||(d.hasCSS=!!c(a,"article,aside,dialog,figcaption,figure,footer,header,hgroup,main,nav,section{display:block}mark{background:#FF0;color:#000}template{display:none}")),q||i(a,d),a}function k(a){for(var b,c=a.getElementsByTagName("*"),e=c.length,f=RegExp("^(?:"+d().join("|")+")$","i"),g=[];e--;)b=c[e],f.test(b.nodeName)&&g.push(b.applyElement(l(b)));return g}function l(a){for(var b,c=a.attributes,d=c.length,e=a.ownerDocument.createElement(A+":"+a.nodeName);d--;)b=c[d],b.specified&&e.setAttribute(b.nodeName,b.nodeValue);return e.style.cssText=a.style.cssText,e}function m(a){for(var b,c=a.split("{"),e=c.length,f=RegExp("(^|[\\s,>+~])("+d().join("|")+")(?=[[\\s,>+~#.:]|$)","gi"),g="$1"+A+"\\:$2";e--;)b=c[e]=c[e].split("}"),b[b.length-1]=b[b.length-1].replace(f,g),c[e]=b.join("}");return c.join("{")}function n(a){for(var b=a.length;b--;)a[b].removeNode()}function o(a){function b(){clearTimeout(g._removeSheetTimer),d&&d.removeNode(!0),d=null}var d,e,g=f(a),h=a.namespaces,i=a.parentWindow;return!B||a.printShived?a:("undefined"==typeof h[A]&&h.add(A),i.attachEvent("onbeforeprint",function(){b();for(var 
f,g,h,i=a.styleSheets,j=[],l=i.length,n=Array(l);l--;)n[l]=i[l];for(;h=n.pop();)if(!h.disabled&&z.test(h.media)){try{f=h.imports,g=f.length}catch(o){g=0}for(l=0;g>l;l++)n.push(f[l]);try{j.push(h.cssText)}catch(o){}}j=m(j.reverse().join("")),e=k(a),d=c(a,j)}),i.attachEvent("onafterprint",function(){n(e),clearTimeout(g._removeSheetTimer),g._removeSheetTimer=setTimeout(b,500)}),a.printShived=!0,a)}var p,q,r="3.7.3",s=a.html5||{},t=/^<|^(?:button|map|select|textarea|object|iframe|option|optgroup)$/i,u=/^(?:a|b|code|div|fieldset|h1|h2|h3|h4|h5|h6|i|label|li|ol|p|q|span|strong|style|table|tbody|td|th|tr|ul)$/i,v="_html5shiv",w=0,x={};!function(){try{var a=b.createElement("a");a.innerHTML="",p="hidden"in a,q=1==a.childNodes.length||function(){b.createElement("a");var a=b.createDocumentFragment();return"undefined"==typeof a.cloneNode||"undefined"==typeof a.createDocumentFragment||"undefined"==typeof a.createElement}()}catch(c){p=!0,q=!0}}();var y={elements:s.elements||"abbr article aside audio bdi canvas data datalist details dialog figcaption figure footer header hgroup main mark meter nav output picture progress section summary template time video",version:r,shivCSS:s.shivCSS!==!1,supportsUnknownElements:q,shivMethods:s.shivMethods!==!1,type:"default",shivDocument:j,createElement:g,createDocumentFragment:h,addElements:e};a.html5=y,j(b);var z=/^$|\b(?:all|print)\b/,A="html5shiv",B=!q&&function(){var c=b.documentElement;return!("undefined"==typeof b.namespaces||"undefined"==typeof b.parentWindow||"undefined"==typeof c.applyElement||"undefined"==typeof c.removeNode||"undefined"==typeof a.attachEvent)}();y.type+=" print",y.shivPrint=o,o(b),"object"==typeof module&&module.exports&&(module.exports=y)}("undefined"!=typeof window?window:this,document); -------------------------------------------------------------------------------- /wellcomeml/ml/bert_vectorizer.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | BERT Vectorizer that embeds text using a pretrained BERT model 4 | """ 5 | import logging 6 | 7 | import tqdm 8 | 9 | from wellcomeml.utils import check_cache_and_download, throw_extra_import_message 10 | 11 | required_modules = "torch,sklearn,numpy" 12 | required_extras = "torch,transformers,sklearn" 13 | try: 14 | from transformers import BertModel, BertTokenizer 15 | from sklearn.base import BaseEstimator, TransformerMixin 16 | import numpy as np 17 | import torch 18 | except ImportError as e: 19 | throw_extra_import_message(error=e, required_modules=required_modules, 20 | extras=required_extras) 21 | 22 | logger = logging.getLogger(__name__) 23 | 24 | 25 | class BertVectorizer(BaseEstimator, TransformerMixin): 26 | def __init__(self, pretrained="bert", sentence_embedding="mean_second_to_last"): 27 | """ 28 | Bert vectorizer parameters 29 | 30 | Args: 31 | pretrained: A pre-trained model name. Currently 'bert' or 'scibert' 32 | sentence_embedding: How to embed a sentence using bert's layers. 33 | Current options: 34 | 'mean_second_to_last', 'mean_last', 'sum_last' or 'mean_last_four' 35 | Default: `mean_second_to_last`.
If a valid option is not set, 36 | returns the pooler layer (embedding for the token [CLS]) 37 | """ 38 | self.pretrained = pretrained 39 | self.sentence_embedding = sentence_embedding 40 | 41 | @classmethod 42 | def save_transformed(cls, path, X_transformed): 43 | """Saves transformed embedded vectors""" 44 | np.save(path, X_transformed) 45 | 46 | @classmethod 47 | def load_transformed(cls, path): 48 | """Loads transformed embedded vectors""" 49 | return np.load(path) 50 | 51 | def bert_embedding(self, x): 52 | tokenized_x = self.tokenizer.tokenize(x) 53 | 54 | # Max sequence length is 512 for BERT. 510 without CLS and SEP 55 | if len(tokenized_x) > 510: 56 | embedded_a = self.bert_embedding(" ".join(tokenized_x[:510])) 57 | embedded_b = self.bert_embedding(" ".join(tokenized_x[510:])) 58 | return embedded_a + embedded_b 59 | 60 | indexed_tokens = self.tokenizer.convert_tokens_to_ids(["[CLS]"] + tokenized_x + ["[SEP]"]) 61 | 62 | tokens_tensor = torch.tensor([indexed_tokens]) 63 | segments_tensor = torch.zeros(tokens_tensor.shape, dtype=torch.long) 64 | with torch.no_grad(): 65 | output = self.model(tokens_tensor, token_type_ids=segments_tensor) 66 | last_layer = output[2][-1] 67 | second_to_last_layer = output[2][-2] 68 | 69 | if self.sentence_embedding == "mean_second_to_last": 70 | embedded_x = second_to_last_layer.mean(dim=1) 71 | elif self.sentence_embedding == "mean_last": 72 | embedded_x = last_layer.mean(dim=1) 73 | elif self.sentence_embedding == "sum_last": 74 | embedded_x = last_layer.sum(dim=1) 75 | elif self.sentence_embedding == "mean_last_four": 76 | embedded_x = torch.stack(output[2][-4:]).mean(dim=0).mean(dim=1) 77 | else: 78 | # Else gives the embedding for the pooler layer. This can be, for 79 | # example, fed into a softmax classifier 80 | embedded_x = output[1] 81 | 82 | return embedded_x.cpu().numpy().flatten() 83 | 84 | def transform(self, X, verbose=True, *_): 85 | X = (tqdm.tqdm(X) if verbose else X) 86 | 87 | return np.array([self.bert_embedding(x) for x in X]) 88 | 89 | def fit(self, *_): 90 | model_name = ( 91 | "bert-base-uncased" 92 | if self.pretrained == "bert" 93 | else "scibert_scivocab_uncased" 94 | ) 95 | 96 | # If model_name isn't a standard pretrained name, check the cache 97 | # and change the name to the full local path 98 | if model_name == "scibert_scivocab_uncased": 99 | model_name = check_cache_and_download(model_name) 100 | 101 | logger.info("Using {} embedding".format(model_name)) 102 | self.model = BertModel.from_pretrained(model_name, output_hidden_states=True) 103 | self.tokenizer = BertTokenizer.from_pretrained(model_name) 104 | self.model.eval() 105 | return self 106 |
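
A minimal usage sketch for BertVectorizer (an illustration added here, not a file from the repository; fit() loads bert-base-uncased, so it needs network access or a populated model cache):

    from wellcomeml.ml.bert_vectorizer import BertVectorizer

    vectorizer = BertVectorizer(sentence_embedding="mean_last")
    vectorizer.fit()  # loads the pretrained model and tokenizer
    X_embedded = vectorizer.transform(["Malaria is spread by mosquitoes."])
    print(X_embedded.shape)  # (1, 768) for bert-base-uncased
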
6 | """ 7 | import logging 8 | import re 9 | 10 | 11 | from wellcomeml.utils import throw_extra_import_message 12 | # Heavy dependencies go here 13 | required_modules = 'spacy,sklearn,scipy' 14 | required_extras = 'spacy,core' 15 | 16 | try: 17 | import spacy 18 | from scipy import sparse 19 | from sklearn.feature_extraction.text import TfidfVectorizer 20 | except ImportError as e: 21 | throw_extra_import_message(error=e, required_modules='spacy', extras='spacy') 22 | 23 | logger = logging.getLogger(__name__) 24 | 25 | 26 | class WellcomeTfidf(TfidfVectorizer): 27 | """ 28 | Class to wrap some basic transformation and text 29 | vectorisation/embedding 30 | """ 31 | 32 | def __init__(self, use_regex=True, use_spacy_lemmatizer=True, **kwargs): 33 | """ 34 | 35 | Args: 36 | Any sklearn "tfidfvectorizer" arguments (min_df, etc.) 37 | 38 | """ 39 | self.embedding = "tf-idf" 40 | self.use_regex = use_regex 41 | self.use_spacy_lemmatizer = use_spacy_lemmatizer 42 | 43 | logger.info("Initialising frequency vectorizer.") 44 | 45 | kwargs["stop_words"] = kwargs.get("stop_words", "english") 46 | 47 | super().__init__(**kwargs) 48 | 49 | self.nlp = spacy.blank("en") 50 | self.nlp.add_pipe("lemmatizer", config={"mode": "lookup"}) 51 | self.nlp.initialize() 52 | 53 | @classmethod 54 | def save_transformed(cls, path, X_transformed): 55 | """Saves transformed embedded vectors""" 56 | sparse.save_npz(path, X_transformed) 57 | 58 | @classmethod 59 | def load_transformed(cls, path): 60 | """Loads transformed embedded vectors""" 61 | return sparse.load_npz(path) 62 | 63 | def regex_transform(self, X, remove_numbers="years", *_): 64 | """ 65 | Extra regular expression transformations to clean text 66 | Args: 67 | X: A list of texts (strings) 68 | *_: 69 | remove_numbers: Whether to remove years or all digits. Caveat: 70 | This does not only remove years, but **any number** between 71 | 1000 and 2999. 
72 | 73 | Returns: 74 | A list of texts with the applied regex transformation 75 | 76 | """ 77 | if remove_numbers == "years": 78 | return [re.sub(r"[1-2]\d{3}", "", text) for text in X] 79 | elif remove_numbers == "digits": 80 | return [re.sub(r"\d", "", text) for text in X] 81 | else: 82 | return X 83 | 84 | def spacy_lemmatizer(self, X, remove_stopwords_and_punct=True): 85 | """ 86 | Uses a spaCy lookup lemmatiser to lemmatise texts 87 | Args: 88 | X: A list of texts (strings) 89 | remove_stopwords_and_punct: Whether to remove stopwords, 90 | punctuation and pronouns 91 | 92 | Returns: 93 | A list of lists of lemmatised tokens, one list per text 94 | """ 95 | 96 | logger.info("Using spacy pre-trained lemmatiser.") 97 | if remove_stopwords_and_punct: 98 | return [ 99 | [ 100 | token.lemma_.lower() 101 | for token in doc 102 | if not token.is_stop 103 | and not token.is_punct 104 | and token.lemma_ != "-PRON-" 105 | ] 106 | for doc in self.nlp.pipe(X) 107 | ] 108 | else: 109 | return [ 110 | [token.lemma_.lower() for token in doc] for doc in self.nlp.pipe(X) 111 | ] 112 | 113 | def _pre_transform(self, X): 114 | if self.use_regex: 115 | X = self.regex_transform(X) 116 | if self.use_spacy_lemmatizer: 117 | X = self.spacy_lemmatizer(X) 118 | 119 | return [" ".join(text) for text in X] 120 | 121 | def transform(self, X): 122 | X = self._pre_transform(X) 123 | 124 | return super().transform(X) 125 | 126 | def fit(self, X, y=None): 127 | X = self._pre_transform(X) 128 | 129 | super().fit(X) 130 | return self 131 | 132 | def fit_transform(self, X, y=None): 133 | X = self._pre_transform(X) 134 | 135 | return super().fit_transform(X, y=y) 136 |
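
A minimal usage sketch for WellcomeTfidf (an illustration added here, not a file from the repository; the lookup lemmatiser assumes the spacy-lookups-data package is installed):

    from wellcomeml.ml.frequency_vectorizer import WellcomeTfidf

    texts = ["Malaria research during 2008", "Research on influenza vaccines"]
    vectorizer = WellcomeTfidf(min_df=1)  # also accepts any other TfidfVectorizer kwargs
    X_transformed = vectorizer.fit_transform(texts)  # regex clean, lemmatise, then tf-idf
    print(X_transformed.shape)  # sparse matrix with one row per text
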
-------------------------------------------------------------------------------- /wellcomeml/ml/vectorizer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # coding: utf-8 3 | 4 | """ 5 | A generic vectorizer that can fall back to tf-idf or bag of words from sklearn 6 | or embed using bert, doc2vec etc 7 | """ 8 | from wellcomeml.utils import throw_extra_import_message 9 | 10 | required_modules = 'sklearn' 11 | required_extras = 'core' 12 | 13 | try: 14 | from sklearn.base import BaseEstimator, TransformerMixin 15 | except ImportError as e: 16 | throw_extra_import_message(error=e, required_modules=required_modules, 17 | extras=required_extras) 18 | 19 | 20 | class Vectorizer(BaseEstimator, TransformerMixin): 21 | """ 22 | Abstract class, sklearn-compatible, that can vectorize texts using 23 | various models. 24 | 25 | """ 26 | 27 | def __init__(self, embedding="tf-idf", cache_transformed=False, **kwargs): 28 | """ 29 | Args: 30 | embedding(str): One of `['tf-idf', 'bert', 'keras', 'doc2vec']` 31 | cache_transformed(bool): Caches the last transformed vector X ( 32 | useful if performing Grid-search as part of a pipeline) 33 | """ 34 | self.embedding = embedding 35 | self.cache_transformed = cache_transformed 36 | 37 | # Only actually import when necessary (extras might not be installed) 38 | 39 | vectorizer_dispatcher = { 40 | "tf-idf": "wellcomeml.ml.frequency_vectorizer.WellcomeTfidf", 41 | "bert": "wellcomeml.ml.bert_vectorizer.BertVectorizer", 42 | "keras": "wellcomeml.ml.keras_vectorizer.KerasVectorizer", 43 | "doc2vec": "wellcomeml.ml.doc2vec_vectorizer.Doc2VecVectorizer", 44 | } 45 | 46 | if not vectorizer_dispatcher.get(embedding): 47 | raise ValueError(f"Model {embedding} not available") 48 | 49 | vectorizer_path = '.'.join(vectorizer_dispatcher.get(embedding).split('.')[:-1]) 50 | vectorizer_class_path = vectorizer_dispatcher.get(embedding).split('.')[-1] 51 | vectorizer = getattr( 52 | __import__(vectorizer_path, fromlist=[vectorizer_class_path]), 53 | vectorizer_class_path 54 | ) 55 | 56 | self.vectorizer = vectorizer(**kwargs) 57 | 58 | def fit(self, X=None, *_): 59 | return self.vectorizer.fit(X) 60 | 61 | def transform(self, X, *_): 62 | X_transformed = self.vectorizer.transform(X) 63 | 64 | if self.cache_transformed: 65 | self.X_transformed = X_transformed 66 | 67 | return X_transformed 68 | 69 | def fit_transform(self, X, y=None, *_): 70 | # Slightly modified fit_transform so it can work with the 71 | # cache_transformed 72 | self.fit(X) 73 | return self.transform(X) 74 | 75 | def save_transformed(self, path, X_transformed): 76 | """ 77 | Saves the transformed vector X_transformed, using the corresponding 78 | save_transformed method for the specific vectorizer. 79 | 80 | Args: 81 | path: A path to the embedding file 82 | X_transformed: A transformed vector (as output by using the 83 | .transform method) 84 | 85 | """ 86 | save_method = getattr(self.vectorizer.__class__, "save_transformed", None) 87 | if not save_method: 88 | raise NotImplementedError( 89 | f"Method save_transformed not implemented" 90 | f" for class " 91 | f"{self.vectorizer.__class__.__name__}" 92 | ) 93 | 94 | return save_method(path=path, X_transformed=X_transformed) 95 | 96 | def load_transformed(self, path): 97 | """ 98 | Loads the transformed vector X_transformed, using the corresponding 99 | load method for the specific vectorizer. 100 | 101 | Args: 102 | path: A path to the file containing embedded vectors 103 | 104 | Returns: 105 | X_transformed (array), like the one returned by the 106 | fit_transform function.
107 | """ 108 | load_method = getattr(self.vectorizer.__class__, "load_transformed", None) 109 | if not load_method: 110 | raise NotImplementedError( 111 | f"Method load_transformed not implemented" 112 | f" for class " 113 | f"{self.vectorizer.__class__.__name__}" 114 | ) 115 | 116 | return load_method(path=path) 117 | -------------------------------------------------------------------------------- /tests/test_ner_spacy.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # coding: utf-8 3 | 4 | import en_core_web_sm 5 | import pytest 6 | from wellcomeml.ml.spacy_ner import SpacyNER 7 | from wellcomeml.metrics.ner_classification_report import ner_classification_report 8 | 9 | 10 | @pytest.fixture(scope="module") 11 | def nlp(): 12 | return en_core_web_sm.load() 13 | 14 | 15 | @pytest.fixture(scope="module") 16 | def X_train(): 17 | return [ 18 | """n Journal of Psychiatry 158: 2071–4\nFreeman MP, 19 | Hibbeln JR, Wisner KL et al. (2006)\nOmega-3 fatty ac""", 20 | """rd, (BKKBN)\n \nJakarta, Indonesia\n29. Drs Titut Prihyugiarto 21 | \n MSPA\n \nSenior Researcher for Reproducti""", 22 | """a Santé, 2008. \n118. Konradsen, F. et coll. 23 | Community uptake of safe storage boxes to reduce self-po""", 24 | """ted that the two treatments can \nbe combined. Contrarily, 25 | Wapf et al. \nnoted that many treatment per""", 26 | """ti-tuberculosis treatment in Mongolia. Int J Tuberc Lung Dis. 27 | 2015;19(6):657–62. \n160. Dudley L, Aze""", 28 | """he \nScottish Heart Health Study: cohort study. BMJ, 1997, 315:722–729. 29 | \nUmesawa M, Iso H, Date C et """, 30 | """T.A., G. Marland, and R.J. Andres (2010). Global, Regional, 31 | and National Fossil-Fuel CO2 Emissions. """, 32 | """Ian Gr\nMr Ian Graayy\nPrincipal Policy Officer 33 | (Public Health and Health Protection), Chartered Insti""", 34 | """. \n3. \nFischer G and Stöver H. Assessing the current 35 | state of opioid-dependence treatment across Eur""", 36 | """ated by\nLlorca et al. (2014) or Pae et al. 
(2015), or when 37 | vortioxetine was assumed to be\nas effecti""", 38 | ] 39 | 40 | 41 | @pytest.fixture(scope="module") 42 | def y_train(): 43 | return [ 44 | [{'start': 36, 'end': 46, 'label': 'PERSON'}, 45 | {'start': 48, 'end': 58, 'label': 'PERSON'}, 46 | {'start': 61, 'end': 69, 'label': 'PERSON'}], 47 | [{'start': 41, 'end': 59, 'label': 'PERSON'}], 48 | [{'start': 21, 'end': 34, 'label': 'PERSON'}], 49 | [{'start': 58, 'end': 62, 'label': 'PERSON'}], 50 | [{'start': 87, 'end': 95, 'label': 'PERSON'}], 51 | [{'start': 72, 'end': 81, 'label': 'PERSON'}, 52 | {'start': 83, 'end': 88, 'label': 'PERSON'}, 53 | {'start': 90, 'end': 96, 'label': 'PERSON'}], 54 | [{'start': 6, 'end': 16, 'label': 'PERSON'}, {'start': 22, 'end': 33, 'label': 'PERSON'}], 55 | [{'start': 0, 'end': 6, 'label': 'PERSON'}, {'start': 10, 'end': 20, 'label': 'PERSON'}], 56 | [{'start': 7, 'end': 16, 'label': 'PERSON'}, {'start': 21, 'end': 30, 'label': 'PERSON'}], 57 | [{'start': 8, 'end': 14, 'label': 'PERSON'}, {'start': 32, 'end': 35, 'label': 'PERSON'}], 58 | ] 59 | 60 | 61 | @pytest.fixture(scope="module") 62 | def ner_groups(): 63 | return ['Group 1', 'Group 2', 'Group 3', 'Group 2', 'Group 1', 'Group 3', 'Group 3', 64 | 'Group 3', 'Group 2', 'Group 1'] 65 | 66 | 67 | def test_fit(X_train, y_train): 68 | spacy_ner = SpacyNER(n_iter=3, dropout=0.2, output=True) 69 | spacy_ner.load("en_core_web_sm") 70 | retrained_nlp = spacy_ner.fit(X_train, y_train) 71 | 72 | assert vars(retrained_nlp)['_meta']['name'] == 'core_web_sm' 73 | 74 | 75 | def test_predict(): 76 | # Using spaCy's nlp model (don't retrain) 77 | spacy_ner = SpacyNER(n_iter=3, dropout=0.2, output=True) 78 | spacy_ner.load("en_core_web_sm") 79 | pred_entities = spacy_ner.predict("Apple is looking at buying U.K. startup for $1 billion") 80 | 81 | # Make sure its not an empty list 82 | assert pred_entities 83 | 84 | 85 | def test_score(X_train, y_train): 86 | # Using spaCy's nlp model (don't retrain) 87 | spacy_ner = SpacyNER(n_iter=3, dropout=0.2, output=True) 88 | spacy_ner.load("en_core_web_sm") 89 | y_pred = [spacy_ner.predict(text) for text in X_train] 90 | f1 = spacy_ner.score(y_train, y_pred, tags=['PERSON']) 91 | 92 | assert isinstance(f1['PERSON'], float) 93 | 94 | 95 | def test_ner_classification_report(X_train, y_train, ner_groups): 96 | spacy_ner = SpacyNER(n_iter=3, dropout=0.2, output=True) 97 | spacy_ner.load("en_core_web_sm") 98 | y_pred = [spacy_ner.predict(text) for text in X_train] 99 | report = ner_classification_report(y_train, y_pred, ner_groups, tags=['PERSON']) 100 | 101 | assert isinstance(report, str) 102 | -------------------------------------------------------------------------------- /wellcomeml/io/s3_policy_data.py: -------------------------------------------------------------------------------- 1 | import json 2 | from io import BytesIO 3 | import gzip 4 | import os 5 | 6 | 7 | class PolicyDocumentsDownloader: 8 | """ 9 | Interact with S3 to get policy document texts 10 | 11 | Args: 12 | bucket_name: the S3 bucket name to get data from 13 | dir_path: the directory path within this s3 bucket to get policy data from 14 | """ 15 | def __init__(self, s3, bucket_name, dir_path): 16 | 17 | self.s3 = s3 18 | self.bucket_name = bucket_name 19 | self.dir_path = dir_path 20 | self.pdf_keys = self.get_all_s3_keys() 21 | 22 | def get_all_s3_keys(self): 23 | """ 24 | https://alexwlchan.net/2017/07/listing-s3-keys/ 25 | Get a list of all keys to look for pdfs in the S3 bucket. 
26 | """ 27 | keys = [] 28 | 29 | kwargs = {'Bucket': self.bucket_name} 30 | while True: 31 | resp = self.s3.list_objects_v2(**kwargs) 32 | for obj in resp['Contents']: 33 | keys.append(obj['Key']) 34 | 35 | try: 36 | kwargs['ContinuationToken'] = resp['NextContinuationToken'] 37 | except KeyError: 38 | break 39 | 40 | pdf_keys = [k for k in keys if self.dir_path in k] 41 | 42 | return pdf_keys 43 | 44 | def get_hashes(self, word_list=None): 45 | """ 46 | Get a list of policy document hashes from the S3 location 47 | 48 | Args: 49 | word_list(list): a list of words to look for in documents, if None then get hashes for 50 | all 51 | Returns: 52 | list: a list of dicts with the file hash and the policy doc source where it is from 53 | """ 54 | 55 | print("Getting hashes for policy documents") 56 | hashes = [] 57 | for key in self.pdf_keys: 58 | print("Loading "+key) 59 | key_name = os.path.split(key)[-1] 60 | response = self.s3.get_object(Bucket=self.bucket_name, Key=key) 61 | content = response['Body'].read() 62 | with gzip.GzipFile(fileobj=BytesIO(content), mode='rb') as fh: 63 | for line in fh: 64 | document = json.loads(line) 65 | if document['text']: 66 | if word_list: 67 | if not any(word.lower() in document['text'].lower() 68 | for word in word_list): 69 | continue 70 | hashes.append({ 71 | "source": key_name, 72 | "file_hash": document['file_hash'] 73 | }) 74 | print(str(len(hashes))+" documents") 75 | 76 | return hashes 77 | 78 | def download(self, hash_list=None): 79 | """ 80 | Download the policy document data from S3 81 | 82 | Args: 83 | hash_list: a list of hashes to specifically download, if None then download all 84 | """ 85 | 86 | print("Downloading policy documents") 87 | documents = [] 88 | hashes_found = set() # A checker so we dont download duplicates 89 | for key in self.pdf_keys: 90 | print("Loading "+key) 91 | key_name = os.path.split(key)[-1] 92 | response = self.s3.get_object(Bucket=self.bucket_name, Key=key) 93 | content = response['Body'].read() 94 | with gzip.GzipFile(fileobj=BytesIO(content), mode='rb') as fh: 95 | for line in fh: 96 | document = json.loads(line) 97 | if ((document['text']) and (document['file_hash'] not in hashes_found)): 98 | if hash_list: 99 | if document['file_hash'] not in set(hash_list): 100 | continue 101 | document["source"] = key_name 102 | documents.append(document) 103 | hashes_found.add(document['file_hash']) 104 | 105 | print(str(len(documents))+" documents") 106 | 107 | return documents 108 | 109 | def save_json(self, documents, file_name): 110 | print("Saving data ...") 111 | with open(file_name, 'w', encoding='utf-8') as output_file: 112 | for document in documents: 113 | json.dump(document, output_file) 114 | output_file.write("\n") 115 | print("Number of documents saved: " + str(len(documents))) 116 | -------------------------------------------------------------------------------- /docs/_build/html/search.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | Search — WellcomeML 2.0.3 documentation 7 | 8 | 9 | 10 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | WellcomeML 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | WellcomeML’s documentation! 
42 | 43 | Contents: 44 | 45 | Examples 46 | List of main modules and descriptions 47 | Clustering text with WellcomeML 48 | Core library documentation 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | WellcomeML 58 | 59 | 60 | 61 | 62 | 63 | 64 | » 65 | Search 66 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | 74 | 75 | 76 | 77 | Please activate JavaScript to enable the search functionality. 78 | 79 | 80 | 81 | 82 | 83 | 84 | 85 | 86 | 87 | 88 | 89 | 103 | 104 | 105 | 106 | 107 | 112 | 115 | 116 | 117 | 118 | 119 | 120 | 121 | -------------------------------------------------------------------------------- /docs/wellcomeml.ml.rst: -------------------------------------------------------------------------------- 1 | wellcomeml.ml package 2 | ===================== 3 | 4 | Submodules 5 | ---------- 6 | 7 | wellcomeml.ml.attention module 8 | ------------------------------ 9 | 10 | .. automodule:: wellcomeml.ml.attention 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | wellcomeml.ml.bert\_classifier module 16 | ------------------------------------- 17 | 18 | .. automodule:: wellcomeml.ml.bert_classifier 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | 23 | wellcomeml.ml.bert\_semantic\_equivalence module 24 | ------------------------------------------------ 25 | 26 | .. automodule:: wellcomeml.ml.bert_semantic_equivalence 27 | :members: 28 | :undoc-members: 29 | :show-inheritance: 30 | 31 | wellcomeml.ml.bert\_vectorizer module 32 | ------------------------------------- 33 | 34 | .. automodule:: wellcomeml.ml.bert_vectorizer 35 | :members: 36 | :undoc-members: 37 | :show-inheritance: 38 | 39 | wellcomeml.ml.bilstm module 40 | --------------------------- 41 | 42 | .. automodule:: wellcomeml.ml.bilstm 43 | :members: 44 | :undoc-members: 45 | :show-inheritance: 46 | 47 | wellcomeml.ml.clustering module 48 | ------------------------------- 49 | 50 | .. automodule:: wellcomeml.ml.clustering 51 | :members: 52 | :undoc-members: 53 | :show-inheritance: 54 | 55 | wellcomeml.ml.cnn module 56 | ------------------------ 57 | 58 | .. automodule:: wellcomeml.ml.cnn 59 | :members: 60 | :undoc-members: 61 | :show-inheritance: 62 | 63 | wellcomeml.ml.constants module 64 | ------------------------------ 65 | 66 | .. automodule:: wellcomeml.ml.constants 67 | :members: 68 | :undoc-members: 69 | :show-inheritance: 70 | 71 | wellcomeml.ml.doc2vec\_vectorizer module 72 | ---------------------------------------- 73 | 74 | .. automodule:: wellcomeml.ml.doc2vec_vectorizer 75 | :members: 76 | :undoc-members: 77 | :show-inheritance: 78 | 79 | wellcomeml.ml.frequency\_vectorizer module 80 | ------------------------------------------ 81 | 82 | .. automodule:: wellcomeml.ml.frequency_vectorizer 83 | :members: 84 | :undoc-members: 85 | :show-inheritance: 86 | 87 | wellcomeml.ml.keras\_utils module 88 | --------------------------------- 89 | 90 | .. automodule:: wellcomeml.ml.keras_utils 91 | :members: 92 | :undoc-members: 93 | :show-inheritance: 94 | 95 | wellcomeml.ml.keras\_vectorizer module 96 | -------------------------------------- 97 | 98 | .. automodule:: wellcomeml.ml.keras_vectorizer 99 | :members: 100 | :undoc-members: 101 | :show-inheritance: 102 | 103 | wellcomeml.ml.sent2vec\_vectorizer module 104 | ----------------------------------------- 105 | 106 | .. automodule:: wellcomeml.ml.sent2vec_vectorizer 107 | :members: 108 | :undoc-members: 109 | :show-inheritance: 110 | 111 | wellcomeml.ml.similarity\_entity\_linking module 112 | ------------------------------------------------ 113 | 114 | .. 
automodule:: wellcomeml.ml.similarity_entity_linking 115 | :members: 116 | :undoc-members: 117 | :show-inheritance: 118 | 119 | wellcomeml.ml.spacy\_classifier module 120 | -------------------------------------- 121 | 122 | .. automodule:: wellcomeml.ml.spacy_classifier 123 | :members: 124 | :undoc-members: 125 | :show-inheritance: 126 | 127 | wellcomeml.ml.spacy\_entity\_linking module 128 | ------------------------------------------- 129 | 130 | .. automodule:: wellcomeml.ml.spacy_entity_linking 131 | :members: 132 | :undoc-members: 133 | :show-inheritance: 134 | 135 | wellcomeml.ml.spacy\_knowledge\_base module 136 | ------------------------------------------- 137 | 138 | .. automodule:: wellcomeml.ml.spacy_knowledge_base 139 | :members: 140 | :undoc-members: 141 | :show-inheritance: 142 | 143 | wellcomeml.ml.spacy\_ner module 144 | ------------------------------- 145 | 146 | .. automodule:: wellcomeml.ml.spacy_ner 147 | :members: 148 | :undoc-members: 149 | :show-inheritance: 150 | 151 | wellcomeml.ml.transformers\_tokenizer module 152 | -------------------------------------------- 153 | 154 | .. automodule:: wellcomeml.ml.transformers_tokenizer 155 | :members: 156 | :undoc-members: 157 | :show-inheritance: 158 | 159 | wellcomeml.ml.vectorizer module 160 | ------------------------------- 161 | 162 | .. automodule:: wellcomeml.ml.vectorizer 163 | :members: 164 | :undoc-members: 165 | :show-inheritance: 166 | 167 | wellcomeml.ml.voting\_classifier module 168 | --------------------------------------- 169 | 170 | .. automodule:: wellcomeml.ml.voting_classifier 171 | :members: 172 | :undoc-members: 173 | :show-inheritance: 174 | 175 | Module contents 176 | --------------- 177 | 178 | .. automodule:: wellcomeml.ml 179 | :members: 180 | :undoc-members: 181 | :show-inheritance: 182 | -------------------------------------------------------------------------------- /docs/_build/html/_sources/wellcomeml.ml.rst.txt: -------------------------------------------------------------------------------- 1 | wellcomeml.ml package 2 | ===================== 3 | 4 | Submodules 5 | ---------- 6 | 7 | wellcomeml.ml.attention module 8 | ------------------------------ 9 | 10 | .. automodule:: wellcomeml.ml.attention 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | wellcomeml.ml.bert\_classifier module 16 | ------------------------------------- 17 | 18 | .. automodule:: wellcomeml.ml.bert_classifier 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | 23 | wellcomeml.ml.bert\_semantic\_equivalence module 24 | ------------------------------------------------ 25 | 26 | .. automodule:: wellcomeml.ml.bert_semantic_equivalence 27 | :members: 28 | :undoc-members: 29 | :show-inheritance: 30 | 31 | wellcomeml.ml.bert\_vectorizer module 32 | ------------------------------------- 33 | 34 | .. automodule:: wellcomeml.ml.bert_vectorizer 35 | :members: 36 | :undoc-members: 37 | :show-inheritance: 38 | 39 | wellcomeml.ml.bilstm module 40 | --------------------------- 41 | 42 | .. automodule:: wellcomeml.ml.bilstm 43 | :members: 44 | :undoc-members: 45 | :show-inheritance: 46 | 47 | wellcomeml.ml.clustering module 48 | ------------------------------- 49 | 50 | .. automodule:: wellcomeml.ml.clustering 51 | :members: 52 | :undoc-members: 53 | :show-inheritance: 54 | 55 | wellcomeml.ml.cnn module 56 | ------------------------ 57 | 58 | .. 
automodule:: wellcomeml.ml.cnn 59 | :members: 60 | :undoc-members: 61 | :show-inheritance: 62 | 63 | wellcomeml.ml.constants module 64 | ------------------------------ 65 | 66 | .. automodule:: wellcomeml.ml.constants 67 | :members: 68 | :undoc-members: 69 | :show-inheritance: 70 | 71 | wellcomeml.ml.doc2vec\_vectorizer module 72 | ---------------------------------------- 73 | 74 | .. automodule:: wellcomeml.ml.doc2vec_vectorizer 75 | :members: 76 | :undoc-members: 77 | :show-inheritance: 78 | 79 | wellcomeml.ml.frequency\_vectorizer module 80 | ------------------------------------------ 81 | 82 | .. automodule:: wellcomeml.ml.frequency_vectorizer 83 | :members: 84 | :undoc-members: 85 | :show-inheritance: 86 | 87 | wellcomeml.ml.keras\_utils module 88 | --------------------------------- 89 | 90 | .. automodule:: wellcomeml.ml.keras_utils 91 | :members: 92 | :undoc-members: 93 | :show-inheritance: 94 | 95 | wellcomeml.ml.keras\_vectorizer module 96 | -------------------------------------- 97 | 98 | .. automodule:: wellcomeml.ml.keras_vectorizer 99 | :members: 100 | :undoc-members: 101 | :show-inheritance: 102 | 103 | wellcomeml.ml.sent2vec\_vectorizer module 104 | ----------------------------------------- 105 | 106 | .. automodule:: wellcomeml.ml.sent2vec_vectorizer 107 | :members: 108 | :undoc-members: 109 | :show-inheritance: 110 | 111 | wellcomeml.ml.similarity\_entity\_linking module 112 | ------------------------------------------------ 113 | 114 | .. automodule:: wellcomeml.ml.similarity_entity_linking 115 | :members: 116 | :undoc-members: 117 | :show-inheritance: 118 | 119 | wellcomeml.ml.spacy\_classifier module 120 | -------------------------------------- 121 | 122 | .. automodule:: wellcomeml.ml.spacy_classifier 123 | :members: 124 | :undoc-members: 125 | :show-inheritance: 126 | 127 | wellcomeml.ml.spacy\_entity\_linking module 128 | ------------------------------------------- 129 | 130 | .. automodule:: wellcomeml.ml.spacy_entity_linking 131 | :members: 132 | :undoc-members: 133 | :show-inheritance: 134 | 135 | wellcomeml.ml.spacy\_knowledge\_base module 136 | ------------------------------------------- 137 | 138 | .. automodule:: wellcomeml.ml.spacy_knowledge_base 139 | :members: 140 | :undoc-members: 141 | :show-inheritance: 142 | 143 | wellcomeml.ml.spacy\_ner module 144 | ------------------------------- 145 | 146 | .. automodule:: wellcomeml.ml.spacy_ner 147 | :members: 148 | :undoc-members: 149 | :show-inheritance: 150 | 151 | wellcomeml.ml.transformers\_tokenizer module 152 | -------------------------------------------- 153 | 154 | .. automodule:: wellcomeml.ml.transformers_tokenizer 155 | :members: 156 | :undoc-members: 157 | :show-inheritance: 158 | 159 | wellcomeml.ml.vectorizer module 160 | ------------------------------- 161 | 162 | .. automodule:: wellcomeml.ml.vectorizer 163 | :members: 164 | :undoc-members: 165 | :show-inheritance: 166 | 167 | wellcomeml.ml.voting\_classifier module 168 | --------------------------------------- 169 | 170 | .. automodule:: wellcomeml.ml.voting_classifier 171 | :members: 172 | :undoc-members: 173 | :show-inheritance: 174 | 175 | Module contents 176 | --------------- 177 | 178 | .. 
automodule:: wellcomeml.ml 179 | :members: 180 | :undoc-members: 181 | :show-inheritance: 182 | -------------------------------------------------------------------------------- /docs/_build/html/_static/js/theme.js: -------------------------------------------------------------------------------- 1 | !function(n){var e={};function t(i){if(e[i])return e[i].exports;var o=e[i]={i:i,l:!1,exports:{}};return n[i].call(o.exports,o,o.exports,t),o.l=!0,o.exports}t.m=n,t.c=e,t.d=function(n,e,i){t.o(n,e)||Object.defineProperty(n,e,{enumerable:!0,get:i})},t.r=function(n){"undefined"!=typeof Symbol&&Symbol.toStringTag&&Object.defineProperty(n,Symbol.toStringTag,{value:"Module"}),Object.defineProperty(n,"__esModule",{value:!0})},t.t=function(n,e){if(1&e&&(n=t(n)),8&e)return n;if(4&e&&"object"==typeof n&&n&&n.__esModule)return n;var i=Object.create(null);if(t.r(i),Object.defineProperty(i,"default",{enumerable:!0,value:n}),2&e&&"string"!=typeof n)for(var o in n)t.d(i,o,function(e){return n[e]}.bind(null,o));return i},t.n=function(n){var e=n&&n.__esModule?function(){return n.default}:function(){return n};return t.d(e,"a",e),e},t.o=function(n,e){return Object.prototype.hasOwnProperty.call(n,e)},t.p="",t(t.s=0)}([function(n,e,t){t(1),n.exports=t(3)},function(n,e,t){(function(){var e="undefined"!=typeof window?window.jQuery:t(2);n.exports.ThemeNav={navBar:null,win:null,winScroll:!1,winResize:!1,linkScroll:!1,winPosition:0,winHeight:null,docHeight:null,isRunning:!1,enable:function(n){var t=this;void 0===n&&(n=!0),t.isRunning||(t.isRunning=!0,e((function(e){t.init(e),t.reset(),t.win.on("hashchange",t.reset),n&&t.win.on("scroll",(function(){t.linkScroll||t.winScroll||(t.winScroll=!0,requestAnimationFrame((function(){t.onScroll()})))})),t.win.on("resize",(function(){t.winResize||(t.winResize=!0,requestAnimationFrame((function(){t.onResize()})))})),t.onResize()})))},enableSticky:function(){this.enable(!0)},init:function(n){n(document);var e=this;this.navBar=n("div.wy-side-scroll:first"),this.win=n(window),n(document).on("click","[data-toggle='wy-nav-top']",(function(){n("[data-toggle='wy-nav-shift']").toggleClass("shift"),n("[data-toggle='rst-versions']").toggleClass("shift")})).on("click",".wy-menu-vertical .current ul li a",(function(){var t=n(this);n("[data-toggle='wy-nav-shift']").removeClass("shift"),n("[data-toggle='rst-versions']").toggleClass("shift"),e.toggleCurrent(t),e.hashChange()})).on("click","[data-toggle='rst-current-version']",(function(){n("[data-toggle='rst-versions']").toggleClass("shift-up")})),n("table.docutils:not(.field-list,.footnote,.citation)").wrap(""),n("table.docutils.footnote").wrap(""),n("table.docutils.citation").wrap(""),n(".wy-menu-vertical ul").not(".simple").siblings("a").each((function(){var t=n(this);expand=n(''),expand.on("click",(function(n){return e.toggleCurrent(t),n.stopPropagation(),!1})),t.prepend(expand)}))},reset:function(){var n=encodeURI(window.location.hash)||"#";try{var e=$(".wy-menu-vertical"),t=e.find('[href="'+n+'"]');if(0===t.length){var i=$('.document [id="'+n.substring(1)+'"]').closest("div.section");0===(t=e.find('[href="#'+i.attr("id")+'"]')).length&&(t=e.find('[href="#"]'))}if(t.length>0){$(".wy-menu-vertical .current").removeClass("current").attr("aria-expanded","false"),t.addClass("current").attr("aria-expanded","true"),t.closest("li.toctree-l1").parent().addClass("current").attr("aria-expanded","true");for(let n=1;n<=10;n++)t.closest("li.toctree-l"+n).addClass("current").attr("aria-expanded","true");t[0].scrollIntoView()}}catch(n){console.log("Error 
expanding nav for anchor",n)}},onScroll:function(){this.winScroll=!1;var n=this.win.scrollTop(),e=n+this.winHeight,t=this.navBar.scrollTop()+(n-this.winPosition);n<0||e>this.docHeight||(this.navBar.scrollTop(t),this.winPosition=n)},onResize:function(){this.winResize=!1,this.winHeight=this.win.height(),this.docHeight=$(document).height()},hashChange:function(){this.linkScroll=!0,this.win.one("hashchange",(function(){this.linkScroll=!1}))},toggleCurrent:function(n){var e=n.closest("li");e.siblings("li.current").removeClass("current").attr("aria-expanded","false"),e.siblings().find("li.current").removeClass("current").attr("aria-expanded","false");var t=e.find("> ul li");t.length&&(t.removeClass("current").attr("aria-expanded","false"),e.toggleClass("current").attr("aria-expanded",(function(n,e){return"true"==e?"false":"true"})))}},"undefined"!=typeof window&&(window.SphinxRtdTheme={Navigation:n.exports.ThemeNav,StickyNav:n.exports.ThemeNav}),function(){for(var n=0,e=["ms","moz","webkit","o"],t=0;t -------------------------------------------------------------------------------- /wellcomeml/ml/voting_classifier.py: -------------------------------------------------------------------------------- > 0.5. 11 | multilabel = False: 12 | The predicted class will be the class with the largest mean probability from all the 13 | estimators. 14 | 15 | voting = "hard": 16 | The predicted class(es) will be any class(es) where the majority (or >= num_agree) of 17 | estimators agree. Majority in even cases is 2/2 or >=3/4 (i.e. not 2/4). 18 | 19 | If there are multiple classes with the majority (a tie) then the first one in 20 | numerical order will be chosen. 21 | 22 | In the multiclass case, if there are multiple classes with >= num_agree, then the 23 | predicted class will be the first class in a numerically sorted list of the tying classes. 24 | e.g. 25 | If there is 1 vote for class 0, 2 votes for class 1 and 2 votes for class 2, 26 | then class 1 will be predicted. 27 | 28 | In the binary case, if there is a tie then class 0 is predicted. 29 | e.g. 30 | If there are 2 votes for class 0 and 2 votes for class 1, then class 0 31 | will be predicted.
32 | 33 | """ 34 | import logging 35 | 36 | from wellcomeml.utils import throw_extra_import_message 37 | 38 | required_modules = 'sklearn,numpy' 39 | required_extras = 'core' 40 | 41 | try: 42 | from sklearn.utils.validation import check_is_fitted 43 | from sklearn.ensemble import VotingClassifier 44 | from sklearn.exceptions import NotFittedError 45 | import numpy as np 46 | except ImportError as e: 47 | throw_extra_import_message(error=e, required_modules=required_modules, extras=required_extras) 48 | 49 | logger = logging.getLogger(__name__) 50 | 51 | 52 | class WellcomeVotingClassifier(VotingClassifier): 53 | def __init__(self, multilabel=False, num_agree=None, *args, **kwargs): 54 | super(WellcomeVotingClassifier, self).__init__(*args, **kwargs) 55 | self.pretrained = self._is_pretrained() 56 | self.multilabel = multilabel 57 | self.num_agree = num_agree 58 | 59 | def _is_pretrained(self): 60 | try: 61 | check_is_fitted(self, "estimators") 62 | except NotFittedError: 63 | return False 64 | return True 65 | 66 | def _get_estimators(self): 67 | if type(self.estimators) == list: 68 | return [est for est in self.estimators] 69 | else: # tuple with named estimators 70 | return [est for _, est in self.estimators] 71 | 72 | def predict(self, X): 73 | if self.pretrained: 74 | check_is_fitted(self, "estimators") 75 | 76 | estimators = self._get_estimators() 77 | 78 | if self.voting == "soft": 79 | if self.num_agree: 80 | logger.warning("num_agree specified but not used in soft voting") 81 | Y_probs = np.array([est.predict_proba(X) for est in estimators]) 82 | Y_prob = np.mean(Y_probs, axis=0) 83 | if self.multilabel: 84 | return np.array(Y_prob > 0.5, dtype=int) 85 | else: 86 | return np.argmax(Y_prob, axis=1) 87 | else: # hard voting 88 | 89 | # If num_agree isn't set then use majority vote 90 | if not self.num_agree: 91 | # So if 4 estimators, >= 3 need to agree 92 | self.num_agree = np.ceil((len(estimators) + 1) / 2) 93 | 94 | Y_preds = [est.predict(X) for est in estimators] 95 | Y_preds = np.array(Y_preds) 96 | if self.multilabel: 97 | return np.array(Y_preds.sum(axis=0) >= self.num_agree, dtype='int32') 98 | else: 99 | votes = np.apply_along_axis(lambda x: max(np.bincount(x)), axis=0, arr=Y_preds) 100 | max_class = np.apply_along_axis( 101 | lambda x: np.argmax(np.bincount(x)), axis=0, arr=Y_preds 102 | ) 103 | # If no maximum over the threshold, then pick the first from an ordered list 104 | # of the other options, e.g. if 5,2,3 were voted on pick 2 (not 0) 105 | options = np.sort(np.transpose(Y_preds)) 106 | return [m if v >= self.num_agree else options[i][0] 107 | for i, (m, v) in enumerate(zip(max_class, votes))] 108 | 109 | else: 110 | return super(WellcomeVotingClassifier, self).predict(X) 111 | --------------------------------------------------------------------------------
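
A minimal sketch of hard voting with WellcomeVotingClassifier over two pre-fitted estimators on toy data (an illustration added here, not a file from the repository):

    import numpy as np
    from sklearn.linear_model import LogisticRegression
    from wellcomeml.ml.voting_classifier import WellcomeVotingClassifier

    X = np.array([[0, 0], [0, 1], [1, 0], [1, 1]])
    y = np.array([0, 0, 1, 1])

    # Already-fitted estimators can be passed in directly as a plain list
    estimators = [LogisticRegression().fit(X, y), LogisticRegression(C=0.1).fit(X, y)]

    # With two estimators, num_agree defaults to the majority (both must agree);
    # ties are broken towards the lowest class label
    voting = WellcomeVotingClassifier(estimators=estimators, voting="hard")
    print(voting.predict(X))
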