├── .circleci └── config.yml ├── .codecov.yaml ├── .coveragerc ├── .github ├── ISSUE_TEMPLATE │ ├── bug_report.md │ └── feature_request.md ├── pull_request_template.md └── workflows │ └── stale.yml ├── .gitignore ├── .readthedocs.yml ├── CHANGELOG.md ├── CONTRIBUTING.md ├── LICENSE ├── MANIFEST.in ├── README.md ├── RELEASING.md ├── docs ├── Makefile ├── _static │ └── octopus.png ├── conf.py ├── index.rst ├── packages.json ├── packages │ ├── analysis.rst │ ├── augmentation.rst │ ├── classification.rst │ ├── labeling.rst │ ├── map.rst │ ├── preprocess.rst │ ├── slicing.rst │ └── utils.rst └── requirements-doc.txt ├── figs ├── ONR.jpg ├── WS_pipeline2.pdf ├── darpa.JPG ├── dp_neurips_2016.png ├── logo_01.png ├── mobilize_logo.png ├── moore_logo.png ├── nih_logo.png ├── user_logos.png └── vldb2018_logo.png ├── pyproject.toml ├── requirements-pyspark.txt ├── requirements.txt ├── scripts ├── check_requirements.py └── sync_api_docs.py ├── setup.cfg ├── setup.py ├── snorkel ├── __init__.py ├── analysis │ ├── __init__.py │ ├── error_analysis.py │ ├── metrics.py │ └── scorer.py ├── augmentation │ ├── __init__.py │ ├── apply │ │ ├── __init__.py │ │ ├── core.py │ │ └── pandas.py │ ├── policy │ │ ├── __init__.py │ │ ├── core.py │ │ └── sampling.py │ └── tf.py ├── classification │ ├── __init__.py │ ├── data.py │ ├── loss.py │ ├── multitask_classifier.py │ ├── task.py │ ├── training │ │ ├── __init__.py │ │ ├── loggers │ │ │ ├── __init__.py │ │ │ ├── checkpointer.py │ │ │ ├── log_manager.py │ │ │ ├── log_writer.py │ │ │ └── tensorboard_writer.py │ │ ├── schedulers │ │ │ ├── __init__.py │ │ │ ├── scheduler.py │ │ │ ├── sequential_scheduler.py │ │ │ └── shuffled_scheduler.py │ │ └── trainer.py │ └── utils.py ├── contrib │ ├── README.md │ └── __init__.py ├── labeling │ ├── __init__.py │ ├── analysis.py │ ├── apply │ │ ├── __init__.py │ │ ├── core.py │ │ ├── dask.py │ │ ├── pandas.py │ │ └── spark.py │ ├── lf │ │ ├── __init__.py │ │ ├── core.py │ │ ├── nlp.py │ │ └── nlp_spark.py │ ├── model │ │ ├── __init__.py │ │ ├── base_labeler.py │ │ ├── baselines.py │ │ ├── graph_utils.py │ │ ├── label_model.py │ │ └── logger.py │ └── utils.py ├── map │ ├── __init__.py │ ├── core.py │ └── spark.py ├── preprocess │ ├── __init__.py │ ├── core.py │ ├── nlp.py │ └── spark.py ├── slicing │ ├── __init__.py │ ├── apply │ │ ├── __init__.py │ │ ├── core.py │ │ ├── dask.py │ │ └── spark.py │ ├── modules │ │ ├── __init__.py │ │ └── slice_combiner.py │ ├── monitor.py │ ├── sf │ │ ├── __init__.py │ │ ├── core.py │ │ └── nlp.py │ ├── sliceaware_classifier.py │ └── utils.py ├── synthetic │ ├── __init__.py │ └── synthetic_data.py ├── types │ ├── __init__.py │ ├── classifier.py │ ├── data.py │ └── hashing.py ├── utils │ ├── __init__.py │ ├── config_utils.py │ ├── core.py │ ├── data_operators.py │ ├── lr_schedulers.py │ └── optimizers.py └── version.py ├── test ├── __init__.py ├── analysis │ ├── test_error_analysis.py │ ├── test_metrics.py │ └── test_scorer.py ├── augmentation │ ├── __init__.py │ ├── apply │ │ ├── __init__.py │ │ └── test_tf_applier.py │ └── policy │ │ ├── __init__.py │ │ ├── test_core.py │ │ └── test_sampling.py ├── classification │ ├── __init__.py │ ├── test_classifier_convergence.py │ ├── test_data.py │ ├── test_loss.py │ ├── test_multitask_classifier.py │ ├── test_task.py │ ├── test_utils.py │ └── training │ │ ├── loggers │ │ ├── __init__.py │ │ ├── test_checkpointer.py │ │ ├── test_log_manager.py │ │ ├── test_log_writer.py │ │ └── test_tensorboard_writer.py │ │ ├── schedulers │ │ └── test_schedulers.py │ │ └── 
test_trainer.py ├── labeling │ ├── __init__.py │ ├── apply │ │ ├── __init__.py │ │ ├── lf_applier_spark_test_script.py │ │ ├── test_lf_applier.py │ │ └── test_spark.py │ ├── lf │ │ ├── test_core.py │ │ ├── test_nlp.py │ │ └── test_nlp_spark.py │ ├── model │ │ ├── __init__.py │ │ ├── test_baseline.py │ │ ├── test_label_model.py │ │ └── test_logger.py │ ├── preprocess │ │ ├── __init__.py │ │ └── test_nlp.py │ ├── test_analysis.py │ ├── test_convergence.py │ └── test_utils.py ├── map │ ├── __init__.py │ ├── test_core.py │ └── test_spark.py ├── slicing │ ├── __init__.py │ ├── apply │ │ ├── __init__.py │ │ └── test_sf_applier.py │ ├── sf │ │ ├── __init__.py │ │ ├── test_core.py │ │ └── test_nlp.py │ ├── test_monitor.py │ ├── test_slice_combiner.py │ ├── test_sliceaware_classifier.py │ └── test_utils.py ├── synthetic │ ├── __init__.py │ └── test_synthetic_data.py └── utils │ ├── __init__.py │ ├── test_config_utils.py │ ├── test_core.py │ └── test_data_operators.py └── tox.ini /.circleci/config.yml: -------------------------------------------------------------------------------- 1 | 2 | 3 | # Use the latest 2.1 version of CircleCI pipeline process engine. 4 | version: 2.1 5 | 6 | orbs: 7 | python: circleci/python@1.2 8 | 9 | commands: 10 | setup_dependencies: 11 | description: "Install dependencies" 12 | parameters: 13 | after-deps: 14 | description: "Steps to run after dependencies are installed" 15 | type: steps 16 | default: [] 17 | steps: 18 | - run: 19 | name: "Add OpenJDK PPA" 20 | command: sudo add-apt-repository -y ppa:openjdk-r/ppa 21 | - run: 22 | name: "Update apt package index" 23 | command: sudo apt-get -qq update 24 | - run: 25 | name: "Install OpenJDK 8 without recommended packages" 26 | command: sudo apt-get install -y openjdk-8-jdk --no-install-recommends 27 | - run: 28 | name: "Run Java Alternatives install for JDK" 29 | command: sudo update-java-alternatives -s java-1.8.0-openjdk-amd64 30 | - run: 31 | name: "Run pip install setuptools and wheel" 32 | command: pip install -U pip setuptools wheel 33 | - run: 34 | name: "Install Tox" 35 | command: pip install -U tox==4.11.4 36 | - run: 37 | name: "Install Code Cov" 38 | command: pip install -U codecov 39 | - steps: << parameters.after-deps >> 40 | 41 | # We want to make sure we run this only on the main branch, release branches, or when we make tags 42 | run_complex: &run_complex 43 | filters: 44 | branches: 45 | only: 46 | - main 47 | - /release-v.*/ 48 | tags: 49 | only: /.*/ 50 | 51 | jobs: 52 | Python311-Unit-Tests: 53 | docker: 54 | - image: cimg/python:3.11 55 | environment: 56 | TOXENV: coverage,doctest,type,check 57 | TOX_INSTALL_DIR: .env 58 | JAVA_HOME: /usr/lib/jvm/java-8-openjdk-amd64 59 | PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION: python 60 | 61 | steps: 62 | - checkout 63 | - setup_dependencies 64 | - run: 65 | name: "Run Tox" 66 | command: tox 67 | 68 | Python311-Integration-Tests: 69 | docker: 70 | - image: cimg/python:3.11 71 | environment: 72 | TOXENV: complex,type,check 73 | TOX_INSTALL_DIR: .env 74 | JAVA_HOME: /usr/lib/jvm/java-8-openjdk-amd64 75 | 76 | steps: 77 | - checkout 78 | - run: 79 | name: Setup python3 80 | command: | 81 | pyenv global 3.11.3 > /dev/null && activated=0 || activated=1 82 | if [[ $activated -ne 0 ]]; then 83 | for i in {1..6}; do 84 | pyenv install 3.11.3 && break || sleep $((2 ** $i)) 85 | done 86 | pyenv global 3.11.3 87 | fi 88 | - setup_dependencies 89 | - run: 90 | name: "Run Tox" 91 | no_output_timeout: 60m 92 | command: | 93 | export PYTHONUNBUFFERED=1 94 | tox 95 | 96 | workflows: 97 | version: 2 98 | 99 | Integration-Tests: 100 | jobs: 101 | - 
Python311-Integration-Tests: 102 | <<: *run_complex 103 | Unit-Tests: 104 | jobs: 105 | - Python311-Unit-Tests 106 | -------------------------------------------------------------------------------- /.codecov.yaml: -------------------------------------------------------------------------------- 1 | coverage: 2 | status: 3 | project: 4 | default: 5 | target: 95% 6 | patch: 7 | default: 8 | threshold: 2% 9 | 10 | comment: 11 | layout: "header, diff, flags, files" 12 | -------------------------------------------------------------------------------- /.coveragerc: -------------------------------------------------------------------------------- 1 | [run] 2 | branch = True 3 | source = snorkel 4 | 5 | [report] 6 | exclude_lines = 7 | pragma: no cover 8 | raise NotImplementedError 9 | if __name__ == .__main__.: 10 | def __repr__ 11 | ignore_errors = True 12 | omit = 13 | test/* 14 | *spark.py 15 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug_report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug report 3 | about: Let us know about a bug you found 4 | --- 5 | 6 | ## Issue description 7 | 8 | A clear and concise description of what the bug is. 9 | 10 | ## Code example/repro steps 11 | 12 | Please try to provide a minimal example to repro the bug. 13 | Error messages and stack traces are also helpful. 14 | 15 | ## Expected behavior 16 | A clear and concise description of what you expected to happen. 17 | 18 | ## Screenshots 19 | If applicable, add screenshots to help explain your problem. 20 | No screenshots of code! 21 | 22 | ## System info 23 | 24 | * How you installed Snorkel (conda, pip, source): 25 | * Build command you used (if compiling from source): 26 | * OS: 27 | * Python version: 28 | * Snorkel version: 29 | * Versions of any other relevant libraries: 30 | 31 | ## Additional context 32 | Add any other context about the problem here. -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Feature request 3 | about: Let us know about something new you want 4 | --- 5 | 6 | ## Is your feature request related to a problem? Please describe. 7 | A clear and concise description of what the problem is. Ex. I'm always frustrated when [...] 8 | 9 | ## Describe the solution you'd like 10 | A clear and concise description of what you want to happen. 11 | 12 | ## Describe alternatives you've considered 13 | A clear and concise description of any alternative solutions or features you've considered. 14 | 15 | ## Additional context 16 | Add any other context or screenshots about the feature request here. 17 | -------------------------------------------------------------------------------- /.github/pull_request_template.md: -------------------------------------------------------------------------------- 1 | ## Description of proposed changes 2 | 3 | ## Related issue(s) 4 | 5 | Fixes # (issue) 6 | 7 | ## Test plan 8 | 9 | ## Checklist 10 | 11 | Need help on these? Just ask! 12 | 13 | * [ ] I have read the **CONTRIBUTING** document. 14 | * [ ] I have updated the documentation accordingly. 15 | * [ ] I have added tests to cover my changes. 16 | * [ ] I have run `tox -e complex` and/or `tox -e spark` if appropriate. 17 | * [ ] All new and existing tests passed. 
18 | -------------------------------------------------------------------------------- /.github/workflows/stale.yml: -------------------------------------------------------------------------------- 1 | name: Mark/close stale issues and pull requests 2 | 3 | on: 4 | schedule: 5 | - cron: "0 12 * * *" 6 | 7 | jobs: 8 | stale: 9 | 10 | runs-on: ubuntu-latest 11 | 12 | steps: 13 | - uses: actions/stale@v1 14 | with: 15 | repo-token: ${{ secrets.GITHUB_TOKEN }} 16 | stale-issue-message: 'This issue is stale because it has been open 90 days with no activity. Remove stale label or comment or this will be closed in 7 days.' 17 | stale-pr-message: 'This pull request is stale because it has been open 90 days with no activity. Remove stale label or comment or this will be closed in 7 days.' 18 | stale-issue-label: 'no-issue-activity' 19 | stale-pr-label: 'no-pr-activity' 20 | exempt-issue-label: 'no-stale' 21 | days-before-stale: 90 22 | days-before-close: 7 23 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | .pypirc 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | 53 | # Translations 54 | *.mo 55 | *.pot 56 | 57 | # Django stuff: 58 | *.log 59 | local_settings.py 60 | db.sqlite3 61 | db.sqlite3-journal 62 | 63 | # Flask stuff: 64 | instance/ 65 | .webassets-cache 66 | 67 | # Scrapy stuff: 68 | .scrapy 69 | 70 | # Sphinx documentation 71 | docs/_build/ 72 | docs/packages/_autosummary 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 
92 | #Pipfile.lock 93 | 94 | # celery beat schedule file 95 | celerybeat-schedule 96 | 97 | # SageMath parsed files 98 | *.sage.py 99 | 100 | # Environments 101 | .env 102 | .venv 103 | env/ 104 | venv/ 105 | ENV/ 106 | env.bak/ 107 | venv.bak/ 108 | 109 | # MacOS 110 | .DS_Store 111 | 112 | # Spyder project settings 113 | .spyderproject 114 | .spyproject 115 | 116 | # Rope project settings 117 | .ropeproject 118 | 119 | # mkdocs documentation 120 | /site 121 | 122 | # mypy 123 | .mypy_cache/ 124 | .dmypy.json 125 | dmypy.json 126 | 127 | # Pyre type checker 128 | .pyre/ 129 | 130 | # Editors 131 | .vscode/ 132 | .code-workspace* 133 | 134 | # Dask 135 | dask-worker-space/ 136 | 137 | # nohup 138 | nohup.out 139 | -------------------------------------------------------------------------------- /.readthedocs.yml: -------------------------------------------------------------------------------- 1 | # .readthedocs.yml 2 | # Read the Docs configuration file 3 | # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details 4 | 5 | # Required 6 | version: 2 7 | 8 | # Build documentation in the docs/ directory with Sphinx 9 | sphinx: 10 | configuration: docs/conf.py 11 | 12 | # Optionally set the version of Python and requirements required to build your docs 13 | python: 14 | version: 3.11 15 | install: 16 | - requirements: docs/requirements-doc.txt 17 | - method: pip 18 | path: . 19 | system_packages: true 20 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include README.md 2 | include LICENSE 3 | include requirements.txt 4 | -------------------------------------------------------------------------------- /RELEASING.md: -------------------------------------------------------------------------------- 1 | # Snorkel Release Guide 2 | 3 | ## Before You Start 4 | 5 | Make sure you have a [PyPI](https://pypi.org) account with maintainer access to the Snorkel project. 6 | Create a `.pypirc` file in your home directory. 7 | It should look like this: 8 | 9 | ``` 10 | [distutils] 11 | index-servers = 12 | pypi 13 | pypitest 14 | 15 | [pypi] 16 | username=YOUR_USERNAME 17 | password=YOUR_PASSWORD 18 | ``` 19 | 20 | Then run `chmod 600 ~/.pypirc` so only you can read/write. 21 | 22 | 23 | ## Release Steps 24 | 25 | 1. Make sure you're in the top-level `snorkel` directory. 26 | 1. Make certain your branch is in sync with the upstream main branch: 27 | 28 | $ git pull origin main 29 | 30 | 1. Add a new changelog entry for the release. 31 | 32 | ## [0.9.0] 33 | ### [Breaking Changes] 34 | ### [Added] 35 | ### [Changed] 36 | ### [Deprecated] 37 | ### [Removed] 38 | Make sure `CHANGELOG.md` is up to date for the release: compare against PRs 39 | merged since the last release. 40 | 41 | 1. Update version to, e.g. 0.9.0 (remove the `+dev` label) in `snorkel/version.py`. 42 | 43 | 44 | 1. Commit these changes and create a PR: 45 | 46 | git checkout -b release-v0.9.0 47 | git add . -u 48 | git commit -m "[RELEASE]: v0.9.0" 49 | git push --set-upstream origin release-v0.9.0 50 | 51 | 1. Once the PR is approved, merge it and pull main locally. 52 | 53 | 1. Tag the release: 54 | 55 | git tag -a v0.9.0 -m "v0.9.0 release" 56 | git push origin v0.9.0 57 | 58 | 1. Build source & wheel distributions: 59 | 60 | rm -rf dist build # clean old builds & distributions 61 | python3 setup.py sdist # create a source distribution 62 | python3 setup.py bdist_wheel # create a universal wheel 63 | 64 | 1. 
Check that everything looks correct by installing the wheel locally and checking the version: 65 | 66 | python3 -m venv test_snorkel # create a virtualenv for testing 67 | source test_snorkel/bin/activate # activate virtualenv 68 | python3 -m pip install dist/snorkel-0.9.0-py3-none-any.whl 69 | python3 -c "import snorkel; print(snorkel.__version__)" 70 | 71 | 1. Publish to PyPI: 72 | 73 | pip install twine # if not installed 74 | twine upload dist/* -r pypi 75 | 76 | 1. A PR is auto-submitted (this will take a few hours) on [`conda-forge/snorkel-feedstock`](https://github.com/conda-forge/snorkel-feedstock) to update the version. 77 | * A maintainer needs to accept and merge those changes. 78 | 79 | 1. Create a new release on GitHub. 80 | * Input the recently-created Tag Version: `v0.9.0` 81 | * Copy the release notes in `CHANGELOG.md` to the GitHub tag. 82 | * Attach the resulting binaries (`dist/snorkel-x.x.x.*`) to the release. 83 | * Publish the release. 84 | 85 | 86 | 1. Update version to, e.g. 0.9.1+dev in `snorkel/version.py`. 87 | 88 | 1. Add a new changelog entry for the unreleased version in `CHANGELOG.md`: 89 | 90 | ## [Unreleased] 91 | ### [Breaking Changes] 92 | ### [Added] 93 | ### [Changed] 94 | ### [Deprecated] 95 | ### [Removed] 96 | 97 | 1. Commit these changes and create a PR: 98 | 99 | git checkout -b bump-v0.9.1+dev 100 | git add . -u 101 | git commit -m "[BUMP]: v0.9.1+dev" 102 | git push --set-upstream origin bump-v0.9.1+dev 103 | 104 | 105 | 1. Add the new tag to [the Snorkel project on ReadTheDocs](https://readthedocs.org/projects/snorkel): 106 | * Trigger a build for main to pull new tags. 107 | * Go to the "Versions" tab, and "Activate" the new tag. 108 | * Go to Admin/Advanced to set this tag as the new default version. 109 | * In "Overview", make sure a build is triggered: 110 | * For the tag `v0.9.0` 111 | * For `latest` 112 | 113 | 114 | ## Credit 115 | * [AllenNLP](https://github.com/allenai/allennlp/blob/master/setup.py) 116 | * [Altair](https://github.com/altair-viz/altair/blob/master/RELEASING.md) 117 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = . 9 | BUILDDIR = _build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /docs/_static/octopus.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/snorkel-team/snorkel/617c92400c50e95ce41fcee84309a86f76cf525c/docs/_static/octopus.png -------------------------------------------------------------------------------- /docs/index.rst: -------------------------------------------------------------------------------- 1 | .. Snorkel documentation master file, created by 2 | sphinx-quickstart on Fri Jul 12 17:34:20 2019. 
3 | You can adapt this file completely to your liking, but it should at least 4 | contain the root `toctree` directive. 5 | 6 | Snorkel's API Documentation 7 | =========================== 8 | 9 | If you're looking for technical details on Snorkel's API, 10 | you're in the right place. 11 | 12 | For more narrative walkthroughs of Snorkel fundamentals or 13 | example use cases, check out our `homepage `_ 14 | and our `tutorials repo `_. 15 | 16 | .. toctree:: 17 | :maxdepth: 1 18 | :caption: Package Reference 19 | 20 | packages/analysis 21 | packages/augmentation 22 | packages/classification 23 | packages/labeling 24 | packages/map 25 | packages/preprocess 26 | packages/slicing 27 | packages/utils 28 | -------------------------------------------------------------------------------- /docs/packages.json: -------------------------------------------------------------------------------- 1 | { 2 | "packages": [ 3 | "analysis", 4 | "augmentation", 5 | "classification", 6 | "labeling", 7 | "map", 8 | "preprocess", 9 | "slicing", 10 | "utils" 11 | ], 12 | "extra_members": { 13 | "labeling": [ 14 | "apply.dask.DaskLFApplier", 15 | "apply.dask.PandasParallelLFApplier", 16 | "apply.spark.SparkLFApplier", 17 | "lf.nlp.NLPLabelingFunction", 18 | "lf.nlp.nlp_labeling_function", 19 | "lf.nlp_spark.SparkNLPLabelingFunction", 20 | "lf.nlp_spark.spark_nlp_labeling_function", 21 | "model.baselines.MajorityClassVoter", 22 | "model.baselines.MajorityLabelVoter", 23 | "model.baselines.RandomVoter", 24 | "model.label_model.LabelModel" 25 | ], 26 | "map": [ 27 | "spark.make_spark_mapper" 28 | ], 29 | "preprocess": [ 30 | "nlp.SpacyPreprocessor", 31 | "spark.make_spark_preprocessor" 32 | ], 33 | "slicing": [ 34 | "apply.dask.DaskSFApplier", 35 | "apply.dask.PandasParallelSFApplier", 36 | "apply.spark.SparkSFApplier", 37 | "sf.nlp.NLPSlicingFunction", 38 | "sf.nlp.nlp_slicing_function" 39 | ] 40 | } 41 | } 42 | -------------------------------------------------------------------------------- /docs/packages/analysis.rst: -------------------------------------------------------------------------------- 1 | Snorkel Analysis Package 2 | ------------------------ 3 | 4 | Generic model analysis utilities shared across Snorkel. 5 | 6 | .. currentmodule:: snorkel.analysis 7 | 8 | .. autosummary:: 9 | :toctree: _autosummary/analysis/ 10 | :nosignatures: 11 | 12 | Scorer 13 | get_label_buckets 14 | get_label_instances 15 | metric_score 16 | -------------------------------------------------------------------------------- /docs/packages/augmentation.rst: -------------------------------------------------------------------------------- 1 | Snorkel Augmentation Package 2 | ---------------------------- 3 | 4 | Programmatic data set augmentation: TF creation and data generation utilities. 5 | 6 | .. currentmodule:: snorkel.augmentation 7 | 8 | .. autosummary:: 9 | :toctree: _autosummary/augmentation/ 10 | :nosignatures: 11 | 12 | ApplyAllPolicy 13 | ApplyEachPolicy 14 | ApplyOnePolicy 15 | MeanFieldPolicy 16 | PandasTFApplier 17 | RandomPolicy 18 | TFApplier 19 | TransformationFunction 20 | transformation_function 21 | -------------------------------------------------------------------------------- /docs/packages/classification.rst: -------------------------------------------------------------------------------- 1 | Snorkel Classification Package 2 | ------------------------------ 3 | 4 | PyTorch-based multi-task learning framework for discriminative modeling. 5 | 6 | .. currentmodule:: snorkel.classification 7 | 8 | .. 
autosummary:: 9 | :toctree: _autosummary/classification/ 10 | :nosignatures: 11 | 12 | Checkpointer 13 | CheckpointerConfig 14 | DictDataLoader 15 | DictDataset 16 | LogManager 17 | LogManagerConfig 18 | LogWriter 19 | LogWriterConfig 20 | MultitaskClassifier 21 | Operation 22 | Task 23 | TensorBoardWriter 24 | Trainer 25 | cross_entropy_with_probs 26 | -------------------------------------------------------------------------------- /docs/packages/labeling.rst: -------------------------------------------------------------------------------- 1 | Snorkel Labeling Package 2 | ------------------------ 3 | 4 | Programmatic data set labeling: LF creation, models, and analysis utilities. 5 | 6 | .. currentmodule:: snorkel.labeling 7 | 8 | .. autosummary:: 9 | :toctree: _autosummary/labeling/ 10 | :nosignatures: 11 | 12 | apply.dask.DaskLFApplier 13 | LFAnalysis 14 | LFApplier 15 | model.label_model.LabelModel 16 | LabelingFunction 17 | model.baselines.MajorityClassVoter 18 | model.baselines.MajorityLabelVoter 19 | lf.nlp.NLPLabelingFunction 20 | PandasLFApplier 21 | apply.dask.PandasParallelLFApplier 22 | model.baselines.RandomVoter 23 | apply.spark.SparkLFApplier 24 | lf.nlp_spark.SparkNLPLabelingFunction 25 | filter_unlabeled_dataframe 26 | labeling_function 27 | lf.nlp.nlp_labeling_function 28 | lf.nlp_spark.spark_nlp_labeling_function 29 | -------------------------------------------------------------------------------- /docs/packages/map.rst: -------------------------------------------------------------------------------- 1 | Snorkel Map Package 2 | ------------------- 3 | 4 | Generic utilities for data point to data point operations. 5 | 6 | .. currentmodule:: snorkel.map 7 | 8 | .. autosummary:: 9 | :toctree: _autosummary/map/ 10 | :nosignatures: 11 | 12 | BaseMapper 13 | LambdaMapper 14 | Mapper 15 | lambda_mapper 16 | spark.make_spark_mapper 17 | -------------------------------------------------------------------------------- /docs/packages/preprocess.rst: -------------------------------------------------------------------------------- 1 | Snorkel Preprocess Package 2 | -------------------------- 3 | 4 | Preprocessors for LFs, TFs, and SFs. 5 | 6 | .. currentmodule:: snorkel.preprocess 7 | 8 | .. autosummary:: 9 | :toctree: _autosummary/preprocess/ 10 | :nosignatures: 11 | 12 | BasePreprocessor 13 | LambdaPreprocessor 14 | Preprocessor 15 | nlp.SpacyPreprocessor 16 | spark.make_spark_preprocessor 17 | preprocessor 18 | -------------------------------------------------------------------------------- /docs/packages/slicing.rst: -------------------------------------------------------------------------------- 1 | Snorkel Slicing Package 2 | ----------------------- 3 | 4 | Programmatic data set slicing: SF creation, monitoring utilities, and representation learning for slices. 5 | 6 | .. currentmodule:: snorkel.slicing 7 | 8 | .. 
autosummary:: 9 | :toctree: _autosummary/slicing/ 10 | :nosignatures: 11 | 12 | apply.dask.DaskSFApplier 13 | sf.nlp.NLPSlicingFunction 14 | apply.dask.PandasParallelSFApplier 15 | PandasSFApplier 16 | SFApplier 17 | SliceAwareClassifier 18 | SliceCombinerModule 19 | SlicingFunction 20 | apply.spark.SparkSFApplier 21 | add_slice_labels 22 | convert_to_slice_tasks 23 | sf.nlp.nlp_slicing_function 24 | slice_dataframe 25 | slicing_function 26 | -------------------------------------------------------------------------------- /docs/packages/utils.rst: -------------------------------------------------------------------------------- 1 | Snorkel Utils Package 2 | --------------------- 3 | 4 | General machine learning utilities shared across Snorkel. 5 | 6 | .. currentmodule:: snorkel.utils 7 | 8 | .. autosummary:: 9 | :toctree: _autosummary/utils/ 10 | :nosignatures: 11 | 12 | filter_labels 13 | preds_to_probs 14 | probs_to_preds 15 | to_int_label_array 16 | -------------------------------------------------------------------------------- /docs/requirements-doc.txt: -------------------------------------------------------------------------------- 1 | sphinx==2.4.5 2 | sphinx_autodoc_typehints==1.7.0 3 | sphinx_rtd_theme==0.4.3 4 | https://download.pytorch.org/whl/cpu/torch-1.4.0%2Bcpu-cp36-cp36m-linux_x86_64.whl 5 | -------------------------------------------------------------------------------- /figs/ONR.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/snorkel-team/snorkel/617c92400c50e95ce41fcee84309a86f76cf525c/figs/ONR.jpg -------------------------------------------------------------------------------- /figs/WS_pipeline2.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/snorkel-team/snorkel/617c92400c50e95ce41fcee84309a86f76cf525c/figs/WS_pipeline2.pdf -------------------------------------------------------------------------------- /figs/darpa.JPG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/snorkel-team/snorkel/617c92400c50e95ce41fcee84309a86f76cf525c/figs/darpa.JPG -------------------------------------------------------------------------------- /figs/dp_neurips_2016.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/snorkel-team/snorkel/617c92400c50e95ce41fcee84309a86f76cf525c/figs/dp_neurips_2016.png -------------------------------------------------------------------------------- /figs/logo_01.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/snorkel-team/snorkel/617c92400c50e95ce41fcee84309a86f76cf525c/figs/logo_01.png -------------------------------------------------------------------------------- /figs/mobilize_logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/snorkel-team/snorkel/617c92400c50e95ce41fcee84309a86f76cf525c/figs/mobilize_logo.png -------------------------------------------------------------------------------- /figs/moore_logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/snorkel-team/snorkel/617c92400c50e95ce41fcee84309a86f76cf525c/figs/moore_logo.png -------------------------------------------------------------------------------- /figs/nih_logo.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/snorkel-team/snorkel/617c92400c50e95ce41fcee84309a86f76cf525c/figs/nih_logo.png -------------------------------------------------------------------------------- /figs/user_logos.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/snorkel-team/snorkel/617c92400c50e95ce41fcee84309a86f76cf525c/figs/user_logos.png -------------------------------------------------------------------------------- /figs/vldb2018_logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/snorkel-team/snorkel/617c92400c50e95ce41fcee84309a86f76cf525c/figs/vldb2018_logo.png -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = [ 3 | "setuptools >= 40.6.2", 4 | "wheel >= 0.30.0", 5 | ] 6 | build-backend = "setuptools.build_meta" 7 | 8 | [tool.black] 9 | line-length = 88 10 | target-version = ['py311'] 11 | exclude = ''' 12 | /( 13 | \.eggs 14 | | \.git 15 | | \.mypy_cache 16 | | \.tox 17 | | \.env 18 | | \.venv 19 | | _build 20 | | build 21 | | dist 22 | )/ 23 | ''' -------------------------------------------------------------------------------- /requirements-pyspark.txt: -------------------------------------------------------------------------------- 1 | # Note: we don't include PySpark in the normal required installs. 2 | # Installing a new version may overwrite your existing system install. 3 | pyspark==3.4.1 4 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # Library dependencies for Python code. You need to install these with 2 | # `pip install -r requirements.txt` or 3 | # `conda install --file requirements.txt` 4 | # to ensure that you can use all Snorkel code. 5 | # NOTE: all essential packages must be placed under a section named 6 | # '#### ESSENTIAL ...' so that the script `./scripts/check_requirements.py` 7 | # can find them. 8 | 9 | #### ESSENTIAL LIBRARIES 10 | 11 | # General scientific computing 12 | numpy>=1.24.0 13 | scipy>=1.2.0 14 | 15 | # Data storage and function application 16 | pandas>=1.0.0 17 | tqdm>=4.33.0 18 | 19 | # Internal models 20 | scikit-learn>=0.20.2 21 | torch>=1.2.0 22 | munkres>=1.0.6 23 | 24 | # LF dependency learning 25 | networkx>=2.2 26 | 27 | # Model introspection tools 28 | protobuf>=3.19.6 29 | tensorboard>=2.13.0 30 | 31 | #### EXTRA/TEST LIBRARIES 32 | 33 | # spaCy (NLP) 34 | spacy>=2.1.0 35 | blis>=0.3.0 36 | 37 | # Dask (parallelism) 38 | dask[dataframe]>=2020.12.0 39 | distributed>=2023.7.0 40 | 41 | # Dill (serialization) 42 | dill>=0.3.0 43 | 44 | #### DEV TOOLS 45 | 46 | black>=22.8 47 | flake8>=3.7.0 48 | importlib_metadata<5 # necessary for flake8 49 | isort>=4.3.0 50 | mypy>=0.760 51 | pydocstyle>=4.0.0 52 | pytest>=6.0.0 53 | pytest-cov>=2.7.0 54 | pytest-doctestplus>=0.3.0 55 | tox>=3.13.0 56 | -------------------------------------------------------------------------------- /scripts/sync_api_docs.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | """Check and update API docs under docs/packages. 
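A quick usage sketch (``--check`` is the only flag the script parses; the ``tox -e check`` and ``tox -e fix`` environments mentioned below presumably wrap these two calls): running ``python scripts/sync_api_docs.py`` regenerates every page under docs/packages/ in place, while ``python scripts/sync_api_docs.py --check`` raises an error if any page is missing or out of date.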
4 | 5 | This script checks and updates the package documentation pages, making sure 6 | that the packages in docs/packages.json are documented and up to date. 7 | Rather than calling this directly, use `tox -e check` or `tox -e fix`. 8 | """ 9 | 10 | import json 11 | import os 12 | import sys 13 | from importlib import import_module 14 | from typing import Any, List, Tuple 15 | 16 | PACKAGE_INFO_PATH = "docs/packages.json" 17 | PACKAGE_PAGE_PATH = "docs/packages" 18 | 19 | 20 | PACKAGE_DOC_TEMPLATE = """{title} 21 | {underscore} 22 | 23 | {docstring} 24 | 25 | .. currentmodule:: snorkel.{package_name} 26 | 27 | .. autosummary:: 28 | :toctree: _autosummary/{package_name}/ 29 | :nosignatures: 30 | 31 | {members} 32 | """ 33 | 34 | 35 | def get_title_and_underscore(package_name: str) -> Tuple[str, str]: 36 | title = f"Snorkel {package_name.capitalize()} Package" 37 | underscore = "-" * len(title) 38 | return title, underscore 39 | 40 | 41 | def get_package_members(package: Any) -> List[str]: 42 | members = [] 43 | for name in dir(package): 44 | if name.startswith("_"): 45 | continue 46 | obj = getattr(package, name) 47 | if isinstance(obj, type) or callable(obj): 48 | members.append(name) 49 | return members 50 | 51 | 52 | def main(check: bool) -> None: 53 | with open(PACKAGE_INFO_PATH, "r") as f: 54 | packages_info = json.load(f) 55 | package_names = sorted(packages_info["packages"]) 56 | if check: 57 | f_basenames = sorted( 58 | [ 59 | os.path.splitext(f_name)[0] 60 | for f_name in os.listdir(PACKAGE_PAGE_PATH) 61 | if f_name.endswith(".rst") 62 | ] 63 | ) 64 | if f_basenames != package_names: 65 | raise ValueError( 66 | "Expected package files do not match actual!\n" 67 | f"Expected: {package_names}\n" 68 | f"Actual: {f_basenames}" 69 | ) 70 | else: 71 | os.makedirs(PACKAGE_PAGE_PATH, exist_ok=True) 72 | for package_name in package_names: 73 | package = import_module(f"snorkel.{package_name}") 74 | docstring = package.__doc__ 75 | title, underscore = get_title_and_underscore(package_name) 76 | all_members = get_package_members(package) 77 | all_members.extend(packages_info["extra_members"].get(package_name, [])) 78 | contents = PACKAGE_DOC_TEMPLATE.format( 79 | title=title, 80 | underscore=underscore, 81 | docstring=docstring, 82 | package_name=package_name, 83 | members="\n    ".join(sorted(all_members, key=lambda s: s.split(".")[-1])), 84 | ) 85 | f_path = os.path.join(PACKAGE_PAGE_PATH, f"{package_name}.rst") 86 | if check: 87 | with open(f_path, "r") as f: 88 | contents_actual = f.read() 89 | if contents != contents_actual: 90 | raise ValueError(f"Contents for {package_name} differ!") 91 | else: 92 | with open(f_path, "w") as f: 93 | f.write(contents) 94 | 95 | 96 | if __name__ == "__main__": 97 | check = False if len(sys.argv) == 1 else (sys.argv[1] == "--check") 98 | sys.exit(main(check)) 99 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [tool:pytest] 2 | testpaths = test 3 | markers = 4 | spark 5 | complex 6 | norecursedirs = .tox 7 | doctest_optionflags = 8 | NORMALIZE_WHITESPACE 9 | ELLIPSIS 10 | FLOAT_CMP 11 | 12 | [flake8] 13 | extend-ignore = 14 | E203, 15 | # Throws errors for '#%%' delimiter in VSCode jupyter notebook syntax 16 | E265, 17 | E501, 18 | E731, 19 | E741, 20 | exclude = 21 | .eggs, 22 | .git, 23 | .mypy_cache, 24 | .tox, 25 | .env, 26 | .venv, 27 | _build, 28 | build, 29 | dist 30 | 31 | [isort] 32 | multi_line_output=3 33 | include_trailing_comma=True 
34 | force_grid_wrap=0 35 | combine_as_imports=True 36 | line_length=88 37 | known_first_party= 38 | snorkel, 39 | known_third_party= 40 | numpy, 41 | pandas, 42 | pyspark, 43 | scipy, 44 | setuptools, 45 | tqdm, 46 | default_section=THIRDPARTY 47 | skip=.env,.venv,.tox 48 | 49 | [pydocstyle] 50 | convention = numpy 51 | add-ignore = 52 | D100, 53 | D104, 54 | D105, 55 | D107, 56 | D202, 57 | D203, 58 | D204, 59 | D213, 60 | D413, 61 | 62 | [mypy] 63 | 64 | [mypy-dask] 65 | ignore_missing_imports = True 66 | 67 | [mypy-dask.distributed] 68 | ignore_missing_imports = True 69 | 70 | [mypy-networkx] 71 | ignore_missing_imports = True 72 | 73 | [mypy-numpy] 74 | ignore_missing_imports = True 75 | 76 | [mypy-numpy.lib] 77 | ignore_missing_imports = True 78 | 79 | [mypy-pandas] 80 | ignore_missing_imports = True 81 | 82 | [mypy-scipy] 83 | ignore_missing_imports = True 84 | 85 | [mypy-scipy.sparse] 86 | ignore_missing_imports = True 87 | 88 | [mypy-sklearn] 89 | ignore_missing_imports = True 90 | 91 | [mypy-sklearn.metrics] 92 | ignore_missing_imports = True 93 | 94 | [mypy-pyspark] 95 | ignore_missing_imports = True 96 | 97 | [mypy-pyspark.sql] 98 | ignore_missing_imports = True 99 | 100 | [mypy-spacy] 101 | ignore_missing_imports = True 102 | 103 | [mypy-tqdm] 104 | ignore_missing_imports = True 105 | 106 | [mypy-torch] 107 | ignore_missing_imports = True 108 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from typing import Dict 2 | 3 | from setuptools import find_packages, setup 4 | 5 | # version.py defines the VERSION and VERSION_SHORT variables. 6 | # We use exec here so we don't import snorkel. 7 | VERSION: Dict[str, str] = {} 8 | with open("snorkel/version.py", "r") as version_file: 9 | exec(version_file.read(), VERSION) 10 | 11 | # Use README.md as the long_description for the package 12 | with open("README.md", "r") as readme_file: 13 | long_description = readme_file.read() 14 | 15 | setup( 16 | name="snorkel", 17 | version=VERSION["VERSION"], 18 | url="https://github.com/snorkel-team/snorkel", 19 | description="A system for quickly generating training data with weak supervision", 20 | long_description_content_type="text/markdown", 21 | long_description=long_description, 22 | license="Apache License 2.0", 23 | classifiers=[ 24 | "Intended Audience :: Science/Research", 25 | "Topic :: Scientific/Engineering :: Information Analysis", 26 | "License :: OSI Approved :: Apache Software License", 27 | "Programming Language :: Python :: 3", 28 | ], 29 | project_urls={ 30 | "Homepage": "https://snorkel.org", 31 | "Source": "https://github.com/snorkel-team/snorkel/", 32 | "Bug Reports": "https://github.com/snorkel-team/snorkel/issues", 33 | "Citation": "https://doi.org/10.14778/3157794.3157797", 34 | }, 35 | packages=find_packages(exclude=("test*",)), 36 | include_package_data=True, 37 | install_requires=[ 38 | "munkres>=1.0.6", 39 | "numpy>=1.24.0", 40 | "scipy>=1.2.0", 41 | "pandas>=1.0.0", 42 | "tqdm>=4.33.0", 43 | "scikit-learn>=0.20.2", 44 | "torch>=1.2.0", 45 | "tensorboard>=2.13.0", 46 | "protobuf>=3.19.6", 47 | "networkx>=2.2", 48 | ], 49 | python_requires=">=3.11", 50 | keywords="machine-learning ai weak-supervision", 51 | ) 52 | -------------------------------------------------------------------------------- /snorkel/__init__.py: -------------------------------------------------------------------------------- 1 | from .version import VERSION as __version__ 
# noqa: F401 2 | -------------------------------------------------------------------------------- /snorkel/analysis/__init__.py: -------------------------------------------------------------------------------- 1 | """Generic model analysis utilities shared across Snorkel.""" 2 | 3 | from .error_analysis import get_label_buckets, get_label_instances # noqa: F401 4 | from .metrics import metric_score # noqa: F401 5 | from .scorer import Scorer # noqa: F401 6 | -------------------------------------------------------------------------------- /snorkel/analysis/error_analysis.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from collections import defaultdict 3 | from typing import DefaultDict, Dict, List, Tuple 4 | 5 | import numpy as np 6 | 7 | from snorkel.utils import to_int_label_array 8 | 9 | 10 | def get_label_buckets(*y: np.ndarray) -> Dict[Tuple[int, ...], np.ndarray]: 11 | """Return data point indices bucketed by label combinations. 12 | 13 | Parameters 14 | ---------- 15 | *y 16 | A list of np.ndarray of (int) labels 17 | 18 | Returns 19 | ------- 20 | Dict[Tuple[int, ...], np.ndarray] 21 | A mapping of each label bucket to a NumPy array of its corresponding indices 22 | 23 | Example 24 | ------- 25 | A common use case is calling ``buckets = get_label_buckets(Y_gold, Y_pred)`` where 26 | ``Y_gold`` is a set of gold (i.e. ground truth) labels and ``Y_pred`` is a 27 | corresponding set of predicted labels. 28 | 29 | >>> Y_gold = np.array([1, 1, 1, 0]) 30 | >>> Y_pred = np.array([1, 1, -1, -1]) 31 | >>> buckets = get_label_buckets(Y_gold, Y_pred) 32 | 33 | The returned ``buckets[(i, j)]`` is a NumPy array of data point indices with 34 | true label i and predicted label j. 35 | 36 | More generally, the returned indices within each bucket refer to the order of the 37 | labels that were passed in as function arguments. 38 | 39 | >>> buckets[(1, 1)] # true positives 40 | array([0, 1]) 41 | >>> (1, 0) in buckets # false positives 42 | False 43 | >>> (0, 1) in buckets # false negatives 44 | False 45 | >>> (0, 0) in buckets # true negatives 46 | False 47 | >>> buckets[(1, -1)] # abstained positives 48 | array([2]) 49 | >>> buckets[(0, -1)] # abstained negatives 50 | array([3]) 51 | """ 52 | buckets: DefaultDict[Tuple[int, int], List[int]] = defaultdict(list) 53 | y_flat = list(map(lambda x: to_int_label_array(x, flatten_vector=True), y)) 54 | if len(set(map(len, y_flat))) != 1: 55 | raise ValueError("Arrays must all have the same number of elements") 56 | for i, labels in enumerate(zip(*y_flat)): 57 | buckets[labels].append(i) 58 | return {k: np.array(v) for k, v in buckets.items()} 59 | 60 | 61 | def get_label_instances( 62 | bucket: Tuple[int, ...], x: np.ndarray, *y: np.ndarray 63 | ) -> np.ndarray: 64 | """Return instances in x with the specified combination of labels. 65 | 66 | Parameters 67 | ---------- 68 | bucket 69 | A tuple of label values corresponding to which instances from x are returned 70 | x 71 | NumPy array of data instances to be returned 72 | *y 73 | A list of np.ndarray of (int) labels 74 | 75 | Returns 76 | ------- 77 | np.ndarray 78 | NumPy array of instances from x with the specified combination of labels 79 | 80 | Example 81 | ------- 82 | A common use case is calling ``get_label_instances(bucket, x.to_numpy(), Y_gold, Y_pred)`` 83 | where ``x`` is a NumPy array of data instances that the labels correspond to, 84 | ``Y_gold`` is a list of gold (i.e. 
ground truth) labels, and 85 | ``Y_pred`` is a corresponding list of predicted labels. 86 | 87 | >>> import pandas as pd 88 | >>> x = pd.DataFrame(data={'col1': ["this is a string", "a second string", "a third string"], 'col2': ["1", "2", "3"]}) 89 | >>> Y_gold = np.array([1, 1, 1]) 90 | >>> Y_pred = np.array([1, 0, 0]) 91 | >>> bucket = (1, 0) 92 | 93 | The returned NumPy array of data instances from ``x`` will correspond to 94 | the rows where the first list had a 1 and the second list had a 0. 95 | >>> get_label_instances(bucket, x.to_numpy(), Y_gold, Y_pred) 96 | array([['a second string', '2'], 97 | ['a third string', '3']], dtype=object) 98 | 99 | More generally, given bucket ``(i, j, ...)`` and lists ``y1, y2, ...`` 100 | the returned data instances from ``x`` will correspond to the rows where 101 | y1 had label i, y2 had label j, and so on. Note that ``x`` and ``y`` 102 | must all be the same length. 103 | """ 104 | if len(y) != len(bucket): 105 | raise ValueError("Number of lists must match the amount of labels in bucket") 106 | if x.shape[0] != len(y[0]): 107 | # Note: the check for all y having the same number of elements occurs in get_label_buckets 108 | raise ValueError( 109 | "Number of rows in x does not match number of elements in at least one label list" 110 | ) 111 | buckets = get_label_buckets(*y) 112 | try: 113 | indices = buckets[bucket] 114 | except KeyError: 115 | logging.warning(f"Bucket {bucket} does not exist.") 116 | return np.array([]) 117 | instances = x[indices] 118 | return instances 119 | -------------------------------------------------------------------------------- /snorkel/analysis/metrics.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Callable, Dict, List, NamedTuple, Optional 2 | 3 | import numpy as np 4 | import sklearn.metrics as skmetrics 5 | 6 | from snorkel.utils import filter_labels, to_int_label_array 7 | 8 | 9 | class Metric(NamedTuple): 10 | """Specification for a metric and the subset of [golds, preds, probs] it expects.""" 11 | 12 | func: Callable[..., float] 13 | inputs: List[str] = ["golds", "preds"] 14 | 15 | 16 | def metric_score( 17 | golds: Optional[np.ndarray] = None, 18 | preds: Optional[np.ndarray] = None, 19 | probs: Optional[np.ndarray] = None, 20 | metric: str = "accuracy", 21 | filter_dict: Optional[Dict[str, List[int]]] = None, 22 | **kwargs: Any, 23 | ) -> float: 24 | """Evaluate a standard metric on a set of predictions/probabilities. 25 | 26 | Parameters 27 | ---------- 28 | golds 29 | An array of gold (int) labels 30 | preds 31 | An array of (int) predictions 32 | probs 33 | An [n_datapoints, n_classes] array of probabilistic (float) predictions 34 | metric 35 | The name of the metric to calculate 36 | filter_dict 37 | A mapping from label set name to the labels that should be filtered out for 38 | that label set 39 | 40 | Returns 41 | ------- 42 | float 43 | The value of the requested metric 44 | 45 | Raises 46 | ------ 47 | ValueError 48 | The requested metric is not currently supported 49 | ValueError 50 | The user attempted to calculate roc_auc score for a non-binary problem 51 | """ 52 | if metric not in METRICS: 53 | msg = f"The metric you provided ({metric}) is not currently implemented." 
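        # Valid metric names are exactly the keys of the METRICS registry defined at the bottom of this module.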
54 | raise ValueError(msg) 55 | 56 | # Print helpful error messages if golds or preds has invalid shape or type 57 | golds = to_int_label_array(golds) if golds is not None else None 58 | preds = to_int_label_array(preds) if preds is not None else None 59 | 60 | # Optionally filter out examples (e.g., abstain predictions or unknown labels) 61 | label_dict: Dict[str, Optional[np.ndarray]] = { 62 | "golds": golds, 63 | "preds": preds, 64 | "probs": probs, 65 | } 66 | if filter_dict: 67 | if set(filter_dict.keys()).difference(set(label_dict.keys())): 68 | raise ValueError( 69 | "filter_dict must only include keys in ['golds', 'preds', 'probs']" 70 | ) 71 | # label_dict is overwritten from type Dict[str, Optional[np.ndarray]] 72 | # to Dict[str, np.ndarray] 73 | label_dict = filter_labels(label_dict, filter_dict) # type: ignore 74 | 75 | # Confirm that required label sets are available 76 | func, label_names = METRICS[metric] 77 | for label_name in label_names: 78 | if label_dict[label_name] is None: 79 | raise ValueError(f"Metric {metric} requires access to {label_name}.") 80 | 81 | label_sets = [label_dict[label_name] for label_name in label_names] 82 | return func(*label_sets, **kwargs) 83 | 84 | 85 | def _coverage_score(preds: np.ndarray) -> float: 86 | return np.sum(preds != -1) / len(preds) 87 | 88 | 89 | def _roc_auc_score(golds: np.ndarray, probs: np.ndarray) -> float: 90 | if not probs.shape[1] == 2: 91 | raise ValueError( 92 | "Metric roc_auc is currently only defined for binary problems." 93 | ) 94 | return skmetrics.roc_auc_score(golds, probs[:, 1]) 95 | 96 | 97 | def _f1_score(golds: np.ndarray, preds: np.ndarray) -> float: 98 | if golds.max() <= 1: 99 | return skmetrics.f1_score(golds, preds) 100 | else: 101 | raise ValueError( 102 | "f1 not supported for multiclass. Try f1_micro or f1_macro instead." 
103 | ) 104 | 105 | 106 | def _f1_micro_score(golds: np.ndarray, preds: np.ndarray) -> float: 107 | return skmetrics.f1_score(golds, preds, average="micro") 108 | 109 | 110 | def _f1_macro_score(golds: np.ndarray, preds: np.ndarray) -> float: 111 | return skmetrics.f1_score(golds, preds, average="macro") 112 | 113 | 114 | # See https://scikit-learn.org/stable/modules/classes.html#module-sklearn.metrics 115 | # for details on the definitions and available kwargs for all metrics from scikit-learn 116 | METRICS = { 117 | "accuracy": Metric(skmetrics.accuracy_score), 118 | "coverage": Metric(_coverage_score, ["preds"]), 119 | "precision": Metric(skmetrics.precision_score), 120 | "recall": Metric(skmetrics.recall_score), 121 | "f1": Metric(_f1_score, ["golds", "preds"]), 122 | "f1_micro": Metric(_f1_micro_score, ["golds", "preds"]), 123 | "f1_macro": Metric(_f1_macro_score, ["golds", "preds"]), 124 | "fbeta": Metric(skmetrics.fbeta_score), 125 | "matthews_corrcoef": Metric(skmetrics.matthews_corrcoef), 126 | "roc_auc": Metric(_roc_auc_score, ["golds", "probs"]), 127 | } 128 | -------------------------------------------------------------------------------- /snorkel/augmentation/__init__.py: -------------------------------------------------------------------------------- 1 | """Programmatic data set augmentation: TF creation and data generation utilities.""" 2 | 3 | from .apply.core import TFApplier # noqa: F401 4 | from .apply.pandas import PandasTFApplier # noqa: F401 5 | from .policy.core import ApplyAllPolicy, ApplyEachPolicy, ApplyOnePolicy # noqa: F401 6 | from .policy.sampling import MeanFieldPolicy, RandomPolicy # noqa: F401 7 | from .tf import TransformationFunction, transformation_function # noqa: F401 8 | -------------------------------------------------------------------------------- /snorkel/augmentation/apply/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/snorkel-team/snorkel/617c92400c50e95ce41fcee84309a86f76cf525c/snorkel/augmentation/apply/__init__.py -------------------------------------------------------------------------------- /snorkel/augmentation/apply/core.py: -------------------------------------------------------------------------------- 1 | from typing import Iterator, List 2 | 3 | from tqdm import tqdm 4 | 5 | from snorkel.augmentation.policy.core import Policy 6 | from snorkel.augmentation.tf import BaseTransformationFunction 7 | from snorkel.types import DataPoint, DataPoints 8 | from snorkel.utils.data_operators import check_unique_names 9 | 10 | 11 | class BaseTFApplier: 12 | """Base class for TF applier objects. 13 | 14 | Base class for TF applier objects, which execute a set of TFs 15 | on a collection of data points. Subclasses should operate on 16 | a single data point collection format (e.g. ``DataFrame``). 17 | Subclasses must implement the ``apply`` method. 
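    As a minimal usage sketch (assuming ``tfs`` is a list of transformation functions and ``policy`` is an augmentation ``Policy``), a concrete subclass such as ``TFApplier`` below is invoked as ``TFApplier(tfs, policy).apply(data_points)``, which returns the augmented collection.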
18 | 19 | Parameters 20 | ---------- 21 | tfs 22 | TFs that this applier executes on examples 23 | policy 24 | Augmentation policy used to generate sequences of TFs 25 | 26 | Raises 27 | ------ 28 | ValueError 29 | If names of TFs are not unique 30 | """ 31 | 32 | def __init__(self, tfs: List[BaseTransformationFunction], policy: Policy) -> None: 33 | self._tfs = tfs 34 | self._tf_names = [tf.name for tf in tfs] 35 | check_unique_names(self._tf_names) 36 | self._policy = policy 37 | 38 | def _apply_policy_to_data_point(self, x: DataPoint) -> DataPoints: 39 | x_transformed = [] 40 | for seq in self._policy.generate_for_example(): 41 | x_t = x 42 | # Handle empty sequence for `keep_original` 43 | transform_applied = len(seq) == 0 44 | # Apply TFs 45 | for tf_idx in seq: 46 | tf = self._tfs[tf_idx] 47 | x_t_or_none = tf(x_t) 48 | # Update if transformation was applied 49 | if x_t_or_none is not None: 50 | transform_applied = True 51 | x_t = x_t_or_none 52 | # Add example if original or transformations applied 53 | if transform_applied: 54 | x_transformed.append(x_t) 55 | return x_transformed 56 | 57 | def __repr__(self) -> str: 58 | policy_name = type(self._policy).__name__ 59 | return f"{type(self).__name__}, Policy: {policy_name}, TFs: {self._tf_names}" 60 | 61 | 62 | class TFApplier(BaseTFApplier): 63 | """TF applier for a list of data points. 64 | 65 | Augments a list of data points (e.g. ``SimpleNamespace``). Primarily 66 | useful for testing. 67 | """ 68 | 69 | def apply_generator( 70 | self, data_points: DataPoints, batch_size: int 71 | ) -> Iterator[List[DataPoint]]: 72 | """Augment a list of data points using TFs and policy in batches. 73 | 74 | This method acts as a generator, yielding augmented data points for 75 | a given input batch of data points. This can be useful in a training 76 | loop when it is too memory-intensive to pregenerate all transformed 77 | examples. 78 | 79 | Parameters 80 | ---------- 81 | data_points 82 | List containing data points to be transformed 83 | batch_size 84 | Batch size for generator. Yields augmented data points 85 | for the next ``batch_size`` input data points. 86 | 87 | Yields 88 | ------ 89 | List[DataPoint] 90 | List of data points in augmented data set for batches of inputs 91 | """ 92 | for i in range(0, len(data_points), batch_size): 93 | batch_transformed: List[DataPoint] = [] 94 | for x in data_points[i : i + batch_size]: 95 | batch_transformed.extend(self._apply_policy_to_data_point(x)) 96 | yield batch_transformed 97 | 98 | def apply( 99 | self, data_points: DataPoints, progress_bar: bool = True 100 | ) -> List[DataPoint]: 101 | """Augment a list of data points using TFs and policy. 102 | 103 | Parameters 104 | ---------- 105 | data_points 106 | List containing data points to be transformed 107 | progress_bar 108 | Display a progress bar? 
109 | 110 | Returns 111 | ------- 112 | List[DataPoint] 113 | List of data points in augmented data set 114 | """ 115 | x_transformed: List[DataPoint] = [] 116 | for x in tqdm(data_points, disable=(not progress_bar)): 117 | x_transformed.extend(self._apply_policy_to_data_point(x)) 118 | return x_transformed 119 | -------------------------------------------------------------------------------- /snorkel/augmentation/apply/pandas.py: -------------------------------------------------------------------------------- 1 | from typing import Iterator, List 2 | 3 | import pandas as pd 4 | from tqdm import tqdm 5 | 6 | from .core import BaseTFApplier 7 | 8 | 9 | class PandasTFApplier(BaseTFApplier): 10 | """TF applier for a Pandas DataFrame. 11 | 12 | Data points are stored as Series in a DataFrame. The TFs 13 | run on data points obtained via a ``pandas.DataFrame.iterrows`` 14 | call, which is single-process and can be slow for large DataFrames. 15 | For large datasets, consider ``DaskTFApplier`` or ``SparkTFApplier``. 16 | """ 17 | 18 | def apply_generator(self, df: pd.DataFrame, batch_size: int) -> Iterator[pd.DataFrame]: 19 | """Augment a Pandas DataFrame of data points using TFs and policy in batches. 20 | 21 | This method acts as a generator, yielding augmented data points for 22 | a given input batch of data points. This can be useful in a training 23 | loop when it is too memory-intensive to pregenerate all transformed 24 | examples. 25 | 26 | Parameters 27 | ---------- 28 | df 29 | Pandas DataFrame containing data points to be transformed 30 | batch_size 31 | Batch size for generator. Yields augmented data points 32 | for the next ``batch_size`` input data points. 33 | 34 | Yields 35 | ------ 36 | pd.DataFrame 37 | Pandas DataFrame of data points in augmented data set 38 | """ 39 | batch_transformed: List[pd.Series] = [] 40 | for i, (_, x) in enumerate(df.iterrows()): 41 | batch_transformed.extend(self._apply_policy_to_data_point(x)) 42 | if (i + 1) % batch_size == 0: 43 | yield pd.concat(batch_transformed, axis=1).T.infer_objects() 44 | batch_transformed = [] 45 | if batch_transformed: yield pd.concat(batch_transformed, axis=1).T.infer_objects() # guard: final batch is empty when len(df) is a multiple of batch_size 46 | 47 | def apply(self, df: pd.DataFrame, progress_bar: bool = True) -> pd.DataFrame: 48 | """Augment a Pandas DataFrame of data points using TFs and policy. 49 | 50 | Parameters 51 | ---------- 52 | df 53 | Pandas DataFrame containing data points to be transformed 54 | progress_bar 55 | Display a progress bar? 
56 | 57 | Returns 58 | ------- 59 | pd.DataFrame 60 | Pandas DataFrame of data points in augmented data set 61 | """ 62 | x_transformed: List[pd.Series] = [] 63 | for _, x in tqdm(df.iterrows(), total=len(df), disable=(not progress_bar)): 64 | x_transformed.extend(self._apply_policy_to_data_point(x)) 65 | return pd.concat(x_transformed, axis=1).T.infer_objects() 66 | -------------------------------------------------------------------------------- /snorkel/augmentation/policy/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/snorkel-team/snorkel/617c92400c50e95ce41fcee84309a86f76cf525c/snorkel/augmentation/policy/__init__.py -------------------------------------------------------------------------------- /snorkel/augmentation/policy/sampling.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional, Sequence 2 | 3 | import numpy as np 4 | 5 | from .core import Policy 6 | 7 | 8 | class MeanFieldPolicy(Policy): 9 | """Sample sequences of TFs according to a distribution. 10 | 11 | Samples sequences of indices of a specified length from a 12 | user-provided distribution. A distribution over TFs can be 13 | learned by a TANDA mean-field model, for example. 14 | See https://hazyresearch.github.io/snorkel/blog/tanda.html 15 | 16 | Parameters 17 | ---------- 18 | n_tfs 19 | Total number of TFs 20 | sequence_length 21 | Number of TFs to run on each data point 22 | p 23 | Probability distribution from which to sample TF indices. 24 | Must have length ``n_tfs`` and be a valid distribution. 25 | n_per_original 26 | Number of transformed data points per original 27 | keep_original 28 | Keep untransformed data point in augmented data set? Note that 29 | even if in-place modifications are made to the original data 30 | point by the TFs being applied, the original data point will 31 | remain unchanged. 32 | 33 | Attributes 34 | ---------- 35 | n 36 | Total number of TFs 37 | n_per_original 38 | See above 39 | keep_original 40 | See above 41 | sequence_length 42 | See above 43 | """ 44 | 45 | def __init__( 46 | self, 47 | n_tfs: int, 48 | sequence_length: int = 1, 49 | p: Optional[Sequence[float]] = None, 50 | n_per_original: int = 1, 51 | keep_original: bool = True, 52 | ) -> None: 53 | self.sequence_length = sequence_length 54 | self._p = p 55 | super().__init__( 56 | n_tfs, n_per_original=n_per_original, keep_original=keep_original 57 | ) 58 | 59 | def generate(self) -> List[int]: 60 | """Generate a sequence of TF indices by sampling from distribution. 61 | 62 | Returns 63 | ------- 64 | List[int] 65 | Indices of TFs to run on data point in order. 66 | """ 67 | return np.random.choice(self.n, size=self.sequence_length, p=self._p).tolist() 68 | 69 | 70 | class RandomPolicy(MeanFieldPolicy): 71 | """Naive random augmentation policy. 72 | 73 | Samples sequences of TF indices a specified length at random 74 | from the total number of TFs. Sampling uniformly at random is 75 | a common baseline approach to data augmentation. 76 | 77 | Parameters 78 | ---------- 79 | n_tfs 80 | Total number of TFs 81 | sequence_length 82 | Number of TFs to run on each data point 83 | n_per_original 84 | Number of transformed data points per original 85 | keep_original 86 | Keep untransformed data point in augmented data set? Note that 87 | even if in-place modifications are made to the original data 88 | point by the TFs being applied, the original data point will 89 | remain unchanged. 
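A small sketch of sampling from such a distribution (the probabilities are illustrative):

from snorkel.augmentation import MeanFieldPolicy

policy = MeanFieldPolicy(n_tfs=3, sequence_length=2, p=[0.6, 0.3, 0.1])
seq = policy.generate()  # e.g., [0, 1]: two TF indices, biased toward TF 0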
90 | 91 | Attributes 92 | ---------- 93 | n 94 | Total number of TFs 95 | n_per_original 96 | See above 97 | keep_original 98 | See above 99 | sequence_length 100 | See above 101 | """ 102 | 103 | def __init__( 104 | self, 105 | n_tfs: int, 106 | sequence_length: int = 1, 107 | n_per_original: int = 1, 108 | keep_original: bool = True, 109 | ) -> None: 110 | super().__init__( 111 | n_tfs, 112 | sequence_length=sequence_length, 113 | p=None, 114 | n_per_original=n_per_original, 115 | keep_original=keep_original, 116 | ) 117 | -------------------------------------------------------------------------------- /snorkel/augmentation/tf.py: -------------------------------------------------------------------------------- 1 | from snorkel.map import BaseMapper, LambdaMapper, Mapper, lambda_mapper 2 | 3 | """Base classes for transformation functions. 4 | 5 | A transformation function (TF) represents an atomic transformation 6 | to a data point in a data augmentation pipeline. Common examples in 7 | image processing include small image rotations or crops. Snorkel 8 | models data augmentation as a sequence of TFs generated by a policy. 9 | """ 10 | 11 | # Used for type checking only 12 | # Note: subclassing as below trips up mypy 13 | BaseTransformationFunction = BaseMapper 14 | 15 | 16 | class TransformationFunction(Mapper): 17 | """Base class for TFs. 18 | 19 | See ``snorkel.map.core.Mapper`` for details. 20 | """ 21 | 22 | pass 23 | 24 | 25 | class LambdaTransformationFunction(LambdaMapper): 26 | """Convenience class for defining TFs from functions. 27 | 28 | See ``snorkel.map.core.LambdaMapper`` for details. 29 | """ 30 | 31 | pass 32 | 33 | 34 | class transformation_function(lambda_mapper): 35 | """Decorate functions to create TFs. 36 | 37 | See ``snorkel.map.core.lambda_mapper`` for details. 38 | 39 | Example 40 | ------- 41 | >>> @transformation_function() 42 | ... def square(x): 43 | ... x.num = x.num ** 2 44 | ... return x 45 | >>> from types import SimpleNamespace 46 | >>> square(SimpleNamespace(num=2)) 47 | namespace(num=4) 48 | """ 49 | 50 | pass 51 | -------------------------------------------------------------------------------- /snorkel/classification/__init__.py: -------------------------------------------------------------------------------- 1 | """PyTorch-based multi-task learning framework for discriminative modeling.""" 2 | 3 | from .data import DictDataLoader, DictDataset # noqa: F401 4 | from .loss import cross_entropy_with_probs # noqa: F401 5 | from .multitask_classifier import MultitaskClassifier # noqa: F401 6 | from .task import Operation, Task # noqa: F401 7 | from .training.loggers import ( # noqa: F401 8 | Checkpointer, 9 | CheckpointerConfig, 10 | LogManager, 11 | LogManagerConfig, 12 | LogWriter, 13 | LogWriterConfig, 14 | TensorBoardWriter, 15 | ) 16 | from .training.trainer import Trainer # noqa: F401 17 | -------------------------------------------------------------------------------- /snorkel/classification/loss.py: -------------------------------------------------------------------------------- 1 | from typing import List, Mapping, Optional 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | 6 | Outputs = Mapping[str, List[torch.Tensor]] 7 | 8 | 9 | def cross_entropy_with_probs( 10 | input: torch.Tensor, 11 | target: torch.Tensor, 12 | weight: Optional[torch.Tensor] = None, 13 | reduction: str = "mean", 14 | ) -> torch.Tensor: 15 | """Calculate cross-entropy loss when targets are probabilities (floats), not ints.
16 | 17 | PyTorch's F.cross_entropy() method requires integer labels; it does not accept 18 | probabilistic labels. We can, however, simulate such functionality with a for loop, 19 | calculating the loss contributed by each class and accumulating the results. 20 | Libraries such as Keras do not require this workaround, as methods like 21 | "categorical_crossentropy" accept float labels natively. 22 | 23 | Note that the method signature is intentionally very similar to F.cross_entropy() 24 | so that it can be used as a drop-in replacement when target labels are changed 25 | from a 1D tensor of ints to a 2D tensor of probabilities. 26 | 27 | Parameters 28 | ---------- 29 | input 30 | A [num_points, num_classes] tensor of logits 31 | target 32 | A [num_points, num_classes] tensor of probabilistic target labels 33 | weight 34 | An optional [num_classes] array of weights to multiply the loss by per class 35 | reduction 36 | One of "none", "mean", "sum", indicating whether to return one loss per data 37 | point, the mean loss, or the sum of losses 38 | 39 | Returns 40 | ------- 41 | torch.Tensor 42 | The calculated loss 43 | 44 | Raises 45 | ------ 46 | ValueError 47 | If an invalid reduction keyword is submitted 48 | """ 49 | num_points, num_classes = input.shape 50 | # Note that t.new_zeros, t.new_full put tensor on same device as t 51 | cum_losses = input.new_zeros(num_points) 52 | for y in range(num_classes): 53 | target_temp = input.new_full((num_points,), y, dtype=torch.long) 54 | y_loss = F.cross_entropy(input, target_temp, reduction="none") 55 | if weight is not None: 56 | y_loss = y_loss * weight[y] 57 | cum_losses += target[:, y].float() * y_loss 58 | 59 | if reduction == "none": 60 | return cum_losses 61 | elif reduction == "mean": 62 | return cum_losses.mean() 63 | elif reduction == "sum": 64 | return cum_losses.sum() 65 | else: 66 | raise ValueError("Keyword 'reduction' must be one of ['none', 'mean', 'sum']") 67 | -------------------------------------------------------------------------------- /snorkel/classification/task.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from functools import partial 3 | from typing import Callable, List, Mapping, Optional, Sequence, Tuple, Union 4 | 5 | import torch 6 | import torch.nn.functional as F 7 | from torch import nn 8 | 9 | from snorkel.analysis import Scorer 10 | 11 | Outputs = Mapping[str, List[torch.FloatTensor]] 12 | 13 | 14 | class Operation: 15 | """A single operation (forward pass of a module) to execute in a Task. 16 | 17 | See ``Task`` for more detail on the usage and semantics of an Operation. 18 | 19 | Parameters 20 | ---------- 21 | name 22 | The name of this operation (defaults to module_name since for most workflows, 23 | each module is only used once per forward pass) 24 | module_name 25 | The name of the module in the module pool that this operation uses 26 | inputs 27 | The inputs that the specified module expects, given as a list of names of 28 | previous operations (or optionally a tuple of the operation name and a key 29 | if the output of that module is a dict instead of a Tensor). 30 | Note that the original input to the model can be referred to as "_input_".
31 | 32 | Example 33 | ------- 34 | >>> op1 = Operation(module_name="linear1", inputs=[("_input_", "features")]) 35 | >>> op2 = Operation(module_name="linear2", inputs=["linear1"]) 36 | >>> op_sequence = [op1, op2] 37 | 38 | Attributes 39 | ---------- 40 | name 41 | See above 42 | module_name 43 | See above 44 | inputs 45 | See above 46 | """ 47 | 48 | def __init__( 49 | self, 50 | module_name: str, 51 | inputs: Sequence[Union[str, Tuple[str, str]]], 52 | name: Optional[str] = None, 53 | ) -> None: 54 | self.name = name or module_name 55 | self.module_name = module_name 56 | self.inputs = inputs 57 | 58 | def __repr__(self) -> str: 59 | return ( 60 | f"Operation(name={self.name}, " 61 | f"module_name={self.module_name}, " 62 | f"inputs={self.inputs})" 63 | ) 64 | 65 | 66 | class Task: 67 | r"""A single task (a collection of modules and specified path through them). 68 | 69 | Parameters 70 | ---------- 71 | name 72 | The name of the task 73 | module_pool 74 | A ModuleDict mapping module names to the modules themselves 75 | op_sequence 76 | A list of ``Operation``\s to execute in order, defining the flow of information 77 | through the network for this task 78 | scorer 79 | A ``Scorer`` with the desired metrics to calculate for this task 80 | loss_func 81 | A function that converts final logits into loss values. 82 | Defaults to F.cross_entropy() if none is provided. 83 | To use probabilistic labels for training, use the Snorkel-defined method 84 | cross_entropy_with_probs() instead. 85 | output_func 86 | A function that converts final logits into 'outputs' (e.g., probabilities). 87 | Defaults to F.softmax(..., dim=1). 88 | 89 | Attributes 90 | ---------- 91 | name 92 | See above 93 | module_pool 94 | See above 95 | op_sequence 96 | See above 97 | scorer 98 | See above 99 | loss_func 100 | See above 101 | output_func 102 | See above 103 | """ 104 | 105 | def __init__( 106 | self, 107 | name: str, 108 | module_pool: nn.ModuleDict, 109 | op_sequence: Sequence[Operation], 110 | scorer: Scorer = Scorer(metrics=["accuracy"]), 111 | loss_func: Optional[Callable[..., torch.Tensor]] = None, 112 | output_func: Optional[Callable[..., torch.Tensor]] = None, 113 | ) -> None: 114 | self.name = name 115 | self.module_pool = module_pool 116 | self.op_sequence = op_sequence 117 | self.loss_func = loss_func or F.cross_entropy 118 | self.output_func = output_func or partial(F.softmax, dim=1) 119 | self.scorer = scorer 120 | 121 | logging.info(f"Created task: {self.name}") 122 | 123 | def __repr__(self) -> str: 124 | cls_name = type(self).__name__ 125 | return f"{cls_name}(name={self.name})" 126 | -------------------------------------------------------------------------------- /snorkel/classification/training/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/snorkel-team/snorkel/617c92400c50e95ce41fcee84309a86f76cf525c/snorkel/classification/training/__init__.py -------------------------------------------------------------------------------- /snorkel/classification/training/loggers/__init__.py: -------------------------------------------------------------------------------- 1 | from .checkpointer import Checkpointer, CheckpointerConfig # noqa: F401 2 | from .log_manager import LogManager, LogManagerConfig # noqa: F401 3 | from .log_writer import LogWriter, LogWriterConfig # noqa: F401 4 | from .tensorboard_writer import TensorBoardWriter # noqa: F401 5 | --------------------------------------------------------------------------------
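Before the logging utilities below, a brief sketch of how ``Operation`` and ``Task`` fit together; the module names and layer sizes are illustrative:

import torch.nn as nn

from snorkel.analysis import Scorer
from snorkel.classification import Operation, Task, cross_entropy_with_probs

module_pool = nn.ModuleDict(
    {"mlp": nn.Sequential(nn.Linear(10, 8), nn.ReLU()), "head": nn.Linear(8, 2)}
)
op_sequence = [
    Operation(module_name="mlp", inputs=[("_input_", "features")]),
    Operation(module_name="head", inputs=["mlp"]),
]
task = Task(
    name="demo_task",
    module_pool=module_pool,
    op_sequence=op_sequence,
    scorer=Scorer(metrics=["accuracy"]),
    loss_func=cross_entropy_with_probs,  # accepts [n, k] probabilistic targets
)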
/snorkel/classification/training/loggers/log_manager.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from typing import Any, Optional 3 | 4 | from snorkel.classification.multitask_classifier import MultitaskClassifier 5 | from snorkel.types import Config 6 | 7 | from .checkpointer import Checkpointer 8 | from .log_writer import LogWriter 9 | 10 | 11 | class LogManagerConfig(Config): 12 | """Settings for a LogManager. 13 | 14 | Parameters 15 | ---------- 16 | counter_unit 17 | The unit to use when assessing when it's time to log. 18 | Options are ["epochs", "batches", "points"] 19 | evaluation_freq 20 | Evaluate performance on the validation set every this many counter_units 21 | """ 22 | 23 | counter_unit: str = "epochs" 24 | evaluation_freq: float = 1.0 25 | 26 | 27 | class LogManager: 28 | """A class to manage logging during the training process. 29 | 30 | Parameters 31 | ---------- 32 | n_batches_per_epoch 33 | Total number of batches per epoch 34 | log_writer 35 | ``LogWriter`` for current run logs 36 | checkpointer 37 | ``Checkpointer`` for current model 38 | kwargs 39 | Settings to update in LogManagerConfig 40 | """ 41 | 42 | def __init__( 43 | self, 44 | n_batches_per_epoch: int, 45 | log_writer: Optional[LogWriter] = None, 46 | checkpointer: Optional[Checkpointer] = None, 47 | **kwargs: Any, 48 | ) -> None: 49 | self.config = LogManagerConfig(**kwargs) # type: ignore 50 | self.n_batches_per_epoch = n_batches_per_epoch 51 | 52 | self.log_writer = log_writer 53 | self.checkpointer = checkpointer 54 | 55 | # Set up counter unit 56 | self.counter_unit = self.config.counter_unit 57 | if self.counter_unit not in ["points", "batches", "epochs"]: 58 | raise ValueError(f"Unrecognized counter_unit: {self.counter_unit}") 59 | 60 | # Set up evaluation frequency 61 | self.evaluation_freq = self.config.evaluation_freq 62 | logging.info(f"Evaluating every {self.evaluation_freq} {self.counter_unit}.") 63 | 64 | # Set up counts since the last evaluation/checkpointing, plus running totals 65 | self.point_count = 0 66 | self.point_total = 0 67 | 68 | self.batch_count = 0 69 | self.batch_total = 0 70 | 71 | self.epoch_count = 0.0 72 | self.epoch_total = 0.0 73 | 74 | self.unit_count = 0.0 75 | self.unit_total = 0.0 76 | 77 | # Count of evaluation triggers since the last checkpointing 78 | self.trigger_count = 0 79 | 80 | def update(self, batch_size: int) -> None: 81 | """Update the counts and running totals.""" 82 | 83 | # Update number of points 84 | self.point_count += batch_size 85 | self.point_total += batch_size 86 | 87 | # Update number of batches 88 | self.batch_count += 1 89 | self.batch_total += 1 90 | 91 | # Update number of epochs 92 | self.epoch_count = self.batch_count / self.n_batches_per_epoch 93 | self.epoch_total = self.batch_total / self.n_batches_per_epoch 94 | 95 | # Update number of units 96 | if self.counter_unit == "points": 97 | self.unit_count = self.point_count 98 | self.unit_total = self.point_total 99 | elif self.counter_unit == "batches": 100 | self.unit_count = self.batch_count 101 | self.unit_total = self.batch_total 102 | elif self.counter_unit == "epochs": 103 | self.unit_count = self.epoch_count 104 | self.unit_total = self.epoch_total 105 | 106 | def trigger_evaluation(self) -> bool: 107 | """Check if current counts trigger evaluation.""" 108 | satisfied = self.unit_count >= self.evaluation_freq 109 | if satisfied: 110 | self.trigger_count += 1 111 | self.reset() 112 | return satisfied 113 | 114 | def 
trigger_checkpointing(self) -> bool: 115 | """Check if current counts trigger checkpointing.""" 116 | if self.checkpointer is None: 117 | return False 118 | satisfied = self.trigger_count >= self.checkpointer.checkpoint_factor 119 | if satisfied: 120 | self.trigger_count = 0 121 | return satisfied 122 | 123 | def reset(self) -> None: 124 | """Reset counters.""" 125 | self.point_count = 0 126 | self.batch_count = 0 127 | self.epoch_count = 0 128 | self.unit_count = 0 129 | 130 | def cleanup(self, model: MultitaskClassifier) -> MultitaskClassifier: 131 | """Close the log writer and checkpointer if needed. Reload best model.""" 132 | if self.log_writer is not None: 133 | self.log_writer.cleanup() 134 | if self.checkpointer is not None: 135 | self.checkpointer.clear() 136 | model = self.checkpointer.load_best_model(model) 137 | return model 138 | -------------------------------------------------------------------------------- /snorkel/classification/training/loggers/log_writer.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | import os 4 | from collections import defaultdict 5 | from datetime import datetime 6 | from typing import Any, DefaultDict, List, Mapping, Optional 7 | 8 | from snorkel.types import Config 9 | 10 | 11 | class LogWriterConfig(Config): 12 | """Settings for a LogWriter. 13 | 14 | Parameters 15 | ---------- 16 | log_dir 17 | The root directory where logs should be saved 18 | run_name 19 | The name of this particular run (defaults to date-time combination if None) 20 | """ 21 | 22 | log_dir: str = "logs" 23 | run_name: Optional[str] = None 24 | 25 | 26 | class LogWriter: 27 | """A class for writing logs. 28 | 29 | Parameters 30 | ---------- 31 | kwargs 32 | Settings to merge into LogWriterConfig 33 | 34 | Attributes 35 | ---------- 36 | config 37 | Merged configuration 38 | run_name 39 | Name of run if provided, otherwise date-time combination 40 | log_dir 41 | The root directory where logs should be saved 42 | run_log 43 | Dictionary of scalar values to log, keyed by value name 44 | """ 45 | 46 | def __init__(self, **kwargs: Any) -> None: 47 | self.config = LogWriterConfig(**kwargs) # type: ignore 48 | 49 | self.run_name = self.config.run_name 50 | if self.run_name is None: 51 | date = datetime.now().strftime("%Y_%m_%d") 52 | time = datetime.now().strftime("%H_%M_%S") 53 | self.run_name = f"{date}/{time}/" 54 | 55 | self.log_dir = os.path.join(self.config.log_dir, self.run_name) 56 | if not os.path.exists(self.log_dir): 57 | os.makedirs(self.log_dir) 58 | 59 | self.run_log: DefaultDict[str, List[List[float]]] = defaultdict(list) 60 | 61 | def add_scalar(self, name: str, value: float, step: float) -> None: 62 | """Log a scalar variable. 63 | 64 | Parameters 65 | ---------- 66 | name 67 | Name of the scalar collection 68 | value 69 | Value of scalar 70 | step 71 | Step axis value 72 | """ 73 | # Note: storing as list for JSON roundtripping 74 | self.run_log[name].append([step, value]) 75 | 76 | def write_config( 77 | self, config: Config, config_filename: str = "config.json" 78 | ) -> None: 79 | """Dump the config to file. 80 | 81 | Parameters 82 | ---------- 83 | config 84 | JSON-compatible config to write to file 85 | config_filename 86 | Name of file in logging directory to write to 87 | """ 88 | self.write_json(config._asdict(), config_filename) 89 | 90 | def write_log(self, log_filename: str) -> None: 91 | """Dump the scalar value log to file.
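To make the counting logic in ``LogManager`` above concrete, a small sketch (the numbers are illustrative):

from snorkel.classification.training.loggers import LogManager

manager = LogManager(n_batches_per_epoch=100, counter_unit="epochs", evaluation_freq=0.5)
for batch in range(1, 101):
    manager.update(batch_size=32)
    if manager.trigger_evaluation():
        print(f"evaluate at batch {batch}")  # fires at batches 50 and 100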
92 | 93 | Parameters 94 | ---------- 95 | log_filename 96 | Name of file in logging directory to write to 97 | """ 98 | self.write_json(self.run_log, log_filename) 99 | 100 | def write_text(self, text: str, filename: str) -> None: 101 | """Dump user-provided text to filename (e.g., the launch command). 102 | 103 | Parameters 104 | ---------- 105 | text 106 | Text to write 107 | filename 108 | Name of file in logging directory to write to 109 | """ 110 | text_path = os.path.join(self.log_dir, filename) 111 | with open(text_path, "w") as f: 112 | f.write(text) 113 | 114 | def write_json(self, dict_to_write: Mapping[str, Any], filename: str) -> None: 115 | """Dump a JSON-compatible object to the root log directory. 116 | 117 | Parameters 118 | ---------- 119 | dict_to_write 120 | JSON-compatible object to log 121 | filename 122 | Name of file in logging directory to write to 123 | """ 124 | if not filename.endswith(".json"): # pragma: no cover 125 | logging.warning( 126 | f"Using write_json() method with a filename without a .json extension: {filename}" 127 | ) 128 | log_path = os.path.join(self.log_dir, filename) 129 | with open(log_path, "w") as f: 130 | json.dump(dict_to_write, f) 131 | 132 | def cleanup(self) -> None: 133 | """Perform final operations and close writer if necessary.""" 134 | self.write_log("log.json") 135 | -------------------------------------------------------------------------------- /snorkel/classification/training/loggers/tensorboard_writer.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | from torch.utils.tensorboard import SummaryWriter 4 | 5 | from snorkel.types import Config 6 | 7 | from .log_writer import LogWriter 8 | 9 | 10 | class TensorBoardWriter(LogWriter): 11 | """A class for logging to TensorBoard during the training process. 12 | 13 | See ``LogWriter`` for more attributes. 14 | 15 | Parameters 16 | ---------- 17 | kwargs 18 | Passed to ``LogWriter`` initializer 19 | 20 | Attributes 21 | ---------- 22 | writer 23 | ``SummaryWriter`` for logging and visualization 24 | """ 25 | 26 | def __init__(self, **kwargs: Any) -> None: 27 | super().__init__(**kwargs) 28 | self.writer = SummaryWriter(self.log_dir) 29 | 30 | def add_scalar(self, name: str, value: float, step: float) -> None: 31 | """Log a scalar variable to TensorBoard. 32 | 33 | Parameters 34 | ---------- 35 | name 36 | Name of the scalar collection 37 | value 38 | Value of scalar 39 | step 40 | Step axis value 41 | """ 42 | self.writer.add_scalar(name, value, step) 43 | 44 | def write_config( 45 | self, config: Config, config_filename: str = "config.json" 46 | ) -> None: 47 | """Dump the config to file and add it to TensorBoard.
48 | 49 | Parameters 50 | ---------- 51 | config 52 | JSON-compatible config to write to TensorBoard 53 | config_filename 54 | File to write config to 55 | """ 56 | super().write_config(config, config_filename) 57 | self.writer.add_text(tag="config", text_string=str(config)) 58 | 59 | def cleanup(self) -> None: 60 | """Close the ``SummaryWriter``.""" 61 | self.writer.close() 62 | -------------------------------------------------------------------------------- /snorkel/classification/training/schedulers/__init__.py: -------------------------------------------------------------------------------- 1 | from .sequential_scheduler import SequentialScheduler 2 | from .shuffled_scheduler import ShuffledScheduler 3 | 4 | batch_schedulers = {"sequential": SequentialScheduler, "shuffled": ShuffledScheduler} 5 | -------------------------------------------------------------------------------- /snorkel/classification/training/schedulers/scheduler.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from typing import Any, Dict, Iterator, Sequence, Tuple 3 | 4 | from torch import Tensor 5 | 6 | from snorkel.classification.data import DictDataLoader # noqa: F401 7 | 8 | BatchIterator = Iterator[ 9 | Tuple[Tuple[Dict[str, Any], Dict[str, Tensor]], "DictDataLoader"] 10 | ] 11 | 12 | 13 | class Scheduler(ABC): 14 | """Return batches from all dataloaders according to a specified strategy.""" 15 | 16 | def __init__(self) -> None: 17 | pass 18 | 19 | @abstractmethod 20 | def get_batches(self, dataloaders: Sequence["DictDataLoader"]) -> BatchIterator: 21 | """Return batches from dataloaders according to a specified strategy. 22 | 23 | Parameters 24 | ---------- 25 | dataloaders 26 | A sequence of dataloaders to get batches from 27 | 28 | Yields 29 | ------ 30 | (batch, dataloader) 31 | batch is a tuple of (X_dict, Y_dict) and dataloader is the dataloader 32 | that that batch came from. That dataloader will not be accessed by the 33 | model; it is passed primarily so that the model can pull the necessary 34 | metadata to know what to do with the batch it has been given. 35 | """ 36 | -------------------------------------------------------------------------------- /snorkel/classification/training/schedulers/sequential_scheduler.py: -------------------------------------------------------------------------------- 1 | from typing import Sequence 2 | 3 | from snorkel.classification.data import DictDataLoader 4 | 5 | from .scheduler import BatchIterator, Scheduler 6 | 7 | 8 | class SequentialScheduler(Scheduler): 9 | """Return batches from all dataloaders in sequential order.""" 10 | 11 | def __init__(self) -> None: 12 | super().__init__() 13 | 14 | def get_batches(self, dataloaders: Sequence[DictDataLoader]) -> BatchIterator: 15 | """Return batches from dataloaders sequentially in the order they were given. 16 | 17 | Parameters 18 | ---------- 19 | dataloaders 20 | A sequence of dataloaders to get batches from 21 | 22 | Yields 23 | ------ 24 | (batch, dataloader) 25 | batch is a tuple of (X_dict, Y_dict) and dataloader is the dataloader 26 | that that batch came from. That dataloader will not be accessed by the 27 | model; it is passed primarily so that the model can pull the necessary 28 | metadata to know what to do with the batch it has been given. 
29 | """ 30 | for dataloader in dataloaders: 31 | for batch in dataloader: 32 | yield batch, dataloader 33 | -------------------------------------------------------------------------------- /snorkel/classification/training/schedulers/shuffled_scheduler.py: -------------------------------------------------------------------------------- 1 | import random 2 | from typing import Sequence 3 | 4 | from snorkel.classification.data import DictDataLoader 5 | 6 | from .scheduler import BatchIterator, Scheduler 7 | 8 | 9 | class ShuffledScheduler(Scheduler): 10 | """Return batches from all dataloaders in shuffled order for each epoch.""" 11 | 12 | def __init__(self) -> None: 13 | super().__init__() 14 | 15 | def get_batches(self, dataloaders: Sequence[DictDataLoader]) -> BatchIterator: 16 | """Return batches in shuffled order from dataloaders. 17 | 18 | Note that this shuffles the batch order, but it does not shuffle the datasets 19 | themselves; shuffling the datasets is specified in the DataLoaders directly. 20 | 21 | Parameters 22 | ---------- 23 | dataloaders 24 | A sequence of dataloaders to get batches from 25 | 26 | Yields 27 | ------ 28 | (batch, dataloader) 29 | batch is a tuple of (X_dict, Y_dict) and dataloader is the dataloader 30 | that that batch came from. That dataloader will not be accessed by the 31 | model; it is passed primarily so that the model can pull the necessary 32 | metadata to know what to do with the batch it has been given. 33 | """ 34 | batch_counts = [len(dl) for dl in dataloaders] 35 | dataloader_iters = [iter(dl) for dl in dataloaders] 36 | 37 | dataloader_indices = [] 38 | for idx, count in enumerate(batch_counts): 39 | dataloader_indices.extend([idx] * count) 40 | 41 | random.shuffle(dataloader_indices) 42 | 43 | for index in dataloader_indices: 44 | yield next(dataloader_iters[index]), dataloaders[index] 45 | -------------------------------------------------------------------------------- /snorkel/classification/utils.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List, Tuple, Union 2 | 3 | import numpy as np 4 | import pandas as pd 5 | import torch 6 | 7 | TensorCollection = Union[torch.Tensor, dict, list, tuple] 8 | 9 | 10 | def list_to_tensor(item_list: List[torch.Tensor]) -> torch.Tensor: 11 | """Convert a list of torch.Tensor into a single torch.Tensor.""" 12 | 13 | # Convert single value tensor 14 | if all(item_list[i].dim() == 0 for i in range(len(item_list))): 15 | item_tensor = torch.stack(item_list, dim=0) 16 | # Convert 2 or more-D tensor with the same shape 17 | elif all( 18 | (item_list[i].size() == item_list[0].size()) and (len(item_list[i].size()) != 1) 19 | for i in range(len(item_list)) 20 | ): 21 | item_tensor = torch.stack(item_list, dim=0) 22 | # Convert reshape to 1-D tensor and then convert 23 | else: 24 | item_tensor, _ = pad_batch([item.view(-1) for item in item_list]) 25 | 26 | return item_tensor 27 | 28 | 29 | def pad_batch( 30 | batch: List[torch.Tensor], 31 | max_len: int = 0, 32 | pad_value: int = 0, 33 | left_padded: bool = False, 34 | ) -> Tuple[torch.Tensor, torch.Tensor]: 35 | """Convert the batch into a padded tensor and mask tensor. 
36 | 37 | Parameters 38 | ---------- 39 | batch 40 | The data for padding 41 | max_len 42 | Maximum length of sequences after padding 43 | pad_value 44 | The value to use for padding 45 | left_padded 46 | If True, pad on the left, otherwise on the right 47 | 48 | Returns 49 | ------- 50 | Tuple[torch.Tensor, torch.Tensor] 51 | The padded matrix and corresponding mask matrix. 52 | """ 53 | 54 | batch_size = len(batch) 55 | max_seq_len = int(np.max([len(item) for item in batch])) # type: ignore 56 | 57 | if max_len > 0 and max_len < max_seq_len: 58 | max_seq_len = max_len 59 | 60 | padded_batch = batch[0].new_full((batch_size, max_seq_len), pad_value) 61 | 62 | for i, item in enumerate(batch): 63 | length = min(len(item), max_seq_len) # type: ignore 64 | if left_padded: 65 | padded_batch[i, -length:] = item[-length:] 66 | else: 67 | padded_batch[i, :length] = item[:length] 68 | 69 | mask_batch = torch.eq(padded_batch.clone().detach(), pad_value).type_as( 70 | padded_batch 71 | ) 72 | 73 | return padded_batch, mask_batch 74 | 75 | 76 | def move_to_device( 77 | obj: TensorCollection, device: int = -1 78 | ) -> TensorCollection: # pragma: no cover 79 | """Recursively move torch.Tensors to a given CUDA device. 80 | 81 | Given a structure (possibly) containing Tensors on the CPU, move all the Tensors 82 | to the specified GPU (or do nothing, if they should be on the CPU). 83 | 84 | Originally from: 85 | https://github.com/HazyResearch/metal/blob/mmtl_clean/metal/utils.py 86 | 87 | Parameters 88 | ---------- 89 | obj 90 | Tensor or collection of Tensors to move 91 | device 92 | Device to move Tensors to 93 | device = -1 -> "cpu" 94 | device = 0 -> "cuda:0" 95 | """ 96 | 97 | if device < 0 or not torch.cuda.is_available(): 98 | return obj 99 | elif isinstance(obj, torch.Tensor): 100 | return obj.cuda(device) # type: ignore 101 | elif isinstance(obj, dict): 102 | return {key: move_to_device(value, device) for key, value in obj.items()} 103 | elif isinstance(obj, list): 104 | return [move_to_device(item, device) for item in obj] 105 | elif isinstance(obj, tuple): 106 | return tuple([move_to_device(item, device) for item in obj]) 107 | else: 108 | return obj 109 | 110 | 111 | def collect_flow_outputs_by_suffix( 112 | output_dict: Dict[str, torch.Tensor], suffix: str 113 | ) -> List[torch.Tensor]: 114 | """Return output_dict outputs specified by suffix, ordered by sorted flow_name.""" 115 | return [ 116 | output_dict[flow_name] 117 | for flow_name in sorted(output_dict.keys()) 118 | if flow_name.endswith(suffix) 119 | ] 120 | 121 | 122 | def metrics_dict_to_dataframe(metrics_dict: Dict[str, float]) -> pd.DataFrame: 123 | """Format a metrics_dict (with keys 'label/dataset/split/metric') as a pandas DataFrame.""" 124 | 125 | metrics = [] 126 | 127 | for full_metric, score in metrics_dict.items(): 128 | label_name, dataset_name, split, metric = tuple(full_metric.split("/")) 129 | metrics.append((label_name, dataset_name, split, metric, score)) 130 | 131 | return pd.DataFrame( 132 | metrics, columns=["label", "dataset", "split", "metric", "score"] 133 | ) 134 | -------------------------------------------------------------------------------- /snorkel/contrib/README.md: -------------------------------------------------------------------------------- 1 | # Snorkel contrib 2 | 3 | tl;dr Have something fun that you think others could benefit from but isn't 4 | ready for prime time yet? Put it here!
5 | 6 | Any code in this directory is not officially supported, and may change or be 7 | removed at any time without notice. 8 | 9 | The contrib directory contains project directories, each of which has designated 10 | owners. It is meant to contain features and contributions whose interfaces may 11 | change, or which require some testing to see whether they can find broader acceptance. 12 | You may be asked to refactor code in contrib to use some feature inside core or 13 | in another contrib project rather than reimplementing the feature. 14 | 15 | When adding a project, please stick to the following directory structure: 16 | Create a project directory in `contrib/`, and mirror the portions of the 17 | Snorkel tree that your project requires underneath `contrib/my_project/`. 18 | 19 | For example, let's say you create foo ops for labeling in two files: 20 | `foo_ops.py` and `foo_ops_test.py`. If you were to merge those files 21 | directly into Snorkel, they would live in `snorkel/labeling/foo_ops.py` and 22 | `test/labeling/foo_ops_test.py`. In `contrib/`, they are part 23 | of project `foo`, and their full paths are `contrib/foo/snorkel/labeling/foo_ops.py` 24 | and `contrib/foo/test/labeling/foo_ops_test.py`. 25 | 26 | 27 | *Adapted from [TensorFlow contrib](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib).* 28 | -------------------------------------------------------------------------------- /snorkel/contrib/__init__.py: -------------------------------------------------------------------------------- 1 | """Contributed modules for Snorkel.""" 2 | -------------------------------------------------------------------------------- /snorkel/labeling/__init__.py: -------------------------------------------------------------------------------- 1 | """Programmatic data set labeling: LF creation, models, and analysis utilities.""" 2 | 3 | from .analysis import LFAnalysis # noqa: F401 4 | from .apply.core import LFApplier # noqa: F401 5 | from .apply.pandas import PandasLFApplier # noqa: F401 6 | from .lf.core import LabelingFunction, labeling_function # noqa: F401 7 | from .utils import filter_unlabeled_dataframe # noqa: F401 8 | -------------------------------------------------------------------------------- /snorkel/labeling/apply/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/snorkel-team/snorkel/617c92400c50e95ce41fcee84309a86f76cf525c/snorkel/labeling/apply/__init__.py -------------------------------------------------------------------------------- /snorkel/labeling/apply/dask.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | from typing import Union 3 | 4 | import numpy as np 5 | import pandas as pd 6 | from dask import dataframe as dd 7 | from dask.distributed import Client 8 | 9 | from .core import BaseLFApplier, _FunctionCaller 10 | from .pandas import apply_lfs_to_data_point, rows_to_triplets 11 | 12 | Scheduler = Union[str, Client] 13 | 14 | 15 | class DaskLFApplier(BaseLFApplier): 16 | """LF applier for a Dask DataFrame. 17 | 18 | Dask DataFrames consist of partitions, each being a Pandas DataFrame. 19 | This allows for efficient parallel computation over DataFrame rows. 
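A sketch of the parallel Pandas path (requires the ``dask[distributed]`` extra; the LF and data are illustrative):

import pandas as pd

from snorkel.labeling import labeling_function
from snorkel.labeling.apply.dask import PandasParallelLFApplier


@labeling_function()
def is_long(x):
    return 1 if len(x.text) > 5 else -1


df = pd.DataFrame({"text": ["hi", "hello there", "snorkel"]})
L = PandasParallelLFApplier([is_long]).apply(df, n_parallel=2)  # one row per point, one column per LF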
20 | For more information, see https://docs.dask.org/en/stable/dataframe.html 21 | """ 22 | 23 | def apply( 24 | self, 25 | df: dd.DataFrame, 26 | scheduler: Scheduler = "processes", 27 | fault_tolerant: bool = False, 28 | ) -> np.ndarray: 29 | """Label Dask DataFrame of data points with LFs. 30 | 31 | Parameters 32 | ---------- 33 | df 34 | Dask DataFrame containing data points to be labeled by LFs 35 | scheduler 36 | A Dask scheduling configuration: either a string option or 37 | a ``Client``. For more information, see 38 | https://docs.dask.org/en/stable/scheduling.html# 39 | fault_tolerant 40 | Output ``-1`` if LF execution fails? 41 | 42 | Returns 43 | ------- 44 | np.ndarray 45 | Matrix of labels emitted by LFs 46 | """ 47 | f_caller = _FunctionCaller(fault_tolerant) 48 | apply_fn = partial(apply_lfs_to_data_point, lfs=self._lfs, f_caller=f_caller) 49 | map_fn = df.map_partitions(lambda p_df: p_df.apply(apply_fn, axis=1)) 50 | labels = map_fn.compute(scheduler=scheduler) 51 | labels_with_index = rows_to_triplets(labels) 52 | return self._numpy_from_row_data(labels_with_index) 53 | 54 | 55 | class PandasParallelLFApplier(DaskLFApplier): 56 | """Parallel LF applier for a Pandas DataFrame. 57 | 58 | Creates a Dask DataFrame from a Pandas DataFrame, then uses 59 | ``DaskLFApplier`` to label data in parallel. See ``DaskLFApplier``. 60 | """ 61 | 62 | def apply( # type: ignore 63 | self, 64 | df: pd.DataFrame, 65 | n_parallel: int = 2, 66 | scheduler: Scheduler = "processes", 67 | fault_tolerant: bool = False, 68 | ) -> np.ndarray: 69 | """Label Pandas DataFrame of data points with LFs in parallel using Dask. 70 | 71 | Parameters 72 | ---------- 73 | df 74 | Pandas DataFrame containing data points to be labeled by LFs 75 | n_parallel 76 | Parallelism level for LF application. Corresponds to ``npartitions`` 77 | in constructed Dask DataFrame. For ``scheduler="processes"``, number 78 | of processes launched. Recommended to be no more than the number 79 | of cores on the running machine. 80 | scheduler 81 | A Dask scheduling configuration: either a string option or 82 | a ``Client``. For more information, see 83 | https://docs.dask.org/en/stable/scheduling.html# 84 | fault_tolerant 85 | Output ``-1`` if LF execution fails? 86 | 87 | Returns 88 | ------- 89 | np.ndarray 90 | Matrix of labels emitted by LFs 91 | """ 92 | if n_parallel < 2: 93 | raise ValueError( 94 | "n_parallel should be >= 2. " 95 | "For single process Pandas, use PandasLFApplier." 96 | ) 97 | df = dd.from_pandas(df, npartitions=n_parallel) 98 | return super().apply(df, scheduler=scheduler, fault_tolerant=fault_tolerant) 99 | -------------------------------------------------------------------------------- /snorkel/labeling/apply/pandas.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | from typing import List, Tuple, Union 3 | 4 | import numpy as np 5 | import pandas as pd 6 | from tqdm import tqdm 7 | 8 | from snorkel.labeling.lf import LabelingFunction 9 | from snorkel.types import DataPoint 10 | 11 | from .core import ApplierMetadata, BaseLFApplier, RowData, _FunctionCaller 12 | 13 | PandasRowData = List[Tuple[int, int]] 14 | 15 | 16 | def apply_lfs_to_data_point( 17 | x: DataPoint, lfs: List[LabelingFunction], f_caller: _FunctionCaller 18 | ) -> PandasRowData: 19 | """Label a single data point with a set of LFs. 
20 | 21 | Parameters 22 | ---------- 23 | x 24 | Data point to label 25 | lfs 26 | Set of LFs to label ``x`` with 27 | f_caller 28 | A ``_FunctionCaller`` to record failed LF executions 29 | 30 | Returns 31 | ------- 32 | RowData 33 | A list of (LF index, label) tuples 34 | """ 35 | labels = [] 36 | for j, lf in enumerate(lfs): 37 | y = f_caller(lf, x) 38 | if y >= 0: 39 | labels.append((j, y)) 40 | return labels 41 | 42 | 43 | def rows_to_triplets(labels: List[PandasRowData]) -> List[RowData]: 44 | """Convert list of list sparse matrix representation to list of triplets.""" 45 | return [ 46 | [(index, j, y) for j, y in row_labels] 47 | for index, row_labels in enumerate(labels) 48 | ] 49 | 50 | 51 | class PandasLFApplier(BaseLFApplier): 52 | """LF applier for a Pandas DataFrame. 53 | 54 | Data points are stored as ``Series`` in a DataFrame. The LFs 55 | are executed via a ``pandas.DataFrame.apply`` call, which 56 | is single-process and can be slow for large DataFrames. 57 | For large datasets, consider ``DaskLFApplier`` or ``SparkLFApplier``. 58 | 59 | Parameters 60 | ---------- 61 | lfs 62 | LFs that this applier executes on examples 63 | 64 | Example 65 | ------- 66 | >>> from snorkel.labeling import labeling_function 67 | >>> @labeling_function() 68 | ... def is_big_num(x): 69 | ... return 1 if x.num > 42 else 0 70 | >>> applier = PandasLFApplier([is_big_num]) 71 | >>> applier.apply(pd.DataFrame(dict(num=[10, 100], text=["hello", "hi"]))) 72 | array([[0], [1]]) 73 | """ 74 | 75 | def apply( 76 | self, 77 | df: pd.DataFrame, 78 | progress_bar: bool = True, 79 | fault_tolerant: bool = False, 80 | return_meta: bool = False, 81 | ) -> Union[np.ndarray, Tuple[np.ndarray, ApplierMetadata]]: 82 | """Label Pandas DataFrame of data points with LFs. 83 | 84 | Parameters 85 | ---------- 86 | df 87 | Pandas DataFrame containing data points to be labeled by LFs 88 | progress_bar 89 | Display a progress bar? 90 | fault_tolerant 91 | Output ``-1`` if LF execution fails? 92 | return_meta 93 | Return metadata from apply call? 94 | 95 | Returns 96 | ------- 97 | np.ndarray 98 | Matrix of labels emitted by LFs 99 | ApplierMetadata 100 | Metadata, such as fault counts, for the apply call 101 | """ 102 | f_caller = _FunctionCaller(fault_tolerant) 103 | apply_fn = partial(apply_lfs_to_data_point, lfs=self._lfs, f_caller=f_caller) 104 | call_fn = df.apply 105 | if progress_bar: 106 | tqdm.pandas() 107 | call_fn = df.progress_apply 108 | labels = call_fn(apply_fn, axis=1) 109 | labels_with_index = rows_to_triplets(labels) 110 | L = self._numpy_from_row_data(labels_with_index) 111 | if return_meta: 112 | return L, ApplierMetadata(f_caller.fault_counts) 113 | return L 114 | -------------------------------------------------------------------------------- /snorkel/labeling/apply/spark.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | 3 | import numpy as np 4 | from pyspark import RDD 5 | 6 | from snorkel.types import DataPoint 7 | 8 | from .core import BaseLFApplier, RowData, _FunctionCaller, apply_lfs_to_data_point 9 | 10 | 11 | class SparkLFApplier(BaseLFApplier): 12 | r"""LF applier for a Spark RDD. 13 | 14 | Data points are stored as ``Row``\s in an RDD, and a Spark 15 | ``map`` job is submitted to execute the LFs. A common 16 | way to obtain an RDD is via a PySpark DataFrame. For an 17 | example usage with AWS EMR instructions, see 18 | ``test/labeling/apply/lf_applier_spark_test_script.py``. 
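A compact sketch of the Spark path (assumes a local Spark installation; the LF and data are illustrative):

from pyspark import SparkContext
from pyspark.sql import Row

from snorkel.labeling import labeling_function
from snorkel.labeling.apply.spark import SparkLFApplier


@labeling_function()
def mentions_good(x):
    return 1 if "good" in x.text else -1


sc = SparkContext.getOrCreate()
rdd = sc.parallelize([Row(text="the movie was good"), Row(text="terrible")])
L = SparkLFApplier([mentions_good]).apply(rdd)  # array([[1], [-1]])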
19 | """ 20 | 21 | def apply(self, data_points: RDD, fault_tolerant: bool = False) -> np.ndarray: 22 | """Label PySpark RDD of data points with LFs. 23 | 24 | Parameters 25 | ---------- 26 | data_points 27 | PySpark RDD containing data points to be labeled by LFs 28 | fault_tolerant 29 | Output ``-1`` if LF execution fails? 30 | 31 | Returns 32 | ------- 33 | np.ndarray 34 | Matrix of labels emitted by LFs 35 | """ 36 | f_caller = _FunctionCaller(fault_tolerant) 37 | 38 | def map_fn(args: Tuple[DataPoint, int]) -> RowData: 39 | return apply_lfs_to_data_point(*args, lfs=self._lfs, f_caller=f_caller) 40 | 41 | labels = data_points.zipWithIndex().map(map_fn).collect() 42 | return self._numpy_from_row_data(labels) 43 | -------------------------------------------------------------------------------- /snorkel/labeling/lf/__init__.py: -------------------------------------------------------------------------------- 1 | from .core import LabelingFunction, labeling_function # noqa: F401 2 | -------------------------------------------------------------------------------- /snorkel/labeling/lf/core.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Callable, List, Mapping, Optional 2 | 3 | from snorkel.preprocess import BasePreprocessor 4 | from snorkel.types import DataPoint 5 | 6 | 7 | class LabelingFunction: 8 | """Base class for labeling functions. 9 | 10 | A labeling function (LF) is a function that takes a data point 11 | as input and produces an integer label, corresponding to a 12 | class. A labeling function can also abstain from voting by 13 | outputting ``-1``. For examples, see the Snorkel tutorials. 14 | 15 | This class wraps a Python function outputting a label. Extra 16 | functionality, such as running preprocessors and storing 17 | resources, is provided. Simple LFs can be defined via a 18 | decorator. See ``labeling_function``. 19 | 20 | Parameters 21 | ---------- 22 | name 23 | Name of the LF 24 | f 25 | Function that implements the core LF logic 26 | resources 27 | Labeling resources passed in to ``f`` via ``kwargs`` 28 | pre 29 | Preprocessors to run on data points before LF execution 30 | 31 | Raises 32 | ------ 33 | ValueError 34 | Calling incorrectly defined preprocessors 35 | 36 | Attributes 37 | ---------- 38 | name 39 | See above 40 | """ 41 | 42 | def __init__( 43 | self, 44 | name: str, 45 | f: Callable[..., int], 46 | resources: Optional[Mapping[str, Any]] = None, 47 | pre: Optional[List[BasePreprocessor]] = None, 48 | ) -> None: 49 | self.name = name 50 | self._f = f 51 | self._resources = resources or {} 52 | self._pre = pre or [] 53 | 54 | def _preprocess_data_point(self, x: DataPoint) -> DataPoint: 55 | for preprocessor in self._pre: 56 | x = preprocessor(x) 57 | if x is None: 58 | raise ValueError("Preprocessor should not return None") 59 | return x 60 | 61 | def __call__(self, x: DataPoint) -> int: 62 | """Label data point. 63 | 64 | Runs all preprocessors, then passes preprocessed data point to LF. 65 | 66 | Parameters 67 | ---------- 68 | x 69 | Data point to label 70 | 71 | Returns 72 | ------- 73 | int 74 | Label for data point 75 | """ 76 | x = self._preprocess_data_point(x) 77 | return self._f(x, **self._resources) 78 | 79 | def __repr__(self) -> str: 80 | preprocessor_str = f", Preprocessors: {self._pre}" 81 | return f"{type(self).__name__} {self.name}{preprocessor_str}" 82 | 83 | 84 | class labeling_function: 85 | """Decorator to define a LabelingFunction object from a function. 
86 | 87 | Parameters 88 | ---------- 89 | name 90 | Name of the LF 91 | resources 92 | Labeling resources passed in to ``f`` via ``kwargs`` 93 | pre 94 | Preprocessors to run on data points before LF execution 95 | 96 | Examples 97 | -------- 98 | >>> @labeling_function() 99 | ... def f(x): 100 | ... return 0 if x.a > 42 else -1 101 | >>> f 102 | LabelingFunction f, Preprocessors: [] 103 | >>> from types import SimpleNamespace 104 | >>> x = SimpleNamespace(a=90, b=12) 105 | >>> f(x) 106 | 0 107 | 108 | >>> @labeling_function(name="my_lf") 109 | ... def g(x): 110 | ... return 0 if x.a > 42 else -1 111 | >>> g 112 | LabelingFunction my_lf, Preprocessors: [] 113 | """ 114 | 115 | def __init__( 116 | self, 117 | name: Optional[str] = None, 118 | resources: Optional[Mapping[str, Any]] = None, 119 | pre: Optional[List[BasePreprocessor]] = None, 120 | ) -> None: 121 | if callable(name): 122 | raise ValueError("Looks like this decorator is missing parentheses!") 123 | self.name = name 124 | self.resources = resources 125 | self.pre = pre 126 | 127 | def __call__(self, f: Callable[..., int]) -> LabelingFunction: 128 | """Wrap a function to create a ``LabelingFunction``. 129 | 130 | Parameters 131 | ---------- 132 | f 133 | Function that implements the core LF logic 134 | 135 | Returns 136 | ------- 137 | LabelingFunction 138 | New ``LabelingFunction`` executing logic in wrapped function 139 | """ 140 | name = self.name or f.__name__ 141 | return LabelingFunction(name=name, f=f, resources=self.resources, pre=self.pre) 142 | -------------------------------------------------------------------------------- /snorkel/labeling/lf/nlp_spark.py: -------------------------------------------------------------------------------- 1 | from snorkel.preprocess.nlp import SpacyPreprocessor 2 | from snorkel.preprocess.spark import make_spark_preprocessor 3 | 4 | from .nlp import ( 5 | BaseNLPLabelingFunction, 6 | SpacyPreprocessorParameters, 7 | base_nlp_labeling_function, 8 | ) 9 | 10 | 11 | class SparkNLPLabelingFunction(BaseNLPLabelingFunction): 12 | r"""Special labeling function type for SpaCy-based LFs running on Spark. 13 | 14 | This class is a Spark-compatible version of ``NLPLabelingFunction``. 15 | See ``NLPLabelingFunction`` for details. 16 | 17 | Parameters 18 | ---------- 19 | name 20 | Name of the LF 21 | f 22 | Function that implements the core LF logic 23 | resources 24 | Labeling resources passed in to ``f`` via ``kwargs`` 25 | pre 26 | Preprocessors to run before SpacyPreprocessor is executed 27 | text_field 28 | Name of data point text field to input 29 | doc_field 30 | Name of data point field to output parsed document to 31 | language 32 | SpaCy model to load 33 | See https://spacy.io/usage/models#usage 34 | disable 35 | List of pipeline components to disable 36 | See https://spacy.io/usage/processing-pipelines#disabling 37 | memoize 38 | Memoize preprocessor outputs? 39 | memoize_key 40 | Hashing function to handle the memoization (default to snorkel.map.core.get_hashable) 41 | gpu 42 | Prefer Spacy GPU processing? 
43 | 44 | Raises 45 | ------ 46 | ValueError 47 | Calling incorrectly defined preprocessors 48 | 49 | Attributes 50 | ---------- 51 | name 52 | See above 53 | """ 54 | 55 | @classmethod 56 | def _create_preprocessor( 57 | cls, parameters: SpacyPreprocessorParameters 58 | ) -> SpacyPreprocessor: 59 | preprocessor = SpacyPreprocessor(**parameters._asdict()) 60 | make_spark_preprocessor(preprocessor) 61 | return preprocessor 62 | 63 | 64 | class spark_nlp_labeling_function(base_nlp_labeling_function): 65 | """Decorator to define a SparkNLPLabelingFunction object from a function. 66 | 67 | Parameters 68 | ---------- 69 | name 70 | Name of the LF 71 | resources 72 | Labeling resources passed in to ``f`` via ``kwargs`` 73 | pre 74 | Preprocessors to run before SpacyPreprocessor is executed 75 | text_field 76 | Name of data point text field to input 77 | doc_field 78 | Name of data point field to output parsed document to 79 | language 80 | SpaCy model to load 81 | See https://spacy.io/usage/models#usage 82 | disable 83 | List of pipeline components to disable 84 | See https://spacy.io/usage/processing-pipelines#disabling 85 | memoize 86 | Memoize preprocessor outputs? 87 | memoize_key 88 | Hashing function to handle the memoization (defaults to snorkel.map.core.get_hashable) 89 | 90 | Example 91 | ------- 92 | >>> @spark_nlp_labeling_function() 93 | ... def has_person_mention(x): 94 | ... person_ents = [ent for ent in x.doc.ents if ent.label_ == "PERSON"] 95 | ... return 0 if len(person_ents) > 0 else -1 96 | >>> has_person_mention 97 | SparkNLPLabelingFunction has_person_mention, Preprocessors: [SpacyPreprocessor...] 98 | 99 | >>> from pyspark.sql import Row 100 | >>> x = Row(text="The movie was good.") 101 | >>> has_person_mention(x) 102 | -1 103 | """ 104 | 105 | _lf_cls = SparkNLPLabelingFunction 106 | -------------------------------------------------------------------------------- /snorkel/labeling/model/__init__.py: -------------------------------------------------------------------------------- 1 | from .baselines import MajorityClassVoter, MajorityLabelVoter, RandomVoter # noqa: F401 2 | from .label_model import LabelModel # noqa: F401 3 | -------------------------------------------------------------------------------- /snorkel/labeling/model/base_labeler.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import pickle 3 | from abc import ABC, abstractmethod 4 | from typing import Any, Dict, List, Optional, Tuple, Union 5 | 6 | import numpy as np 7 | 8 | from snorkel.analysis import Scorer 9 | from snorkel.utils import probs_to_preds 10 | 11 | 12 | class BaseLabeler(ABC): 13 | """Abstract baseline label voter class.""" 14 | 15 | def __init__(self, cardinality: int = 2, **kwargs: Any) -> None: 16 | self.cardinality = cardinality 17 | 18 | @abstractmethod 19 | def predict_proba(self, L: np.ndarray) -> np.ndarray: 20 | """Abstract method for predicting probabilistic labels given a label matrix. 21 | 22 | Parameters 23 | ---------- 24 | L 25 | An [n,m] matrix with values in {-1,0,1,...,k-1} 26 | 27 | Returns 28 | ------- 29 | np.ndarray 30 | An [n,k] array of probabilistic labels 31 | """ 32 | pass 33 | 34 | def predict( 35 | self, 36 | L: np.ndarray, 37 | return_probs: Optional[bool] = False, 38 | tie_break_policy: str = "abstain", 39 | ) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]: 40 | """Return predicted labels, with ties broken according to policy.
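A sketch of the tie-breaking behavior, using the concrete ``MajorityLabelVoter`` defined later in this package:

import numpy as np

from snorkel.labeling.model import MajorityLabelVoter

L = np.array([[0, 1], [1, 1]])  # the first row is a 0-vs-1 tie
voter = MajorityLabelVoter(cardinality=2)
voter.predict(L, tie_break_policy="abstain")  # array([-1, 1]): the tie abstains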
41 | 42 | Policies to break ties include: 43 | "abstain": return an abstain vote (-1) 44 | "true-random": randomly choose among the tied options 45 | "random": randomly choose among tied options using a deterministic hash 46 | 47 | NOTE: if tie_break_policy="true-random", repeated runs may have slightly different 48 | results due to differences in broken ties 49 | 50 | 51 | Parameters 52 | ---------- 53 | L 54 | An [n,m] matrix with values in {-1,0,1,...,k-1} 55 | return_probs 56 | Whether to return probs along with preds 57 | tie_break_policy 58 | Policy to break ties when converting probabilistic labels to predictions 59 | 60 | Returns 61 | ------- 62 | np.ndarray 63 | An [n,1] array of integer labels 64 | 65 | (np.ndarray, np.ndarray) 66 | An [n,1] array of integer labels and an [n,k] array of probabilistic labels 67 | """ 68 | Y_probs = self.predict_proba(L) 69 | Y_p = probs_to_preds(Y_probs, tie_break_policy) 70 | if return_probs: 71 | return Y_p, Y_probs 72 | return Y_p 73 | 74 | def score( 75 | self, 76 | L: np.ndarray, 77 | Y: np.ndarray, 78 | metrics: Optional[List[str]] = ["accuracy"], 79 | tie_break_policy: str = "abstain", 80 | ) -> Dict[str, float]: 81 | """Calculate one or more scores from user-specified and/or user-defined metrics. 82 | 83 | Parameters 84 | ---------- 85 | L 86 | An [n,m] matrix with values in {-1,0,1,...,k-1} 87 | Y 88 | Gold labels associated with data points in L 89 | metrics 90 | A list of metric names 91 | tie_break_policy 92 | Policy to break ties when converting probabilistic labels to predictions 93 | 94 | 95 | Returns 96 | ------- 97 | Dict[str, float] 98 | A dictionary mapping metric names to metric scores 99 | """ 100 | if tie_break_policy == "abstain": # pragma: no cover 101 | logging.warning( 102 | "Metrics calculated over data points with non-abstain labels only" 103 | ) 104 | 105 | Y_pred, Y_prob = self.predict( 106 | L, return_probs=True, tie_break_policy=tie_break_policy 107 | ) 108 | 109 | scorer = Scorer(metrics=metrics) 110 | results = scorer.score(Y, Y_pred, Y_prob) 111 | return results 112 | 113 | def save(self, destination: str) -> None: 114 | """Save label model. 115 | 116 | Parameters 117 | ---------- 118 | destination 119 | Filename for saving model 120 | 121 | Example 122 | ------- 123 | >>> label_model.save('./saved_label_model.pkl') # doctest: +SKIP 124 | """ 125 | with open(destination, "wb") as f: 126 | pickle.dump(self.__dict__, f) 127 | 128 | 129 | def load(self, source: str) -> None: 130 | """Load existing label model. 131 | 132 | Parameters 133 | ---------- 134 | source 135 | Filename to load model from 136 | 137 | Example 138 | ------- 139 | Load parameters saved in ``saved_label_model`` 140 | 141 | >>> label_model.load('./saved_label_model.pkl') # doctest: +SKIP 142 | """ 143 | with open(source, "rb") as f: 144 | tmp_dict = pickle.load(f) 145 | 146 | self.__dict__.update(tmp_dict) 147 | -------------------------------------------------------------------------------- /snorkel/labeling/model/baselines.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | import numpy as np 4 | 5 | from snorkel.labeling.model.base_labeler import BaseLabeler 6 | 7 | 8 | class RandomVoter(BaseLabeler): 9 | """Random vote label model.
10 | 11 | Example 12 | ------- 13 | >>> L = np.array([[0, 0, -1], [-1, 0, 1], [1, -1, 0]]) 14 | >>> random_voter = RandomVoter() 15 | >>> predictions = random_voter.predict_proba(L) 16 | """ 17 | 18 | def predict_proba(self, L: np.ndarray) -> np.ndarray: 19 | """ 20 | Assign random votes to the data points. 21 | 22 | Parameters 23 | ---------- 24 | L 25 | An [n, m] matrix of labels 26 | 27 | Returns 28 | ------- 29 | np.ndarray 30 | A [n, k] array of probabilistic labels 31 | 32 | Example 33 | ------- 34 | >>> L = np.array([[0, 0, -1], [-1, 0, 1], [1, -1, 0]]) 35 | >>> random_voter = RandomVoter() 36 | >>> predictions = random_voter.predict_proba(L) 37 | """ 38 | n = L.shape[0] 39 | Y_p = np.random.rand(n, self.cardinality) 40 | Y_p /= Y_p.sum(axis=1).reshape(-1, 1) 41 | return Y_p 42 | 43 | 44 | class MajorityClassVoter(BaseLabeler): 45 | """Majority class label model.""" 46 | 47 | def fit( # type: ignore 48 | self, balance: np.ndarray, *args: Any, **kwargs: Any 49 | ) -> None: 50 | """Train majority class model. 51 | 52 | Set class balance for majority class label model. 53 | 54 | Parameters 55 | ---------- 56 | balance 57 | A [k] array of class probabilities 58 | """ 59 | self.balance = balance 60 | 61 | def predict_proba(self, L: np.ndarray) -> np.ndarray: 62 | """Predict probabilities using majority class. 63 | 64 | Assign majority class vote to each datapoint. 65 | In case of multiple majority classes, assign equal probabilities among them. 66 | 67 | 68 | Parameters 69 | ---------- 70 | L 71 | An [n, m] matrix of labels 72 | 73 | Returns 74 | ------- 75 | np.ndarray 76 | A [n, k] array of probabilistic labels 77 | 78 | Example 79 | ------- 80 | >>> L = np.array([[0, 0, -1], [-1, 0, 1], [1, -1, 0]]) 81 | >>> maj_class_voter = MajorityClassVoter() 82 | >>> maj_class_voter.fit(balance=np.array([0.8, 0.2])) 83 | >>> maj_class_voter.predict_proba(L) 84 | array([[1., 0.], 85 | [1., 0.], 86 | [1., 0.]]) 87 | """ 88 | n = L.shape[0] 89 | Y_p = np.zeros((n, self.cardinality)) 90 | max_classes = np.where(self.balance == max(self.balance)) 91 | for c in max_classes: 92 | Y_p[:, c] = 1.0 93 | Y_p /= Y_p.sum(axis=1).reshape(-1, 1) 94 | return Y_p 95 | 96 | 97 | class MajorityLabelVoter(BaseLabeler): 98 | """Majority vote label model.""" 99 | 100 | def predict_proba(self, L: np.ndarray) -> np.ndarray: 101 | """Predict probabilities using majority vote. 102 | 103 | Assign vote by calculating majority vote across all labeling functions. 104 | In case of ties, non-integer probabilities are possible. 105 | 106 | Parameters 107 | ---------- 108 | L 109 | An [n, m] matrix of labels 110 | 111 | Returns 112 | ------- 113 | np.ndarray 114 | A [n, k] array of probabilistic labels 115 | 116 | Example 117 | ------- 118 | >>> L = np.array([[0, 0, -1], [-1, 0, 1], [1, -1, 0]]) 119 | >>> maj_voter = MajorityLabelVoter() 120 | >>> maj_voter.predict_proba(L) 121 | array([[1. , 0. 
],
122 |            [0.5, 0.5],
123 |            [0.5, 0.5]])
124 |         """
125 |         n, m = L.shape
126 |         Y_p = np.zeros((n, self.cardinality))
127 |         for i in range(n):
128 |             counts = np.zeros(self.cardinality)
129 |             for j in range(m):
130 |                 if L[i, j] != -1:
131 |                     counts[L[i, j]] += 1
132 |             Y_p[i, :] = np.where(counts == max(counts), 1, 0)
133 |         Y_p /= Y_p.sum(axis=1).reshape(-1, 1)
134 |         return Y_p
135 |
--------------------------------------------------------------------------------
/snorkel/labeling/model/graph_utils.py:
--------------------------------------------------------------------------------
1 | from typing import Iterable, List, Tuple
2 |
3 | import networkx as nx
4 |
5 |
6 | def get_clique_tree(nodes: Iterable[int], edges: List[Tuple[int, int]]) -> nx.Graph:
7 |     """
8 |     Given a set of int nodes i and edges (i,j), returns a clique tree.
9 |
10 |     Clique tree is an object G for which:
11 |     - G.nodes[i]['members'] contains the set of original nodes in the ith
12 |       maximal clique
13 |     - G[i][j]['members'] contains the set of original nodes in the separator
14 |       set between maximal cliques i and j
15 |
16 |     Note: This method is currently only implemented for chordal graphs; TODO:
17 |     add a step to triangulate non-chordal graphs.
18 |
19 |     Parameters
20 |     ----------
21 |     nodes
22 |         A list of node indices
23 |     edges
24 |         A list of tuples, where each tuple has indices for connected nodes
25 |
26 |     Returns
27 |     -------
28 |     networkx.Graph
29 |         An object G representing the clique tree
30 |     """
31 |     # Form the original graph G1
32 |     G1 = nx.Graph()
33 |     G1.add_nodes_from(nodes)
34 |     G1.add_edges_from(edges)
35 |
36 |     # Check if graph is chordal
37 |     # TODO: Add step to triangulate graph if not
38 |     if not nx.is_chordal(G1):
39 |         raise NotImplementedError("Graph triangulation not implemented.")
40 |
41 |     # Create maximal clique graph G2
42 |     # Each node is a maximal clique C_i
43 |     # Let w = |C_i \cap C_j|; C_i, C_j have an edge with weight w if w > 0
44 |     G2 = nx.Graph()
45 |     for i, c in enumerate(nx.chordal_graph_cliques(G1)):
46 |         G2.add_node(i, members=c)
47 |     for i in G2.nodes:
48 |         for j in G2.nodes:
49 |             S = G2.nodes[i]["members"].intersection(G2.nodes[j]["members"])
50 |             w = len(S)
51 |             if w > 0:
52 |                 G2.add_edge(i, j, weight=w, members=S)
53 |
54 |     # Return a minimum spanning tree of G2
55 |     return nx.minimum_spanning_tree(G2)
56 |
--------------------------------------------------------------------------------
/snorkel/labeling/model/logger.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from collections import defaultdict
3 | from typing import DefaultDict, Dict, List, Optional
4 |
5 |
6 | class Logger:
7 |     """Class for logging LabelModel.
8 |
9 |     Parameters
10 |     ----------
11 |     log_freq
12 |         Number of units at which to log model
13 |
14 |     Attributes
15 |     ----------
16 |     log_freq
17 |         Number of units at which to log model
18 |     unit_count
19 |         Running total of number of units passed without logging
20 |     """
21 |
22 |     def __init__(self, log_freq: int) -> None:
23 |         self.log_freq = log_freq
24 |         self.unit_count = -1
25 |
26 |     def check(self) -> bool:
27 |         """Check if the logging frequency has been met.
28 |
29 |         Returns
30 |         -------
31 |         bool
32 |             Whether to log or not based on logging frequency
33 |         """
34 |         self.unit_count += 1
35 |         return self.unit_count % self.log_freq == 0
36 |
37 |     def log(self, metrics_dict: Dict[str, float]) -> None:
38 |         """Print all metrics in metrics_dict to screen.
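
        Metric names must take the form ``task/split/metric`` or
        ``split/metric``; train and valid scores are grouped into separate
        blocks of the logged line.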
39 |
40 |         Parameters
41 |         ----------
42 |         metrics_dict
43 |             Dictionary of metric names (keys) and values to log
44 |
45 |         Raises
46 |         ------
47 |         Exception
48 |             If metric names are formatted incorrectly
49 |         """
50 |         score_strings: DefaultDict[str, List[str]] = defaultdict(list)
51 |         for full_name, value in metrics_dict.items():
52 |             task: Optional[str]
53 |             if full_name.count("/") == 2:
54 |                 task, split, metric = full_name.split("/")
55 |             elif full_name.count("/") == 1:
56 |                 task = None
57 |                 split, metric = full_name.split("/")
58 |             else:
59 |                 msg = f"Metric should have form task/split/metric or split/metric, not: {full_name}"
60 |                 raise Exception(msg)
61 |
62 |             if task:
63 |                 metric_name = f"{task}/{metric}"
64 |             else:
65 |                 metric_name = metric
66 |             if isinstance(value, float):
67 |                 score_strings[split].append(f"{metric_name}={value:0.3f}")
68 |             else:
69 |                 score_strings[split].append(f"{metric_name}={value}")
70 |
71 |         string = f"[{self.unit_count} epochs]:"
72 |
73 |         if score_strings["train"]:
74 |             train_scores = f"{', '.join(score_strings['train'])}"
75 |             string += f" TRAIN:[{train_scores}]"
76 |         if score_strings["valid"]:
77 |             valid_scores = f"{', '.join(score_strings['valid'])}"
78 |             string += f" VALID:[{valid_scores}]"
79 |         logging.info(string)
80 |
--------------------------------------------------------------------------------
/snorkel/labeling/utils.py:
--------------------------------------------------------------------------------
1 | from typing import Tuple
2 |
3 | import numpy as np
4 | import pandas as pd
5 |
6 |
7 | def filter_unlabeled_dataframe(
8 |     X: pd.DataFrame, y: np.ndarray, L: np.ndarray
9 | ) -> Tuple[pd.DataFrame, np.ndarray]:
10 |     """Filter out examples not covered by any labeling function.
11 |
12 |     Parameters
13 |     ----------
14 |     X
15 |         Data points in a Pandas DataFrame.
16 |     y
17 |         Matrix of probabilities output by label model's predict_proba method.
18 |     L
19 |         Matrix of labels emitted by LFs.
20 |
21 |     Returns
22 |     -------
23 |     pd.DataFrame
24 |         Data points that were labeled by at least one LF in L.
25 |     np.ndarray
26 |         Probabilities matrix for data points labeled by at least one LF in L.
27 |     """
28 |     mask = (L != -1).any(axis=1)
29 |     return X.iloc[mask], y[mask]
30 |
--------------------------------------------------------------------------------
/snorkel/map/__init__.py:
--------------------------------------------------------------------------------
1 | """Generic utilities for data point to data point operations."""
2 |
3 | from .core import BaseMapper, LambdaMapper, Mapper, lambda_mapper  # noqa: F401
4 |
--------------------------------------------------------------------------------
/snorkel/map/spark.py:
--------------------------------------------------------------------------------
1 | from pyspark.sql import Row
2 |
3 | from snorkel.types import FieldMap
4 |
5 | from .core import Mapper
6 |
7 |
8 | def _update_fields(x: Row, mapped_fields: FieldMap) -> Row:
9 |     # ``pyspark.sql.Row`` objects are not mutable, so we need to
10 |     # reconstruct them
11 |     all_fields = x.asDict()
12 |     all_fields.update(mapped_fields)
13 |     return Row(**all_fields)
14 |
15 |
16 | def make_spark_mapper(mapper: Mapper) -> Mapper:
17 |     """Convert ``Mapper`` to be compatible with PySpark.
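
    Because ``pyspark.sql.Row`` objects are immutable, the mapper's field
    update step is swapped out for one that rebuilds each ``Row`` from its
    updated field dictionary.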
18 |
19 |     Parameters
20 |     ----------
21 |     mapper
22 |         Mapper to make compatible with PySpark
23 |     """
24 |     mapper._update_fields = _update_fields  # type: ignore
25 |     return mapper
26 |
--------------------------------------------------------------------------------
/snorkel/preprocess/__init__.py:
--------------------------------------------------------------------------------
1 | """Preprocessors for LFs, TFs, and SFs."""
2 |
3 | from .core import (  # noqa: F401
4 |     BasePreprocessor,
5 |     LambdaPreprocessor,
6 |     Preprocessor,
7 |     preprocessor,
8 | )
9 |
--------------------------------------------------------------------------------
/snorkel/preprocess/core.py:
--------------------------------------------------------------------------------
1 | """Base classes for preprocessors.
2 |
3 | A preprocessor is a data point to data point mapping in a labeling
4 | pipeline. This allows Snorkel operations (e.g. LFs) to share common
5 | preprocessing steps that make it easier to express labeling logic.
6 | A simple example for text processing is concatenating the title and
7 | body of an article. For a more complex example, see
8 | ``snorkel.preprocess.nlp.SpacyPreprocessor``.
9 | """
10 |
11 | from snorkel.map import BaseMapper, LambdaMapper, Mapper, lambda_mapper
12 |
13 | # Used for type checking only
14 | # Note: subclassing as below trips up mypy
15 | BasePreprocessor = BaseMapper
16 |
17 |
18 | class Preprocessor(Mapper):
19 |     """Base class for preprocessors.
20 |
21 |     See ``snorkel.map.core.Mapper`` for details.
22 |     """
23 |
24 |     pass
25 |
26 |
27 | class LambdaPreprocessor(LambdaMapper):
28 |     """Convenience class for defining preprocessors from functions.
29 |
30 |     See ``snorkel.map.core.LambdaMapper`` for details.
31 |     """
32 |
33 |     pass
34 |
35 |
36 | class preprocessor(lambda_mapper):
37 |     """Decorate functions to create preprocessors.
38 |
39 |     See ``snorkel.map.core.lambda_mapper`` for details.
40 |
41 |     Example
42 |     -------
43 |     >>> @preprocessor()
44 |     ... def combine_text_preprocessor(x):
45 |     ...     x.article = f"{x.title} {x.body}"
46 |     ...     return x
47 |     >>> from snorkel.preprocess.nlp import SpacyPreprocessor
48 |     >>> spacy_preprocessor = SpacyPreprocessor("article", "article_parsed")
49 |
50 |     We can now add our preprocessors to an LF.
51 |
52 |     >>> preprocessors = [combine_text_preprocessor, spacy_preprocessor]
53 |     >>> from snorkel.labeling.lf import labeling_function
54 |     >>> @labeling_function(pre=preprocessors)
55 |     ... def article_mentions_person(x):
56 |     ...     for ent in x.article_parsed.ents:
57 |     ...         if ent.label_ == "PERSON":
58 |     ...             return ABSTAIN
59 |     ...     return NEGATIVE
60 |     """
61 |
62 |     pass
63 |
--------------------------------------------------------------------------------
/snorkel/preprocess/nlp.py:
--------------------------------------------------------------------------------
1 | from typing import List, Optional
2 |
3 | import spacy
4 |
5 | from snorkel.types import FieldMap, HashingFunction
6 |
7 | from .core import BasePreprocessor, Preprocessor
8 |
9 | EN_CORE_WEB_SM = "en_core_web_sm"
10 |
11 |
12 | class SpacyPreprocessor(Preprocessor):
13 |     """Preprocessor that parses input text via a SpaCy model.
14 |
15 |     A common approach to writing LFs over text is to first use
16 |     a natural language parser to decompose the text into tokens,
17 |     part-of-speech tags, etc. SpaCy (https://spacy.io/) is a
18 |     popular tool for doing this. This preprocessor adds a
19 |     SpaCy ``Doc`` object to the data point. A ``Doc`` object is
20 |     a sequence of ``Token`` objects, which contain information
21 |     on lemmatization, parts-of-speech, etc. ``Doc`` objects also
22 |     contain fields like ``Doc.ents``, a list of named entities,
23 |     and ``Doc.noun_chunks``, a list of noun phrases. For details
24 |     of SpaCy ``Doc`` objects and a full attribute listing,
25 |     see https://spacy.io/api/doc.
26 |
27 |     Parameters
28 |     ----------
29 |     text_field
30 |         Name of data point text field to input
31 |     doc_field
32 |         Name of data point field to output parsed document to
33 |     language
34 |         SpaCy model to load
35 |         See https://spacy.io/usage/models#usage
36 |     disable
37 |         List of pipeline components to disable
38 |         See https://spacy.io/usage/processing-pipelines#disabling
39 |     pre
40 |         Preprocessors to run before this preprocessor is executed
41 |     memoize
42 |         Memoize preprocessor outputs?
43 |     memoize_key
44 |         Hashing function to handle the memoization (defaults to snorkel.map.core.get_hashable)
45 |     gpu
46 |         Prefer SpaCy GPU processing?
47 |     """
48 |
49 |     def __init__(
50 |         self,
51 |         text_field: str,
52 |         doc_field: str,
53 |         language: str = EN_CORE_WEB_SM,
54 |         disable: Optional[List[str]] = None,
55 |         pre: Optional[List[BasePreprocessor]] = None,
56 |         memoize: bool = False,
57 |         memoize_key: Optional[HashingFunction] = None,
58 |         gpu: bool = False,
59 |     ) -> None:
60 |         name = type(self).__name__
61 |         super().__init__(
62 |             name,
63 |             field_names=dict(text=text_field),
64 |             mapped_field_names=dict(doc=doc_field),
65 |             pre=pre,
66 |             memoize=memoize,
67 |             memoize_key=memoize_key,
68 |         )
69 |         self.gpu = gpu
70 |         if self.gpu:
71 |             spacy.prefer_gpu()
72 |         self._nlp = spacy.load(language, disable=disable or [])
73 |
74 |     def run(self, text: str) -> FieldMap:  # type: ignore
75 |         """Run the SpaCy model on input text.
76 |
77 |         Parameters
78 |         ----------
79 |         text
80 |             Text of document to parse
81 |
82 |         Returns
83 |         -------
84 |         FieldMap
85 |             Dictionary with a single key (``"doc"``), mapping to the
86 |             parsed SpaCy ``Doc`` object
87 |         """
88 |         # Note: not trying to add the fields of `Doc` to top-level
89 |         # as most are Cython property methods computed on the fly.
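        # The single "doc" key matches mapped_field_names in __init__, so the
        # parsed Doc lands on the data point field named by doc_field.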
90 | return dict(doc=self._nlp(text)) 91 | -------------------------------------------------------------------------------- /snorkel/preprocess/spark.py: -------------------------------------------------------------------------------- 1 | from snorkel.map.spark import make_spark_mapper as make_spark_preprocessor # noqa: F401 2 | -------------------------------------------------------------------------------- /snorkel/slicing/__init__.py: -------------------------------------------------------------------------------- 1 | """Programmatic data set slicing: SF creation, monitoring utilities, and representation learning for slices.""" 2 | 3 | from .apply.core import PandasSFApplier, SFApplier # noqa: F401 4 | from .modules.slice_combiner import SliceCombinerModule # noqa: F401 5 | from .monitor import slice_dataframe # noqa: F401 6 | from .sf.core import SlicingFunction, slicing_function # noqa: F401 7 | from .sliceaware_classifier import SliceAwareClassifier # noqa: F401 8 | from .utils import add_slice_labels, convert_to_slice_tasks # noqa: F401 9 | -------------------------------------------------------------------------------- /snorkel/slicing/apply/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/snorkel-team/snorkel/617c92400c50e95ce41fcee84309a86f76cf525c/snorkel/slicing/apply/__init__.py -------------------------------------------------------------------------------- /snorkel/slicing/apply/core.py: -------------------------------------------------------------------------------- 1 | from snorkel.labeling import LFApplier, PandasLFApplier 2 | 3 | 4 | class SFApplier(LFApplier): 5 | """SF applier for a list of data points. 6 | 7 | See ``snorkel.labeling.core.LFApplier`` for details. 8 | """ 9 | 10 | _use_recarray = True 11 | 12 | 13 | class PandasSFApplier(PandasLFApplier): 14 | """SF applier for a Pandas DataFrame. 15 | 16 | See ``snorkel.labeling.core.PandasLFApplier`` for details. 17 | """ 18 | 19 | _use_recarray = True 20 | -------------------------------------------------------------------------------- /snorkel/slicing/apply/dask.py: -------------------------------------------------------------------------------- 1 | from snorkel.labeling.apply.dask import ( # pragma: no cover 2 | DaskLFApplier, 3 | PandasParallelLFApplier, 4 | ) 5 | 6 | 7 | class DaskSFApplier(DaskLFApplier): # pragma: no cover 8 | """SF applier for a Dask DataFrame. 9 | 10 | See ``snorkel.labeling.apply.dask.DaskLFApplier`` for details. 11 | """ 12 | 13 | _use_recarray = True 14 | 15 | 16 | class PandasParallelSFApplier(PandasParallelLFApplier): # pragma: no cover 17 | """Parallel SF applier for a Pandas DataFrame. 18 | 19 | See ``snorkel.labeling.apply.dask.PandasParallelLFApplier`` for details. 
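
    Example
    -------
    A minimal usage sketch (requires Dask; assumes a DataFrame ``df`` and the
    parent applier's ``n_parallel`` keyword):

    >>> from snorkel.slicing import slicing_function
    >>> @slicing_function()
    ... def short_comment(x):
    ...     return len(x.text.split()) < 5
    >>> applier = PandasParallelSFApplier([short_comment])  # doctest: +SKIP
    >>> S = applier.apply(df, n_parallel=2)  # doctest: +SKIP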
20 | """ 21 | 22 | _use_recarray = True 23 | -------------------------------------------------------------------------------- /snorkel/slicing/apply/spark.py: -------------------------------------------------------------------------------- 1 | from snorkel.labeling.apply.spark import SparkLFApplier as SparkSFApplier # noqa: F401 2 | -------------------------------------------------------------------------------- /snorkel/slicing/modules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/snorkel-team/snorkel/617c92400c50e95ce41fcee84309a86f76cf525c/snorkel/slicing/modules/__init__.py -------------------------------------------------------------------------------- /snorkel/slicing/monitor.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | 4 | from snorkel.slicing import PandasSFApplier 5 | from snorkel.slicing.sf import SlicingFunction 6 | 7 | 8 | def slice_dataframe( 9 | df: pd.DataFrame, slicing_function: SlicingFunction 10 | ) -> pd.DataFrame: 11 | """Return a dataframe with examples corresponding to specified ``SlicingFunction``. 12 | 13 | Parameters 14 | ---------- 15 | df 16 | A pandas DataFrame that will be sliced 17 | slicing_function 18 | SlicingFunction which will operate over df to return a subset of examples; 19 | function returns a subset of data for which ``slicing_function`` output is True 20 | 21 | Returns 22 | ------- 23 | pd.DataFrame 24 | A DataFrame including only examples belonging to slice_name 25 | """ 26 | 27 | S = PandasSFApplier([slicing_function]).apply(df) 28 | 29 | # Index into the SF labels by name 30 | df_idx = np.where(S[slicing_function.name])[0] # type: ignore 31 | return df.iloc[df_idx] 32 | -------------------------------------------------------------------------------- /snorkel/slicing/sf/__init__.py: -------------------------------------------------------------------------------- 1 | from .core import SlicingFunction, slicing_function # noqa: F401 2 | -------------------------------------------------------------------------------- /snorkel/slicing/sf/core.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Callable, List, Mapping, Optional 2 | 3 | from snorkel.labeling.lf import LabelingFunction 4 | from snorkel.preprocess import BasePreprocessor 5 | 6 | 7 | class SlicingFunction(LabelingFunction): 8 | """Base class for slicing functions. 9 | 10 | See ``snorkel.labeling.lf.LabelingFunction`` for details. 11 | """ 12 | 13 | pass 14 | 15 | 16 | class slicing_function: 17 | """Decorator to define a SlicingFunction object from a function. 18 | 19 | Parameters 20 | ---------- 21 | name 22 | Name of the SF 23 | resources 24 | Slicing resources passed in to ``f`` via ``kwargs`` 25 | preprocessors 26 | Preprocessors to run on data points before SF execution 27 | 28 | Examples 29 | -------- 30 | >>> @slicing_function() 31 | ... def f(x): 32 | ... return x.a > 42 33 | >>> f 34 | SlicingFunction f, Preprocessors: [] 35 | >>> from types import SimpleNamespace 36 | >>> x = SimpleNamespace(a=90, b=12) 37 | >>> f(x) 38 | True 39 | 40 | >>> @slicing_function(name="my_sf") 41 | ... def g(x): 42 | ... 
43 |     >>> g
44 |     SlicingFunction my_sf, Preprocessors: []
45 |     """
46 |
47 |     def __init__(
48 |         self,
49 |         name: Optional[str] = None,
50 |         resources: Optional[Mapping[str, Any]] = None,
51 |         pre: Optional[List[BasePreprocessor]] = None,
52 |     ) -> None:
53 |         if callable(name):
54 |             raise ValueError("Looks like this decorator is missing parentheses!")
55 |         self.name = name
56 |         self.resources = resources
57 |         self.pre = pre
58 |
59 |     def __call__(self, f: Callable[..., int]) -> SlicingFunction:
60 |         """Wrap a function to create a ``SlicingFunction``.
61 |
62 |         Parameters
63 |         ----------
64 |         f
65 |             Function that implements the core SF logic
66 |
67 |         Returns
68 |         -------
69 |         SlicingFunction
70 |             New ``SlicingFunction`` executing logic in wrapped function
71 |         """
72 |         name = self.name or f.__name__
73 |         return SlicingFunction(name=name, f=f, resources=self.resources, pre=self.pre)
74 |
--------------------------------------------------------------------------------
/snorkel/slicing/sf/nlp.py:
--------------------------------------------------------------------------------
1 | from snorkel.labeling.lf.nlp import (
2 |     BaseNLPLabelingFunction,
3 |     SpacyPreprocessorParameters,
4 |     base_nlp_labeling_function,
5 | )
6 | from snorkel.preprocess.nlp import SpacyPreprocessor
7 |
8 |
9 | class NLPSlicingFunction(BaseNLPLabelingFunction):
10 |     r"""Special slicing function type for spaCy-based SFs.
11 |
12 |     This class is a special version of ``SlicingFunction``. It
13 |     has a ``SpacyPreprocessor`` integrated which shares a cache
14 |     with all other ``NLPSlicingFunction`` instances. This makes
15 |     it easy to define SFs that have a text input field and have
16 |     logic written over spaCy ``Doc`` objects. Examples passed
17 |     into an ``NLPSlicingFunction`` will have a new field which
18 |     can be accessed which contains a spaCy ``Doc``. By default,
19 |     this field is called ``doc``. A ``Doc`` object is
20 |     a sequence of ``Token`` objects, which contain information
21 |     on lemmatization, parts-of-speech, etc. ``Doc`` objects also
22 |     contain fields like ``Doc.ents``, a list of named entities,
23 |     and ``Doc.noun_chunks``, a list of noun phrases. For details
24 |     of spaCy ``Doc`` objects and a full attribute listing,
25 |     see https://spacy.io/api/doc.
26 |
27 |     Simple ``NLPSlicingFunction``\s can be defined via a
28 |     decorator. See ``nlp_slicing_function``.
29 |
30 |     Parameters
31 |     ----------
32 |     name
33 |         Name of the SF
34 |     f
35 |         Function that implements the core SF logic
36 |     resources
37 |         Slicing resources passed in to ``f`` via ``kwargs``
38 |     pre
39 |         Preprocessors to run before SpacyPreprocessor is executed
40 |     text_field
41 |         Name of data point text field to input
42 |     doc_field
43 |         Name of data point field to output parsed document to
44 |     language
45 |         spaCy model to load
46 |         See https://spacy.io/usage/models#usage
47 |     disable
48 |         List of pipeline components to disable
49 |         See https://spacy.io/usage/processing-pipelines#disabling
50 |     memoize
51 |         Memoize preprocessor outputs?
52 |     memoize_key
53 |         Hashing function to handle the memoization (defaults to snorkel.map.core.get_hashable)
54 |
55 |     Raises
56 |     ------
57 |     ValueError
58 |         Calling incorrectly defined preprocessors
59 |
60 |     Example
61 |     -------
62 |     >>> def f(x):
63 |     ...     person_ents = [ent for ent in x.doc.ents if ent.label_ == "PERSON"]
64 |     ...     return len(person_ents) > 0
65 |     >>> has_person_mention = NLPSlicingFunction(name="has_person_mention", f=f)
66 |     >>> has_person_mention
67 |     NLPSlicingFunction has_person_mention, Preprocessors: [SpacyPreprocessor...]
68 |
69 |     >>> from types import SimpleNamespace
70 |     >>> x = SimpleNamespace(text="The movie was good.")
71 |     >>> has_person_mention(x)
72 |     False
73 |
74 |     Attributes
75 |     ----------
76 |     name
77 |         See above
78 |     """
79 |
80 |     @classmethod
81 |     def _create_preprocessor(
82 |         cls, parameters: SpacyPreprocessorParameters
83 |     ) -> SpacyPreprocessor:
84 |         return SpacyPreprocessor(**parameters._asdict())
85 |
86 |
87 | class nlp_slicing_function(base_nlp_labeling_function):
88 |     """Decorator to define an ``NLPSlicingFunction`` object from a function.
89 |
90 |     TODO: Implement a common parent decorator for Snorkel operators
91 |     """
92 |
93 |     _lf_cls = NLPSlicingFunction
94 |
--------------------------------------------------------------------------------
/snorkel/synthetic/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/snorkel-team/snorkel/617c92400c50e95ce41fcee84309a86f76cf525c/snorkel/synthetic/__init__.py
--------------------------------------------------------------------------------
/snorkel/synthetic/synthetic_data.py:
--------------------------------------------------------------------------------
1 | from typing import Tuple
2 |
3 | import numpy as np
4 |
5 |
6 | def generate_simple_label_matrix(
7 |     n: int, m: int, cardinality: int, abstain_multiplier: float = 1.0
8 | ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
9 |     """Generate a synthetic label matrix with true parameters and labels.
10 |
11 |     This function generates a set of labeling function conditional probability tables,
12 |     P(LF=l | Y=y), stored as a matrix P, and true labels Y, and then generates the
13 |     resulting label matrix L.
14 |
15 |     Parameters
16 |     ----------
17 |     n
18 |         Number of data points
19 |     m
20 |         Number of labeling functions
21 |     cardinality
22 |         Cardinality of true labels (i.e.
not including abstains) 23 | abstain_multiplier 24 | Factor to multiply the probability of abstaining by 25 | 26 | Returns 27 | ------- 28 | Tuple[np.ndarray, np.ndarray, np.ndarray] 29 | A tuple containing the LF conditional probabilities P, 30 | the true labels Y, and the output label matrix L 31 | """ 32 | # Generate the conditional probability tables for the LFs 33 | # The first axis is LF, second is LF output label, third is true class label 34 | # Note that we include abstains in the LF output space, and that we bias the 35 | # conditional probabilities towards being non-adversarial 36 | P = np.empty((m, cardinality + 1, cardinality)) 37 | for i in range(m): 38 | p = np.random.rand(cardinality + 1, cardinality) 39 | 40 | # Bias the LFs to being non-adversarial 41 | p[1:, :] += (cardinality - 1) * np.eye(cardinality) 42 | 43 | # Optionally increase the abstain probability by some multiplier; note this is 44 | # to simulate the common setting where LFs label very sparsely 45 | p[0, :] *= abstain_multiplier 46 | 47 | # Normalize the conditional probabilities table 48 | P[i] = p @ np.diag(1 / p.sum(axis=0)) 49 | 50 | # Generate the true datapoint labels 51 | # Note: Assuming balanced classes to start 52 | Y = np.random.choice(cardinality, n) 53 | 54 | # Generate the label matrix L 55 | L: np.ndarray = np.empty((n, m), dtype=int) 56 | for i in range(n): 57 | for j in range(m): 58 | L[i, j] = np.random.choice(cardinality + 1, p=P[j, :, Y[i]]) - 1 59 | return P, Y, L 60 | -------------------------------------------------------------------------------- /snorkel/types/__init__.py: -------------------------------------------------------------------------------- 1 | from .classifier import Config # noqa: F401 2 | from .data import DataPoint, DataPoints, Field, FieldMap # noqa: F401 3 | from .hashing import HashingFunction # noqa: F401 4 | -------------------------------------------------------------------------------- /snorkel/types/classifier.py: -------------------------------------------------------------------------------- 1 | from typing import NamedTuple 2 | 3 | Config = NamedTuple 4 | -------------------------------------------------------------------------------- /snorkel/types/data.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Mapping, Sequence 2 | 3 | DataPoint = Any 4 | DataPoints = Sequence[DataPoint] 5 | 6 | Field = Any 7 | FieldMap = Mapping[str, Field] 8 | -------------------------------------------------------------------------------- /snorkel/types/hashing.py: -------------------------------------------------------------------------------- 1 | from collections.abc import Hashable 2 | from typing import Any, Callable 3 | 4 | HashingFunction = Callable[[Any], Hashable] 5 | -------------------------------------------------------------------------------- /snorkel/utils/__init__.py: -------------------------------------------------------------------------------- 1 | """General machine learning utilities shared across Snorkel.""" 2 | 3 | from .core import ( # noqa: F401 4 | filter_labels, 5 | preds_to_probs, 6 | probs_to_preds, 7 | to_int_label_array, 8 | ) 9 | -------------------------------------------------------------------------------- /snorkel/utils/config_utils.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict 2 | 3 | from snorkel.types import Config 4 | 5 | 6 | def merge_config(config: Config, config_updates: Dict[str, Any]) -> Config: 7 | 
"""Merge a (potentially nested) dict of kwargs into a config (NamedTuple). 8 | 9 | Parameters 10 | ---------- 11 | config 12 | An instantiated Config to update 13 | config_updates 14 | A potentially nested dict of settings to update in the Config 15 | 16 | Returns 17 | ------- 18 | Config 19 | The updated Config 20 | 21 | Example 22 | ------- 23 | ``` 24 | config_updates = { 25 | "n_epochs": 5, 26 | "optimizer_config": { 27 | "lr": 0.001, 28 | } 29 | } 30 | trainer_config = merge_config(TrainerConfig(), config_updates) 31 | ``` 32 | """ 33 | for key, value in config_updates.items(): 34 | if isinstance(value, dict): 35 | config_updates[key] = merge_config(getattr(config, key), value) 36 | return config._replace(**config_updates) 37 | -------------------------------------------------------------------------------- /snorkel/utils/data_operators.py: -------------------------------------------------------------------------------- 1 | from collections import Counter 2 | from typing import List 3 | 4 | 5 | def check_unique_names(names: List[str]) -> None: 6 | """Check that operator names are unique.""" 7 | k, ct = Counter(names).most_common(1)[0] 8 | if ct > 1: 9 | raise ValueError(f"Operator names not unique: {ct} operators with name {k}") 10 | -------------------------------------------------------------------------------- /snorkel/utils/lr_schedulers.py: -------------------------------------------------------------------------------- 1 | from snorkel.types import Config 2 | 3 | 4 | class ExponentialLRSchedulerConfig(Config): 5 | """Settings for Exponential decay learning rate scheduler.""" 6 | 7 | gamma: float = 0.9 8 | 9 | 10 | class StepLRSchedulerConfig(Config): 11 | """Settings for Step decay learning rate scheduler.""" 12 | 13 | gamma: float = 0.9 14 | step_size: int = 5 15 | 16 | 17 | class LRSchedulerConfig(Config): 18 | """Settings common to all LRSchedulers. 
19 | 20 | Parameters 21 | ---------- 22 | warmup_steps 23 | The number of warmup_units over which to perform learning rate warmup (a linear 24 | increase from 0 to the specified lr) 25 | warmup_unit 26 | The unit to use when counting warmup (one of ["batches", "epochs"]) 27 | warmup_percentage 28 | The percentage of the training procedure to warm up over (ignored if 29 | warmup_steps is non-zero) 30 | min_lr 31 | The minimum learning rate to use during training (the learning rate specified 32 | by a learning rate scheduler will be rounded up to this if it is lower) 33 | exponential_config 34 | Extra settings for the ExponentialLRScheduler 35 | step_config 36 | Extra settings for the StepLRScheduler 37 | """ 38 | 39 | warmup_steps: float = 0 # warm up steps 40 | warmup_unit: str = "batches" # [epochs, batches] 41 | warmup_percentage: float = 0.0 # warm up percentage 42 | min_lr: float = 0.0 # minimum learning rate 43 | exponential_config: ExponentialLRSchedulerConfig = ( 44 | ExponentialLRSchedulerConfig() # type:ignore 45 | ) 46 | step_config: StepLRSchedulerConfig = StepLRSchedulerConfig() # type:ignore 47 | -------------------------------------------------------------------------------- /snorkel/utils/optimizers.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | 3 | from snorkel.types import Config 4 | 5 | 6 | class SGDOptimizerConfig(Config): 7 | """Settings for SGD optimizer.""" 8 | 9 | momentum: float = 0.9 10 | 11 | 12 | class AdamOptimizerConfig(Config): 13 | """Settings for Adam optimizer.""" 14 | 15 | amsgrad: bool = False 16 | betas: Tuple[float, float] = (0.9, 0.999) 17 | 18 | 19 | class AdamaxOptimizerConfig(Config): 20 | """Settings for Adamax optimizer.""" 21 | 22 | betas: Tuple[float, float] = (0.9, 0.999) 23 | eps: float = 1e-8 24 | 25 | 26 | class OptimizerConfig(Config): 27 | """Settings common to all optimizers.""" 28 | 29 | sgd_config: SGDOptimizerConfig = SGDOptimizerConfig() # type:ignore 30 | adam_config: AdamOptimizerConfig = AdamOptimizerConfig() # type:ignore 31 | adamax_config: AdamaxOptimizerConfig = AdamaxOptimizerConfig() # type:ignore 32 | -------------------------------------------------------------------------------- /snorkel/version.py: -------------------------------------------------------------------------------- 1 | _MAJOR = "0" 2 | _MINOR = "10" 3 | _REVISION = "1+dev" 4 | 5 | VERSION_SHORT = f"{_MAJOR}.{_MINOR}" 6 | VERSION = f"{_MAJOR}.{_MINOR}.{_REVISION}" 7 | -------------------------------------------------------------------------------- /test/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/snorkel-team/snorkel/617c92400c50e95ce41fcee84309a86f76cf525c/test/__init__.py -------------------------------------------------------------------------------- /test/analysis/test_error_analysis.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import numpy as np 4 | 5 | from snorkel.analysis import get_label_buckets, get_label_instances 6 | 7 | 8 | class ErrorAnalysisTest(unittest.TestCase): 9 | def test_get_label_buckets(self) -> None: 10 | y1 = np.array([[2], [1], [3], [1], [1], [3]]) 11 | y2 = np.array([1, 2, 3, 1, 2, 3]) 12 | buckets = get_label_buckets(y1, y2) 13 | expected_buckets = {(2, 1): [0], (1, 2): [1, 4], (3, 3): [2, 5], (1, 1): [3]} 14 | expected_buckets = {k: np.array(v) for k, v in expected_buckets.items()} 15 | 
np.testing.assert_equal(buckets, expected_buckets) 16 | 17 | y1_1d = np.array([2, 1, 3, 1, 1, 3]) 18 | buckets = get_label_buckets(y1_1d, y2) 19 | np.testing.assert_equal(buckets, expected_buckets) 20 | 21 | def test_get_label_buckets_multi(self) -> None: 22 | y1 = np.array([[2], [1], [3], [1], [1], [3]]) 23 | y2 = np.array([1, 2, 3, 1, 2, 3]) 24 | y3 = np.array([[3], [2], [1], [1], [2], [3]]) 25 | buckets = get_label_buckets(y1, y2, y3) 26 | expected_buckets = { 27 | (2, 1, 3): [0], 28 | (1, 2, 2): [1, 4], 29 | (3, 3, 1): [2], 30 | (1, 1, 1): [3], 31 | (3, 3, 3): [5], 32 | } 33 | expected_buckets = {k: np.array(v) for k, v in expected_buckets.items()} 34 | np.testing.assert_equal(buckets, expected_buckets) 35 | 36 | def test_get_label_buckets_bad_shape(self) -> None: 37 | with self.assertRaisesRegex(ValueError, "same number of elements"): 38 | get_label_buckets(np.array([0, 1, 1]), np.array([1, 1])) 39 | 40 | def test_get_label_instances(self) -> None: 41 | x = np.array([[1, 2], [3, 4], [5, 6], [7, 8]]) 42 | y1 = np.array([1, 0, 0, 0]) 43 | y2 = np.array([1, 1, 1, 0]) 44 | instances = get_label_instances((0, 1), x, y1, y2) 45 | expected_instances = np.array([[3, 4], [5, 6]]) 46 | np.testing.assert_equal(instances, expected_instances) 47 | 48 | x = np.array(["this", "is", "a", "test", "of", "multi"]) 49 | y1 = np.array([[2], [1], [3], [1], [1], [3]]) 50 | y2 = np.array([1, 2, 3, 1, 2, 3]) 51 | y3 = np.array([[3], [2], [1], [1], [2], [3]]) 52 | instances = get_label_instances((3, 3, 3), x, y1, y2, y3) 53 | expected_instances = np.array(["multi"]) 54 | np.testing.assert_equal(instances, expected_instances) 55 | 56 | def test_get_label_instances_exceptions(self) -> None: 57 | x = np.array([[1, 2], [3, 4], [5, 6], [7, 8]]) 58 | y1 = np.array([1, 0, 0, 0]) 59 | y2 = np.array([1, 1, 1, 0]) 60 | instances = get_label_instances((2, 0), x, y1, y2) 61 | expected_instances = np.array([]) 62 | np.testing.assert_equal(instances, expected_instances) 63 | 64 | with self.assertRaisesRegex( 65 | ValueError, "Number of lists must match the amount of labels in bucket" 66 | ): 67 | get_label_instances((1, 0), x, y1) 68 | 69 | x = np.array([[1, 2], [3, 4], [5, 6]]) 70 | with self.assertRaisesRegex( 71 | ValueError, 72 | "Number of rows in x does not match number of elements in at least one label list", 73 | ): 74 | get_label_instances((1, 0), x, y1, y2) 75 | 76 | 77 | if __name__ == "__main__": 78 | unittest.main() 79 | -------------------------------------------------------------------------------- /test/augmentation/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/snorkel-team/snorkel/617c92400c50e95ce41fcee84309a86f76cf525c/test/augmentation/__init__.py -------------------------------------------------------------------------------- /test/augmentation/apply/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/snorkel-team/snorkel/617c92400c50e95ce41fcee84309a86f76cf525c/test/augmentation/apply/__init__.py -------------------------------------------------------------------------------- /test/augmentation/policy/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/snorkel-team/snorkel/617c92400c50e95ce41fcee84309a86f76cf525c/test/augmentation/policy/__init__.py -------------------------------------------------------------------------------- /test/augmentation/policy/test_core.py: 
-------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from snorkel.augmentation import ApplyAllPolicy, ApplyEachPolicy 4 | 5 | 6 | class TestPolicy(unittest.TestCase): 7 | def test_apply_each_policy(self): 8 | policy = ApplyEachPolicy(3, keep_original=True) 9 | samples = policy.generate_for_example() 10 | self.assertEqual(samples, [[], [0], [1], [2]]) 11 | 12 | policy = ApplyEachPolicy(3, keep_original=False) 13 | samples = policy.generate_for_example() 14 | self.assertEqual(samples, [[0], [1], [2]]) 15 | 16 | def test_apply_all_policy(self): 17 | policy = ApplyAllPolicy(3, n_per_original=2, keep_original=False) 18 | samples = policy.generate_for_example() 19 | self.assertEqual(samples, [[0, 1, 2], [0, 1, 2]]) 20 | -------------------------------------------------------------------------------- /test/augmentation/policy/test_sampling.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from snorkel.augmentation import MeanFieldPolicy, RandomPolicy 4 | 5 | 6 | class TestSamplingPolicy(unittest.TestCase): 7 | def test_random_policy(self): 8 | policy = RandomPolicy(2, sequence_length=2) 9 | n_samples = 100 10 | samples = [policy.generate() for _ in range(n_samples)] 11 | a_ct = samples.count([0, 0]) 12 | b_ct = samples.count([0, 1]) 13 | c_ct = samples.count([1, 0]) 14 | d_ct = samples.count([1, 1]) 15 | self.assertGreater(a_ct, 0) 16 | self.assertGreater(b_ct, 0) 17 | self.assertGreater(c_ct, 0) 18 | self.assertGreater(d_ct, 0) 19 | self.assertEqual(a_ct + b_ct + c_ct + d_ct, n_samples) 20 | 21 | def test_mean_field_policy(self): 22 | policy = MeanFieldPolicy(2, sequence_length=2, p=[1, 0]) 23 | n_samples = 100 24 | samples = [policy.generate() for _ in range(n_samples)] 25 | self.assertEqual(samples.count([0, 0]), n_samples) 26 | -------------------------------------------------------------------------------- /test/classification/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/snorkel-team/snorkel/617c92400c50e95ce41fcee84309a86f76cf525c/test/classification/__init__.py -------------------------------------------------------------------------------- /test/classification/test_classifier_convergence.py: -------------------------------------------------------------------------------- 1 | import random 2 | import unittest 3 | from typing import List 4 | 5 | import numpy as np 6 | import pandas as pd 7 | import pytest 8 | import torch 9 | import torch.nn as nn 10 | 11 | from snorkel.analysis import Scorer 12 | from snorkel.classification import ( 13 | DictDataLoader, 14 | DictDataset, 15 | MultitaskClassifier, 16 | Operation, 17 | Task, 18 | Trainer, 19 | ) 20 | 21 | N_TRAIN = 1000 22 | N_VALID = 300 23 | 24 | 25 | class ClassifierConvergenceTest(unittest.TestCase): 26 | @classmethod 27 | def setUpClass(cls): 28 | # Ensure deterministic runs 29 | random.seed(123) 30 | np.random.seed(123) 31 | torch.manual_seed(123) 32 | 33 | @pytest.mark.complex 34 | def test_convergence(self): 35 | """Test multitask classifier convergence with two tasks.""" 36 | 37 | dataloaders = [] 38 | 39 | for offset, task_name in zip([0.0, 0.25], ["task1", "task2"]): 40 | df = create_data(N_TRAIN, offset) 41 | dataloader = create_dataloader(df, "train", task_name) 42 | dataloaders.append(dataloader) 43 | 44 | for offset, task_name in zip([0.0, 0.25], ["task1", "task2"]): 45 | df = create_data(N_VALID, offset) 46 | dataloader = 
create_dataloader(df, "valid", task_name) 47 | dataloaders.append(dataloader) 48 | 49 | task1 = create_task("task1", module_suffixes=["A", "A"]) 50 | task2 = create_task("task2", module_suffixes=["A", "B"]) 51 | model = MultitaskClassifier(tasks=[task1, task2]) 52 | 53 | # Train 54 | trainer = Trainer(lr=0.0024, n_epochs=10, progress_bar=False) 55 | trainer.fit(model, dataloaders) 56 | scores = model.score(dataloaders) 57 | 58 | # Confirm near perfect scores on both tasks 59 | for idx, task_name in enumerate(["task1", "task2"]): 60 | self.assertGreater(scores[f"{task_name}/TestData/valid/accuracy"], 0.95) 61 | 62 | # Calculate/check train/val loss 63 | train_dataset = dataloaders[idx].dataset 64 | train_loss_output = model.calculate_loss( 65 | train_dataset.X_dict, train_dataset.Y_dict 66 | ) 67 | train_loss = train_loss_output[0][task_name].item() 68 | self.assertLess(train_loss, 0.05) 69 | 70 | val_dataset = dataloaders[2 + idx].dataset 71 | val_loss_output = model.calculate_loss( 72 | val_dataset.X_dict, val_dataset.Y_dict 73 | ) 74 | val_loss = val_loss_output[0][task_name].item() 75 | self.assertLess(val_loss, 0.05) 76 | 77 | 78 | def create_data(n: int, offset=0) -> pd.DataFrame: 79 | """Create uniform X data from [-1, 1] on both axes. 80 | 81 | Create labels with linear decision boundaries related to the two coordinates of X. 82 | """ 83 | X = (np.random.random((n, 2)) * 2 - 1).astype(np.float32) 84 | Y = (X[:, 0] < X[:, 1] + offset).astype(int) 85 | 86 | df = pd.DataFrame({"x1": X[:, 0], "x2": X[:, 1], "y": Y}) 87 | return df 88 | 89 | 90 | def create_dataloader(df: pd.DataFrame, split: str, task_name: str) -> DictDataLoader: 91 | dataset = DictDataset( 92 | name="TestData", 93 | split=split, 94 | X_dict={ 95 | "coordinates": torch.stack( 96 | (torch.tensor(df["x1"]), torch.tensor(df["x2"])), dim=1 97 | ) 98 | }, 99 | Y_dict={task_name: torch.tensor(df["y"], dtype=torch.long)}, 100 | ) 101 | 102 | dataloader = DictDataLoader( 103 | dataset=dataset, batch_size=4, shuffle=(dataset.split == "train") 104 | ) 105 | return dataloader 106 | 107 | 108 | def create_task(task_name: str, module_suffixes: List[str]) -> Task: 109 | module1_name = f"linear1{module_suffixes[0]}" 110 | module2_name = f"linear2{module_suffixes[1]}" 111 | 112 | module_pool = nn.ModuleDict( 113 | { 114 | module1_name: nn.Sequential(nn.Linear(2, 20), nn.ReLU()), 115 | module2_name: nn.Linear(20, 2), 116 | } 117 | ) 118 | 119 | op1 = Operation(module_name=module1_name, inputs=[("_input_", "coordinates")]) 120 | op2 = Operation(module_name=module2_name, inputs=[op1.name]) 121 | 122 | op_sequence = [op1, op2] 123 | 124 | task = Task( 125 | name=task_name, 126 | module_pool=module_pool, 127 | op_sequence=op_sequence, 128 | scorer=Scorer(metrics=["accuracy"]), 129 | ) 130 | 131 | return task 132 | 133 | 134 | if __name__ == "__main__": 135 | unittest.main() 136 | -------------------------------------------------------------------------------- /test/classification/test_task.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import torch.nn as nn 4 | 5 | from snorkel.classification import Operation, Task 6 | 7 | TASK_NAME = "TestTask" 8 | 9 | 10 | class TaskTest(unittest.TestCase): 11 | def test_task_creation(self): 12 | module_pool = nn.ModuleDict( 13 | { 14 | "linear1": nn.Sequential(nn.Linear(2, 10), nn.ReLU()), 15 | "linear2": nn.Linear(10, 1), 16 | } 17 | ) 18 | 19 | op_sequence = [ 20 | Operation( 21 | name="the_first_layer", module_name="linear1", 
inputs=["_input_"] 22 | ), 23 | Operation( 24 | name="the_second_layer", 25 | module_name="linear2", 26 | inputs=["the_first_layer"], 27 | ), 28 | ] 29 | 30 | task = Task(name=TASK_NAME, module_pool=module_pool, op_sequence=op_sequence) 31 | 32 | # Task has no functionality on its own 33 | # Here we only confirm that the object was initialized 34 | self.assertEqual(task.name, TASK_NAME) 35 | 36 | 37 | if __name__ == "__main__": 38 | unittest.main() 39 | -------------------------------------------------------------------------------- /test/classification/test_utils.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import torch 4 | 5 | from snorkel.classification.utils import ( 6 | collect_flow_outputs_by_suffix, 7 | list_to_tensor, 8 | pad_batch, 9 | ) 10 | 11 | 12 | class UtilsTest(unittest.TestCase): 13 | def test_pad_batch(self): 14 | batch = [torch.Tensor([1, 2]), torch.Tensor([3]), torch.Tensor([4, 5, 6])] 15 | padded_batch, mask_batch = pad_batch(batch) 16 | 17 | self.assertTrue( 18 | torch.equal(padded_batch, torch.Tensor([[1, 2, 0], [3, 0, 0], [4, 5, 6]])) 19 | ) 20 | self.assertTrue( 21 | torch.equal(mask_batch, torch.Tensor([[0, 0, 1], [0, 1, 1], [0, 0, 0]])) 22 | ) 23 | 24 | padded_batch, mask_batch = pad_batch(batch, max_len=2) 25 | 26 | self.assertTrue( 27 | torch.equal(padded_batch, torch.Tensor([[1, 2], [3, 0], [4, 5]])) 28 | ) 29 | self.assertTrue(torch.equal(mask_batch, torch.Tensor([[0, 0], [0, 1], [0, 0]]))) 30 | 31 | padded_batch, mask_batch = pad_batch(batch, pad_value=-1) 32 | 33 | self.assertTrue( 34 | torch.equal( 35 | padded_batch, torch.Tensor([[1, 2, -1], [3, -1, -1], [4, 5, 6]]) 36 | ) 37 | ) 38 | self.assertTrue( 39 | torch.equal(mask_batch, torch.Tensor([[0, 0, 1], [0, 1, 1], [0, 0, 0]])) 40 | ) 41 | 42 | padded_batch, mask_batch = pad_batch(batch, left_padded=True) 43 | 44 | self.assertTrue( 45 | torch.equal(padded_batch, torch.Tensor([[0, 1, 2], [0, 0, 3], [4, 5, 6]])) 46 | ) 47 | self.assertTrue( 48 | torch.equal(mask_batch, torch.Tensor([[1, 0, 0], [1, 1, 0], [0, 0, 0]])) 49 | ) 50 | 51 | padded_batch, mask_batch = pad_batch(batch, max_len=2, left_padded=True) 52 | 53 | self.assertTrue( 54 | torch.equal(padded_batch, torch.Tensor([[1, 2], [0, 3], [5, 6]])) 55 | ) 56 | self.assertTrue(torch.equal(mask_batch, torch.Tensor([[0, 0], [1, 0], [0, 0]]))) 57 | 58 | def test_list_to_tensor(self): 59 | # list of 1-D tensor with the different length 60 | batch = [torch.Tensor([1, 2]), torch.Tensor([3]), torch.Tensor([4, 5, 6])] 61 | 62 | padded_batch = list_to_tensor(batch) 63 | 64 | self.assertTrue( 65 | torch.equal(padded_batch, torch.Tensor([[1, 2, 0], [3, 0, 0], [4, 5, 6]])) 66 | ) 67 | 68 | # list of 1-D tensor with the same length 69 | batch = [ 70 | torch.Tensor([1, 2, 3]), 71 | torch.Tensor([4, 5, 6]), 72 | torch.Tensor([7, 8, 9]), 73 | ] 74 | 75 | padded_batch = list_to_tensor(batch) 76 | 77 | self.assertTrue( 78 | torch.equal(padded_batch, torch.Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])) 79 | ) 80 | 81 | # list of 2-D tensor with the same size 82 | batch = [ 83 | torch.Tensor([[1, 2, 3], [1, 2, 3]]), 84 | torch.Tensor([[4, 5, 6], [4, 5, 6]]), 85 | torch.Tensor([[7, 8, 9], [7, 8, 9]]), 86 | ] 87 | 88 | padded_batch = list_to_tensor(batch) 89 | 90 | self.assertTrue( 91 | torch.equal( 92 | padded_batch, 93 | torch.Tensor( 94 | [ 95 | [[1, 2, 3], [1, 2, 3]], 96 | [[4, 5, 6], [4, 5, 6]], 97 | [[7, 8, 9], [7, 8, 9]], 98 | ] 99 | ), 100 | ) 101 | ) 102 | 103 | # list of tensor with the different size 
104 | batch = [ 105 | torch.Tensor([[1, 2], [2, 3]]), 106 | torch.Tensor([4, 5, 6]), 107 | torch.Tensor([7, 8, 9, 0]), 108 | ] 109 | 110 | padded_batch = list_to_tensor(batch) 111 | 112 | self.assertTrue( 113 | torch.equal( 114 | padded_batch, torch.Tensor([[1, 2, 2, 3], [4, 5, 6, 0], [7, 8, 9, 0]]) 115 | ) 116 | ) 117 | 118 | def test_collect_flow_outputs_by_suffix(self): 119 | flow_dict = { 120 | "a_pred_head": torch.Tensor([1]), 121 | "b_pred_head": torch.Tensor([2]), 122 | "c_pred": torch.Tensor([3]), 123 | } 124 | outputs = collect_flow_outputs_by_suffix(flow_dict, "_head") 125 | self.assertIn(torch.Tensor([1]), outputs) 126 | self.assertIn(torch.Tensor([2]), outputs) 127 | 128 | if __name__ == "__main__": 129 | unittest.main() 130 | -------------------------------------------------------------------------------- /test/classification/training/loggers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/snorkel-team/snorkel/617c92400c50e95ce41fcee84309a86f76cf525c/test/classification/training/loggers/__init__.py -------------------------------------------------------------------------------- /test/classification/training/loggers/test_checkpointer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | import tempfile 4 | import unittest 5 | 6 | from snorkel.classification.multitask_classifier import MultitaskClassifier 7 | from snorkel.classification.training.loggers import Checkpointer 8 | 9 | log_manager_config = {"counter_unit": "epochs", "evaluation_freq": 1} 10 | 11 | 12 | class TestLogManager(unittest.TestCase): 13 | def setUp(self) -> None: 14 | self.test_dir = tempfile.mkdtemp() 15 | 16 | def tearDown(self) -> None: 17 | shutil.rmtree(self.test_dir) 18 | 19 | def test_checkpointer(self) -> None: 20 | checkpointer = Checkpointer( 21 | **log_manager_config, 22 | checkpoint_dir=self.test_dir, 23 | checkpoint_runway=3, 24 | checkpoint_metric="task/dataset/valid/f1:max", 25 | ) 26 | model = MultitaskClassifier([]) 27 | checkpointer.checkpoint(2, model, {"task/dataset/valid/f1": 0.5}) 28 | self.assertEqual(len(checkpointer.best_metric_dict), 0) 29 | checkpointer.checkpoint( 30 | 3, model, {"task/dataset/valid/f1": 0.8, "task/dataset/valid/f2": 0.5} 31 | ) 32 | self.assertEqual(checkpointer.best_metric_dict["task/dataset/valid/f1"], 0.8) 33 | checkpointer.checkpoint(4, model, {"task/dataset/valid/f1": 0.9}) 34 | self.assertEqual(checkpointer.best_metric_dict["task/dataset/valid/f1"], 0.9) 35 | 36 | def test_checkpointer_min(self) -> None: 37 | checkpointer = Checkpointer( 38 | **log_manager_config, 39 | checkpoint_dir=self.test_dir, 40 | checkpoint_runway=3, 41 | checkpoint_metric="task/dataset/valid/f1:min", 42 | ) 43 | model = MultitaskClassifier([]) 44 | checkpointer.checkpoint( 45 | 3, model, {"task/dataset/valid/f1": 0.8, "task/dataset/valid/f2": 0.5} 46 | ) 47 | self.assertEqual(checkpointer.best_metric_dict["task/dataset/valid/f1"], 0.8) 48 | checkpointer.checkpoint(4, model, {"task/dataset/valid/f1": 0.7}) 49 | self.assertEqual(checkpointer.best_metric_dict["task/dataset/valid/f1"], 0.7) 50 | 51 | def test_checkpointer_clear(self) -> None: 52 | checkpoint_dir = os.path.join(self.test_dir, "clear") 53 | checkpointer = Checkpointer( 54 | **log_manager_config, 55 | checkpoint_dir=checkpoint_dir, 56 | checkpoint_metric="task/dataset/valid/f1:max", 57 | checkpoint_clear=True, 58 | ) 59 | model = MultitaskClassifier([]) 60 | 
checkpointer.checkpoint(1, model, {"task/dataset/valid/f1": 0.8}) 61 | expected_files = ["checkpoint_1.pth", "best_model_task_dataset_valid_f1.pth"] 62 | self.assertEqual(set(os.listdir(checkpoint_dir)), set(expected_files)) 63 | checkpointer.clear() 64 | expected_files = ["best_model_task_dataset_valid_f1.pth"] 65 | self.assertEqual(os.listdir(checkpoint_dir), expected_files) 66 | 67 | def test_checkpointer_load_best(self) -> None: 68 | checkpoint_dir = os.path.join(self.test_dir, "clear") 69 | checkpointer = Checkpointer( 70 | **log_manager_config, 71 | checkpoint_dir=checkpoint_dir, 72 | checkpoint_metric="task/dataset/valid/f1:max", 73 | ) 74 | model = MultitaskClassifier([]) 75 | checkpointer.checkpoint(1, model, {"task/dataset/valid/f1": 0.8}) 76 | load_model = checkpointer.load_best_model(model) 77 | self.assertEqual(model, load_model) 78 | 79 | def test_bad_checkpoint_runway(self) -> None: 80 | with self.assertRaisesRegex(ValueError, "checkpoint_runway"): 81 | Checkpointer(**log_manager_config, checkpoint_runway=-1) 82 | 83 | def test_no_zero_frequency(self) -> None: 84 | with self.assertRaisesRegex(ValueError, "checkpoint freq"): 85 | Checkpointer( 86 | **log_manager_config, checkpoint_dir=self.test_dir, checkpoint_factor=0 87 | ) 88 | 89 | def test_bad_metric_name(self) -> None: 90 | with self.assertRaisesRegex(ValueError, "metric_name:mode"): 91 | Checkpointer( 92 | **log_manager_config, 93 | checkpoint_dir=self.test_dir, 94 | checkpoint_metric="task/dataset/split/f1-min", 95 | ) 96 | 97 | with self.assertRaisesRegex(ValueError, "metric mode"): 98 | Checkpointer( 99 | **log_manager_config, 100 | checkpoint_dir=self.test_dir, 101 | checkpoint_metric="task/dataset/split/f1:mode", 102 | ) 103 | 104 | with self.assertRaisesRegex(ValueError, "checkpoint_metric must be formatted"): 105 | Checkpointer( 106 | **log_manager_config, 107 | checkpoint_dir=self.test_dir, 108 | checkpoint_metric="accuracy:max", 109 | ) 110 | -------------------------------------------------------------------------------- /test/classification/training/loggers/test_log_writer.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import shutil 4 | import tempfile 5 | import unittest 6 | 7 | from snorkel.classification.training.loggers import LogWriter 8 | from snorkel.types import Config 9 | 10 | 11 | class TempConfig(Config): 12 | a: int = 42 13 | b: str = "foo" 14 | 15 | 16 | class TestLogWriter(unittest.TestCase): 17 | def setUp(self): 18 | self.test_dir = tempfile.mkdtemp() 19 | 20 | def tearDown(self): 21 | shutil.rmtree(self.test_dir) 22 | 23 | def test_log_writer(self): 24 | run_name = "my_run" 25 | log_writer = LogWriter(run_name=run_name, log_dir=self.test_dir) 26 | log_writer.add_scalar("my_value", value=0.5, step=2) 27 | 28 | log_filename = "my_log.json" 29 | log_writer.write_log(log_filename) 30 | 31 | log_path = os.path.join(self.test_dir, run_name, log_filename) 32 | with open(log_path, "r") as f: 33 | log = json.load(f) 34 | 35 | log_expected = dict(my_value=[[2, 0.5]]) 36 | self.assertEqual(log, log_expected) 37 | 38 | def test_write_text(self) -> None: 39 | run_name = "my_run" 40 | filename = "my_text.txt" 41 | text = "my log text" 42 | log_writer = LogWriter(run_name=run_name, log_dir=self.test_dir) 43 | log_writer.write_text(text, filename) 44 | log_path = os.path.join(self.test_dir, run_name, filename) 45 | with open(log_path, "r") as f: 46 | file_text = f.read() 47 | self.assertEqual(text, file_text) 48 | 49 | def 
test_write_config(self) -> None: 50 | run_name = "my_run" 51 | config = TempConfig(b="bar") # type: ignore 52 | log_writer = LogWriter(run_name=run_name, log_dir=self.test_dir) 53 | log_writer.write_config(config) 54 | log_path = os.path.join(self.test_dir, run_name, "config.json") 55 | with open(log_path, "r") as f: 56 | file_config = json.load(f) 57 | self.assertEqual(config._asdict(), file_config) 58 | 59 | 60 | if __name__ == "__main__": 61 | unittest.main() 62 | -------------------------------------------------------------------------------- /test/classification/training/loggers/test_tensorboard_writer.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import shutil 4 | import tempfile 5 | import unittest 6 | 7 | from snorkel.classification.training.loggers import TensorBoardWriter 8 | from snorkel.types import Config 9 | 10 | 11 | class TempConfig(Config): 12 | a: int = 42 13 | b: str = "foo" 14 | 15 | 16 | class TestTensorBoardWriter(unittest.TestCase): 17 | def setUp(self): 18 | self.test_dir = tempfile.mkdtemp() 19 | 20 | def tearDown(self): 21 | shutil.rmtree(self.test_dir) 22 | 23 | def test_tensorboard_writer(self): 24 | # Note: this just tests API calls. We rely on 25 | # tensorboard's unit tests for correctness. 26 | run_name = "my_run" 27 | config = TempConfig(b="bar") 28 | writer = TensorBoardWriter(run_name=run_name, log_dir=self.test_dir) 29 | writer.add_scalar("my_value", value=0.5, step=2) 30 | writer.write_config(config) 31 | log_path = os.path.join(self.test_dir, run_name, "config.json") 32 | with open(log_path, "r") as f: 33 | file_config = json.load(f) 34 | self.assertEqual(config._asdict(), file_config) 35 | writer.cleanup() 36 | -------------------------------------------------------------------------------- /test/classification/training/schedulers/test_schedulers.py: -------------------------------------------------------------------------------- 1 | import random 2 | import unittest 3 | 4 | import numpy as np 5 | import torch 6 | 7 | from snorkel.classification import DictDataLoader, DictDataset 8 | from snorkel.classification.training.schedulers import ( 9 | SequentialScheduler, 10 | ShuffledScheduler, 11 | ) 12 | 13 | dataset1 = DictDataset( 14 | "d1", 15 | "train", 16 | X_dict={"data": [0, 1, 2, 3, 4]}, 17 | Y_dict={"labels": torch.LongTensor([1, 1, 1, 1, 1])}, 18 | ) 19 | dataset2 = DictDataset( 20 | "d2", 21 | "train", 22 | X_dict={"data": [5, 6, 7, 8, 9]}, 23 | Y_dict={"labels": torch.LongTensor([2, 2, 2, 2, 2])}, 24 | ) 25 | 26 | dataloader1 = DictDataLoader(dataset1, batch_size=2) 27 | dataloader2 = DictDataLoader(dataset2, batch_size=2) 28 | dataloaders = [dataloader1, dataloader2] 29 | 30 | 31 | class SequentialTest(unittest.TestCase): 32 | def test_sequential(self): 33 | scheduler = SequentialScheduler() 34 | data = [] 35 | for batch, dl in scheduler.get_batches(dataloaders): 36 | X_dict, Y_dict = batch 37 | data.extend(X_dict["data"]) 38 | self.assertEqual(data, sorted(data)) 39 | 40 | def test_shuffled(self): 41 | random.seed(123) 42 | np.random.seed(123) 43 | torch.manual_seed(123) 44 | scheduler = ShuffledScheduler() 45 | data = [] 46 | for batch, dl in scheduler.get_batches(dataloaders): 47 | X_dict, Y_dict = batch 48 | data.extend(X_dict["data"]) 49 | self.assertNotEqual(data, sorted(data)) 50 | 51 | 52 | if __name__ == "__main__": 53 | unittest.main() 54 | -------------------------------------------------------------------------------- /test/labeling/__init__.py: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/snorkel-team/snorkel/617c92400c50e95ce41fcee84309a86f76cf525c/test/labeling/__init__.py -------------------------------------------------------------------------------- /test/labeling/apply/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/snorkel-team/snorkel/617c92400c50e95ce41fcee84309a86f76cf525c/test/labeling/apply/__init__.py -------------------------------------------------------------------------------- /test/labeling/apply/lf_applier_spark_test_script.py: -------------------------------------------------------------------------------- 1 | """ 2 | This script is used to manually test 3 | `snorkel.labeling.apply.lf_applier_spark.SparkLFApplier` 4 | 5 | To test on AWS EMR: 6 | 1. Allocate an EMR cluster (e.g. label 5.24.0) with > 1 worker and SSH permissions 7 | 2. Clone and pip install snorkel on the master node 8 | ``` 9 | sudo yum install git 10 | git clone https://github.com/snorkel-team/snorkel 11 | cd snorkel 12 | python3 -m pip install -t snorkel-package . 13 | cd snorkel-package 14 | zip -r ../snorkel-package.zip . 15 | cd .. 16 | ``` 17 | 3. Run 18 | ``` 19 | sudo sed -i -e \ 20 | '$a\\export PYSPARK_PYTHON=/usr/bin/python3' \ 21 | /etc/spark/conf/spark-env.sh 22 | ``` 23 | 4. Run 24 | ``` 25 | spark-submit \ 26 | --py-files snorkel-package.zip \ 27 | test/labeling/apply/lf_applier_spark_test_script.py 28 | ``` 29 | """ 30 | 31 | import logging 32 | from typing import List 33 | 34 | import numpy as np 35 | from pyspark import SparkContext 36 | 37 | from snorkel.labeling.apply.spark import SparkLFApplier 38 | from snorkel.labeling.lf import labeling_function 39 | from snorkel.types import DataPoint 40 | 41 | logging.basicConfig(level=logging.INFO) 42 | 43 | 44 | @labeling_function() 45 | def f(x: DataPoint) -> int: 46 | return 1 if x > 42 else 0 47 | 48 | 49 | @labeling_function(resources=dict(db=[3, 6, 9])) 50 | def g(x: DataPoint, db: List[int]) -> int: 51 | return 1 if x in db else 0 52 | 53 | 54 | DATA = [3, 43, 12, 9] 55 | L_EXPECTED = np.array([[0, 1], [1, 0], [0, 0], [0, 1]]) 56 | 57 | 58 | def build_lf_matrix() -> None: 59 | logging.info("Getting Spark context") 60 | sc = SparkContext() 61 | sc.addPyFile("snorkel-package.zip") 62 | rdd = sc.parallelize(DATA) 63 | 64 | logging.info("Applying LFs") 65 | lf_applier = SparkLFApplier([f, g]) 66 | L = lf_applier.apply(rdd) 67 | 68 | np.testing.assert_equal(L.toarray(), L_EXPECTED) 69 | 70 | 71 | if __name__ == "__main__": 72 | build_lf_matrix() 73 | -------------------------------------------------------------------------------- /test/labeling/apply/test_spark.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from typing import List 3 | 4 | import numpy as np 5 | import pandas as pd 6 | import pytest 7 | from pyspark import SparkContext 8 | from pyspark.sql import Row, SQLContext 9 | 10 | from snorkel.labeling import labeling_function 11 | from snorkel.labeling.apply.spark import SparkLFApplier 12 | from snorkel.preprocess import preprocessor 13 | from snorkel.types import DataPoint 14 | 15 | 16 | @preprocessor() 17 | def square(x: Row) -> Row: 18 | return Row(num=x.num, num_squared=x.num**2) 19 | 20 | 21 | @labeling_function() 22 | def f(x: DataPoint) -> int: 23 | return 0 if x.num > 42 else -1 24 | 25 | 26 | @labeling_function(pre=[square]) 27 | def fp(x: DataPoint) -> int: 
28 | return 0 if x.num_squared > 42 else -1 29 | 30 | 31 | @labeling_function(resources=dict(db=[3, 6, 9])) 32 | def g(x: DataPoint, db: List[int]) -> int: 33 | return 0 if x.num in db else -1 34 | 35 | 36 | @labeling_function() 37 | def f_bad(x: DataPoint) -> int: 38 | return 0 if x.mum > 42 else -1 39 | 40 | 41 | DATA = [3, 43, 12, 9, 3] 42 | L_EXPECTED = np.array([[-1, 0], [0, -1], [-1, -1], [-1, 0], [-1, 0]]) 43 | L_EXPECTED_BAD = np.array([[-1, -1], [0, -1], [-1, -1], [-1, -1], [-1, -1]]) 44 | L_PREPROCESS_EXPECTED = np.array([[-1, -1], [0, 0], [-1, 0], [-1, 0], [-1, -1]]) 45 | 46 | TEXT_DATA = ["Jane", "Jane plays soccer.", "Jane plays soccer."] 47 | L_TEXT_EXPECTED = np.array([[0, -1], [0, 0], [0, 0]]) 48 | 49 | 50 | class TestSparkApplier(unittest.TestCase): 51 | @pytest.mark.complex 52 | @pytest.mark.spark 53 | def test_lf_applier_spark(self) -> None: 54 | sc = SparkContext.getOrCreate() 55 | sql = SQLContext(sc) 56 | df = pd.DataFrame(dict(num=DATA)) 57 | rdd = sql.createDataFrame(df).rdd 58 | applier = SparkLFApplier([f, g]) 59 | L = applier.apply(rdd) 60 | np.testing.assert_equal(L, L_EXPECTED) 61 | 62 | @pytest.mark.complex 63 | @pytest.mark.spark 64 | def test_lf_applier_spark_fault(self) -> None: 65 | sc = SparkContext.getOrCreate() 66 | sql = SQLContext(sc) 67 | df = pd.DataFrame(dict(num=DATA)) 68 | rdd = sql.createDataFrame(df).rdd 69 | applier = SparkLFApplier([f, f_bad]) 70 | with self.assertRaises(Exception): 71 | applier.apply(rdd) 72 | L = applier.apply(rdd, fault_tolerant=True) 73 | np.testing.assert_equal(L, L_EXPECTED_BAD) 74 | 75 | @pytest.mark.complex 76 | @pytest.mark.spark 77 | def test_lf_applier_spark_preprocessor(self) -> None: 78 | sc = SparkContext.getOrCreate() 79 | sql = SQLContext(sc) 80 | df = pd.DataFrame(dict(num=DATA)) 81 | rdd = sql.createDataFrame(df).rdd 82 | applier = SparkLFApplier([f, fp]) 83 | L = applier.apply(rdd) 84 | np.testing.assert_equal(L, L_PREPROCESS_EXPECTED) 85 | 86 | @pytest.mark.complex 87 | @pytest.mark.spark 88 | def test_lf_applier_spark_preprocessor_memoized(self) -> None: 89 | sc = SparkContext.getOrCreate() 90 | sql = SQLContext(sc) 91 | 92 | @preprocessor(memoize=True) 93 | def square_memoize(x: DataPoint) -> DataPoint: 94 | return Row(num=x.num, num_squared=x.num**2) 95 | 96 | @labeling_function(pre=[square_memoize]) 97 | def fp_memoized(x: DataPoint) -> int: 98 | return 0 if x.num_squared > 42 else -1 99 | 100 | df = pd.DataFrame(dict(num=DATA)) 101 | rdd = sql.createDataFrame(df).rdd 102 | applier = SparkLFApplier([f, fp_memoized]) 103 | L = applier.apply(rdd) 104 | np.testing.assert_equal(L, L_PREPROCESS_EXPECTED) 105 | -------------------------------------------------------------------------------- /test/labeling/lf/test_core.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import unittest 3 | from types import SimpleNamespace 4 | from typing import List 5 | 6 | from snorkel.labeling import LabelingFunction, labeling_function 7 | from snorkel.preprocess import preprocessor 8 | from snorkel.types import DataPoint 9 | 10 | 11 | @preprocessor() 12 | def square(x: DataPoint) -> DataPoint: 13 | x.num = x.num**2 14 | return x 15 | 16 | 17 | @preprocessor() 18 | def returns_none(x: DataPoint) -> DataPoint: 19 | return None 20 | 21 | 22 | def f(x: DataPoint) -> int: 23 | return 0 if x.num > 42 else -1 24 | 25 | 26 | def g(x: DataPoint, db: List[int]) -> int: 27 | return 0 if x.num in db else -1 28 | 29 | 30 | class TestLabelingFunction(unittest.TestCase): 31 | def 
_run_lf(self, lf: LabelingFunction) -> None: 32 | x_43 = SimpleNamespace(num=43) 33 | x_19 = SimpleNamespace(num=19) 34 | self.assertEqual(lf(x_43), 0) 35 | self.assertEqual(lf(x_19), -1) 36 | 37 | def test_labeling_function(self) -> None: 38 | lf = LabelingFunction(name="my_lf", f=f) 39 | self._run_lf(lf) 40 | 41 | def test_labeling_function_resources(self) -> None: 42 | db = [3, 6, 43] 43 | lf = LabelingFunction(name="my_lf", f=g, resources=dict(db=db)) 44 | self._run_lf(lf) 45 | 46 | def test_labeling_function_preprocessor(self) -> None: 47 | lf = LabelingFunction(name="my_lf", f=f, pre=[square, square]) 48 | x_43 = SimpleNamespace(num=43) 49 | x_6 = SimpleNamespace(num=6) 50 | x_2 = SimpleNamespace(num=2) 51 | self.assertEqual(lf(x_43), 0) 52 | self.assertEqual(lf(x_6), 0) 53 | self.assertEqual(lf(x_2), -1) 54 | 55 | def test_labeling_function_returns_none(self) -> None: 56 | lf = LabelingFunction(name="my_lf", f=f, pre=[square, returns_none]) 57 | x_43 = SimpleNamespace(num=43) 58 | with self.assertRaises(ValueError): 59 | lf(x_43) 60 | 61 | def test_labeling_function_serialize(self) -> None: 62 | db = [3, 6, 43] 63 | lf = LabelingFunction(name="my_lf", f=g, resources=dict(db=db)) 64 | lf_load = pickle.loads(pickle.dumps(lf)) 65 | self._run_lf(lf_load) 66 | 67 | def test_labeling_function_decorator(self) -> None: 68 | @labeling_function() 69 | def lf(x: DataPoint) -> int: 70 | return 0 if x.num > 42 else -1 71 | 72 | self.assertIsInstance(lf, LabelingFunction) 73 | self.assertEqual(lf.name, "lf") 74 | self._run_lf(lf) 75 | 76 | def test_labeling_function_decorator_args(self) -> None: 77 | db = [3, 6, 43**2] 78 | 79 | @labeling_function(name="my_lf", resources=dict(db=db), pre=[square]) 80 | def lf(x: DataPoint, db: List[int]) -> int: 81 | return 0 if x.num in db else -1 82 | 83 | self.assertIsInstance(lf, LabelingFunction) 84 | self.assertEqual(lf.name, "my_lf") 85 | self._run_lf(lf) 86 | 87 | def test_labeling_function_decorator_no_parens(self) -> None: 88 | with self.assertRaisesRegex(ValueError, "missing parentheses"): 89 | 90 | @labeling_function 91 | def lf(x: DataPoint) -> int: 92 | return 0 if x.num > 42 else -1 93 | -------------------------------------------------------------------------------- /test/labeling/lf/test_nlp.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from types import SimpleNamespace 3 | 4 | import dill 5 | import pytest 6 | 7 | from snorkel.labeling.lf.nlp import NLPLabelingFunction, nlp_labeling_function 8 | from snorkel.preprocess import preprocessor 9 | from snorkel.types import DataPoint 10 | 11 | 12 | @preprocessor() 13 | def combine_text(x: DataPoint) -> DataPoint: 14 | x.text = f"{x.title} {x.article}" 15 | return x 16 | 17 | 18 | def has_person_mention(x: DataPoint) -> int: 19 | person_ents = [ent for ent in x.doc.ents if ent.label_ == "PERSON"] 20 | return 0 if len(person_ents) > 0 else -1 21 | 22 | 23 | class TestNLPLabelingFunction(unittest.TestCase): 24 | def _run_lf(self, lf: NLPLabelingFunction) -> None: 25 | x = SimpleNamespace( 26 | num=8, title="Great film!", article="The movie is really great!" 
27 | ) 28 | self.assertEqual(lf(x), -1) 29 | x = SimpleNamespace(num=8, title="Nice movie!", article="Jane Doe acted well.") 30 | self.assertEqual(lf(x), 0) 31 | 32 | def test_nlp_labeling_function(self) -> None: 33 | lf = NLPLabelingFunction(name="my_lf", f=has_person_mention, pre=[combine_text]) 34 | self._run_lf(lf) 35 | 36 | def test_nlp_labeling_function_memoized(self) -> None: 37 | lf = NLPLabelingFunction(name="my_lf", f=has_person_mention, pre=[combine_text]) 38 | lf._nlp_config.nlp.reset_cache() 39 | self.assertEqual(len(lf._nlp_config.nlp._cache), 0) 40 | self._run_lf(lf) 41 | self.assertEqual(len(lf._nlp_config.nlp._cache), 2) 42 | self._run_lf(lf) 43 | self.assertEqual(len(lf._nlp_config.nlp._cache), 2) 44 | 45 | @pytest.mark.complex 46 | def test_labeling_function_serialize(self) -> None: 47 | lf = NLPLabelingFunction(name="my_lf", f=has_person_mention, pre=[combine_text]) 48 | lf_load = dill.loads(dill.dumps(lf)) 49 | self._run_lf(lf_load) 50 | 51 | def test_nlp_labeling_function_decorator(self) -> None: 52 | @nlp_labeling_function(pre=[combine_text]) 53 | def has_person_mention(x: DataPoint) -> int: 54 | person_ents = [ent for ent in x.doc.ents if ent.label_ == "PERSON"] 55 | return 0 if len(person_ents) > 0 else -1 56 | 57 | self.assertIsInstance(has_person_mention, NLPLabelingFunction) 58 | self.assertEqual(has_person_mention.name, "has_person_mention") 59 | self._run_lf(has_person_mention) 60 | 61 | def test_nlp_labeling_function_decorator_no_parens(self) -> None: 62 | with self.assertRaisesRegex(ValueError, "missing parentheses"): 63 | 64 | @nlp_labeling_function 65 | def has_person_mention(x: DataPoint) -> int: 66 | person_ents = [ent for ent in x.doc.ents if ent.label_ == "PERSON"] 67 | return 0 if len(person_ents) > 0 else -1 68 | 69 | def test_nlp_labeling_function_shared_cache(self) -> None: 70 | lf = NLPLabelingFunction(name="my_lf", f=has_person_mention, pre=[combine_text]) 71 | 72 | @nlp_labeling_function(pre=[combine_text]) 73 | def lf2(x: DataPoint) -> int: 74 | return 0 if len(x.doc) < 9 else -1 75 | 76 | lf._nlp_config.nlp.reset_cache() 77 | self.assertEqual(len(lf._nlp_config.nlp._cache), 0) 78 | self.assertEqual(len(lf2._nlp_config.nlp._cache), 0) 79 | self._run_lf(lf) 80 | self.assertEqual(len(lf._nlp_config.nlp._cache), 2) 81 | self.assertEqual(len(lf2._nlp_config.nlp._cache), 2) 82 | self._run_lf(lf2) 83 | self.assertEqual(len(lf._nlp_config.nlp._cache), 2) 84 | self.assertEqual(len(lf2._nlp_config.nlp._cache), 2) 85 | 86 | def test_nlp_labeling_function_raises(self) -> None: 87 | with self.assertRaisesRegex(ValueError, "different parameters"): 88 | 89 | @nlp_labeling_function() 90 | def has_person_mention(x: DataPoint) -> int: 91 | person_ents = [ent for ent in x.doc.ents if ent.label_ == "PERSON"] 92 | return 0 if len(person_ents) > 0 else -1 93 | -------------------------------------------------------------------------------- /test/labeling/lf/test_nlp_spark.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from types import SimpleNamespace 3 | 4 | import pytest 5 | from pyspark.sql import Row 6 | 7 | from snorkel.labeling.lf.nlp import NLPLabelingFunction 8 | from snorkel.labeling.lf.nlp_spark import ( 9 | SparkNLPLabelingFunction, 10 | spark_nlp_labeling_function, 11 | ) 12 | from snorkel.types import DataPoint 13 | 14 | 15 | def has_person_mention(x: DataPoint) -> int: 16 | person_ents = [ent for ent in x.doc.ents if ent.label_ == "PERSON"] 17 | return 0 if len(person_ents) > 0 else -1 18 | 
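# Note: SparkNLPLabelingFunction mirrors NLPLabelingFunction but operates on
# pyspark.sql.Row data points; as with the other Spark tests, this class is
# gated behind the "spark" pytest marker and needs a local Spark installation.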
19 | 20 | @pytest.mark.spark 21 | class TestNLPLabelingFunction(unittest.TestCase): 22 | def _run_lf(self, lf: SparkNLPLabelingFunction) -> None: 23 | x = Row(num=8, text="The movie is really great!") 24 | self.assertEqual(lf(x), -1) 25 | x = Row(num=8, text="Jane Doe acted well.") 26 | self.assertEqual(lf(x), 0) 27 | 28 | def test_nlp_labeling_function(self) -> None: 29 | lf = SparkNLPLabelingFunction(name="my_lf", f=has_person_mention) 30 | self._run_lf(lf) 31 | 32 | def test_nlp_labeling_function_decorator(self) -> None: 33 | @spark_nlp_labeling_function() 34 | def has_person_mention(x: DataPoint) -> int: 35 | person_ents = [ent for ent in x.doc.ents if ent.label_ == "PERSON"] 36 | return 0 if len(person_ents) > 0 else -1 37 | 38 | self.assertIsInstance(has_person_mention, SparkNLPLabelingFunction) 39 | self.assertEqual(has_person_mention.name, "has_person_mention") 40 | self._run_lf(has_person_mention) 41 | 42 | def test_spark_nlp_labeling_function_with_nlp_labeling_function(self) -> None: 43 | # Do they have separate _nlp_configs? 44 | lf = NLPLabelingFunction(name="my_lf", f=has_person_mention) 45 | lf_spark = SparkNLPLabelingFunction(name="my_lf_spark", f=has_person_mention) 46 | self.assertEqual(lf(SimpleNamespace(num=8, text="Jane Doe acted well.")), 0) 47 | self._run_lf(lf_spark) 48 | -------------------------------------------------------------------------------- /test/labeling/model/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/snorkel-team/snorkel/617c92400c50e95ce41fcee84309a86f76cf525c/test/labeling/model/__init__.py -------------------------------------------------------------------------------- /test/labeling/model/test_baseline.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import numpy as np 4 | 5 | from snorkel.labeling.model import MajorityClassVoter, MajorityLabelVoter, RandomVoter 6 | 7 | 8 | class BaselineModelTest(unittest.TestCase): 9 | def test_random_vote(self): 10 | L = np.array([[0, 1, 0], [-1, 3, 2], [2, -1, -1], [0, 1, 1]]) 11 | rand_voter = RandomVoter() 12 | Y_p = rand_voter.predict_proba(L) 13 | self.assertLessEqual(Y_p.max(), 1.0) 14 | self.assertGreaterEqual(Y_p.min(), 0.0) 15 | np.testing.assert_array_almost_equal( 16 | np.sum(Y_p, axis=1), np.ones(np.shape(L)[0]) 17 | ) 18 | 19 | def test_majority_class_vote(self): 20 | L = np.array([[0, 1, 0], [1, 1, 0], [1, 1, 0], [-1, -1, 1]]) 21 | mc_voter = MajorityClassVoter() 22 | mc_voter.fit(balance=np.array([0.8, 0.2])) 23 | Y_p = mc_voter.predict_proba(L) 24 | 25 | Y_p_true = np.array([[1.0, 0.0], [1.0, 0.0], [1.0, 0.0], [1.0, 0.0]]) 26 | np.testing.assert_array_almost_equal(Y_p, Y_p_true) 27 | 28 | def test_majority_label_vote(self): 29 | L = np.array([[0, 1, 0], [0, 1, 0], [1, 0, 0], [-1, -1, 1]]) 30 | ml_voter = MajorityLabelVoter() 31 | Y_p = ml_voter.predict_proba(L) 32 | 33 | Y_p_true = np.array([[1.0, 0.0], [1.0, 0.0], [1.0, 0.0], [0.0, 1.0]]) 34 | np.testing.assert_array_almost_equal(Y_p, Y_p_true) 35 | 36 | 37 | if __name__ == "__main__": 38 | unittest.main() 39 | -------------------------------------------------------------------------------- /test/labeling/model/test_logger.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from snorkel.labeling.model.logger import Logger 4 | 5 | 6 | class LoggerTest(unittest.TestCase): 7 | def test_basic(self): 8 | metrics_dict = {"train/loss": 0.01} 9 | 
logger = Logger(log_freq=1) 10 | logger.log(metrics_dict) 11 | 12 | metrics_dict = {"train/message": "well done!"} 13 | logger = Logger(log_freq=1) 14 | logger.log(metrics_dict) 15 | 16 | def test_bad_metrics_dict(self): 17 | bad_metrics_dict = {"task1/slice1/train/loss": 0.05} 18 | logger = Logger(log_freq=1) 19 | self.assertRaises(Exception, logger.log, bad_metrics_dict) 20 | 21 | def test_valid_metrics_dict(self): 22 | mtl_metrics_dict = {"task1/valid/loss": 0.05} 23 | logger = Logger(log_freq=1) 24 | logger.log(mtl_metrics_dict) 25 | 26 | 27 | if __name__ == "__main__": 28 | unittest.main() 29 | -------------------------------------------------------------------------------- /test/labeling/preprocess/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/snorkel-team/snorkel/617c92400c50e95ce41fcee84309a86f76cf525c/test/labeling/preprocess/__init__.py -------------------------------------------------------------------------------- /test/labeling/preprocess/test_nlp.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from types import SimpleNamespace 3 | 4 | from snorkel.preprocess.nlp import SpacyPreprocessor 5 | 6 | 7 | class TestSpacyPreprocessor(unittest.TestCase): 8 | def test_spacy_preprocessor(self) -> None: 9 | x = SimpleNamespace(text="Jane plays soccer.") 10 | preprocessor = SpacyPreprocessor("text", "doc") 11 | x_preprocessed = preprocessor(x) 12 | assert x_preprocessed is not None 13 | self.assertEqual(len(x_preprocessed.doc), 4) 14 | token = x_preprocessed.doc[0] 15 | self.assertEqual(token.text, "Jane") 16 | self.assertEqual(token.pos_, "PROPN") 17 | -------------------------------------------------------------------------------- /test/labeling/test_convergence.py: -------------------------------------------------------------------------------- 1 | import random 2 | import unittest 3 | 4 | import numpy as np 5 | import pandas as pd 6 | import pytest 7 | import torch 8 | 9 | from snorkel.labeling import LabelingFunction, PandasLFApplier, labeling_function 10 | from snorkel.labeling.model import LabelModel 11 | from snorkel.preprocess import preprocessor 12 | from snorkel.types import DataPoint 13 | 14 | 15 | def create_data(n: int) -> pd.DataFrame: 16 | """Create random pairs x1, x2 in [-1., 1.] 
with label x1 > x2 + 0.25."""
17 |     X = np.random.random((n, 2)) * 2 - 1
18 |     Y = (X[:, 0] > X[:, 1] + 0.25).astype(int)
19 | 
20 |     df = pd.DataFrame(
21 |         {"x0": np.random.randint(0, 1000, n), "x1": X[:, 0], "x2": X[:, 1], "y": Y}
22 |     )
23 |     return df
24 | 
25 | 
26 | def get_positive_labeling_function(divisor: int) -> LabelingFunction:
27 |     """Get LabelingFunction that abstains unless x0 is divisible by divisor."""
28 | 
29 |     def f(x):
30 |         return 1 if x.x0 % divisor == 0 and x.x1 > x.x2 + 0.25 else -1
31 | 
32 |     return LabelingFunction(f"lf_pos_{divisor}", f)
33 | 
34 | 
35 | def get_negative_labeling_function(divisor: int) -> LabelingFunction:
36 |     """Get LabelingFunction that abstains unless x0 is divisible by divisor."""
37 | 
38 |     def f(x):
39 |         return 0 if x.x0 % divisor == 0 and x.x1 <= x.x2 + 0.25 else -1
40 | 
41 |     return LabelingFunction(f"lf_neg_{divisor}", f)
42 | 
43 | 
44 | @preprocessor()
45 | def copy_features(x: DataPoint) -> DataPoint:
46 |     """Compute x2 + 0.25 for direct comparison to x1."""
47 |     x.x3 = x.x2 + 0.25
48 |     return x
49 | 
50 | 
51 | @labeling_function(pre=[copy_features], resources=dict(divisor=3))
52 | def f(x: DataPoint, divisor: int) -> int:
53 |     # Abstain unless x0 % divisor == 1 and x1 exceeds x3 (= x2 + 0.25).
54 |     return 0 if x.x0 % divisor == 1 and x.x1 > x.x3 else -1
55 | 
56 | 
57 | class LabelingConvergenceTest(unittest.TestCase):
58 |     @classmethod
59 |     def setUpClass(cls):
60 |         # Ensure deterministic runs
61 |         random.seed(123)
62 |         np.random.seed(123)
63 |         torch.manual_seed(123)
64 | 
65 |         # Create raw data
66 |         cls.N_TRAIN = 1500
67 | 
68 |         cls.cardinality = 2
69 |         cls.df_train = create_data(cls.N_TRAIN)
70 | 
71 |     @pytest.mark.complex
72 |     def test_labeling_convergence(self) -> None:
73 |         """Test convergence of the end-to-end labeling pipeline."""
74 |         # Apply LFs
75 |         labeling_functions = (
76 |             [f]
77 |             + [get_positive_labeling_function(divisor) for divisor in range(2, 9)]
78 |             + [get_negative_labeling_function(divisor) for divisor in range(2, 9)]
79 |         )
80 |         applier = PandasLFApplier(labeling_functions)
81 |         L_train = applier.apply(self.df_train, progress_bar=False)
82 | 
83 |         self.assertEqual(L_train.shape, (self.N_TRAIN, len(labeling_functions)))
84 | 
85 |         # Train LabelModel
86 |         label_model = LabelModel(cardinality=self.cardinality, verbose=False)
87 |         label_model.fit(L_train, n_epochs=100, lr=0.01, l2=0.0)
88 |         Y_lm = label_model.predict_proba(L_train).argmax(axis=1)
89 |         Y = self.df_train.y
90 |         err = np.where(Y != Y_lm, 1, 0).sum() / self.N_TRAIN
91 |         self.assertLess(err, 0.06)
92 | 
93 | 
94 | if __name__ == "__main__":
95 |     unittest.main()
96 | 
-------------------------------------------------------------------------------- /test/labeling/test_utils.py: --------------------------------------------------------------------------------
1 | import unittest
2 | 
3 | import numpy as np
4 | import pandas as pd
5 | 
6 | from snorkel.labeling import filter_unlabeled_dataframe
7 | 
8 | 
9 | class TestAnalysis(unittest.TestCase):
10 |     def test_filter_unlabeled_dataframe(self) -> None:
11 |         X = pd.DataFrame(dict(A=["x", "y", "z"], B=[1, 2, 3]))
12 |         y = np.array(
13 |             [[0.25, 0.25, 0.25, 0.25], [1.0, 0.0, 0.0, 0.0], [0.2, 0.3, 0.5, 0.0]]
14 |         )
15 |         L = np.array([[0, 1, -1], [-1, -1, -1], [1, 1, 0]])
16 |         X_filtered, y_filtered = filter_unlabeled_dataframe(X, y, L)
17 |         # Row 1 ("y") is dropped because every LF abstained (all -1) on it.
18 |         self.assertTrue(
19 |             np.array_equal(
20 |                 X_filtered.values, np.array([["x", 1], ["z", 3]], dtype=object)
21 |             )
22 |         )
23 |         np.testing.assert_array_almost_equal(
24 |             y_filtered, np.array([[0.25, 0.25, 0.25, 0.25], [0.2, 0.3, 0.5, 0.0]])
25 |         )
26 | 
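# ----------------------------------------------------------------------------
# A minimal usage sketch (annotation, not part of the test suite; the toy data
# below is invented for illustration): filter_unlabeled_dataframe pairs a
# training dataframe with label-model probabilities and drops every row whose
# label-matrix row is all -1, i.e. rows on which every LF abstained.
#
#     import numpy as np
#     import pandas as pd
#     from snorkel.labeling import filter_unlabeled_dataframe
#
#     X = pd.DataFrame(dict(text=["a", "b", "c"]))
#     probs = np.array([[0.9, 0.1], [0.5, 0.5], [0.2, 0.8]])
#     L = np.array([[0, -1], [-1, -1], [1, 1]])
#     X_train, probs_train = filter_unlabeled_dataframe(X, probs, L)
#     # Row "b" is dropped: both LFs abstained on it.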
-------------------------------------------------------------------------------- /test/map/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/snorkel-team/snorkel/617c92400c50e95ce41fcee84309a86f76cf525c/test/map/__init__.py -------------------------------------------------------------------------------- /test/slicing/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/snorkel-team/snorkel/617c92400c50e95ce41fcee84309a86f76cf525c/test/slicing/__init__.py -------------------------------------------------------------------------------- /test/slicing/apply/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/snorkel-team/snorkel/617c92400c50e95ce41fcee84309a86f76cf525c/test/slicing/apply/__init__.py -------------------------------------------------------------------------------- /test/slicing/apply/test_sf_applier.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from types import SimpleNamespace 3 | from typing import List 4 | 5 | from snorkel.preprocess import preprocessor 6 | from snorkel.slicing import SFApplier, slicing_function 7 | from snorkel.types import DataPoint 8 | 9 | 10 | @preprocessor() 11 | def square(x: DataPoint) -> DataPoint: 12 | x.num_squared = x.num**2 13 | return x 14 | 15 | 16 | class SquareHitTracker: 17 | def __init__(self): 18 | self.n_hits = 0 19 | 20 | def __call__(self, x: float) -> float: 21 | self.n_hits += 1 22 | return x**2 23 | 24 | 25 | @slicing_function() 26 | def f(x: DataPoint) -> int: 27 | return x.num > 42 28 | 29 | 30 | @slicing_function(pre=[square]) 31 | def fp(x: DataPoint) -> int: 32 | return x.num_squared > 42 33 | 34 | 35 | @slicing_function(resources=dict(db=[3, 6, 9])) 36 | def g(x: DataPoint, db: List[int]) -> int: 37 | return x.num in db 38 | 39 | 40 | DATA = [3, 43, 12, 9, 3] 41 | S_EXPECTED = {"f": [0, 1, 0, 0, 0], "g": [1, 0, 0, 1, 1]} 42 | S_PREPROCESS_EXPECTED = {"f": [0, 1, 0, 0, 0], "fp": [0, 1, 1, 1, 0]} 43 | 44 | 45 | class TestSFApplier(unittest.TestCase): 46 | def test_sf_applier(self) -> None: 47 | data_points = [SimpleNamespace(num=num) for num in DATA] 48 | applier = SFApplier([f, g]) 49 | S = applier.apply(data_points, progress_bar=False) 50 | self.assertEqual(S["f"].tolist(), S_EXPECTED["f"]) 51 | self.assertEqual(S["g"].tolist(), S_EXPECTED["g"]) 52 | S = applier.apply(data_points, progress_bar=True) 53 | self.assertEqual(S["f"].tolist(), S_EXPECTED["f"]) 54 | self.assertEqual(S["g"].tolist(), S_EXPECTED["g"]) 55 | 56 | def test_sf_applier_preprocessor(self) -> None: 57 | data_points = [SimpleNamespace(num=num) for num in DATA] 58 | applier = SFApplier([f, fp]) 59 | S = applier.apply(data_points, progress_bar=False) 60 | self.assertEqual(S["f"].tolist(), S_PREPROCESS_EXPECTED["f"]) 61 | self.assertEqual(S["fp"].tolist(), S_PREPROCESS_EXPECTED["fp"]) 62 | -------------------------------------------------------------------------------- /test/slicing/sf/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/snorkel-team/snorkel/617c92400c50e95ce41fcee84309a86f76cf525c/test/slicing/sf/__init__.py -------------------------------------------------------------------------------- /test/slicing/sf/test_core.py: 
-------------------------------------------------------------------------------- 1 | import unittest 2 | from types import SimpleNamespace 3 | 4 | from snorkel.slicing import SlicingFunction, slicing_function 5 | 6 | 7 | class TestSlicingFunction(unittest.TestCase): 8 | def _run_sf(self, sf: SlicingFunction) -> None: 9 | x_43 = SimpleNamespace(num=43) 10 | x_19 = SimpleNamespace(num=19) 11 | self.assertEqual(sf(x_43), True) 12 | self.assertEqual(sf(x_19), False) 13 | 14 | def _run_sf_raise(self, sf: SlicingFunction) -> None: 15 | x_none = SimpleNamespace(num=None) 16 | with self.assertRaises(TypeError): 17 | sf(x_none) 18 | 19 | def test_slicing_function_decorator(self) -> None: 20 | @slicing_function() 21 | def sf(x) -> int: 22 | return x.num > 42 23 | 24 | self.assertIsInstance(sf, SlicingFunction) 25 | self.assertEqual(sf.name, "sf") 26 | self._run_sf(sf) 27 | self._run_sf_raise(sf) 28 | 29 | def test_slicing_function_decorator_no_parens(self) -> None: 30 | with self.assertRaisesRegex(ValueError, "missing parentheses"): 31 | 32 | @slicing_function 33 | def sf(x) -> int: 34 | return 0 if x.num > 42 else -1 35 | -------------------------------------------------------------------------------- /test/slicing/sf/test_nlp.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from types import SimpleNamespace 3 | 4 | from snorkel.preprocess import preprocessor 5 | from snorkel.slicing.sf.nlp import NLPSlicingFunction, nlp_slicing_function 6 | from snorkel.types import DataPoint 7 | 8 | 9 | @preprocessor() 10 | def combine_text(x: DataPoint) -> DataPoint: 11 | x.text = f"{x.title} {x.article}" 12 | return x 13 | 14 | 15 | def has_person_mention(x: DataPoint) -> int: 16 | person_ents = [ent for ent in x.doc.ents if ent.label_ == "PERSON"] 17 | return 0 if len(person_ents) > 0 else -1 18 | 19 | 20 | class TestNLPSlicingFunction(unittest.TestCase): 21 | def _run_sf(self, sf: NLPSlicingFunction) -> None: 22 | x = SimpleNamespace( 23 | num=8, title="Great film!", article="The movie is really great!" 
24 | ) 25 | self.assertEqual(sf(x), -1) 26 | x = SimpleNamespace(num=8, title="Nice movie!", article="Jane Doe acted well.") 27 | self.assertEqual(sf(x), 0) 28 | 29 | def test_nlp_slicing_function(self) -> None: 30 | sf = NLPSlicingFunction(name="my_sf", f=has_person_mention, pre=[combine_text]) 31 | self._run_sf(sf) 32 | 33 | def test_nlp_slicing_function_decorator(self) -> None: 34 | @nlp_slicing_function(pre=[combine_text]) 35 | def has_person_mention(x: DataPoint) -> int: 36 | person_ents = [ent for ent in x.doc.ents if ent.label_ == "PERSON"] 37 | return 0 if len(person_ents) > 0 else -1 38 | 39 | self.assertIsInstance(has_person_mention, NLPSlicingFunction) 40 | self.assertEqual(has_person_mention.name, "has_person_mention") 41 | self._run_sf(has_person_mention) 42 | -------------------------------------------------------------------------------- /test/slicing/test_monitor.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import pandas as pd 4 | 5 | from snorkel.slicing import slicing_function 6 | from snorkel.slicing.monitor import slice_dataframe 7 | 8 | DATA = [5, 10, 19, 22, 25] 9 | 10 | 11 | @slicing_function() 12 | def sf(x): 13 | return x.num < 20 14 | 15 | 16 | class PandasSlicerTest(unittest.TestCase): 17 | @classmethod 18 | def setUpClass(cls): 19 | cls.df = pd.DataFrame(dict(num=DATA)) 20 | 21 | def test_slice(self): 22 | self.assertEqual(len(self.df), 5) 23 | 24 | # Should return a subset 25 | sliced_df = slice_dataframe(self.df, sf) 26 | self.assertEqual(len(sliced_df), 3) 27 | -------------------------------------------------------------------------------- /test/slicing/test_utils.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import pandas as pd 4 | import torch 5 | import torch.nn as nn 6 | 7 | from snorkel.classification import DictDataLoader, DictDataset, Operation, Task 8 | from snorkel.slicing import ( 9 | PandasSFApplier, 10 | add_slice_labels, 11 | convert_to_slice_tasks, 12 | slicing_function, 13 | ) 14 | 15 | 16 | @slicing_function() 17 | def f(x): 18 | return x.val < 0.25 19 | 20 | 21 | class UtilsTest(unittest.TestCase): 22 | def test_add_slice_labels(self): 23 | # Create dummy data 24 | # Given slicing function f(), we expect the first two entries to be active 25 | x = torch.Tensor([0.1, 0.2, 0.3, 0.4, 0.5]) 26 | y = torch.Tensor([0, 1, 1, 0, 1]).long() 27 | dataset = DictDataset( 28 | name="TestData", split="train", X_dict={"data": x}, Y_dict={"TestTask": y} 29 | ) 30 | 31 | # Ensure that we start with 1 labelset 32 | self.assertEqual(len(dataset.Y_dict), 1) 33 | 34 | # Apply SFs with PandasSFApplier 35 | df = pd.DataFrame({"val": x, "y": y}) 36 | slicing_functions = [f] 37 | applier = PandasSFApplier(slicing_functions) 38 | S = applier.apply(df, progress_bar=False) 39 | 40 | dataloader = DictDataLoader(dataset) 41 | 42 | dummy_task = create_dummy_task(task_name="TestTask") 43 | add_slice_labels(dataloader, dummy_task, S) 44 | 45 | # Ensure that all the fields are present 46 | labelsets = dataloader.dataset.Y_dict 47 | self.assertIn("TestTask", labelsets) 48 | self.assertIn("TestTask_slice:base_ind", labelsets) 49 | self.assertIn("TestTask_slice:base_pred", labelsets) 50 | self.assertIn("TestTask_slice:f_ind", labelsets) 51 | self.assertIn("TestTask_slice:f_pred", labelsets) 52 | self.assertEqual(len(labelsets), 5) 53 | 54 | # Ensure "ind" contains mask 55 | self.assertEqual( 56 | 
labelsets["TestTask_slice:f_ind"].numpy().tolist(), [1, 1, 0, 0, 0] 57 | ) 58 | self.assertEqual( 59 | labelsets["TestTask_slice:base_ind"].numpy().tolist(), [1, 1, 1, 1, 1] 60 | ) 61 | 62 | # Ensure "pred" contains masked elements 63 | self.assertEqual( 64 | labelsets["TestTask_slice:f_pred"].numpy().tolist(), [0, 1, -1, -1, -1] 65 | ) 66 | self.assertEqual( 67 | labelsets["TestTask_slice:base_pred"].numpy().tolist(), [0, 1, 1, 0, 1] 68 | ) 69 | self.assertEqual(labelsets["TestTask"].numpy().tolist(), [0, 1, 1, 0, 1]) 70 | 71 | def test_convert_to_slice_tasks(self): 72 | task_name = "TestTask" 73 | task = create_dummy_task(task_name) 74 | 75 | slice_names = ["slice_a", "slice_b", "slice_c"] 76 | slice_tasks = convert_to_slice_tasks(task, slice_names) 77 | 78 | slice_task_names = [t.name for t in slice_tasks] 79 | # Check for original base task 80 | self.assertIn(task_name, slice_task_names) 81 | 82 | # Check for 2 tasks (pred + ind) per slice, accounting for base slice 83 | for slice_name in slice_names + ["base"]: 84 | self.assertIn(f"{task_name}_slice:{slice_name}_pred", slice_task_names) 85 | self.assertIn(f"{task_name}_slice:{slice_name}_ind", slice_task_names) 86 | 87 | self.assertEqual(len(slice_tasks), 2 * (len(slice_names) + 1) + 1) 88 | 89 | # Test that modules share the same body flow operations 90 | # NOTE: Use "is" comparison to check object equality 91 | body_flow = task.op_sequence[:-1] 92 | ind_and_pred_tasks = [ 93 | t for t in slice_tasks if "_ind" in t.name or "_pred" in t.name 94 | ] 95 | for op in body_flow: 96 | for slice_task in ind_and_pred_tasks: 97 | self.assertTrue( 98 | slice_task.module_pool[op.module_name] 99 | is task.module_pool[op.module_name] 100 | ) 101 | 102 | # Test that pred tasks share the same predictor head 103 | pred_tasks = [t for t in slice_tasks if "_pred" in t.name] 104 | predictor_head_name = pred_tasks[0].op_sequence[-1].module_name 105 | shared_predictor_head = pred_tasks[0].module_pool[predictor_head_name] 106 | for pred_task in pred_tasks[1:]: 107 | self.assertTrue( 108 | pred_task.module_pool[predictor_head_name] is shared_predictor_head 109 | ) 110 | 111 | 112 | def create_dummy_task(task_name): 113 | # Create dummy task 114 | module_pool = nn.ModuleDict( 115 | {"linear1": nn.Linear(2, 10), "linear2": nn.Linear(10, 2)} 116 | ) 117 | 118 | op_sequence = [ 119 | Operation(name="encoder", module_name="linear1", inputs=["_input_"]), 120 | Operation(name="prediction_head", module_name="linear2", inputs=["encoder"]), 121 | ] 122 | 123 | task = Task(name=task_name, module_pool=module_pool, op_sequence=op_sequence) 124 | return task 125 | -------------------------------------------------------------------------------- /test/synthetic/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/snorkel-team/snorkel/617c92400c50e95ce41fcee84309a86f76cf525c/test/synthetic/__init__.py -------------------------------------------------------------------------------- /test/synthetic/test_synthetic_data.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from typing import Optional 3 | 4 | import numpy as np 5 | 6 | from snorkel.labeling import LFAnalysis 7 | from snorkel.synthetic.synthetic_data import generate_simple_label_matrix 8 | 9 | 10 | class TestGenerateSimpleLabelMatrix(unittest.TestCase): 11 | """Testing the generate_simple_label_matrix function.""" 12 | 13 | def setUp(self) -> None: 14 | """Set constants for the 
tests.""" 15 | self.m = 10 # Number of LFs 16 | self.n = 1000 # Number of data points 17 | 18 | def _test_generate_L(self, k: int, decimal: Optional[int] = 2) -> None: 19 | """Test generated label matrix L for consistency with P, Y. 20 | 21 | This tests for consistency between the true conditional LF probabilities, P, 22 | and the empirical ones computed from L and Y, where P, L, and Y are generated 23 | by the generate_simple_label_matrix function. 24 | 25 | Parameters 26 | ---------- 27 | k 28 | Cardinality 29 | decimal 30 | Number of decimals to check element-wise error, err < 1.5 * 10**(-decimal) 31 | """ 32 | np.random.seed(123) 33 | P, Y, L = generate_simple_label_matrix(self.n, self.m, k) 34 | P_emp = LFAnalysis(L).lf_empirical_probs(Y, k=k) 35 | np.testing.assert_array_almost_equal(P, P_emp, decimal=decimal) 36 | 37 | def test_generate_L(self) -> None: 38 | """Test the generated dataset for consistency.""" 39 | self._test_generate_L(2, decimal=1) 40 | 41 | def test_generate_L_multiclass(self) -> None: 42 | """Test the generated dataset for consistency with cardinality=3.""" 43 | self._test_generate_L(3, decimal=1) 44 | 45 | 46 | if __name__ == "__main__": 47 | unittest.main() 48 | -------------------------------------------------------------------------------- /test/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/snorkel-team/snorkel/617c92400c50e95ce41fcee84309a86f76cf525c/test/utils/__init__.py -------------------------------------------------------------------------------- /test/utils/test_config_utils.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from snorkel.types import Config 4 | from snorkel.utils.config_utils import merge_config 5 | 6 | 7 | class FooConfig(Config): 8 | a: float = 0.5 9 | 10 | 11 | class BarConfig(Config): 12 | a: int = 1 13 | foo_config: FooConfig = FooConfig() # type: ignore 14 | 15 | 16 | class UtilsTest(unittest.TestCase): 17 | def test_merge_config(self): 18 | config_updates = {"a": 2, "foo_config": {"a": 0.75}} 19 | bar_config = merge_config(BarConfig(), config_updates) 20 | self.assertEqual(bar_config.a, 2) 21 | self.assertEqual(bar_config.foo_config.a, 0.75) 22 | -------------------------------------------------------------------------------- /test/utils/test_core.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import numpy as np 4 | 5 | from snorkel.utils import ( 6 | filter_labels, 7 | preds_to_probs, 8 | probs_to_preds, 9 | to_int_label_array, 10 | ) 11 | 12 | PROBS = np.array([[0.1, 0.9], [0.7, 0.3]]) 13 | PREDS = np.array([1, 0]) 14 | PREDS_ROUND = np.array([[0, 1], [1, 0]]) 15 | 16 | 17 | class UtilsTest(unittest.TestCase): 18 | def test_to_int_label_array(self): 19 | X = np.array([[1], [0], [2.0]]) 20 | Y_expected = np.array([1, 0, 2]) 21 | Y = to_int_label_array(X, flatten_vector=True) 22 | np.testing.assert_array_equal(Y, Y_expected) 23 | 24 | Y = to_int_label_array(np.array([[1]]), flatten_vector=True) 25 | Y_expected = np.array([1]) 26 | np.testing.assert_array_equal(Y, Y_expected) 27 | 28 | Y = to_int_label_array(X, flatten_vector=False) 29 | Y_expected = np.array([[1], [0], [2]]) 30 | np.testing.assert_array_equal(Y, Y_expected) 31 | 32 | X = np.array([[1], [0], [2.1]]) 33 | with self.assertRaisesRegex(ValueError, "non-integer value"): 34 | to_int_label_array(X) 35 | 36 | X = np.array([[1, 0], [0, 1]]) 37 | with 
self.assertRaisesRegex(ValueError, "1d np.array"):
38 |             to_int_label_array(X, flatten_vector=True)
39 | 
40 |     def test_preds_to_probs(self):
41 |         np.testing.assert_array_equal(preds_to_probs(PREDS, 2), PREDS_ROUND)
42 | 
43 |     def test_probs_to_preds(self):
44 |         np.testing.assert_array_equal(probs_to_preds(PROBS), PREDS)
45 | 
46 |         # abstains with ties
47 |         probs = np.array([[0.33, 0.33, 0.33]])
48 |         preds = probs_to_preds(probs, tie_break_policy="abstain")
49 |         true_preds = np.array([-1])
50 |         np.testing.assert_array_equal(preds, true_preds)
51 | 
52 |         # true random with ties
53 |         probs = np.array([[0.33, 0.33, 0.33]])
54 |         random_preds = []
55 |         for _ in range(10):
56 |             preds = probs_to_preds(probs, tie_break_policy="true-random")
57 |             random_preds.append(preds[0])
58 | 
59 |         # check predicted labels within range
60 |         self.assertLessEqual(max(random_preds), 2)
61 |         self.assertGreaterEqual(min(random_preds), 0)
62 | 
63 |         # deterministic random with ties
64 |         probs = np.array(
65 |             [[0.33, 0.33, 0.33], [0.0, 0.5, 0.5], [0.33, 0.33, 0.33], [0.5, 0.5, 0]]
66 |         )
67 |         random_preds = []
68 |         for _ in range(10):
69 |             preds = probs_to_preds(probs, tie_break_policy="random")
70 |             random_preds.append(preds)
71 | 
72 |         # check labels are the same across runs
73 |         for i in range(len(random_preds) - 1):
74 |             np.testing.assert_array_equal(random_preds[i], random_preds[i + 1])
75 | 
76 |         # check predicted labels within range (one instance suffices; all runs are identical)
77 |         self.assertLessEqual(max(random_preds[0]), 2)
78 |         self.assertGreaterEqual(min(random_preds[0]), 0)
79 | 
80 |         # check invalid policy
81 |         with self.assertRaisesRegex(ValueError, "policy not recognized"):
82 |             preds = probs_to_preds(probs, tie_break_policy="negative")
83 | 
84 |         # check invalid input
85 |         with self.assertRaisesRegex(ValueError, "probs must have probabilities"):
86 |             preds = probs_to_preds(np.array([[0.33], [0.33]]))
87 | 
88 |     def test_filter_labels(self):
89 |         golds = np.array([-1, 0, 0, 1, 1])
90 |         preds = np.array([0, 0, 1, 1, -1])
91 |         filtered = filter_labels(
92 |             label_dict={"golds": golds, "preds": preds},
93 |             filter_dict={"golds": [-1], "preds": [-1]},
94 |         )
95 |         np.testing.assert_array_equal(filtered["golds"], np.array([0, 0, 1]))
96 |         np.testing.assert_array_equal(filtered["preds"], np.array([0, 1, 1]))
97 | 
98 |     def test_filter_labels_probs(self):
99 |         golds = np.array([-1, 0, 0, 1, 1])
100 |         preds = np.array([0, 0, 1, 1, -1])
101 |         probs = np.array([[0.8, 0.2], [0.8, 0.2], [0.2, 0.8], [0.2, 0.8], [0.5, 0.5]])
102 |         filtered = filter_labels(
103 |             label_dict={"golds": golds, "preds": preds, "probs": probs},
104 |             filter_dict={"golds": [-1], "preds": [-1]},
105 |         )
106 |         np.testing.assert_array_equal(filtered["golds"], np.array([0, 0, 1]))
107 |         np.testing.assert_array_equal(filtered["preds"], np.array([0, 1, 1]))
108 |         # probs rows survive the same row mask as golds/preds.
109 |         np.testing.assert_array_equal(
110 |             filtered["probs"], np.array([[0.8, 0.2], [0.2, 0.8], [0.2, 0.8]])
111 |         )
112 | 
113 | 
114 | if __name__ == "__main__":
115 |     unittest.main()
116 | 
-------------------------------------------------------------------------------- /test/utils/test_data_operators.py: --------------------------------------------------------------------------------
1 | import unittest
2 | 
3 | from snorkel.utils.data_operators import check_unique_names
4 | 
5 | 
6 | class DataOperatorsTest(unittest.TestCase):
7 |     def test_check_unique_names(self):
8 |         check_unique_names(["alice", "bob", "chuck"])
9 |         with self.assertRaisesRegex(ValueError, "3 operators with name c"):
10 |             check_unique_names(["a", "a", "b", "c", "c", "c"])
11 | 
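# ----------------------------------------------------------------------------
# A minimal usage sketch (annotation, not part of the test suite; the operator
# names below are invented for illustration): check_unique_names raises a
# ValueError when operator (LF/SF/TF) names collide, which would otherwise
# make the output matrix columns ambiguous when a list of operators is applied.
#
#     from snorkel.utils.data_operators import check_unique_names
#
#     try:
#         check_unique_names(["lf_keyword", "lf_regex", "lf_keyword"])
#     except ValueError as e:
#         print(e)  # mentions "2 operators with name lf_keyword"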
-------------------------------------------------------------------------------- /tox.ini: --------------------------------------------------------------------------------
1 | [tox]
2 | skip_missing_interpreters = true
3 | envlist =
4 |     py311,
5 |     type,
6 |     check,
7 |     doctest,
8 | isolated_build = true
9 | 
10 | [testenv]
11 | description = run the test driver with {basepython}
12 | # Note: in order to allow dependency library install reuse
13 | # on CI, we allow overriding the default envdir
14 | # (specified as `{toxworkdir}/{envname}`) by setting the
15 | # environment variable `TOX_INSTALL_DIR`. We avoid
16 | # collision with the already-used `TOX_ENV_DIR`.
17 | envdir = {env:TOX_INSTALL_DIR:{toxworkdir}/{envname}}
18 | # Note: we try to keep the deps the same for all tests
19 | # running on CI so that we skip reinstalling dependency
20 | # libraries for all testenvs
21 | deps =
22 |     -rrequirements.txt
23 |     -rrequirements-pyspark.txt
24 | commands_pre = python -m spacy download en_core_web_sm
25 | commands = python -m pytest {posargs:-m 'not spark and not complex'}
26 | 
27 | [testenv:spark]
28 | description = run the test driver for spark tests with {basepython}
29 | passenv = JAVA_HOME
30 | commands = python -m pytest -m spark {posargs}
31 | 
32 | [testenv:complex]
33 | description = run the test driver for integration tests with {basepython}
34 | commands = python -m pytest -m 'complex and not spark' {posargs}
35 | 
36 | [testenv:doctest]
37 | description = run doctest
38 | skipsdist = true
39 | commands = python -m pytest --doctest-plus snorkel
40 | 
41 | [testenv:check]
42 | description = check the code and doc style
43 | basepython = python3
44 | allowlist_externals =
45 |     {toxinidir}/scripts/check_requirements.py
46 |     {toxinidir}/scripts/sync_api_docs.py
47 | commands_pre =
48 | commands =
49 |     isort -rc -c .
50 |     black --check .
51 |     flake8 .
52 |     pydocstyle snorkel
53 |     {toxinidir}/scripts/check_requirements.py
54 |     {toxinidir}/scripts/sync_api_docs.py --check
55 | 
56 | [testenv:type]
57 | description = run static type checking
58 | basepython = python3
59 | commands_pre =
60 | commands = mypy -p snorkel --disallow-untyped-defs --disallow-incomplete-defs --no-implicit-optional
61 | 
62 | [testenv:coverage]
63 | description = run coverage checks
64 | basepython = python3
65 | # Note: make sure this matches testenv since this is used
66 | # on CI as the default unit test runner
67 | commands = python -m pytest -m 'not spark and not complex' --cov=snorkel
68 | 
69 | [testenv:fix]
70 | description = run code stylers
71 | basepython = python3
72 | usedevelop = True
73 | commands_pre =
74 | commands =
75 |     isort -rc .
76 |     black .
77 | 
78 | [testenv:doc]
79 | description = build docs
80 | basepython = python3
81 | skipsdist = True
82 | commands_pre = python -m pip install -U -r docs/requirements-doc.txt
83 | commands =
84 |     rm -rf docs/_build
85 |     rm -rf docs/packages/_autosummary
86 |     make -C docs/ html
87 |     {toxinidir}/scripts/sync_api_docs.py
88 | 
--------------------------------------------------------------------------------
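# ----------------------------------------------------------------------------
# Usage note (annotation): with the tox.ini above, typical local invocations
# would be
#
#     tox -e coverage        # unit tests with coverage, as CI runs them
#     tox -e complex         # integration tests ("complex" marker)
#     tox -e spark           # Spark tests; JAVA_HOME must be set
#     tox -e check           # isort/black/flake8/pydocstyle + doc sync checks
#     tox -e fix             # auto-apply isort and black
#
# Arguments after "--" are forwarded to pytest via {posargs}, e.g.
# "tox -e complex -- -k convergence" narrows the run to matching test names.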