├── .circleci
│   └── config.yml
├── .codecov.yaml
├── .coveragerc
├── .github
│   ├── ISSUE_TEMPLATE
│   │   ├── bug_report.md
│   │   └── feature_request.md
│   ├── pull_request_template.md
│   └── workflows
│       └── stale.yml
├── .gitignore
├── .readthedocs.yml
├── CHANGELOG.md
├── CONTRIBUTING.md
├── LICENSE
├── MANIFEST.in
├── README.md
├── RELEASING.md
├── docs
│   ├── Makefile
│   ├── _static
│   │   └── octopus.png
│   ├── conf.py
│   ├── index.rst
│   ├── packages.json
│   ├── packages
│   │   ├── analysis.rst
│   │   ├── augmentation.rst
│   │   ├── classification.rst
│   │   ├── labeling.rst
│   │   ├── map.rst
│   │   ├── preprocess.rst
│   │   ├── slicing.rst
│   │   └── utils.rst
│   └── requirements-doc.txt
├── figs
│   ├── ONR.jpg
│   ├── WS_pipeline2.pdf
│   ├── darpa.JPG
│   ├── dp_neurips_2016.png
│   ├── logo_01.png
│   ├── mobilize_logo.png
│   ├── moore_logo.png
│   ├── nih_logo.png
│   ├── user_logos.png
│   └── vldb2018_logo.png
├── pyproject.toml
├── requirements-pyspark.txt
├── requirements.txt
├── scripts
│   ├── check_requirements.py
│   └── sync_api_docs.py
├── setup.cfg
├── setup.py
├── snorkel
│   ├── __init__.py
│   ├── analysis
│   │   ├── __init__.py
│   │   ├── error_analysis.py
│   │   ├── metrics.py
│   │   └── scorer.py
│   ├── augmentation
│   │   ├── __init__.py
│   │   ├── apply
│   │   │   ├── __init__.py
│   │   │   ├── core.py
│   │   │   └── pandas.py
│   │   ├── policy
│   │   │   ├── __init__.py
│   │   │   ├── core.py
│   │   │   └── sampling.py
│   │   └── tf.py
│   ├── classification
│   │   ├── __init__.py
│   │   ├── data.py
│   │   ├── loss.py
│   │   ├── multitask_classifier.py
│   │   ├── task.py
│   │   ├── training
│   │   │   ├── __init__.py
│   │   │   ├── loggers
│   │   │   │   ├── __init__.py
│   │   │   │   ├── checkpointer.py
│   │   │   │   ├── log_manager.py
│   │   │   │   ├── log_writer.py
│   │   │   │   └── tensorboard_writer.py
│   │   │   ├── schedulers
│   │   │   │   ├── __init__.py
│   │   │   │   ├── scheduler.py
│   │   │   │   ├── sequential_scheduler.py
│   │   │   │   └── shuffled_scheduler.py
│   │   │   └── trainer.py
│   │   └── utils.py
│   ├── contrib
│   │   ├── README.md
│   │   └── __init__.py
│   ├── labeling
│   │   ├── __init__.py
│   │   ├── analysis.py
│   │   ├── apply
│   │   │   ├── __init__.py
│   │   │   ├── core.py
│   │   │   ├── dask.py
│   │   │   ├── pandas.py
│   │   │   └── spark.py
│   │   ├── lf
│   │   │   ├── __init__.py
│   │   │   ├── core.py
│   │   │   ├── nlp.py
│   │   │   └── nlp_spark.py
│   │   ├── model
│   │   │   ├── __init__.py
│   │   │   ├── base_labeler.py
│   │   │   ├── baselines.py
│   │   │   ├── graph_utils.py
│   │   │   ├── label_model.py
│   │   │   └── logger.py
│   │   └── utils.py
│   ├── map
│   │   ├── __init__.py
│   │   ├── core.py
│   │   └── spark.py
│   ├── preprocess
│   │   ├── __init__.py
│   │   ├── core.py
│   │   ├── nlp.py
│   │   └── spark.py
│   ├── slicing
│   │   ├── __init__.py
│   │   ├── apply
│   │   │   ├── __init__.py
│   │   │   ├── core.py
│   │   │   ├── dask.py
│   │   │   └── spark.py
│   │   ├── modules
│   │   │   ├── __init__.py
│   │   │   └── slice_combiner.py
│   │   ├── monitor.py
│   │   ├── sf
│   │   │   ├── __init__.py
│   │   │   ├── core.py
│   │   │   └── nlp.py
│   │   ├── sliceaware_classifier.py
│   │   └── utils.py
│   ├── synthetic
│   │   ├── __init__.py
│   │   └── synthetic_data.py
│   ├── types
│   │   ├── __init__.py
│   │   ├── classifier.py
│   │   ├── data.py
│   │   └── hashing.py
│   ├── utils
│   │   ├── __init__.py
│   │   ├── config_utils.py
│   │   ├── core.py
│   │   ├── data_operators.py
│   │   ├── lr_schedulers.py
│   │   └── optimizers.py
│   └── version.py
├── test
│   ├── __init__.py
│   ├── analysis
│   │   ├── test_error_analysis.py
│   │   ├── test_metrics.py
│   │   └── test_scorer.py
│   ├── augmentation
│   │   ├── __init__.py
│   │   ├── apply
│   │   │   ├── __init__.py
│   │   │   └── test_tf_applier.py
│   │   └── policy
│   │       ├── __init__.py
│   │       ├── test_core.py
│   │       └── test_sampling.py
│   ├── classification
│   │   ├── __init__.py
│   │   ├── test_classifier_convergence.py
│   │   ├── test_data.py
│   │   ├── test_loss.py
│   │   ├── test_multitask_classifier.py
│   │   ├── test_task.py
│   │   ├── test_utils.py
│   │   └── training
│   │       ├── loggers
│   │       │   ├── __init__.py
│   │       │   ├── test_checkpointer.py
│   │       │   ├── test_log_manager.py
│   │       │   ├── test_log_writer.py
│   │       │   └── test_tensorboard_writer.py
│   │       ├── schedulers
│   │       │   └── test_schedulers.py
│   │       └── test_trainer.py
│   ├── labeling
│   │   ├── __init__.py
│   │   ├── apply
│   │   │   ├── __init__.py
│   │   │   ├── lf_applier_spark_test_script.py
│   │   │   ├── test_lf_applier.py
│   │   │   └── test_spark.py
│   │   ├── lf
│   │   │   ├── test_core.py
│   │   │   ├── test_nlp.py
│   │   │   └── test_nlp_spark.py
│   │   ├── model
│   │   │   ├── __init__.py
│   │   │   ├── test_baseline.py
│   │   │   ├── test_label_model.py
│   │   │   └── test_logger.py
│   │   ├── preprocess
│   │   │   ├── __init__.py
│   │   │   └── test_nlp.py
│   │   ├── test_analysis.py
│   │   ├── test_convergence.py
│   │   └── test_utils.py
│   ├── map
│   │   ├── __init__.py
│   │   ├── test_core.py
│   │   └── test_spark.py
│   ├── slicing
│   │   ├── __init__.py
│   │   ├── apply
│   │   │   ├── __init__.py
│   │   │   └── test_sf_applier.py
│   │   ├── sf
│   │   │   ├── __init__.py
│   │   │   ├── test_core.py
│   │   │   └── test_nlp.py
│   │   ├── test_monitor.py
│   │   ├── test_slice_combiner.py
│   │   ├── test_sliceaware_classifier.py
│   │   └── test_utils.py
│   ├── synthetic
│   │   ├── __init__.py
│   │   └── test_synthetic_data.py
│   └── utils
│       ├── __init__.py
│       ├── test_config_utils.py
│       ├── test_core.py
│       └── test_data_operators.py
└── tox.ini
/.circleci/config.yml:
--------------------------------------------------------------------------------
1 |
2 |
3 | # Use the latest 2.1 version of CircleCI pipeline process engine.
4 | version: 2.1
5 |
6 | orbs:
7 | python: circleci/python@1.2
8 |
9 | commands:
10 | setup_dependencies:
11 |     description: "Install dependencies"
12 | parameters:
13 | after-deps:
14 |         description: "Steps to run after dependencies are installed"
15 | type: steps
16 | default: []
17 | steps:
18 | - run:
19 |           name: "Add OpenJDK PPA"
20 |           command: sudo add-apt-repository -y ppa:openjdk-r/ppa
21 |       - run:
22 |           name: "Update apt package lists"
23 |           command: sudo apt-get -qq update
24 |       - run:
25 |           name: "Install OpenJDK 8 without recommended packages"
26 |           command: sudo apt-get install -y openjdk-8-jdk --no-install-recommends
27 |       - run:
28 |           name: "Set OpenJDK 8 as the default Java"
29 |           command: sudo update-java-alternatives -s java-1.8.0-openjdk-amd64
30 |       - run:
31 |           name: "Upgrade pip, setuptools, and wheel"
32 |           command: pip install -U pip setuptools wheel
33 |       - run:
34 |           name: "Install tox"
35 |           command: pip install -U tox==4.11.4
36 |       - run:
37 |           name: "Install codecov"
38 |           command: pip install -U codecov
39 | - steps: << parameters.after-deps >>
40 |
41 | # Run the heavier jobs only on the main branch, release branches, or tags
42 | run_complex: &run_complex
43 | filters:
44 | branches:
45 | only:
46 | - main
47 | - /release-v.*/
48 | tags:
49 | only: /.*/
50 |
51 | jobs:
52 | Python311-Unit-Tests:
53 | docker:
54 | - image: cimg/python:3.11
55 | environment:
56 | TOXENV: coverage,doctest,type,check
57 | TOX_INSTALL_DIR: .env
58 | JAVA_HOME: /usr/lib/jvm/java-8-openjdk-amd64
59 | PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION: python
60 |
61 | steps:
62 | - checkout
63 | - setup_dependencies
64 | - run:
65 | name: "Run Tox"
66 | command: tox
67 |
68 | Python311-Integration-Tests:
69 | docker:
70 | - image: cimg/python:3.11
71 | environment:
72 | TOXENV: complex,type,check
73 | TOX_INSTALL_DIR: .env
74 | JAVA_HOME: /usr/lib/jvm/java-8-openjdk-amd64
75 |
76 | steps:
77 | - checkout
78 | - run:
79 | name: Setup python3
80 | command: |
81 |             pyenv global 3.11.3 > /dev/null && activated=0 || activated=1  # succeeds if 3.11.3 is already installed
82 |             if [[ $activated -ne 0 ]]; then
83 |               for i in {1..6}; do  # retry flaky installs with exponential backoff
84 |                 pyenv install 3.11.3 && break || sleep $((2 ** $i))
85 |               done
86 |               pyenv global 3.11.3
87 |             fi
88 | - setup_dependencies
89 | - run:
90 | name: "Run Tox"
91 | no_output_timeout: 60m
92 | command: |
93 | export PYTHONUNBUFFERED=1
94 | tox
95 |
96 | workflows:
97 | version: 2
98 |
99 | Integration-Tests:
100 | jobs:
101 | - Python311-Integration-Tests:
102 | <<: *run_complex
103 | Unit-Tests:
104 | jobs:
105 | - Python311-Unit-Tests
106 |
--------------------------------------------------------------------------------
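Note: the `after-deps` steps parameter in `setup_dependencies` above lets a job append extra steps once the shared dependency setup has finished. A minimal sketch of how a job could use it (the extra step shown here is hypothetical, not part of this config):

    steps:
      - checkout
      - setup_dependencies:
          after-deps:
            - run:
                name: "Install PySpark extras (hypothetical)"
                command: pip install -r requirements-pyspark.txt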
/.codecov.yaml:
--------------------------------------------------------------------------------
1 | coverage:
2 | status:
3 | project:
4 | default:
5 | target: 95%
6 | patch:
7 | default:
8 | threshold: 2%
9 |
10 | comment:
11 | layout: "header, diff, flags, files"
12 |
--------------------------------------------------------------------------------
/.coveragerc:
--------------------------------------------------------------------------------
1 | [run]
2 | branch = True
3 | source = snorkel
4 |
5 | [report]
6 | exclude_lines =
7 | pragma: no cover
8 | raise NotImplementedError
9 | if __name__ == .__main__.:
10 | def __repr__
11 | ignore_errors = True
12 | omit =
13 | test/*
14 | *spark.py
15 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/bug_report.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: Bug report
3 | about: Let us know about a bug you found
4 | ---
5 |
6 | ## Issue description
7 |
8 | A clear and concise description of what the bug is.
9 |
10 | ## Code example/repro steps
11 |
12 | Please try to provide a minimal example to repro the bug.
13 | Error messages and stack traces are also helpful.
14 |
15 | ## Expected behavior
16 | A clear and concise description of what you expected to happen.
17 |
18 | ## Screenshots
19 | If applicable, add screenshots to help explain your problem.
20 | No screenshots of code!
21 |
22 | ## System info
23 |
24 | * How you installed Snorkel (conda, pip, source):
25 | * Build command you used (if compiling from source):
26 | * OS:
27 | * Python version:
28 | * Snorkel version:
29 | * Versions of any other relevant libraries:
30 |
31 | ## Additional context
32 | Add any other context about the problem here.
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/feature_request.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: Feature request
3 | about: Let us know about something new you want
4 | ---
5 |
6 | ## Is your feature request related to a problem? Please describe.
7 | A clear and concise description of what the problem is. Ex. I'm always frustrated when [...]
8 |
9 | ## Describe the solution you'd like
10 | A clear and concise description of what you want to happen.
11 |
12 | ## Describe alternatives you've considered
13 | A clear and concise description of any alternative solutions or features you've considered.
14 |
15 | ## Additional context
16 | Add any other context or screenshots about the feature request here.
17 |
--------------------------------------------------------------------------------
/.github/pull_request_template.md:
--------------------------------------------------------------------------------
1 | ## Description of proposed changes
2 |
3 | ## Related issue(s)
4 |
5 | Fixes # (issue)
6 |
7 | ## Test plan
8 |
9 | ## Checklist
10 |
11 | Need help on these? Just ask!
12 |
13 | * [ ] I have read the **CONTRIBUTING** document.
14 | * [ ] I have updated the documentation accordingly.
15 | * [ ] I have added tests to cover my changes.
16 | * [ ] I have run `tox -e complex` and/or `tox -e spark` if appropriate.
17 | * [ ] All new and existing tests passed.
18 |
--------------------------------------------------------------------------------
/.github/workflows/stale.yml:
--------------------------------------------------------------------------------
1 | name: Mark/close stale issues and pull requests
2 |
3 | on:
4 | schedule:
5 | - cron: "0 12 * * *"
6 |
7 | jobs:
8 | stale:
9 |
10 | runs-on: ubuntu-latest
11 |
12 | steps:
13 | - uses: actions/stale@v1
14 | with:
15 | repo-token: ${{ secrets.GITHUB_TOKEN }}
16 | stale-issue-message: 'This issue is stale because it has been open 90 days with no activity. Remove stale label or comment or this will be closed in 7 days.'
17 | stale-pr-message: 'This pull request is stale because it has been open 90 days with no activity. Remove stale label or comment or this will be closed in 7 days.'
18 | stale-issue-label: 'no-issue-activity'
19 | stale-pr-label: 'no-pr-activity'
20 | exempt-issue-label: 'no-stale'
21 | days-before-stale: 90
22 | days-before-close: 7
23 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | .pypirc
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | .hypothesis/
51 | .pytest_cache/
52 |
53 | # Translations
54 | *.mo
55 | *.pot
56 |
57 | # Django stuff:
58 | *.log
59 | local_settings.py
60 | db.sqlite3
61 | db.sqlite3-journal
62 |
63 | # Flask stuff:
64 | instance/
65 | .webassets-cache
66 |
67 | # Scrapy stuff:
68 | .scrapy
69 |
70 | # Sphinx documentation
71 | docs/_build/
72 | docs/packages/_autosummary
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # celery beat schedule file
95 | celerybeat-schedule
96 |
97 | # SageMath parsed files
98 | *.sage.py
99 |
100 | # Environments
101 | .env
102 | .venv
103 | env/
104 | venv/
105 | ENV/
106 | env.bak/
107 | venv.bak/
108 |
109 | # MacOS
110 | .DS_Store
111 |
112 | # Spyder project settings
113 | .spyderproject
114 | .spyproject
115 |
116 | # Rope project settings
117 | .ropeproject
118 |
119 | # mkdocs documentation
120 | /site
121 |
122 | # mypy
123 | .mypy_cache/
124 | .dmypy.json
125 | dmypy.json
126 |
127 | # Pyre type checker
128 | .pyre/
129 |
130 | # Editors
131 | .vscode/
132 | .code-workspace*
133 |
134 | # Dask
135 | dask-worker-space/
136 |
137 | # nohup
138 | nohup.out
139 |
--------------------------------------------------------------------------------
/.readthedocs.yml:
--------------------------------------------------------------------------------
1 | # .readthedocs.yml
2 | # Read the Docs configuration file
3 | # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details
4 |
5 | # Required
6 | version: 2
7 |
8 | # Build documentation in the docs/ directory with Sphinx
9 | sphinx:
10 | configuration: docs/conf.py
11 |
12 | # Optionally set the version of Python and requirements required to build your docs
13 | python:
14 | version: 3.11
15 | install:
16 | - requirements: docs/requirements-doc.txt
17 | - method: pip
18 | path: .
19 | system_packages: true
20 |
--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
1 | include README.md
2 | include LICENSE
3 | include requirements.txt
4 |
--------------------------------------------------------------------------------
/RELEASING.md:
--------------------------------------------------------------------------------
1 | # Snorkel Release Guide
2 |
3 | ## Before You Start
4 |
5 | Make sure you have a [PyPI](https://pypi.org) account with maintainer access to the Snorkel project.
6 | Create a `.pypirc` file in your home directory.
7 | It should look like this:
8 |
9 | ```
10 | [distutils]
11 | index-servers =
12 | pypi
13 | pypitest
14 |
15 | [pypi]
16 | username=YOUR_USERNAME
17 | password=YOUR_PASSWORD
18 | ```
19 |
20 | Then run `chmod 600 ~/.pypirc` so only you can read/write it.
21 |
22 |
23 | ## Release Steps
24 |
25 | 1. Make sure you're in the top-level `snorkel` directory.
26 | 1. Make certain your branch is in sync with upstream main:
27 |
28 | $ git pull origin main
29 |
30 | 1. Add a new changelog entry for the release.
31 |
32 | ## [0.9.0]
33 | ### [Breaking Changes]
34 | ### [Added]
35 | ### [Changed]
36 | ### [Deprecated]
37 | ### [Removed]
38 | Make sure `CHANGELOG.md` is up to date for the release: compare against PRs
39 | merged since the last release.
40 |
41 | 1. Update the version in `snorkel/version.py` to the release version, e.g. `0.9.0` (remove the `+dev` suffix).
42 |
43 |
44 | 1. Commit these changes and create a PR:
45 |
46 | git checkout -b release-v0.9.0
47 | git add . -u
48 | git commit -m "[RELEASE]: v0.9.0"
49 | git push --set-upstream origin release-v0.9.0
50 |
51 | 1. Once the PR is approved, merge it and pull main locally.
52 |
53 | 1. Tag the release:
54 |
55 | git tag -a v0.9.0 -m "v0.9.0 release"
56 | git push origin v0.9.0
57 |
58 | 1. Build source & wheel distributions:
59 |
60 | rm -rf dist build # clean old builds & distributions
61 | python3 setup.py sdist # create a source distribution
62 |         python3 setup.py bdist_wheel            # create a wheel
63 |
64 | 1. Check that everything looks correct by installing the wheel locally and checking the version:
65 |
66 | python3 -m venv test_snorkel # create a virtualenv for testing
67 | source test_snorkel/bin/activate # activate virtualenv
68 |         python3 -m pip install dist/snorkel-0.9.0-py3-none-any.whl
69 | python3 -c "import snorkel; print(snorkel.__version__)"
70 |
71 | 1. Publish to PyPI
72 |
73 | pip install twine # if not installed
74 | twine upload dist/* -r pypi
75 |
76 | 1. A PR is auto-submitted (this will take a few hours) on [`conda-forge/snorkel-feedstock`](https://github.com/conda-forge/snorkel-feedstock) to update the version.
77 | * A maintainer needs to accept and merge those changes.
78 |
79 | 1. Create a new release on GitHub.
80 |     * Input the recently-created tag version: `v0.9.0`
81 |     * Copy the release notes in `CHANGELOG.md` to the GitHub tag.
82 |     * Attach the resulting distributions (`dist/snorkel-x.x.x.*`) to the release.
83 | * Publish the release.
84 |
85 |
86 | 1. Update the version in `snorkel/version.py` to the next development version, e.g. `0.9.1+dev`.
87 |
88 | 1. Add a new changelog entry for the unreleased version in `CHANGELOG.md`:
89 |
90 | ## [Unreleased]
91 | ### [Breaking Changes]
92 | ### [Added]
93 | ### [Changed]
94 | ### [Deprecated]
95 | ### [Removed]
96 |
97 | 1. Commit these changes and create a PR:
98 |
99 | git checkout -b bump-v0.9.1+dev
100 | git add . -u
101 | git commit -m "[BUMP]: v0.9.1+dev"
102 | git push --set-upstream origin bump-v0.9.1+dev
103 |
104 |
105 | 1. Add the new tag to [the Snorkel project on ReadTheDocs](https://readthedocs.org/projects/snorkel):
106 | * Trigger a build for main to pull new tags.
107 | * Go to the "Versions" tab, and "Activate" the new tag.
108 | * Go to Admin/Advanced to set this tag as the new default version.
109 | * In "Overview", make sure a build is triggered:
110 |     * For the tag `v0.9.0`
111 | * For `latest`
112 |
113 |
114 | ## Credit
115 | * [AllenNLP](https://github.com/allenai/allennlp/blob/master/setup.py)
116 | * [Altair](https://github.com/altair-viz/altair/blob/master/RELEASING.md)
117 |
--------------------------------------------------------------------------------
/docs/Makefile:
--------------------------------------------------------------------------------
1 | # Minimal makefile for Sphinx documentation
2 | #
3 |
4 | # You can set these variables from the command line, and also
5 | # from the environment for the first two.
6 | SPHINXOPTS ?=
7 | SPHINXBUILD ?= sphinx-build
8 | SOURCEDIR = .
9 | BUILDDIR = _build
10 |
11 | # Put it first so that "make" without argument is like "make help".
12 | help:
13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
14 |
15 | .PHONY: help Makefile
16 |
17 | # Catch-all target: route all unknown targets to Sphinx using the new
18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
19 | %: Makefile
20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
21 |
--------------------------------------------------------------------------------
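Note: thanks to the catch-all `%` target above, any Sphinx builder name works as a make target; for example, run from `docs/`:

    make html    # routed to: sphinx-build -M html . _build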
/docs/_static/octopus.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/snorkel-team/snorkel/617c92400c50e95ce41fcee84309a86f76cf525c/docs/_static/octopus.png
--------------------------------------------------------------------------------
/docs/index.rst:
--------------------------------------------------------------------------------
1 | .. Snorkel documentation master file, created by
2 | sphinx-quickstart on Fri Jul 12 17:34:20 2019.
3 | You can adapt this file completely to your liking, but it should at least
4 | contain the root `toctree` directive.
5 |
6 | Snorkel's API Documentation
7 | ===========================
8 |
9 | If you're looking for technical details on Snorkel's API,
10 | you're in the right place.
11 |
12 | For more narrative walkthroughs of Snorkel fundamentals or
13 | example use cases, check out our `homepage <https://snorkel.org/>`_
14 | and our `tutorials repo <https://github.com/snorkel-team/snorkel-tutorials>`_.
15 |
16 | .. toctree::
17 | :maxdepth: 1
18 | :caption: Package Reference
19 |
20 | packages/analysis
21 | packages/augmentation
22 | packages/classification
23 | packages/labeling
24 | packages/map
25 | packages/preprocess
26 | packages/slicing
27 | packages/utils
28 |
--------------------------------------------------------------------------------
/docs/packages.json:
--------------------------------------------------------------------------------
1 | {
2 | "packages": [
3 | "analysis",
4 | "augmentation",
5 | "classification",
6 | "labeling",
7 | "map",
8 | "preprocess",
9 | "slicing",
10 | "utils"
11 | ],
12 | "extra_members": {
13 | "labeling": [
14 | "apply.dask.DaskLFApplier",
15 | "apply.dask.PandasParallelLFApplier",
16 | "apply.spark.SparkLFApplier",
17 | "lf.nlp.NLPLabelingFunction",
18 | "lf.nlp.nlp_labeling_function",
19 | "lf.nlp_spark.SparkNLPLabelingFunction",
20 | "lf.nlp_spark.spark_nlp_labeling_function",
21 | "model.baselines.MajorityClassVoter",
22 | "model.baselines.MajorityLabelVoter",
23 | "model.baselines.RandomVoter",
24 | "model.label_model.LabelModel"
25 | ],
26 | "map": [
27 | "spark.make_spark_mapper"
28 | ],
29 | "preprocess": [
30 | "nlp.SpacyPreprocessor",
31 | "spark.make_spark_preprocessor"
32 | ],
33 | "slicing": [
34 | "apply.dask.DaskSFApplier",
35 | "apply.dask.PandasParallelSFApplier",
36 | "apply.spark.SparkSFApplier",
37 | "sf.nlp.NLPSlicingFunction",
38 | "sf.nlp.nlp_slicing_function"
39 | ]
40 | }
41 | }
42 |
--------------------------------------------------------------------------------
/docs/packages/analysis.rst:
--------------------------------------------------------------------------------
1 | Snorkel Analysis Package
2 | ------------------------
3 |
4 | Generic model analysis utilities shared across Snorkel.
5 |
6 | .. currentmodule:: snorkel.analysis
7 |
8 | .. autosummary::
9 | :toctree: _autosummary/analysis/
10 | :nosignatures:
11 |
12 | Scorer
13 | get_label_buckets
14 | get_label_instances
15 | metric_score
16 |
--------------------------------------------------------------------------------
/docs/packages/augmentation.rst:
--------------------------------------------------------------------------------
1 | Snorkel Augmentation Package
2 | ----------------------------
3 |
4 | Programmatic data set augmentation: TF creation and data generation utilities.
5 |
6 | .. currentmodule:: snorkel.augmentation
7 |
8 | .. autosummary::
9 | :toctree: _autosummary/augmentation/
10 | :nosignatures:
11 |
12 | ApplyAllPolicy
13 | ApplyEachPolicy
14 | ApplyOnePolicy
15 | MeanFieldPolicy
16 | PandasTFApplier
17 | RandomPolicy
18 | TFApplier
19 | TransformationFunction
20 | transformation_function
21 |
--------------------------------------------------------------------------------
/docs/packages/classification.rst:
--------------------------------------------------------------------------------
1 | Snorkel Classification Package
2 | ------------------------------
3 |
4 | PyTorch-based multi-task learning framework for discriminative modeling.
5 |
6 | .. currentmodule:: snorkel.classification
7 |
8 | .. autosummary::
9 | :toctree: _autosummary/classification/
10 | :nosignatures:
11 |
12 | Checkpointer
13 | CheckpointerConfig
14 | DictDataLoader
15 | DictDataset
16 | LogManager
17 | LogManagerConfig
18 | LogWriter
19 | LogWriterConfig
20 | MultitaskClassifier
21 | Operation
22 | Task
23 | TensorBoardWriter
24 | Trainer
25 | cross_entropy_with_probs
26 |
--------------------------------------------------------------------------------
/docs/packages/labeling.rst:
--------------------------------------------------------------------------------
1 | Snorkel Labeling Package
2 | ------------------------
3 |
4 | Programmatic data set labeling: LF creation, models, and analysis utilities.
5 |
6 | .. currentmodule:: snorkel.labeling
7 |
8 | .. autosummary::
9 | :toctree: _autosummary/labeling/
10 | :nosignatures:
11 |
12 | apply.dask.DaskLFApplier
13 | LFAnalysis
14 | LFApplier
15 | model.label_model.LabelModel
16 | LabelingFunction
17 | model.baselines.MajorityClassVoter
18 | model.baselines.MajorityLabelVoter
19 | lf.nlp.NLPLabelingFunction
20 | PandasLFApplier
21 | apply.dask.PandasParallelLFApplier
22 | model.baselines.RandomVoter
23 | apply.spark.SparkLFApplier
24 | lf.nlp_spark.SparkNLPLabelingFunction
25 | filter_unlabeled_dataframe
26 | labeling_function
27 | lf.nlp.nlp_labeling_function
28 | lf.nlp_spark.spark_nlp_labeling_function
29 |
--------------------------------------------------------------------------------
/docs/packages/map.rst:
--------------------------------------------------------------------------------
1 | Snorkel Map Package
2 | -------------------
3 |
4 | Generic utilities for data point to data point operations.
5 |
6 | .. currentmodule:: snorkel.map
7 |
8 | .. autosummary::
9 | :toctree: _autosummary/map/
10 | :nosignatures:
11 |
12 | BaseMapper
13 | LambdaMapper
14 | Mapper
15 | lambda_mapper
16 | spark.make_spark_mapper
17 |
--------------------------------------------------------------------------------
/docs/packages/preprocess.rst:
--------------------------------------------------------------------------------
1 | Snorkel Preprocess Package
2 | --------------------------
3 |
4 | Preprocessors for LFs, TFs, and SFs.
5 |
6 | .. currentmodule:: snorkel.preprocess
7 |
8 | .. autosummary::
9 | :toctree: _autosummary/preprocess/
10 | :nosignatures:
11 |
12 | BasePreprocessor
13 | LambdaPreprocessor
14 | Preprocessor
15 | nlp.SpacyPreprocessor
16 | spark.make_spark_preprocessor
17 | preprocessor
18 |
--------------------------------------------------------------------------------
/docs/packages/slicing.rst:
--------------------------------------------------------------------------------
1 | Snorkel Slicing Package
2 | -----------------------
3 |
4 | Programmatic data set slicing: SF creation, monitoring utilities, and representation learning for slices.
5 |
6 | .. currentmodule:: snorkel.slicing
7 |
8 | .. autosummary::
9 | :toctree: _autosummary/slicing/
10 | :nosignatures:
11 |
12 | apply.dask.DaskSFApplier
13 | sf.nlp.NLPSlicingFunction
14 | apply.dask.PandasParallelSFApplier
15 | PandasSFApplier
16 | SFApplier
17 | SliceAwareClassifier
18 | SliceCombinerModule
19 | SlicingFunction
20 | apply.spark.SparkSFApplier
21 | add_slice_labels
22 | convert_to_slice_tasks
23 | sf.nlp.nlp_slicing_function
24 | slice_dataframe
25 | slicing_function
26 |
--------------------------------------------------------------------------------
/docs/packages/utils.rst:
--------------------------------------------------------------------------------
1 | Snorkel Utils Package
2 | ---------------------
3 |
4 | General machine learning utilities shared across Snorkel.
5 |
6 | .. currentmodule:: snorkel.utils
7 |
8 | .. autosummary::
9 | :toctree: _autosummary/utils/
10 | :nosignatures:
11 |
12 | filter_labels
13 | preds_to_probs
14 | probs_to_preds
15 | to_int_label_array
16 |
--------------------------------------------------------------------------------
/docs/requirements-doc.txt:
--------------------------------------------------------------------------------
1 | sphinx==2.4.5
2 | sphinx_autodoc_typehints==1.7.0
3 | sphinx_rtd_theme==0.4.3
4 | https://download.pytorch.org/whl/cpu/torch-1.4.0%2Bcpu-cp36-cp36m-linux_x86_64.whl
5 |
--------------------------------------------------------------------------------
/figs/ONR.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/snorkel-team/snorkel/617c92400c50e95ce41fcee84309a86f76cf525c/figs/ONR.jpg
--------------------------------------------------------------------------------
/figs/WS_pipeline2.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/snorkel-team/snorkel/617c92400c50e95ce41fcee84309a86f76cf525c/figs/WS_pipeline2.pdf
--------------------------------------------------------------------------------
/figs/darpa.JPG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/snorkel-team/snorkel/617c92400c50e95ce41fcee84309a86f76cf525c/figs/darpa.JPG
--------------------------------------------------------------------------------
/figs/dp_neurips_2016.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/snorkel-team/snorkel/617c92400c50e95ce41fcee84309a86f76cf525c/figs/dp_neurips_2016.png
--------------------------------------------------------------------------------
/figs/logo_01.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/snorkel-team/snorkel/617c92400c50e95ce41fcee84309a86f76cf525c/figs/logo_01.png
--------------------------------------------------------------------------------
/figs/mobilize_logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/snorkel-team/snorkel/617c92400c50e95ce41fcee84309a86f76cf525c/figs/mobilize_logo.png
--------------------------------------------------------------------------------
/figs/moore_logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/snorkel-team/snorkel/617c92400c50e95ce41fcee84309a86f76cf525c/figs/moore_logo.png
--------------------------------------------------------------------------------
/figs/nih_logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/snorkel-team/snorkel/617c92400c50e95ce41fcee84309a86f76cf525c/figs/nih_logo.png
--------------------------------------------------------------------------------
/figs/user_logos.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/snorkel-team/snorkel/617c92400c50e95ce41fcee84309a86f76cf525c/figs/user_logos.png
--------------------------------------------------------------------------------
/figs/vldb2018_logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/snorkel-team/snorkel/617c92400c50e95ce41fcee84309a86f76cf525c/figs/vldb2018_logo.png
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = [
3 | "setuptools >= 40.6.2",
4 | "wheel >= 0.30.0",
5 | ]
6 | build-backend = "setuptools.build_meta"
7 |
8 | [tool.black]
9 | line-length = 88
10 | target-version = ['py311']
11 | exclude = '''
12 | /(
13 | \.eggs
14 | | \.git
15 | | \.mypy_cache
16 | | \.tox
17 | | \.env
18 | | \.venv
19 | | _build
20 | | build
21 | | dist
22 | )/
23 | '''
--------------------------------------------------------------------------------
/requirements-pyspark.txt:
--------------------------------------------------------------------------------
1 | # Note: we don't include PySpark in the normal required installs.
2 | # Installing a new version may overwrite your existing system install.
3 | pyspark==3.4.1
4 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | # Library dependencies for Python code. You need to install these with
2 | # `pip install -r requirements.txt` or
3 | # `conda install --file requirements.txt`
4 | # to ensure that you can use all Snorkel code.
5 | # NOTE: all essential packages must be placed under a section named
6 | # '#### ESSENTIAL ...' so that the script `./scripts/check_requirements.py`
7 | # can find them.
8 |
9 | #### ESSENTIAL LIBRARIES
10 |
11 | # General scientific computing
12 | numpy>=1.24.0
13 | scipy>=1.2.0
14 |
15 | # Data storage and function application
16 | pandas>=1.0.0
17 | tqdm>=4.33.0
18 |
19 | # Internal models
20 | scikit-learn>=0.20.2
21 | torch>=1.2.0
22 | munkres>=1.0.6
23 |
24 | # LF dependency learning
25 | networkx>=2.2
26 |
27 | # Model introspection tools
28 | protobuf>=3.19.6
29 | tensorboard>=2.13.0
30 |
31 | #### EXTRA/TEST LIBRARIES
32 |
33 | # spaCy (NLP)
34 | spacy>=2.1.0
35 | blis>=0.3.0
36 |
37 | # Dask (parallelism)
38 | dask[dataframe]>=2020.12.0
39 | distributed>=2023.7.0
40 |
41 | # Dill (serialization)
42 | dill>=0.3.0
43 |
44 | #### DEV TOOLS
45 |
46 | black>=22.8
47 | flake8>=3.7.0
48 | importlib_metadata<5 # necessary for flake8
49 | isort>=4.3.0
50 | mypy>=0.760
51 | pydocstyle>=4.0.0
52 | pytest>=6.0.0
53 | pytest-cov>=2.7.0
54 | pytest-doctestplus>=0.3.0
55 | tox>=3.13.0
56 |
--------------------------------------------------------------------------------
/scripts/sync_api_docs.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 |
3 | """Check and update API docs under docs/packages.
4 |
5 | This script checks and updates the package documentation pages, making sure
6 | that the packages in docs/packages.json are documented and up to date.
7 | Rather than calling this directly, use `tox -e check` or `tox -e fix`.
8 | """
9 |
10 | import json
11 | import os
12 | import sys
13 | from importlib import import_module
14 | from typing import Any, List, Tuple
15 |
16 | PACKAGE_INFO_PATH = "docs/packages.json"
17 | PACKAGE_PAGE_PATH = "docs/packages"
18 |
19 |
20 | PACKAGE_DOC_TEMPLATE = """{title}
21 | {underscore}
22 |
23 | {docstring}
24 |
25 | .. currentmodule:: snorkel.{package_name}
26 |
27 | .. autosummary::
28 | :toctree: _autosummary/{package_name}/
29 | :nosignatures:
30 |
31 | {members}
32 | """
33 |
34 |
35 | def get_title_and_underscore(package_name: str) -> Tuple[str, str]:
36 | title = f"Snorkel {package_name.capitalize()} Package"
37 | underscore = "-" * len(title)
38 | return title, underscore
39 |
40 |
41 | def get_package_members(package: Any) -> List[str]:
42 | members = []
43 | for name in dir(package):
44 | if name.startswith("_"):
45 | continue
46 | obj = getattr(package, name)
47 | if isinstance(obj, type) or callable(obj):
48 | members.append(name)
49 | return members
50 |
51 |
52 | def main(check: bool) -> None:
53 | with open(PACKAGE_INFO_PATH, "r") as f:
54 | packages_info = json.load(f)
55 | package_names = sorted(packages_info["packages"])
56 | if check:
57 | f_basenames = sorted(
58 | [
59 | os.path.splitext(f_name)[0]
60 | for f_name in os.listdir(PACKAGE_PAGE_PATH)
61 | if f_name.endswith(".rst")
62 | ]
63 | )
64 | if f_basenames != package_names:
65 | raise ValueError(
66 | "Expected package files do not match actual!\n"
67 | f"Expected: {package_names}\n"
68 | f"Actual: {f_basenames}"
69 | )
70 | else:
71 | os.makedirs(PACKAGE_PAGE_PATH, exist_ok=True)
72 | for package_name in package_names:
73 | package = import_module(f"snorkel.{package_name}")
74 | docstring = package.__doc__
75 | title, underscore = get_title_and_underscore(package_name)
76 | all_members = get_package_members(package)
77 | all_members.extend(packages_info["extra_members"].get(package_name, []))
78 | contents = PACKAGE_DOC_TEMPLATE.format(
79 | title=title,
80 | underscore=underscore,
81 | docstring=docstring,
82 | package_name=package_name,
83 | members="\n ".join(sorted(all_members, key=lambda s: s.split(".")[-1])),
84 | )
85 | f_path = os.path.join(PACKAGE_PAGE_PATH, f"{package_name}.rst")
86 | if check:
87 | with open(f_path, "r") as f:
88 | contents_actual = f.read()
89 | if contents != contents_actual:
90 | raise ValueError(f"Contents for {package_name} differ!")
91 | else:
92 | with open(f_path, "w") as f:
93 | f.write(contents)
94 |
95 |
96 | if __name__ == "__main__":
97 | check = False if len(sys.argv) == 1 else (sys.argv[1] == "--check")
98 | sys.exit(main(check))
99 |
--------------------------------------------------------------------------------
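Note: as the docstring says, this script is normally invoked via `tox -e check` or `tox -e fix`, but it can also be run directly from the repo root (so the `docs/` paths resolve):

    python scripts/sync_api_docs.py          # rewrite docs/packages/*.rst in place
    python scripts/sync_api_docs.py --check  # raise if any package page is missing or stale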
/setup.cfg:
--------------------------------------------------------------------------------
1 | [tool:pytest]
2 | testpaths = test
3 | markers =
4 | spark
5 | complex
6 | norecursedirs = .tox
7 | doctest_optionflags =
8 | NORMALIZE_WHITESPACE
9 | ELLIPSIS
10 | FLOAT_CMP
11 |
12 | [flake8]
13 | extend-ignore =
14 | E203,
15 | # Throws errors for '#%%' delimiter in VSCode jupyter notebook syntax
16 | E265,
17 | E501,
18 | E731,
19 | E741,
20 | exclude =
21 | .eggs,
22 | .git,
23 | .mypy_cache,
24 | .tox,
25 | .env,
26 | .venv,
27 | _build,
28 | build,
29 | dist
30 |
31 | [isort]
32 | multi_line_output=3
33 | include_trailing_comma=True
34 | force_grid_wrap=0
35 | combine_as_imports=True
36 | line_length=88
37 | known_first_party=
38 | snorkel,
39 | known_third_party=
40 | numpy,
41 | pandas,
42 | pyspark,
43 | scipy,
44 | setuptools,
45 | tqdm,
46 | default_section=THIRDPARTY
47 | skip=.env,.venv,.tox
48 |
49 | [pydocstyle]
50 | convention = numpy
51 | add-ignore =
52 | D100,
53 | D104,
54 | D105,
55 | D107,
56 | D202,
57 | D203,
58 | D204,
59 | D213,
60 | D413,
61 |
62 | [mypy]
63 |
64 | [mypy-dask]
65 | ignore_missing_imports = True
66 |
67 | [mypy-dask.distributed]
68 | ignore_missing_imports = True
69 |
70 | [mypy-networkx]
71 | ignore_missing_imports = True
72 |
73 | [mypy-numpy]
74 | ignore_missing_imports = True
75 |
76 | [mypy-numpy.lib]
77 | ignore_missing_imports = True
78 |
79 | [mypy-pandas]
80 | ignore_missing_imports = True
81 |
82 | [mypy-scipy]
83 | ignore_missing_imports = True
84 |
85 | [mypy-scipy.sparse]
86 | ignore_missing_imports = True
87 |
88 | [mypy-sklearn]
89 | ignore_missing_imports = True
90 |
91 | [mypy-sklearn.metrics]
92 | ignore_missing_imports = True
93 |
94 | [mypy-pyspark]
95 | ignore_missing_imports = True
96 |
97 | [mypy-pyspark.sql]
98 | ignore_missing_imports = True
99 |
100 | [mypy-spacy]
101 | ignore_missing_imports = True
102 |
103 | [mypy-tqdm]
104 | ignore_missing_imports = True
105 |
106 | [mypy-torch]
107 | ignore_missing_imports = True
108 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from typing import Dict
2 |
3 | from setuptools import find_packages, setup
4 |
5 | # version.py defines the VERSION and VERSION_SHORT variables.
6 | # We use exec here so we don't import snorkel.
7 | VERSION: Dict[str, str] = {}
8 | with open("snorkel/version.py", "r") as version_file:
9 | exec(version_file.read(), VERSION)
10 |
11 | # Use README.md as the long_description for the package
12 | with open("README.md", "r") as readme_file:
13 | long_description = readme_file.read()
14 |
15 | setup(
16 | name="snorkel",
17 | version=VERSION["VERSION"],
18 | url="https://github.com/snorkel-team/snorkel",
19 | description="A system for quickly generating training data with weak supervision",
20 | long_description_content_type="text/markdown",
21 | long_description=long_description,
22 | license="Apache License 2.0",
23 | classifiers=[
24 | "Intended Audience :: Science/Research",
25 | "Topic :: Scientific/Engineering :: Information Analysis",
26 | "License :: OSI Approved :: Apache Software License",
27 | "Programming Language :: Python :: 3",
28 | ],
29 | project_urls={
30 | "Homepage": "https://snorkel.org",
31 | "Source": "https://github.com/snorkel-team/snorkel/",
32 | "Bug Reports": "https://github.com/snorkel-team/snorkel/issues",
33 | "Citation": "https://doi.org/10.14778/3157794.3157797",
34 | },
35 | packages=find_packages(exclude=("test*",)),
36 | include_package_data=True,
37 | install_requires=[
38 | "munkres>=1.0.6",
39 | "numpy>=1.24.0",
40 | "scipy>=1.2.0",
41 | "pandas>=1.0.0",
42 | "tqdm>=4.33.0",
43 | "scikit-learn>=0.20.2",
44 | "torch>=1.2.0",
45 | "tensorboard>=2.13.0",
46 | "protobuf>=3.19.6",
47 | "networkx>=2.2",
48 | ],
49 | python_requires=">=3.11",
50 | keywords="machine-learning ai weak-supervision",
51 | )
52 |
--------------------------------------------------------------------------------
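Note: `snorkel/version.py` itself is not included in this dump; per the comment in `setup.py` and the `+dev` convention in `RELEASING.md`, it presumably looks something like the following sketch (values illustrative only):

    # Hypothetical snorkel/version.py (illustrative values; the real file defines
    # the VERSION and VERSION_SHORT variables read via exec() in setup.py)
    VERSION = "0.9.1+dev"
    VERSION_SHORT = "0.9"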
/snorkel/__init__.py:
--------------------------------------------------------------------------------
1 | from .version import VERSION as __version__ # noqa: F401
2 |
--------------------------------------------------------------------------------
/snorkel/analysis/__init__.py:
--------------------------------------------------------------------------------
1 | """Generic model analysis utilities shared across Snorkel."""
2 |
3 | from .error_analysis import get_label_buckets, get_label_instances # noqa: F401
4 | from .metrics import metric_score # noqa: F401
5 | from .scorer import Scorer # noqa: F401
6 |
--------------------------------------------------------------------------------
/snorkel/analysis/error_analysis.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from collections import defaultdict
3 | from typing import DefaultDict, Dict, List, Tuple
4 |
5 | import numpy as np
6 |
7 | from snorkel.utils import to_int_label_array
8 |
9 |
10 | def get_label_buckets(*y: np.ndarray) -> Dict[Tuple[int, ...], np.ndarray]:
11 | """Return data point indices bucketed by label combinations.
12 |
13 | Parameters
14 | ----------
15 | *y
16 | A list of np.ndarray of (int) labels
17 |
18 | Returns
19 | -------
20 | Dict[Tuple[int, ...], np.ndarray]
21 | A mapping of each label bucket to a NumPy array of its corresponding indices
22 |
23 | Example
24 | -------
25 | A common use case is calling ``buckets = get_label_buckets(Y_gold, Y_pred)`` where
26 | ``Y_gold`` is a set of gold (i.e. ground truth) labels and ``Y_pred`` is a
27 | corresponding set of predicted labels.
28 |
29 | >>> Y_gold = np.array([1, 1, 1, 0])
30 | >>> Y_pred = np.array([1, 1, -1, -1])
31 | >>> buckets = get_label_buckets(Y_gold, Y_pred)
32 |
33 | The returned ``buckets[(i, j)]`` is a NumPy array of data point indices with
34 | true label i and predicted label j.
35 |
36 | More generally, the returned indices within each bucket refer to the order of the
37 | labels that were passed in as function arguments.
38 |
39 | >>> buckets[(1, 1)] # true positives
40 | array([0, 1])
41 | >>> (1, 0) in buckets # false positives
42 | False
43 | >>> (0, 1) in buckets # false negatives
44 | False
45 | >>> (0, 0) in buckets # true negatives
46 | False
47 | >>> buckets[(1, -1)] # abstained positives
48 | array([2])
49 | >>> buckets[(0, -1)] # abstained negatives
50 | array([3])
51 | """
52 | buckets: DefaultDict[Tuple[int, int], List[int]] = defaultdict(list)
53 | y_flat = list(map(lambda x: to_int_label_array(x, flatten_vector=True), y))
54 | if len(set(map(len, y_flat))) != 1:
55 | raise ValueError("Arrays must all have the same number of elements")
56 | for i, labels in enumerate(zip(*y_flat)):
57 | buckets[labels].append(i)
58 | return {k: np.array(v) for k, v in buckets.items()}
59 |
60 |
61 | def get_label_instances(
62 | bucket: Tuple[int, ...], x: np.ndarray, *y: np.ndarray
63 | ) -> np.ndarray:
64 | """Return instances in x with the specified combination of labels.
65 |
66 | Parameters
67 | ----------
68 | bucket
69 | A tuple of label values corresponding to which instances from x are returned
70 | x
71 | NumPy array of data instances to be returned
72 | *y
73 | A list of np.ndarray of (int) labels
74 |
75 | Returns
76 | -------
77 | np.ndarray
78 | NumPy array of instances from x with the specified combination of labels
79 |
80 | Example
81 | -------
82 | A common use case is calling ``get_label_instances(bucket, x.to_numpy(), Y_gold, Y_pred)``
83 | where ``x`` is a NumPy array of data instances that the labels correspond to,
84 | ``Y_gold`` is a list of gold (i.e. ground truth) labels, and
85 | ``Y_pred`` is a corresponding list of predicted labels.
86 |
87 | >>> import pandas as pd
88 | >>> x = pd.DataFrame(data={'col1': ["this is a string", "a second string", "a third string"], 'col2': ["1", "2", "3"]})
89 | >>> Y_gold = np.array([1, 1, 1])
90 | >>> Y_pred = np.array([1, 0, 0])
91 | >>> bucket = (1, 0)
92 |
93 | The returned NumPy array of data instances from ``x`` will correspond to
94 | the rows where the first list had a 1 and the second list had a 0.
95 | 
96 |     >>> get_label_instances(bucket, x.to_numpy(), Y_gold, Y_pred)
97 |     array([['a second string', '2'],
98 |            ['a third string', '3']], dtype=object)
99 | 
100 |     More generally, given bucket ``(i, j, ...)`` and lists ``y1, y2, ...``
101 |     the returned data instances from ``x`` will correspond to the rows where
102 |     y1 had label i, y2 had label j, and so on. Note that ``x`` and ``y``
103 |     must all be the same length.
104 |     """
105 |     if len(y) != len(bucket):
106 |         raise ValueError("Number of lists must match the number of labels in bucket")
107 |     if x.shape[0] != len(y[0]):
108 |         # Note: the check for all y having the same number of elements occurs in get_label_buckets
109 |         raise ValueError(
110 |             "Number of rows in x does not match number of elements in at least one label list"
111 |         )
112 |     buckets = get_label_buckets(*y)
113 |     try:
114 |         indices = buckets[bucket]
115 |     except KeyError:
116 |         logging.warning(f"Bucket {bucket} does not exist.")
117 |         return np.array([])
118 |     instances = x[indices]
119 |     return instances
120 | 
--------------------------------------------------------------------------------
/snorkel/analysis/metrics.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Callable, Dict, List, NamedTuple, Optional
2 |
3 | import numpy as np
4 | import sklearn.metrics as skmetrics
5 |
6 | from snorkel.utils import filter_labels, to_int_label_array
7 |
8 |
9 | class Metric(NamedTuple):
10 | """Specification for a metric and the subset of [golds, preds, probs] it expects."""
11 |
12 | func: Callable[..., float]
13 | inputs: List[str] = ["golds", "preds"]
14 |
15 |
16 | def metric_score(
17 | golds: Optional[np.ndarray] = None,
18 | preds: Optional[np.ndarray] = None,
19 | probs: Optional[np.ndarray] = None,
20 | metric: str = "accuracy",
21 | filter_dict: Optional[Dict[str, List[int]]] = None,
22 | **kwargs: Any,
23 | ) -> float:
24 | """Evaluate a standard metric on a set of predictions/probabilities.
25 |
26 | Parameters
27 | ----------
28 | golds
29 | An array of gold (int) labels
30 | preds
31 | An array of (int) predictions
32 | probs
33 | An [n_datapoints, n_classes] array of probabilistic (float) predictions
34 | metric
35 | The name of the metric to calculate
36 | filter_dict
37 | A mapping from label set name to the labels that should be filtered out for
38 | that label set
39 |
40 | Returns
41 | -------
42 | float
43 | The value of the requested metric
44 |
45 | Raises
46 | ------
47 | ValueError
48 | The requested metric is not currently supported
49 | ValueError
50 | The user attempted to calculate roc_auc score for a non-binary problem
51 | """
52 | if metric not in METRICS:
53 | msg = f"The metric you provided ({metric}) is not currently implemented."
54 | raise ValueError(msg)
55 |
56 | # Print helpful error messages if golds or preds has invalid shape or type
57 | golds = to_int_label_array(golds) if golds is not None else None
58 | preds = to_int_label_array(preds) if preds is not None else None
59 |
60 | # Optionally filter out examples (e.g., abstain predictions or unknown labels)
61 | label_dict: Dict[str, Optional[np.ndarray]] = {
62 | "golds": golds,
63 | "preds": preds,
64 | "probs": probs,
65 | }
66 | if filter_dict:
67 | if set(filter_dict.keys()).difference(set(label_dict.keys())):
68 | raise ValueError(
69 | "filter_dict must only include keys in ['golds', 'preds', 'probs']"
70 | )
71 | # label_dict is overwritten from type Dict[str, Optional[np.ndarray]]
72 | # to Dict[str, np.ndarray]
73 | label_dict = filter_labels(label_dict, filter_dict) # type: ignore
74 |
75 | # Confirm that required label sets are available
76 | func, label_names = METRICS[metric]
77 | for label_name in label_names:
78 | if label_dict[label_name] is None:
79 | raise ValueError(f"Metric {metric} requires access to {label_name}.")
80 |
81 | label_sets = [label_dict[label_name] for label_name in label_names]
82 | return func(*label_sets, **kwargs)
83 |
84 |
85 | def _coverage_score(preds: np.ndarray) -> float:
86 | return np.sum(preds != -1) / len(preds)
87 |
88 |
89 | def _roc_auc_score(golds: np.ndarray, probs: np.ndarray) -> float:
90 | if not probs.shape[1] == 2:
91 | raise ValueError(
92 | "Metric roc_auc is currently only defined for binary problems."
93 | )
94 | return skmetrics.roc_auc_score(golds, probs[:, 1])
95 |
96 |
97 | def _f1_score(golds: np.ndarray, preds: np.ndarray) -> float:
98 | if golds.max() <= 1:
99 | return skmetrics.f1_score(golds, preds)
100 | else:
101 | raise ValueError(
102 | "f1 not supported for multiclass. Try f1_micro or f1_macro instead."
103 | )
104 |
105 |
106 | def _f1_micro_score(golds: np.ndarray, preds: np.ndarray) -> float:
107 | return skmetrics.f1_score(golds, preds, average="micro")
108 |
109 |
110 | def _f1_macro_score(golds: np.ndarray, preds: np.ndarray) -> float:
111 | return skmetrics.f1_score(golds, preds, average="macro")
112 |
113 |
114 | # See https://scikit-learn.org/stable/modules/classes.html#module-sklearn.metrics
115 | # for details on the definitions and available kwargs for all metrics from scikit-learn
116 | METRICS = {
117 | "accuracy": Metric(skmetrics.accuracy_score),
118 | "coverage": Metric(_coverage_score, ["preds"]),
119 | "precision": Metric(skmetrics.precision_score),
120 | "recall": Metric(skmetrics.recall_score),
121 | "f1": Metric(_f1_score, ["golds", "preds"]),
122 | "f1_micro": Metric(_f1_micro_score, ["golds", "preds"]),
123 | "f1_macro": Metric(_f1_macro_score, ["golds", "preds"]),
124 | "fbeta": Metric(skmetrics.fbeta_score),
125 | "matthews_corrcoef": Metric(skmetrics.matthews_corrcoef),
126 | "roc_auc": Metric(_roc_auc_score, ["golds", "probs"]),
127 | }
128 |
--------------------------------------------------------------------------------
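Note: a minimal usage sketch for `metric_score` (arrays invented for illustration), using `filter_dict` to drop abstains (-1) before scoring:

    import numpy as np
    from snorkel.analysis import metric_score

    golds = np.array([0, 1, 1, 0])
    preds = np.array([0, 1, 0, -1])
    # Filter out data points where preds == -1 (abstain), then score accuracy
    acc = metric_score(golds, preds, metric="accuracy", filter_dict={"preds": [-1]})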
/snorkel/augmentation/__init__.py:
--------------------------------------------------------------------------------
1 | """Programmatic data set augmentation: TF creation and data generation utilities."""
2 |
3 | from .apply.core import TFApplier # noqa: F401
4 | from .apply.pandas import PandasTFApplier # noqa: F401
5 | from .policy.core import ApplyAllPolicy, ApplyEachPolicy, ApplyOnePolicy # noqa: F401
6 | from .policy.sampling import MeanFieldPolicy, RandomPolicy # noqa: F401
7 | from .tf import TransformationFunction, transformation_function # noqa: F401
8 |
--------------------------------------------------------------------------------
/snorkel/augmentation/apply/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/snorkel-team/snorkel/617c92400c50e95ce41fcee84309a86f76cf525c/snorkel/augmentation/apply/__init__.py
--------------------------------------------------------------------------------
/snorkel/augmentation/apply/core.py:
--------------------------------------------------------------------------------
1 | from typing import Iterator, List
2 |
3 | from tqdm import tqdm
4 |
5 | from snorkel.augmentation.policy.core import Policy
6 | from snorkel.augmentation.tf import BaseTransformationFunction
7 | from snorkel.types import DataPoint, DataPoints
8 | from snorkel.utils.data_operators import check_unique_names
9 |
10 |
11 | class BaseTFApplier:
12 | """Base class for TF applier objects.
13 |
14 | Base class for TF applier objects, which execute a set of TF
15 | on a collection of data points. Subclasses should operate on
16 | a single data point collection format (e.g. ``DataFrame``).
17 | Subclasses must implement the ``apply`` method.
18 |
19 | Parameters
20 | ----------
21 | tfs
22 | TFs that this applier executes on examples
23 | policy
24 | Augmentation policy used to generate sequences of TFs
25 |
26 | Raises
27 | ------
28 | ValueError
29 | If names of TFs are not unique
30 | """
31 |
32 | def __init__(self, tfs: List[BaseTransformationFunction], policy: Policy) -> None:
33 | self._tfs = tfs
34 | self._tf_names = [tf.name for tf in tfs]
35 | check_unique_names(self._tf_names)
36 | self._policy = policy
37 |
38 | def _apply_policy_to_data_point(self, x: DataPoint) -> DataPoints:
39 | x_transformed = []
40 | for seq in self._policy.generate_for_example():
41 | x_t = x
42 | # Handle empty sequence for `keep_original`
43 | transform_applied = len(seq) == 0
44 | # Apply TFs
45 | for tf_idx in seq:
46 | tf = self._tfs[tf_idx]
47 | x_t_or_none = tf(x_t)
48 | # Update if transformation was applied
49 | if x_t_or_none is not None:
50 | transform_applied = True
51 | x_t = x_t_or_none
52 | # Add example if original or transformations applied
53 | if transform_applied:
54 | x_transformed.append(x_t)
55 | return x_transformed
56 |
57 | def __repr__(self) -> str:
58 | policy_name = type(self._policy).__name__
59 | return f"{type(self).__name__}, Policy: {policy_name}, TFs: {self._tf_names}"
60 |
61 |
62 | class TFApplier(BaseTFApplier):
63 | """TF applier for a list of data points.
64 |
65 | Augments a list of data points (e.g. ``SimpleNamespace``). Primarily
66 | useful for testing.
67 | """
68 |
69 | def apply_generator(
70 | self, data_points: DataPoints, batch_size: int
71 | ) -> Iterator[List[DataPoint]]:
72 | """Augment a list of data points using TFs and policy in batches.
73 |
74 | This method acts as a generator, yielding augmented data points for
75 | a given input batch of data points. This can be useful in a training
76 | loop when it is too memory-intensive to pregenerate all transformed
77 | examples.
78 |
79 | Parameters
80 | ----------
81 | data_points
82 | List containing data points to be transformed
83 | batch_size
84 | Batch size for generator. Yields augmented data points
85 | for the next ``batch_size`` input data points.
86 |
87 | Yields
88 | ------
89 | List[DataPoint]
90 | List of data points in augmented data set for batches of inputs
91 | """
92 | for i in range(0, len(data_points), batch_size):
93 | batch_transformed: List[DataPoint] = []
94 | for x in data_points[i : i + batch_size]:
95 | batch_transformed.extend(self._apply_policy_to_data_point(x))
96 | yield batch_transformed
97 |
98 | def apply(
99 | self, data_points: DataPoints, progress_bar: bool = True
100 | ) -> List[DataPoint]:
101 | """Augment a list of data points using TFs and policy.
102 |
103 | Parameters
104 | ----------
105 | data_points
106 | List containing data points to be transformed
107 | progress_bar
108 | Display a progress bar?
109 |
110 | Returns
111 | -------
112 | List[DataPoint]
113 | List of data points in augmented data set
114 | """
115 | x_transformed: List[DataPoint] = []
116 | for x in tqdm(data_points, disable=(not progress_bar)):
117 | x_transformed.extend(self._apply_policy_to_data_point(x))
118 | return x_transformed
119 |
--------------------------------------------------------------------------------
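Note: a minimal end-to-end sketch of `TFApplier` (the TF and data points are invented for illustration):

    from types import SimpleNamespace

    from snorkel.augmentation import RandomPolicy, TFApplier, transformation_function

    @transformation_function()
    def shout(x):
        # TFs receive a copy of the data point and return the transformed copy
        x.text = x.text.upper()
        return x

    data = [SimpleNamespace(text="hello"), SimpleNamespace(text="world")]
    policy = RandomPolicy(1, sequence_length=1, n_per_original=1, keep_original=True)
    augmented = TFApplier([shout], policy).apply(data, progress_bar=False)
    # len(augmented) == 4: each original plus one transformed copy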
/snorkel/augmentation/apply/pandas.py:
--------------------------------------------------------------------------------
1 | from typing import Iterator, List
2 |
3 | import pandas as pd
4 | from tqdm import tqdm
5 |
6 | from .core import BaseTFApplier
7 |
8 |
9 | class PandasTFApplier(BaseTFApplier):
10 | """TF applier for a Pandas DataFrame.
11 |
12 | Data points are stored as Series in a DataFrame. The TFs
13 | run on data points obtained via a ``pandas.DataFrame.iterrows``
14 | call, which is single-process and can be slow for large DataFrames.
15 | For large datasets, consider ``DaskTFApplier`` or ``SparkTFApplier``.
16 | """
17 |
18 |     def apply_generator(self, df: pd.DataFrame, batch_size: int) -> Iterator[pd.DataFrame]:
19 | """Augment a Pandas DataFrame of data points using TFs and policy in batches.
20 |
21 | This method acts as a generator, yielding augmented data points for
22 | a given input batch of data points. This can be useful in a training
23 | loop when it is too memory-intensive to pregenerate all transformed
24 | examples.
25 |
26 | Parameters
27 | ----------
28 | df
29 | Pandas DataFrame containing data points to be transformed
30 | batch_size
31 | Batch size for generator. Yields augmented data points
32 | for the next ``batch_size`` input data points.
33 |
34 |         Yields
35 |         ------
36 |         pd.DataFrame
37 |             Pandas DataFrame of data points in the augmented data set, one per batch
38 | """
39 | batch_transformed: List[pd.Series] = []
40 | for i, (_, x) in enumerate(df.iterrows()):
41 | batch_transformed.extend(self._apply_policy_to_data_point(x))
42 | if (i + 1) % batch_size == 0:
43 | yield pd.concat(batch_transformed, axis=1).T.infer_objects()
44 | batch_transformed = []
45 |         if batch_transformed:  # skip empty final batch when len(df) is a multiple of batch_size
46 |             yield pd.concat(batch_transformed, axis=1).T.infer_objects()
47 | def apply(self, df: pd.DataFrame, progress_bar: bool = True) -> pd.DataFrame:
48 | """Augment a Pandas DataFrame of data points using TFs and policy.
49 |
50 | Parameters
51 | ----------
52 | df
53 | Pandas DataFrame containing data points to be transformed
54 | progress_bar
55 | Display a progress bar?
56 |
57 | Returns
58 | -------
59 | pd.DataFrame
60 | Pandas DataFrame of data points in augmented data set
61 | """
62 | x_transformed: List[pd.Series] = []
63 | for _, x in tqdm(df.iterrows(), total=len(df), disable=(not progress_bar)):
64 | x_transformed.extend(self._apply_policy_to_data_point(x))
65 | return pd.concat(x_transformed, axis=1).T.infer_objects()
66 |
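67 | if __name__ == "__main__":
68 |     # Minimal usage sketch; the TF and DataFrame are illustrative.
69 |     from snorkel.augmentation.policy.sampling import RandomPolicy
70 |     from snorkel.augmentation.tf import transformation_function
71 |
72 |     @transformation_function()
73 |     def shout(x):
74 |         x.text = x.text.upper()
75 |         return x
76 |
77 |     policy = RandomPolicy(n_tfs=1, sequence_length=1, keep_original=True)
78 |     applier = PandasTFApplier([shout], policy)
79 |     df = pd.DataFrame({"text": ["hello", "hi"]})
80 |     # Each original row is kept and followed by its upper-cased copy
81 |     print(applier.apply(df, progress_bar=False))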
--------------------------------------------------------------------------------
/snorkel/augmentation/policy/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/snorkel-team/snorkel/617c92400c50e95ce41fcee84309a86f76cf525c/snorkel/augmentation/policy/__init__.py
--------------------------------------------------------------------------------
/snorkel/augmentation/policy/sampling.py:
--------------------------------------------------------------------------------
1 | from typing import List, Optional, Sequence
2 |
3 | import numpy as np
4 |
5 | from .core import Policy
6 |
7 |
8 | class MeanFieldPolicy(Policy):
9 | """Sample sequences of TFs according to a distribution.
10 |
11 | Samples sequences of indices of a specified length from a
12 | user-provided distribution. A distribution over TFs can be
13 | learned by a TANDA mean-field model, for example.
14 | See https://hazyresearch.github.io/snorkel/blog/tanda.html
15 |
16 | Parameters
17 | ----------
18 | n_tfs
19 | Total number of TFs
20 | sequence_length
21 | Number of TFs to run on each data point
22 | p
23 | Probability distribution from which to sample TF indices.
24 | Must have length ``n_tfs`` and be a valid distribution.
25 | n_per_original
26 | Number of transformed data points per original
27 | keep_original
28 | Keep untransformed data point in augmented data set? Note that
29 | even if in-place modifications are made to the original data
30 | point by the TFs being applied, the original data point will
31 | remain unchanged.
32 |
33 | Attributes
34 | ----------
35 | n
36 | Total number of TFs
37 | n_per_original
38 | See above
39 | keep_original
40 | See above
41 | sequence_length
42 | See above
43 | """
44 |
45 | def __init__(
46 | self,
47 | n_tfs: int,
48 | sequence_length: int = 1,
49 | p: Optional[Sequence[float]] = None,
50 | n_per_original: int = 1,
51 | keep_original: bool = True,
52 | ) -> None:
53 | self.sequence_length = sequence_length
54 | self._p = p
55 | super().__init__(
56 | n_tfs, n_per_original=n_per_original, keep_original=keep_original
57 | )
58 |
59 | def generate(self) -> List[int]:
60 | """Generate a sequence of TF indices by sampling from distribution.
61 |
62 | Returns
63 | -------
64 | List[int]
65 | Indices of TFs to run on data point in order.
66 | """
67 | return np.random.choice(self.n, size=self.sequence_length, p=self._p).tolist()
68 |
69 |
70 | class RandomPolicy(MeanFieldPolicy):
71 | """Naive random augmentation policy.
72 |
73 | Samples sequences of TF indices a specified length at random
74 | from the total number of TFs. Sampling uniformly at random is
75 | a common baseline approach to data augmentation.
76 |
77 | Parameters
78 | ----------
79 | n_tfs
80 | Total number of TFs
81 | sequence_length
82 | Number of TFs to run on each data point
83 | n_per_original
84 | Number of transformed data points per original
85 | keep_original
86 | Keep untransformed data point in augmented data set? Note that
87 | even if in-place modifications are made to the original data
88 | point by the TFs being applied, the original data point will
89 | remain unchanged.
90 |
91 | Attributes
92 | ----------
93 | n
94 | Total number of TFs
95 | n_per_original
96 | See above
97 | keep_original
98 | See above
99 | sequence_length
100 | See above
101 | """
102 |
103 | def __init__(
104 | self,
105 | n_tfs: int,
106 | sequence_length: int = 1,
107 | n_per_original: int = 1,
108 | keep_original: bool = True,
109 | ) -> None:
110 | super().__init__(
111 | n_tfs,
112 | sequence_length=sequence_length,
113 | p=None,
114 | n_per_original=n_per_original,
115 | keep_original=keep_original,
116 | )
117 |
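118 | if __name__ == "__main__":
119 |     # Minimal sketch: a policy over two TFs that heavily favors TF 0.
120 |     policy = MeanFieldPolicy(n_tfs=2, sequence_length=2, p=[0.9, 0.1])
121 |     print(policy.generate())  # e.g. [0, 0]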
--------------------------------------------------------------------------------
/snorkel/augmentation/tf.py:
--------------------------------------------------------------------------------
1 | """Base classes for transformation functions.
2 |
3 | A transformation function (TF) represents an atomic transformation
4 | to a data point in a data augmentation pipeline. Common examples in
5 | image processing include small image rotations or crops. Snorkel
6 | models data augmentation as a sequence of TFs generated by a policy.
7 | """
8 |
9 | from snorkel.map import BaseMapper, LambdaMapper, Mapper, lambda_mapper
10 |
11 | # Used for type checking only
12 | # Note: an alias is used here because subclassing BaseMapper trips up mypy
13 | BaseTransformationFunction = BaseMapper
14 |
15 |
16 | class TransformationFunction(Mapper):
17 | """Base class for TFs.
18 |
19 | See ``snorkel.map.core.Mapper`` for details.
20 | """
21 |
22 | pass
23 |
24 |
25 | class LambdaTransformationFunction(LambdaMapper):
26 | """Convenience class for definining TFs from functions.
27 |
28 | See ``snorkel.map.core.LambdaMapper`` for details.
29 | """
30 |
31 | pass
32 |
33 |
34 | class transformation_function(lambda_mapper):
35 | """Decorate functions to create TFs.
36 |
37 | See ``snorkel.map.core.lambda_mapper`` for details.
38 |
39 | Example
40 | -------
41 | >>> @transformation_function()
42 | ... def square(x):
43 | ... x.num = x.num ** 2
44 | ... return x
45 | >>> from types import SimpleNamespace
46 | >>> square(SimpleNamespace(num=2))
47 | namespace(num=4)
48 | """
49 |
50 | pass
51 |
--------------------------------------------------------------------------------
/snorkel/classification/__init__.py:
--------------------------------------------------------------------------------
1 | """PyTorch-based multi-task learning framework for discriminative modeling."""
2 |
3 | from .data import DictDataLoader, DictDataset # noqa: F401
4 | from .loss import cross_entropy_with_probs # noqa: F401
5 | from .multitask_classifier import MultitaskClassifier # noqa: F401
6 | from .task import Operation, Task # noqa: F401
7 | from .training.loggers import ( # noqa: F401
8 | Checkpointer,
9 | CheckpointerConfig,
10 | LogManager,
11 | LogManagerConfig,
12 | LogWriter,
13 | LogWriterConfig,
14 | TensorBoardWriter,
15 | )
16 | from .training.trainer import Trainer # noqa: F401
17 |
--------------------------------------------------------------------------------
/snorkel/classification/loss.py:
--------------------------------------------------------------------------------
1 | from typing import List, Mapping, Optional
2 |
3 | import torch
4 | import torch.nn.functional as F
5 |
6 | Outputs = Mapping[str, List[torch.Tensor]]
7 |
8 |
9 | def cross_entropy_with_probs(
10 | input: torch.Tensor,
11 | target: torch.Tensor,
12 | weight: Optional[torch.Tensor] = None,
13 | reduction: str = "mean",
14 | ) -> torch.Tensor:
15 | """Calculate cross-entropy loss when targets are probabilities (floats), not ints.
16 |
17 |     PyTorch's F.cross_entropy() method requires integer labels; it does not accept
18 |     probabilistic labels. We can, however, simulate such functionality with a for loop,
19 |     calculating the loss contributed by each class and accumulating the results.
20 |     Libraries such as Keras do not require this workaround, as methods like
21 |     "categorical_crossentropy" accept float labels natively.
22 |
23 |     Note that the method signature is intentionally very similar to F.cross_entropy()
24 |     so that it can be used as a drop-in replacement when target labels are changed
25 |     from a 1D tensor of ints to a 2D tensor of probabilities.
26 |
27 | Parameters
28 | ----------
29 | input
30 | A [num_points, num_classes] tensor of logits
31 | target
32 | A [num_points, num_classes] tensor of probabilistic target labels
33 | weight
34 | An optional [num_classes] array of weights to multiply the loss by per class
35 | reduction
36 | One of "none", "mean", "sum", indicating whether to return one loss per data
37 | point, the mean loss, or the sum of losses
38 |
39 | Returns
40 | -------
41 | torch.Tensor
42 | The calculated loss
43 |
44 | Raises
45 | ------
46 | ValueError
47 | If an invalid reduction keyword is submitted
48 | """
49 | num_points, num_classes = input.shape
50 | # Note that t.new_zeros, t.new_full put tensor on same device as t
51 | cum_losses = input.new_zeros(num_points)
52 | for y in range(num_classes):
53 | target_temp = input.new_full((num_points,), y, dtype=torch.long)
54 | y_loss = F.cross_entropy(input, target_temp, reduction="none")
55 | if weight is not None:
56 | y_loss = y_loss * weight[y]
57 | cum_losses += target[:, y].float() * y_loss
58 |
59 | if reduction == "none":
60 | return cum_losses
61 | elif reduction == "mean":
62 | return cum_losses.mean()
63 | elif reduction == "sum":
64 | return cum_losses.sum()
65 | else:
66 | raise ValueError("Keyword 'reduction' must be one of ['none', 'mean', 'sum']")
67 |
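68 | if __name__ == "__main__":
69 |     # Minimal sketch: with one-hot probabilistic targets, this loss matches
70 |     # F.cross_entropy on the equivalent integer targets.
71 |     logits = torch.tensor([[2.0, 0.5], [0.1, 1.5]])
72 |     hard_targets = torch.tensor([0, 1])
73 |     soft_targets = torch.tensor([[1.0, 0.0], [0.0, 1.0]])
74 |     assert torch.allclose(
75 |         cross_entropy_with_probs(logits, soft_targets),
76 |         F.cross_entropy(logits, hard_targets),
77 |     )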
--------------------------------------------------------------------------------
/snorkel/classification/task.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from functools import partial
3 | from typing import Callable, List, Mapping, Optional, Sequence, Tuple, Union
4 |
5 | import torch
6 | import torch.nn.functional as F
7 | from torch import nn
8 |
9 | from snorkel.analysis import Scorer
10 |
11 | Outputs = Mapping[str, List[torch.FloatTensor]]
12 |
13 |
14 | class Operation:
15 | """A single operation (forward pass of a module) to execute in a Task.
16 |
17 | See ``Task`` for more detail on the usage and semantics of an Operation.
18 |
19 | Parameters
20 | ----------
21 | name
22 | The name of this operation (defaults to module_name since for most workflows,
23 | each module is only used once per forward pass)
24 | module_name
25 | The name of the module in the module pool that this operation uses
26 | inputs
27 | The inputs that the specified module expects, given as a list of names of
28 | previous operations (or optionally a tuple of the operation name and a key
29 | if the output of that module is a dict instead of a Tensor).
30 | Note that the original input to the model can be referred to as "_input_".
31 |
32 | Example
33 | -------
34 | >>> op1 = Operation(module_name="linear1", inputs=[("_input_", "features")])
35 | >>> op2 = Operation(module_name="linear2", inputs=["linear1"])
36 | >>> op_sequence = [op1, op2]
37 |
38 | Attributes
39 | ----------
40 | name
41 | See above
42 | module_name
43 | See above
44 | inputs
45 | See above
46 | """
47 |
48 | def __init__(
49 | self,
50 | module_name: str,
51 | inputs: Sequence[Union[str, Tuple[str, str]]],
52 | name: Optional[str] = None,
53 | ) -> None:
54 | self.name = name or module_name
55 | self.module_name = module_name
56 | self.inputs = inputs
57 |
58 | def __repr__(self) -> str:
59 | return (
60 | f"Operation(name={self.name}, "
61 | f"module_name={self.module_name}, "
62 | f"inputs={self.inputs})"
63 | )
64 |
65 |
66 | class Task:
67 | r"""A single task (a collection of modules and specified path through them).
68 |
69 | Parameters
70 | ----------
71 | name
72 | The name of the task
73 | module_pool
74 | A ModuleDict mapping module names to the modules themselves
75 | op_sequence
76 | A list of ``Operation``\s to execute in order, defining the flow of information
77 | through the network for this task
78 | scorer
79 | A ``Scorer`` with the desired metrics to calculate for this task
80 | loss_func
81 | A function that converts final logits into loss values.
82 | Defaults to F.cross_entropy() if none is provided.
83 |         To use probabilistic labels for training, use the Snorkel-defined method
84 | cross_entropy_with_probs() instead.
85 | output_func
86 | A function that converts final logits into 'outputs' (e.g. probabilities)
87 | Defaults to F.softmax(..., dim=1).
88 |
89 | Attributes
90 | ----------
91 | name
92 | See above
93 | module_pool
94 | See above
95 | op_sequence
96 | See above
97 | scorer
98 | See above
99 | loss_func
100 | See above
101 | output_func
102 | See above
103 | """
104 |
105 | def __init__(
106 | self,
107 | name: str,
108 | module_pool: nn.ModuleDict,
109 | op_sequence: Sequence[Operation],
110 | scorer: Scorer = Scorer(metrics=["accuracy"]),
111 | loss_func: Optional[Callable[..., torch.Tensor]] = None,
112 | output_func: Optional[Callable[..., torch.Tensor]] = None,
113 | ) -> None:
114 | self.name = name
115 | self.module_pool = module_pool
116 | self.op_sequence = op_sequence
117 | self.loss_func = loss_func or F.cross_entropy
118 | self.output_func = output_func or partial(F.softmax, dim=1)
119 | self.scorer = scorer
120 |
121 | logging.info(f"Created task: {self.name}")
122 |
123 | def __repr__(self) -> str:
124 | cls_name = type(self).__name__
125 | return f"{cls_name}(name={self.name})"
126 |
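127 | if __name__ == "__main__":
128 |     # Minimal sketch mirroring the Operation docstring example; module names
129 |     # and layer sizes are illustrative.
130 |     module_pool = nn.ModuleDict(
131 |         {"linear1": nn.Linear(10, 10), "linear2": nn.Linear(10, 2)}
132 |     )
133 |     op1 = Operation(module_name="linear1", inputs=[("_input_", "features")])
134 |     op2 = Operation(module_name="linear2", inputs=["linear1"])
135 |     task = Task(name="my_task", module_pool=module_pool, op_sequence=[op1, op2])
136 |     print(task)  # Task(name=my_task)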
--------------------------------------------------------------------------------
/snorkel/classification/training/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/snorkel-team/snorkel/617c92400c50e95ce41fcee84309a86f76cf525c/snorkel/classification/training/__init__.py
--------------------------------------------------------------------------------
/snorkel/classification/training/loggers/__init__.py:
--------------------------------------------------------------------------------
1 | from .checkpointer import Checkpointer, CheckpointerConfig # noqa: F401
2 | from .log_manager import LogManager, LogManagerConfig # noqa: F401
3 | from .log_writer import LogWriter, LogWriterConfig # noqa: F401
4 | from .tensorboard_writer import TensorBoardWriter # noqa: F401
5 |
--------------------------------------------------------------------------------
/snorkel/classification/training/loggers/log_manager.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from typing import Any, Optional
3 |
4 | from snorkel.classification.multitask_classifier import MultitaskClassifier
5 | from snorkel.types import Config
6 |
7 | from .checkpointer import Checkpointer
8 | from .log_writer import LogWriter
9 |
10 |
11 | class LogManagerConfig(Config):
12 | """Manager for checkpointing model.
13 |
14 | Parameters
15 | ----------
16 | counter_unit
17 | The unit to use when assessing when it's time to log.
18 | Options are ["epochs", "batches", "points"]
19 |     evaluation_freq
20 | Evaluate performance on the validation set every this many counter_units
21 | """
22 |
23 | counter_unit: str = "epochs"
24 | evaluation_freq: float = 1.0
25 |
26 |
27 | class LogManager:
28 | """A class to manage logging during training progress.
29 |
30 | Parameters
31 | ----------
32 | n_batches_per_epoch
33 |         Total number of batches per epoch
34 | log_writer
35 | ``LogWriter`` for current run logs
36 | checkpointer
37 | ``Checkpointer`` for current model
38 | kwargs
39 | Settings to update in LogManagerConfig
40 | """
41 |
42 | def __init__(
43 | self,
44 | n_batches_per_epoch: int,
45 | log_writer: Optional[LogWriter] = None,
46 | checkpointer: Optional[Checkpointer] = None,
47 | **kwargs: Any,
48 | ) -> None:
49 | self.config = LogManagerConfig(**kwargs) # type: ignore
50 | self.n_batches_per_epoch = n_batches_per_epoch
51 |
52 | self.log_writer = log_writer
53 | self.checkpointer = checkpointer
54 |
55 | # Set up counter unit
56 | self.counter_unit = self.config.counter_unit
57 | if self.counter_unit not in ["points", "batches", "epochs"]:
58 | raise ValueError(f"Unrecognized counter_unit: {self.counter_unit}")
59 |
60 | # Set up evaluation frequency
61 | self.evaluation_freq = self.config.evaluation_freq
62 | logging.info(f"Evaluating every {self.evaluation_freq} {self.counter_unit}.")
63 |
64 |         # Counts since the last evaluation/checkpoint, plus running totals
65 | self.point_count = 0
66 | self.point_total = 0
67 |
68 | self.batch_count = 0
69 | self.batch_total = 0
70 |
71 | self.epoch_count = 0.0
72 | self.epoch_total = 0.0
73 |
74 | self.unit_count = 0.0
75 | self.unit_total = 0.0
76 |
77 |         # Number of evaluations triggered since the last checkpoint
78 | self.trigger_count = 0
79 |
80 | def update(self, batch_size: int) -> None:
81 | """Update the count and total number."""
82 |
83 | # Update number of points
84 | self.point_count += batch_size
85 | self.point_total += batch_size
86 |
87 | # Update number of batches
88 | self.batch_count += 1
89 | self.batch_total += 1
90 |
91 | # Update number of epochs
92 | self.epoch_count = self.batch_count / self.n_batches_per_epoch
93 | self.epoch_total = self.batch_total / self.n_batches_per_epoch
94 |
95 | # Update number of units
96 | if self.counter_unit == "points":
97 | self.unit_count = self.point_count
98 | self.unit_total = self.point_total
99 |         elif self.counter_unit == "batches":
100 | self.unit_count = self.batch_count
101 | self.unit_total = self.batch_total
102 | elif self.counter_unit == "epochs":
103 | self.unit_count = self.epoch_count
104 | self.unit_total = self.epoch_total
105 |
106 | def trigger_evaluation(self) -> bool:
107 | """Check if current counts trigger evaluation."""
108 | satisfied = self.unit_count >= self.evaluation_freq
109 | if satisfied:
110 | self.trigger_count += 1
111 | self.reset()
112 | return satisfied
113 |
114 | def trigger_checkpointing(self) -> bool:
115 | """Check if current counts trigger checkpointing."""
116 | if self.checkpointer is None:
117 | return False
118 | satisfied = self.trigger_count >= self.checkpointer.checkpoint_factor
119 | if satisfied:
120 | self.trigger_count = 0
121 | return satisfied
122 |
123 | def reset(self) -> None:
124 | """Reset counters."""
125 | self.point_count = 0
126 | self.batch_count = 0
127 |         self.epoch_count = 0.0
128 |         self.unit_count = 0.0
129 |
130 | def cleanup(self, model: MultitaskClassifier) -> MultitaskClassifier:
131 | """Close the log writer and checkpointer if needed. Reload best model."""
132 | if self.log_writer is not None:
133 | self.log_writer.cleanup()
134 | if self.checkpointer is not None:
135 | self.checkpointer.clear()
136 | model = self.checkpointer.load_best_model(model)
137 | return model
138 |
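139 | if __name__ == "__main__":
140 |     # Minimal sketch: count in batches and evaluate every 2 batches; the
141 |     # numbers are illustrative.
142 |     manager = LogManager(
143 |         n_batches_per_epoch=4, counter_unit="batches", evaluation_freq=2
144 |     )
145 |     for _ in range(4):
146 |         manager.update(batch_size=32)
147 |         if manager.trigger_evaluation():
148 |             print(f"evaluate at batch {manager.batch_total}")  # batches 2 and 4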
--------------------------------------------------------------------------------
/snorkel/classification/training/loggers/log_writer.py:
--------------------------------------------------------------------------------
1 | import json
2 | import logging
3 | import os
4 | from collections import defaultdict
5 | from datetime import datetime
6 | from typing import Any, DefaultDict, List, Mapping, Optional
7 |
8 | from snorkel.types import Config
9 |
10 |
11 | class LogWriterConfig(Config):
12 | """Manager for checkpointing model.
13 |
14 | Parameters
15 | ----------
16 | log_dir
17 | The root directory where logs should be saved
18 | run_name
19 | The name of this particular run (defaults to date-time combination if None)
20 | """
21 |
22 | log_dir: str = "logs"
23 | run_name: Optional[str] = None
24 |
25 |
26 | class LogWriter:
27 | """A class for writing logs.
28 |
29 | Parameters
30 | ----------
31 | kwargs
32 | Settings to merge into LogWriterConfig
33 |
34 | Attributes
35 | ----------
36 | config
37 | Merged configuration
38 | run_name
39 | Name of run if provided, otherwise date-time combination
40 | log_dir
41 | The root directory where logs should be saved
42 | run_log
43 | Dictionary of scalar values to log, keyed by value name
44 | """
45 |
46 | def __init__(self, **kwargs: Any) -> None:
47 | self.config = LogWriterConfig(**kwargs) # type: ignore
48 |
49 | self.run_name = self.config.run_name
50 | if self.run_name is None:
51 |             # Take a single timestamp so the date and time parts stay consistent
52 |             now = datetime.now()
53 |             self.run_name = now.strftime("%Y_%m_%d/%H_%M_%S/")
54 |
55 | self.log_dir = os.path.join(self.config.log_dir, self.run_name)
56 |         # exist_ok avoids a race between the existence check and creation
57 |         os.makedirs(self.log_dir, exist_ok=True)
58 |
59 | self.run_log: DefaultDict[str, List[List[float]]] = defaultdict(list)
60 |
61 | def add_scalar(self, name: str, value: float, step: float) -> None:
62 | """Log a scalar variable.
63 |
64 | Parameters
65 | ----------
66 | name
67 | Name of the scalar collection
68 | value
69 | Value of scalar
70 | step
71 | Step axis value
72 | """
73 | # Note: storing as list for JSON roundtripping
74 | self.run_log[name].append([step, value])
75 |
76 | def write_config(
77 | self, config: Config, config_filename: str = "config.json"
78 | ) -> None:
79 | """Dump the config to file.
80 |
81 | Parameters
82 | ----------
83 | config
84 | JSON-compatible config to write to file
85 | config_filename
86 | Name of file in logging directory to write to
87 | """
88 | self.write_json(config._asdict(), config_filename)
89 |
90 | def write_log(self, log_filename: str) -> None:
91 | """Dump the scalar value log to file.
92 |
93 | Parameters
94 | ----------
95 | log_filename
96 | Name of file in logging directory to write to
97 | """
98 | self.write_json(self.run_log, log_filename)
99 |
100 | def write_text(self, text: str, filename: str) -> None:
101 | """Dump user-provided text to filename (e.g., the launch command).
102 |
103 | Parameters
104 | ----------
105 | text
106 | Text to write
107 | filename
108 | Name of file in logging directory to write to
109 | """
110 | text_path = os.path.join(self.log_dir, filename)
111 | with open(text_path, "w") as f:
112 | f.write(text)
113 |
114 | def write_json(self, dict_to_write: Mapping[str, Any], filename: str) -> None:
115 | """Dump a JSON-compatbile object to root log directory.
116 |
117 | Parameters
118 | ----------
119 | dict_to_write
120 |             JSON-compatible object to log
121 | filename
122 | Name of file in logging directory to write to
123 | """
124 | if not filename.endswith(".json"): # pragma: no cover
125 | logging.warning(
126 | f"Using write_json() method with a filename without a .json extension: {filename}"
127 | )
128 | log_path = os.path.join(self.log_dir, filename)
129 | with open(log_path, "w") as f:
130 | json.dump(dict_to_write, f)
131 |
132 | def cleanup(self) -> None:
133 | """Perform final operations and close writer if necessary."""
134 | self.write_log("log.json")
135 |
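136 | if __name__ == "__main__":
137 |     # Minimal sketch: scalars accumulate in run_log and cleanup() dumps them
138 |     # to logs/demo/log.json; the run name is illustrative.
139 |     writer = LogWriter(log_dir="logs", run_name="demo")
140 |     for step, loss in enumerate([0.9, 0.7, 0.5]):
141 |         writer.add_scalar("train/loss", loss, step)
142 |     writer.cleanup()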
--------------------------------------------------------------------------------
/snorkel/classification/training/loggers/tensorboard_writer.py:
--------------------------------------------------------------------------------
1 | from typing import Any
2 |
3 | from torch.utils.tensorboard import SummaryWriter
4 |
5 | from snorkel.types import Config
6 |
7 | from .log_writer import LogWriter
8 |
9 |
10 | class TensorBoardWriter(LogWriter):
11 | """A class for logging to Tensorboard during training process.
12 |
13 | See ``LogWriter`` for more attributes.
14 |
15 | Parameters
16 | ----------
17 | kwargs
18 | Passed to ``LogWriter`` initializer
19 |
20 | Attributes
21 | ----------
22 | writer
23 | ``SummaryWriter`` for logging and visualization
24 | """
25 |
26 | def __init__(self, **kwargs: Any) -> None:
27 | super().__init__(**kwargs)
28 | self.writer = SummaryWriter(self.log_dir)
29 |
30 | def add_scalar(self, name: str, value: float, step: float) -> None:
31 | """Log a scalar variable to TensorBoard.
32 |
33 | Parameters
34 | ----------
35 | name
36 | Name of the scalar collection
37 | value
38 | Value of scalar
39 | step
40 | Step axis value
41 | """
42 | self.writer.add_scalar(name, value, step)
43 |
44 | def write_config(
45 | self, config: Config, config_filename: str = "config.json"
46 | ) -> None:
47 | """Dump the config to file and add it to TensorBoard.
48 |
49 | Parameters
50 | ----------
51 | config
52 | JSON-compatible config to write to TensorBoard
53 | config_filename
54 | File to write config to
55 | """
56 | super().write_config(config, config_filename)
57 | self.writer.add_text(tag="config", text_string=str(config))
58 |
59 | def cleanup(self) -> None:
60 | """Close the ``SummaryWriter``."""
61 | self.writer.close()
62 |
--------------------------------------------------------------------------------
/snorkel/classification/training/schedulers/__init__.py:
--------------------------------------------------------------------------------
1 | from .sequential_scheduler import SequentialScheduler
2 | from .shuffled_scheduler import ShuffledScheduler
3 |
4 | batch_schedulers = {"sequential": SequentialScheduler, "shuffled": ShuffledScheduler}
5 |
--------------------------------------------------------------------------------
/snorkel/classification/training/schedulers/scheduler.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 | from typing import Any, Dict, Iterator, Sequence, Tuple
3 |
4 | from torch import Tensor
5 |
6 | from snorkel.classification.data import DictDataLoader # noqa: F401
7 |
8 | BatchIterator = Iterator[
9 | Tuple[Tuple[Dict[str, Any], Dict[str, Tensor]], "DictDataLoader"]
10 | ]
11 |
12 |
13 | class Scheduler(ABC):
14 | """Return batches from all dataloaders according to a specified strategy."""
15 |
16 | def __init__(self) -> None:
17 | pass
18 |
19 | @abstractmethod
20 | def get_batches(self, dataloaders: Sequence["DictDataLoader"]) -> BatchIterator:
21 | """Return batches from dataloaders according to a specified strategy.
22 |
23 | Parameters
24 | ----------
25 | dataloaders
26 | A sequence of dataloaders to get batches from
27 |
28 | Yields
29 | ------
30 | (batch, dataloader)
31 | batch is a tuple of (X_dict, Y_dict) and dataloader is the dataloader
32 |             that the batch came from. The dataloader will not be accessed by the
33 | model; it is passed primarily so that the model can pull the necessary
34 | metadata to know what to do with the batch it has been given.
35 | """
36 |
--------------------------------------------------------------------------------
/snorkel/classification/training/schedulers/sequential_scheduler.py:
--------------------------------------------------------------------------------
1 | from typing import Sequence
2 |
3 | from snorkel.classification.data import DictDataLoader
4 |
5 | from .scheduler import BatchIterator, Scheduler
6 |
7 |
8 | class SequentialScheduler(Scheduler):
9 | """Return batches from all dataloaders in sequential order."""
10 |
11 | def __init__(self) -> None:
12 | super().__init__()
13 |
14 | def get_batches(self, dataloaders: Sequence[DictDataLoader]) -> BatchIterator:
15 | """Return batches from dataloaders sequentially in the order they were given.
16 |
17 | Parameters
18 | ----------
19 | dataloaders
20 | A sequence of dataloaders to get batches from
21 |
22 | Yields
23 | ------
24 | (batch, dataloader)
25 | batch is a tuple of (X_dict, Y_dict) and dataloader is the dataloader
26 |             that the batch came from. The dataloader will not be accessed by the
27 | model; it is passed primarily so that the model can pull the necessary
28 | metadata to know what to do with the batch it has been given.
29 | """
30 | for dataloader in dataloaders:
31 | for batch in dataloader:
32 | yield batch, dataloader
33 |
--------------------------------------------------------------------------------
/snorkel/classification/training/schedulers/shuffled_scheduler.py:
--------------------------------------------------------------------------------
1 | import random
2 | from typing import Sequence
3 |
4 | from snorkel.classification.data import DictDataLoader
5 |
6 | from .scheduler import BatchIterator, Scheduler
7 |
8 |
9 | class ShuffledScheduler(Scheduler):
10 | """Return batches from all dataloaders in shuffled order for each epoch."""
11 |
12 | def __init__(self) -> None:
13 | super().__init__()
14 |
15 | def get_batches(self, dataloaders: Sequence[DictDataLoader]) -> BatchIterator:
16 | """Return batches in shuffled order from dataloaders.
17 |
18 | Note that this shuffles the batch order, but it does not shuffle the datasets
19 | themselves; shuffling the datasets is specified in the DataLoaders directly.
20 |
21 | Parameters
22 | ----------
23 | dataloaders
24 | A sequence of dataloaders to get batches from
25 |
26 | Yields
27 | ------
28 | (batch, dataloader)
29 | batch is a tuple of (X_dict, Y_dict) and dataloader is the dataloader
30 |             that the batch came from. The dataloader will not be accessed by the
31 | model; it is passed primarily so that the model can pull the necessary
32 | metadata to know what to do with the batch it has been given.
33 | """
34 | batch_counts = [len(dl) for dl in dataloaders]
35 | dataloader_iters = [iter(dl) for dl in dataloaders]
36 |
37 | dataloader_indices = []
38 | for idx, count in enumerate(batch_counts):
39 | dataloader_indices.extend([idx] * count)
40 |
41 | random.shuffle(dataloader_indices)
42 |
43 | for index in dataloader_indices:
44 | yield next(dataloader_iters[index]), dataloaders[index]
45 |
--------------------------------------------------------------------------------
/snorkel/classification/utils.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, List, Tuple, Union
2 |
3 | import numpy as np
4 | import pandas as pd
5 | import torch
6 |
7 | TensorCollection = Union[torch.Tensor, dict, list, tuple]
8 |
9 |
10 | def list_to_tensor(item_list: List[torch.Tensor]) -> torch.Tensor:
11 | """Convert a list of torch.Tensor into a single torch.Tensor."""
12 |
13 |     # Stack 0-dim (single-value) tensors into a 1-D tensor
14 | if all(item_list[i].dim() == 0 for i in range(len(item_list))):
15 | item_tensor = torch.stack(item_list, dim=0)
16 |     # Stack tensors of 2+ dimensions that all share the same shape
17 | elif all(
18 | (item_list[i].size() == item_list[0].size()) and (len(item_list[i].size()) != 1)
19 | for i in range(len(item_list))
20 | ):
21 | item_tensor = torch.stack(item_list, dim=0)
22 |     # Otherwise, flatten each tensor to 1-D and pad to a common length
23 | else:
24 | item_tensor, _ = pad_batch([item.view(-1) for item in item_list])
25 |
26 | return item_tensor
27 |
28 |
29 | def pad_batch(
30 | batch: List[torch.Tensor],
31 | max_len: int = 0,
32 | pad_value: int = 0,
33 | left_padded: bool = False,
34 | ) -> Tuple[torch.Tensor, torch.Tensor]:
35 | """Convert the batch into a padded tensor and mask tensor.
36 |
37 | Parameters
38 | ----------
39 | batch
40 | The data for padding
41 | max_len
42 |         Maximum sequence length to pad to; if 0, use the longest item in the batch
43 | pad_value
44 | The value to use for padding
45 | left_padded
46 | If True, pad on the left, otherwise on the right
47 |
48 | Returns
49 | -------
50 | Tuple[torch.Tensor, torch.Tensor]
51 |         The padded matrix and corresponding mask matrix.
52 | """
53 |
54 | batch_size = len(batch)
55 | max_seq_len = int(np.max([len(item) for item in batch])) # type: ignore
56 |
57 | if max_len > 0 and max_len < max_seq_len:
58 | max_seq_len = max_len
59 |
60 | padded_batch = batch[0].new_full((batch_size, max_seq_len), pad_value)
61 |
62 | for i, item in enumerate(batch):
63 | length = min(len(item), max_seq_len) # type: ignore
64 | if left_padded:
65 | padded_batch[i, -length:] = item[-length:]
66 | else:
67 | padded_batch[i, :length] = item[:length]
68 |
69 | mask_batch = torch.eq(padded_batch.clone().detach(), pad_value).type_as(
70 | padded_batch
71 | )
72 |
73 | return padded_batch, mask_batch
74 |
75 |
76 | def move_to_device(
77 | obj: TensorCollection, device: int = -1
78 | ) -> TensorCollection: # pragma: no cover
79 | """Recursively move torch.Tensors to a given CUDA device.
80 |
81 |     Given a structure (possibly) containing Tensors on the CPU, move all the Tensors
82 |     to the specified GPU (or do nothing, if they should be on the CPU).
83 |
84 | Originally from:
85 | https://github.com/HazyResearch/metal/blob/mmtl_clean/metal/utils.py
86 |
87 | Parameters
88 | ----------
89 | obj
90 | Tensor or collection of Tensors to move
91 | device
92 | Device to move Tensors to
93 | device = -1 -> "cpu"
94 | device = 0 -> "cuda:0"
95 | """
96 |
97 | if device < 0 or not torch.cuda.is_available():
98 | return obj
99 | elif isinstance(obj, torch.Tensor):
100 | return obj.cuda(device) # type: ignore
101 | elif isinstance(obj, dict):
102 | return {key: move_to_device(value, device) for key, value in obj.items()}
103 | elif isinstance(obj, list):
104 | return [move_to_device(item, device) for item in obj]
105 | elif isinstance(obj, tuple):
106 | return tuple([move_to_device(item, device) for item in obj])
107 | else:
108 | return obj
109 |
110 |
111 | def collect_flow_outputs_by_suffix(
112 | output_dict: Dict[str, torch.Tensor], suffix: str
113 | ) -> List[torch.Tensor]:
114 | """Return output_dict outputs specified by suffix, ordered by sorted flow_name."""
115 | return [
116 | output_dict[flow_name]
117 | for flow_name in sorted(output_dict.keys())
118 | if flow_name.endswith(suffix)
119 | ]
120 |
121 |
122 | def metrics_dict_to_dataframe(metrics_dict: Dict[str, float]) -> pd.DataFrame:
123 | """Format a metrics_dict (with keys 'label/dataset/split/metric') format as a pandas DataFrame."""
124 |
125 | metrics = []
126 |
127 | for full_metric, score in metrics_dict.items():
128 | label_name, dataset_name, split, metric = tuple(full_metric.split("/"))
129 | metrics.append((label_name, dataset_name, split, metric, score))
130 |
131 | return pd.DataFrame(
132 | metrics, columns=["label", "dataset", "split", "metric", "score"]
133 | )
134 |
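135 | if __name__ == "__main__":
136 |     # Minimal sketch: pad two variable-length sequences to a common length.
137 |     batch = [torch.tensor([1, 2, 3]), torch.tensor([4, 5])]
138 |     padded, mask = pad_batch(batch)
139 |     print(padded)  # tensor([[1, 2, 3], [4, 5, 0]])
140 |     print(mask)  # nonzero entries mark padded positions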
--------------------------------------------------------------------------------
/snorkel/contrib/README.md:
--------------------------------------------------------------------------------
1 | # Snorkel contrib
2 |
3 | tl;dr Have something fun that you think others could benefit from but that
4 | isn't ready for prime time yet? Put it here!
5 |
6 | Any code in this directory is not officially supported, and may change or be
7 | removed at any time without notice.
8 |
9 | The contrib directory contains project directories, each of which has designated
10 | owners. It is meant to contain features and contributions whose interfaces may
11 | change, or which require some testing to see whether they can find broader acceptance.
12 | You may be asked to refactor code in contrib to use some feature inside core or
13 | in another contrib project rather than reimplementing the feature.
14 |
15 | When adding a project, please stick to the following directory structure:
16 | Create a project directory in `contrib/`, and mirror the portions of the
17 | Snorkel tree that your project requires underneath `contrib/my_project/`.
18 |
19 | For example, let's say you create foo ops for labeling in two files:
20 | `foo_ops.py` and `foo_ops_test.py`. If you were to merge those files
21 | directly into Snorkel, they would live in `snorkel/labeling/foo_ops.py` and
22 | `test/labeling/foo_ops_test.py`. In `contrib/`, they are part
23 | of project `foo`, and their full paths are `contrib/foo/snorkel/labeling/foo_ops.py`
24 | and `contrib/foo/test/labeling/foo_ops_test.py`.
25 |
26 |
27 | *Adapted from [TensorFlow contrib](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib).*
28 |
--------------------------------------------------------------------------------
/snorkel/contrib/__init__.py:
--------------------------------------------------------------------------------
1 | """Contributed modules for Snorkel."""
2 |
--------------------------------------------------------------------------------
/snorkel/labeling/__init__.py:
--------------------------------------------------------------------------------
1 | """Programmatic data set labeling: LF creation, models, and analysis utilities."""
2 |
3 | from .analysis import LFAnalysis # noqa: F401
4 | from .apply.core import LFApplier # noqa: F401
5 | from .apply.pandas import PandasLFApplier # noqa: F401
6 | from .lf.core import LabelingFunction, labeling_function # noqa: F401
7 | from .utils import filter_unlabeled_dataframe # noqa: F401
8 |
--------------------------------------------------------------------------------
/snorkel/labeling/apply/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/snorkel-team/snorkel/617c92400c50e95ce41fcee84309a86f76cf525c/snorkel/labeling/apply/__init__.py
--------------------------------------------------------------------------------
/snorkel/labeling/apply/dask.py:
--------------------------------------------------------------------------------
1 | from functools import partial
2 | from typing import Union
3 |
4 | import numpy as np
5 | import pandas as pd
6 | from dask import dataframe as dd
7 | from dask.distributed import Client
8 |
9 | from .core import BaseLFApplier, _FunctionCaller
10 | from .pandas import apply_lfs_to_data_point, rows_to_triplets
11 |
12 | Scheduler = Union[str, Client]
13 |
14 |
15 | class DaskLFApplier(BaseLFApplier):
16 | """LF applier for a Dask DataFrame.
17 |
18 | Dask DataFrames consist of partitions, each being a Pandas DataFrame.
19 | This allows for efficient parallel computation over DataFrame rows.
20 | For more information, see https://docs.dask.org/en/stable/dataframe.html
21 | """
22 |
23 | def apply(
24 | self,
25 | df: dd.DataFrame,
26 | scheduler: Scheduler = "processes",
27 | fault_tolerant: bool = False,
28 | ) -> np.ndarray:
29 | """Label Dask DataFrame of data points with LFs.
30 |
31 | Parameters
32 | ----------
33 | df
34 | Dask DataFrame containing data points to be labeled by LFs
35 | scheduler
36 | A Dask scheduling configuration: either a string option or
37 | a ``Client``. For more information, see
38 | https://docs.dask.org/en/stable/scheduling.html#
39 | fault_tolerant
40 | Output ``-1`` if LF execution fails?
41 |
42 | Returns
43 | -------
44 | np.ndarray
45 | Matrix of labels emitted by LFs
46 | """
47 | f_caller = _FunctionCaller(fault_tolerant)
48 | apply_fn = partial(apply_lfs_to_data_point, lfs=self._lfs, f_caller=f_caller)
49 | map_fn = df.map_partitions(lambda p_df: p_df.apply(apply_fn, axis=1))
50 | labels = map_fn.compute(scheduler=scheduler)
51 | labels_with_index = rows_to_triplets(labels)
52 | return self._numpy_from_row_data(labels_with_index)
53 |
54 |
55 | class PandasParallelLFApplier(DaskLFApplier):
56 | """Parallel LF applier for a Pandas DataFrame.
57 |
58 | Creates a Dask DataFrame from a Pandas DataFrame, then uses
59 | ``DaskLFApplier`` to label data in parallel. See ``DaskLFApplier``.
60 | """
61 |
62 | def apply( # type: ignore
63 | self,
64 | df: pd.DataFrame,
65 | n_parallel: int = 2,
66 | scheduler: Scheduler = "processes",
67 | fault_tolerant: bool = False,
68 | ) -> np.ndarray:
69 | """Label Pandas DataFrame of data points with LFs in parallel using Dask.
70 |
71 | Parameters
72 | ----------
73 | df
74 | Pandas DataFrame containing data points to be labeled by LFs
75 | n_parallel
76 | Parallelism level for LF application. Corresponds to ``npartitions``
77 | in constructed Dask DataFrame. For ``scheduler="processes"``, number
78 | of processes launched. Recommended to be no more than the number
79 | of cores on the running machine.
80 | scheduler
81 | A Dask scheduling configuration: either a string option or
82 | a ``Client``. For more information, see
83 | https://docs.dask.org/en/stable/scheduling.html#
84 | fault_tolerant
85 | Output ``-1`` if LF execution fails?
86 |
87 | Returns
88 | -------
89 | np.ndarray
90 | Matrix of labels emitted by LFs
91 | """
92 | if n_parallel < 2:
93 | raise ValueError(
94 | "n_parallel should be >= 2. "
95 | "For single process Pandas, use PandasLFApplier."
96 | )
97 | df = dd.from_pandas(df, npartitions=n_parallel)
98 | return super().apply(df, scheduler=scheduler, fault_tolerant=fault_tolerant)
99 |
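100 | if __name__ == "__main__":
101 |     # Minimal sketch: label a small DataFrame across two worker processes;
102 |     # the LF and data are illustrative.
103 |     from snorkel.labeling import labeling_function
104 |
105 |     @labeling_function()
106 |     def is_positive(x):
107 |         return 1 if x.num > 0 else -1
108 |
109 |     df = pd.DataFrame({"num": [-1, 2, 3, -4]})
110 |     L = PandasParallelLFApplier([is_positive]).apply(df, n_parallel=2)
111 |     print(L)  # shape (4, 1), entries in {-1, 1}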
--------------------------------------------------------------------------------
/snorkel/labeling/apply/pandas.py:
--------------------------------------------------------------------------------
1 | from functools import partial
2 | from typing import List, Tuple, Union
3 |
4 | import numpy as np
5 | import pandas as pd
6 | from tqdm import tqdm
7 |
8 | from snorkel.labeling.lf import LabelingFunction
9 | from snorkel.types import DataPoint
10 |
11 | from .core import ApplierMetadata, BaseLFApplier, RowData, _FunctionCaller
12 |
13 | PandasRowData = List[Tuple[int, int]]
14 |
15 |
16 | def apply_lfs_to_data_point(
17 | x: DataPoint, lfs: List[LabelingFunction], f_caller: _FunctionCaller
18 | ) -> PandasRowData:
19 | """Label a single data point with a set of LFs.
20 |
21 | Parameters
22 | ----------
23 | x
24 | Data point to label
25 | lfs
26 | Set of LFs to label ``x`` with
27 | f_caller
28 | A ``_FunctionCaller`` to record failed LF executions
29 |
30 | Returns
31 | -------
32 |     PandasRowData
33 | A list of (LF index, label) tuples
34 | """
35 | labels = []
36 | for j, lf in enumerate(lfs):
37 | y = f_caller(lf, x)
38 | if y >= 0:
39 | labels.append((j, y))
40 | return labels
41 |
42 |
43 | def rows_to_triplets(labels: List[PandasRowData]) -> List[RowData]:
44 | """Convert list of list sparse matrix representation to list of triplets."""
45 | return [
46 | [(index, j, y) for j, y in row_labels]
47 | for index, row_labels in enumerate(labels)
48 | ]
49 |
50 |
51 | class PandasLFApplier(BaseLFApplier):
52 | """LF applier for a Pandas DataFrame.
53 |
54 | Data points are stored as ``Series`` in a DataFrame. The LFs
55 | are executed via a ``pandas.DataFrame.apply`` call, which
56 | is single-process and can be slow for large DataFrames.
57 | For large datasets, consider ``DaskLFApplier`` or ``SparkLFApplier``.
58 |
59 | Parameters
60 | ----------
61 | lfs
62 | LFs that this applier executes on examples
63 |
64 | Example
65 | -------
66 | >>> from snorkel.labeling import labeling_function
67 | >>> @labeling_function()
68 | ... def is_big_num(x):
69 | ... return 1 if x.num > 42 else 0
70 | >>> applier = PandasLFApplier([is_big_num])
71 | >>> applier.apply(pd.DataFrame(dict(num=[10, 100], text=["hello", "hi"])))
72 | array([[0], [1]])
73 | """
74 |
75 | def apply(
76 | self,
77 | df: pd.DataFrame,
78 | progress_bar: bool = True,
79 | fault_tolerant: bool = False,
80 | return_meta: bool = False,
81 | ) -> Union[np.ndarray, Tuple[np.ndarray, ApplierMetadata]]:
82 | """Label Pandas DataFrame of data points with LFs.
83 |
84 | Parameters
85 | ----------
86 | df
87 | Pandas DataFrame containing data points to be labeled by LFs
88 | progress_bar
89 | Display a progress bar?
90 | fault_tolerant
91 | Output ``-1`` if LF execution fails?
92 | return_meta
93 | Return metadata from apply call?
94 |
95 | Returns
96 | -------
97 | np.ndarray
98 | Matrix of labels emitted by LFs
99 | ApplierMetadata
100 | Metadata, such as fault counts, for the apply call
101 | """
102 | f_caller = _FunctionCaller(fault_tolerant)
103 | apply_fn = partial(apply_lfs_to_data_point, lfs=self._lfs, f_caller=f_caller)
104 | call_fn = df.apply
105 | if progress_bar:
106 | tqdm.pandas()
107 | call_fn = df.progress_apply
108 | labels = call_fn(apply_fn, axis=1)
109 | labels_with_index = rows_to_triplets(labels)
110 | L = self._numpy_from_row_data(labels_with_index)
111 | if return_meta:
112 | return L, ApplierMetadata(f_caller.fault_counts)
113 | return L
114 |
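115 | if __name__ == "__main__":
116 |     # Minimal sketch of fault-tolerant application: the None row raises inside
117 |     # the LF, so it is recorded as a fault and labeled -1. The LF is illustrative.
118 |     from snorkel.labeling import labeling_function
119 |
120 |     @labeling_function()
121 |     def starts_with_a(x):
122 |         return 0 if x.text.startswith("a") else -1
123 |
124 |     df = pd.DataFrame({"text": ["apple", None]})
125 |     L, meta = PandasLFApplier([starts_with_a]).apply(
126 |         df, fault_tolerant=True, return_meta=True, progress_bar=False
127 |     )
128 |     print(L, meta)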
--------------------------------------------------------------------------------
/snorkel/labeling/apply/spark.py:
--------------------------------------------------------------------------------
1 | from typing import Tuple
2 |
3 | import numpy as np
4 | from pyspark import RDD
5 |
6 | from snorkel.types import DataPoint
7 |
8 | from .core import BaseLFApplier, RowData, _FunctionCaller, apply_lfs_to_data_point
9 |
10 |
11 | class SparkLFApplier(BaseLFApplier):
12 | r"""LF applier for a Spark RDD.
13 |
14 | Data points are stored as ``Row``\s in an RDD, and a Spark
15 | ``map`` job is submitted to execute the LFs. A common
16 | way to obtain an RDD is via a PySpark DataFrame. For an
17 | example usage with AWS EMR instructions, see
18 | ``test/labeling/apply/lf_applier_spark_test_script.py``.
19 | """
20 |
21 | def apply(self, data_points: RDD, fault_tolerant: bool = False) -> np.ndarray:
22 | """Label PySpark RDD of data points with LFs.
23 |
24 | Parameters
25 | ----------
26 | data_points
27 | PySpark RDD containing data points to be labeled by LFs
28 | fault_tolerant
29 | Output ``-1`` if LF execution fails?
30 |
31 | Returns
32 | -------
33 | np.ndarray
34 | Matrix of labels emitted by LFs
35 | """
36 | f_caller = _FunctionCaller(fault_tolerant)
37 |
38 | def map_fn(args: Tuple[DataPoint, int]) -> RowData:
39 | return apply_lfs_to_data_point(*args, lfs=self._lfs, f_caller=f_caller)
40 |
41 | labels = data_points.zipWithIndex().map(map_fn).collect()
42 | return self._numpy_from_row_data(labels)
43 |
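44 | if __name__ == "__main__":
45 |     # Minimal sketch on a local SparkSession; the LF and data are illustrative.
46 |     from pyspark.sql import Row, SparkSession
47 |
48 |     from snorkel.labeling import labeling_function
49 |
50 |     @labeling_function()
51 |     def is_big_num(x):
52 |         return 1 if x.num > 42 else 0
53 |
54 |     spark = SparkSession.builder.master("local[1]").getOrCreate()
55 |     rdd = spark.sparkContext.parallelize([Row(num=10), Row(num=100)])
56 |     print(SparkLFApplier([is_big_num]).apply(rdd))  # array([[0], [1]])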
--------------------------------------------------------------------------------
/snorkel/labeling/lf/__init__.py:
--------------------------------------------------------------------------------
1 | from .core import LabelingFunction, labeling_function # noqa: F401
2 |
--------------------------------------------------------------------------------
/snorkel/labeling/lf/core.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Callable, List, Mapping, Optional
2 |
3 | from snorkel.preprocess import BasePreprocessor
4 | from snorkel.types import DataPoint
5 |
6 |
7 | class LabelingFunction:
8 | """Base class for labeling functions.
9 |
10 | A labeling function (LF) is a function that takes a data point
11 | as input and produces an integer label, corresponding to a
12 | class. A labeling function can also abstain from voting by
13 | outputting ``-1``. For examples, see the Snorkel tutorials.
14 |
15 | This class wraps a Python function outputting a label. Extra
16 | functionality, such as running preprocessors and storing
17 | resources, is provided. Simple LFs can be defined via a
18 | decorator. See ``labeling_function``.
19 |
20 | Parameters
21 | ----------
22 | name
23 | Name of the LF
24 | f
25 | Function that implements the core LF logic
26 | resources
27 | Labeling resources passed in to ``f`` via ``kwargs``
28 | pre
29 | Preprocessors to run on data points before LF execution
30 |
31 | Raises
32 | ------
33 | ValueError
34 | Calling incorrectly defined preprocessors
35 |
36 | Attributes
37 | ----------
38 | name
39 | See above
40 | """
41 |
42 | def __init__(
43 | self,
44 | name: str,
45 | f: Callable[..., int],
46 | resources: Optional[Mapping[str, Any]] = None,
47 | pre: Optional[List[BasePreprocessor]] = None,
48 | ) -> None:
49 | self.name = name
50 | self._f = f
51 | self._resources = resources or {}
52 | self._pre = pre or []
53 |
54 | def _preprocess_data_point(self, x: DataPoint) -> DataPoint:
55 | for preprocessor in self._pre:
56 | x = preprocessor(x)
57 | if x is None:
58 | raise ValueError("Preprocessor should not return None")
59 | return x
60 |
61 | def __call__(self, x: DataPoint) -> int:
62 | """Label data point.
63 |
64 | Runs all preprocessors, then passes preprocessed data point to LF.
65 |
66 | Parameters
67 | ----------
68 | x
69 | Data point to label
70 |
71 | Returns
72 | -------
73 | int
74 | Label for data point
75 | """
76 | x = self._preprocess_data_point(x)
77 | return self._f(x, **self._resources)
78 |
79 | def __repr__(self) -> str:
80 | preprocessor_str = f", Preprocessors: {self._pre}"
81 | return f"{type(self).__name__} {self.name}{preprocessor_str}"
82 |
83 |
84 | class labeling_function:
85 | """Decorator to define a LabelingFunction object from a function.
86 |
87 | Parameters
88 | ----------
89 | name
90 | Name of the LF
91 | resources
92 | Labeling resources passed in to ``f`` via ``kwargs``
93 | pre
94 | Preprocessors to run on data points before LF execution
95 |
96 | Examples
97 | --------
98 | >>> @labeling_function()
99 | ... def f(x):
100 | ... return 0 if x.a > 42 else -1
101 | >>> f
102 | LabelingFunction f, Preprocessors: []
103 | >>> from types import SimpleNamespace
104 | >>> x = SimpleNamespace(a=90, b=12)
105 | >>> f(x)
106 | 0
107 |
108 | >>> @labeling_function(name="my_lf")
109 | ... def g(x):
110 | ... return 0 if x.a > 42 else -1
111 | >>> g
112 | LabelingFunction my_lf, Preprocessors: []
113 | """
114 |
115 | def __init__(
116 | self,
117 | name: Optional[str] = None,
118 | resources: Optional[Mapping[str, Any]] = None,
119 | pre: Optional[List[BasePreprocessor]] = None,
120 | ) -> None:
121 | if callable(name):
122 | raise ValueError("Looks like this decorator is missing parentheses!")
123 | self.name = name
124 | self.resources = resources
125 | self.pre = pre
126 |
127 | def __call__(self, f: Callable[..., int]) -> LabelingFunction:
128 | """Wrap a function to create a ``LabelingFunction``.
129 |
130 | Parameters
131 | ----------
132 | f
133 | Function that implements the core LF logic
134 |
135 | Returns
136 | -------
137 | LabelingFunction
138 | New ``LabelingFunction`` executing logic in wrapped function
139 | """
140 | name = self.name or f.__name__
141 | return LabelingFunction(name=name, f=f, resources=self.resources, pre=self.pre)
142 |
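143 | if __name__ == "__main__":
144 |     # Minimal sketch: resources are forwarded to the LF as keyword arguments.
145 |     from types import SimpleNamespace
146 |
147 |     @labeling_function(resources={"keywords": {"sale", "offer"}})
148 |     def mentions_keyword(x, keywords):
149 |         return 0 if any(word in x.text for word in keywords) else -1
150 |
151 |     print(mentions_keyword(SimpleNamespace(text="limited time offer")))  # 0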
--------------------------------------------------------------------------------
/snorkel/labeling/lf/nlp_spark.py:
--------------------------------------------------------------------------------
1 | from snorkel.preprocess.nlp import SpacyPreprocessor
2 | from snorkel.preprocess.spark import make_spark_preprocessor
3 |
4 | from .nlp import (
5 | BaseNLPLabelingFunction,
6 | SpacyPreprocessorParameters,
7 | base_nlp_labeling_function,
8 | )
9 |
10 |
11 | class SparkNLPLabelingFunction(BaseNLPLabelingFunction):
12 | r"""Special labeling function type for SpaCy-based LFs running on Spark.
13 |
14 | This class is a Spark-compatible version of ``NLPLabelingFunction``.
15 | See ``NLPLabelingFunction`` for details.
16 |
17 | Parameters
18 | ----------
19 | name
20 | Name of the LF
21 | f
22 | Function that implements the core LF logic
23 | resources
24 | Labeling resources passed in to ``f`` via ``kwargs``
25 | pre
26 | Preprocessors to run before SpacyPreprocessor is executed
27 | text_field
28 | Name of data point text field to input
29 | doc_field
30 | Name of data point field to output parsed document to
31 | language
32 | SpaCy model to load
33 | See https://spacy.io/usage/models#usage
34 | disable
35 | List of pipeline components to disable
36 | See https://spacy.io/usage/processing-pipelines#disabling
37 | memoize
38 | Memoize preprocessor outputs?
39 | memoize_key
40 |         Hashing function to handle the memoization (defaults to snorkel.map.core.get_hashable)
41 | gpu
42 |         Prefer SpaCy GPU processing?
43 |
44 | Raises
45 | ------
46 | ValueError
47 | Calling incorrectly defined preprocessors
48 |
49 | Attributes
50 | ----------
51 | name
52 | See above
53 | """
54 |
55 | @classmethod
56 | def _create_preprocessor(
57 | cls, parameters: SpacyPreprocessorParameters
58 | ) -> SpacyPreprocessor:
59 | preprocessor = SpacyPreprocessor(**parameters._asdict())
60 | make_spark_preprocessor(preprocessor)
61 | return preprocessor
62 |
63 |
64 | class spark_nlp_labeling_function(base_nlp_labeling_function):
65 | """Decorator to define a SparkNLPLabelingFunction object from a function.
66 |
67 | Parameters
68 | ----------
69 | name
70 | Name of the LF
71 | resources
72 | Labeling resources passed in to ``f`` via ``kwargs``
73 | pre
74 | Preprocessors to run before SpacyPreprocessor is executed
75 | text_field
76 | Name of data point text field to input
77 | doc_field
78 | Name of data point field to output parsed document to
79 | language
80 | SpaCy model to load
81 | See https://spacy.io/usage/models#usage
82 | disable
83 | List of pipeline components to disable
84 | See https://spacy.io/usage/processing-pipelines#disabling
85 | memoize
86 | Memoize preprocessor outputs?
87 | memoize_key
88 |         Hashing function to handle the memoization (defaults to snorkel.map.core.get_hashable)
89 |
90 | Example
91 | -------
92 | >>> @spark_nlp_labeling_function()
93 | ... def has_person_mention(x):
94 | ... person_ents = [ent for ent in x.doc.ents if ent.label_ == "PERSON"]
95 | ... return 0 if len(person_ents) > 0 else -1
96 | >>> has_person_mention
97 | SparkNLPLabelingFunction has_person_mention, Preprocessors: [SpacyPreprocessor...]
98 |
99 | >>> from pyspark.sql import Row
100 | >>> x = Row(text="The movie was good.")
101 | >>> has_person_mention(x)
102 | -1
103 | """
104 |
105 | _lf_cls = SparkNLPLabelingFunction
106 |
--------------------------------------------------------------------------------
/snorkel/labeling/model/__init__.py:
--------------------------------------------------------------------------------
1 | from .baselines import MajorityClassVoter, MajorityLabelVoter, RandomVoter # noqa: F401
2 | from .label_model import LabelModel # noqa: F401
3 |
--------------------------------------------------------------------------------
/snorkel/labeling/model/base_labeler.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import pickle
3 | from abc import ABC, abstractmethod
4 | from typing import Any, Dict, List, Optional, Tuple, Union
5 |
6 | import numpy as np
7 |
8 | from snorkel.analysis import Scorer
9 | from snorkel.utils import probs_to_preds
10 |
11 |
12 | class BaseLabeler(ABC):
13 | """Abstract baseline label voter class."""
14 |
15 | def __init__(self, cardinality: int = 2, **kwargs: Any) -> None:
16 | self.cardinality = cardinality
17 |
18 | @abstractmethod
19 | def predict_proba(self, L: np.ndarray) -> np.ndarray:
20 | """Abstract method for predicting probabilistic labels given a label matrix.
21 |
22 | Parameters
23 | ----------
24 | L
25 |             An [n,m] matrix with values in {-1,0,1,...,k-1}
26 |
27 | Returns
28 | -------
29 | np.ndarray
30 | An [n,k] array of probabilistic labels
31 | """
32 | pass
33 |
34 | def predict(
35 | self,
36 | L: np.ndarray,
37 |         return_probs: bool = False,
38 | tie_break_policy: str = "abstain",
39 | ) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]:
40 | """Return predicted labels, with ties broken according to policy.
41 |
42 | Policies to break ties include:
43 | "abstain": return an abstain vote (-1)
44 | "true-random": randomly choose among the tied options
45 | "random": randomly choose among tied option using deterministic hash
46 |
47 | NOTE: if tie_break_policy="true-random", repeated runs may have slightly different
48 |         results due to differences in how ties are broken
49 |
50 |
51 | Parameters
52 | ----------
53 | L
54 | An [n,m] matrix with values in {-1,0,1,...,k-1}
55 | return_probs
56 | Whether to return probs along with preds
57 | tie_break_policy
58 | Policy to break ties when converting probabilistic labels to predictions
59 |
60 | Returns
61 | -------
62 | np.ndarray
63 | An [n,1] array of integer labels
64 |
65 | (np.ndarray, np.ndarray)
66 | An [n,1] array of integer labels and an [n,k] array of probabilistic labels
67 | """
68 | Y_probs = self.predict_proba(L)
69 | Y_p = probs_to_preds(Y_probs, tie_break_policy)
70 | if return_probs:
71 | return Y_p, Y_probs
72 | return Y_p
73 |
74 | def score(
75 | self,
76 | L: np.ndarray,
77 | Y: np.ndarray,
78 | metrics: Optional[List[str]] = ["accuracy"],
79 | tie_break_policy: str = "abstain",
80 | ) -> Dict[str, float]:
81 | """Calculate one or more scores from user-specified and/or user-defined metrics.
82 |
83 | Parameters
84 | ----------
85 | L
86 | An [n,m] matrix with values in {-1,0,1,...,k-1}
87 | Y
88 | Gold labels associated with data points in L
89 | metrics
90 | A list of metric names
91 | tie_break_policy
92 | Policy to break ties when converting probabilistic labels to predictions
93 |
94 |
95 | Returns
96 | -------
97 | Dict[str, float]
98 | A dictionary mapping metric names to metric scores
99 | """
100 | if tie_break_policy == "abstain": # pragma: no cover
101 | logging.warning(
102 | "Metrics calculated over data points with non-abstain labels only"
103 | )
104 |
105 | Y_pred, Y_prob = self.predict(
106 | L, return_probs=True, tie_break_policy=tie_break_policy
107 | )
108 |
109 | scorer = Scorer(metrics=metrics)
110 | results = scorer.score(Y, Y_pred, Y_prob)
111 | return results
112 |
113 | def save(self, destination: str) -> None:
114 | """Save label model.
115 |
116 | Parameters
117 | ----------
118 | destination
119 | Filename for saving model
120 |
121 | Example
122 | -------
123 | >>> label_model.save('./saved_label_model.pkl') # doctest: +SKIP
124 | """
125 | # Use a context manager so the handle is closed even on error
126 | with open(destination, "wb") as f:
127 | pickle.dump(self.__dict__, f)
128 |
129 | def load(self, source: str) -> None:
130 | """Load existing label model.
131 |
132 | Parameters
133 | ----------
134 | source
135 | Filename to load model from
136 |
137 | Example
138 | -------
139 | Load parameters saved in ``saved_label_model``
140 |
141 | >>> label_model.load('./saved_label_model.pkl') # doctest: +SKIP
142 | """
143 | with open(source, "rb") as f:
144 | tmp_dict = pickle.load(f)
145 | # Restore the saved state onto this instance
146 | self.__dict__.update(tmp_dict)
147 |
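A minimal save/load sketch for the API above, using the concrete ``MajorityLabelVoter`` from ``baselines.py`` below; the pickle path is illustrative:

```python
import numpy as np

from snorkel.labeling.model import MajorityLabelVoter

L = np.array([[0, 0, -1], [-1, 0, 1], [1, -1, 0]])
labeler = MajorityLabelVoter(cardinality=2)
preds = labeler.predict(L, tie_break_policy="abstain")  # ties -> -1

labeler.save("./labeler.pkl")    # pickles self.__dict__ to disk
restored = MajorityLabelVoter()
restored.load("./labeler.pkl")   # restores the pickled state
assert (restored.predict(L) == preds).all()
```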
--------------------------------------------------------------------------------
/snorkel/labeling/model/baselines.py:
--------------------------------------------------------------------------------
1 | from typing import Any
2 |
3 | import numpy as np
4 |
5 | from snorkel.labeling.model.base_labeler import BaseLabeler
6 |
7 |
8 | class RandomVoter(BaseLabeler):
9 | """Random vote label model.
10 |
11 | Example
12 | -------
13 | >>> L = np.array([[0, 0, -1], [-1, 0, 1], [1, -1, 0]])
14 | >>> random_voter = RandomVoter()
15 | >>> predictions = random_voter.predict_proba(L)
16 | """
17 |
18 | def predict_proba(self, L: np.ndarray) -> np.ndarray:
19 | """
20 | Assign random votes to the data points.
21 |
22 | Parameters
23 | ----------
24 | L
25 | An [n, m] matrix of labels
26 |
27 | Returns
28 | -------
29 | np.ndarray
30 | A [n, k] array of probabilistic labels
31 |
32 | Example
33 | -------
34 | >>> L = np.array([[0, 0, -1], [-1, 0, 1], [1, -1, 0]])
35 | >>> random_voter = RandomVoter()
36 | >>> predictions = random_voter.predict_proba(L)
37 | """
38 | n = L.shape[0]
39 | Y_p = np.random.rand(n, self.cardinality)
40 | Y_p /= Y_p.sum(axis=1).reshape(-1, 1)
41 | return Y_p
42 |
43 |
44 | class MajorityClassVoter(BaseLabeler):
45 | """Majority class label model."""
46 |
47 | def fit( # type: ignore
48 | self, balance: np.ndarray, *args: Any, **kwargs: Any
49 | ) -> None:
50 | """Train majority class model.
51 |
52 | Set class balance for majority class label model.
53 |
54 | Parameters
55 | ----------
56 | balance
57 | A [k] array of class probabilities
58 | """
59 | self.balance = balance
60 |
61 | def predict_proba(self, L: np.ndarray) -> np.ndarray:
62 | """Predict probabilities using majority class.
63 |
64 | Assign majority class vote to each datapoint.
65 | In case of multiple majority classes, assign equal probabilities among them.
66 |
67 |
68 | Parameters
69 | ----------
70 | L
71 | An [n, m] matrix of labels
72 |
73 | Returns
74 | -------
75 | np.ndarray
76 | A [n, k] array of probabilistic labels
77 |
78 | Example
79 | -------
80 | >>> L = np.array([[0, 0, -1], [-1, 0, 1], [1, -1, 0]])
81 | >>> maj_class_voter = MajorityClassVoter()
82 | >>> maj_class_voter.fit(balance=np.array([0.8, 0.2]))
83 | >>> maj_class_voter.predict_proba(L)
84 | array([[1., 0.],
85 | [1., 0.],
86 | [1., 0.]])
87 | """
88 | n = L.shape[0]
89 | Y_p = np.zeros((n, self.cardinality))
90 | max_classes = np.where(self.balance == max(self.balance))
91 | for c in max_classes:
92 | Y_p[:, c] = 1.0
93 | Y_p /= Y_p.sum(axis=1).reshape(-1, 1)
94 | return Y_p
95 |
96 |
97 | class MajorityLabelVoter(BaseLabeler):
98 | """Majority vote label model."""
99 |
100 | def predict_proba(self, L: np.ndarray) -> np.ndarray:
101 | """Predict probabilities using majority vote.
102 |
103 | Assign vote by calculating majority vote across all labeling functions.
104 | In case of ties, non-integer probabilities are possible.
105 |
106 | Parameters
107 | ----------
108 | L
109 | An [n, m] matrix of labels
110 |
111 | Returns
112 | -------
113 | np.ndarray
114 | A [n, k] array of probabilistic labels
115 |
116 | Example
117 | -------
118 | >>> L = np.array([[0, 0, -1], [-1, 0, 1], [1, -1, 0]])
119 | >>> maj_voter = MajorityLabelVoter()
120 | >>> maj_voter.predict_proba(L)
121 | array([[1. , 0. ],
122 | [0.5, 0.5],
123 | [0.5, 0.5]])
124 | """
125 | n, m = L.shape
126 | Y_p = np.zeros((n, self.cardinality))
127 | for i in range(n):
128 | counts = np.zeros(self.cardinality)
129 | for j in range(m):
130 | if L[i, j] != -1:
131 | counts[L[i, j]] += 1
132 | Y_p[i, :] = np.where(counts == max(counts), 1, 0)
133 | Y_p /= Y_p.sum(axis=1).reshape(-1, 1)
134 | return Y_p
135 |
--------------------------------------------------------------------------------
/snorkel/labeling/model/graph_utils.py:
--------------------------------------------------------------------------------
1 | from typing import Iterable, List, Tuple
2 |
3 | import networkx as nx
4 |
5 |
6 | def get_clique_tree(nodes: Iterable[int], edges: List[Tuple[int, int]]) -> nx.Graph:
7 | """
8 | Given a set of int nodes i and edges (i,j), returns a clique tree.
9 |
10 | Clique tree is an object G for which:
11 | - G.nodes[i]['members'] contains the set of original nodes in the ith
12 | maximal clique
13 | - G[i][j]['members'] contains the set of original nodes in the separator
14 | set between maximal cliques i and j
15 |
16 | Note: This method is currently only implemented for chordal graphs; TODO:
17 | add a step to triangulate non-chordal graphs.
18 |
19 | Parameters
20 | ----------
21 | nodes
22 | A list of node indices
23 | edges
24 | A list of tuples, where each tuple has indices for connected nodes
25 |
26 | Returns
27 | -------
28 | networkx.Graph
29 | An object G representing clique tree
30 | """
31 | # Form the original graph G1
32 | G1 = nx.Graph()
33 | G1.add_nodes_from(nodes)
34 | G1.add_edges_from(edges)
35 |
36 | # Check if graph is chordal
37 | # TODO: Add step to triangulate graph if not
38 | if not nx.is_chordal(G1):
39 | raise NotImplementedError("Graph triangulation not implemented.")
40 |
41 | # Create maximal clique graph G2
42 | # Each node is a maximal clique C_i
43 | # Let w = |C_i \cap C_j|; C_i, C_j have an edge with weight w if w > 0
44 | G2 = nx.Graph()
45 | for i, c in enumerate(nx.chordal_graph_cliques(G1)):
46 | G2.add_node(i, members=c)
47 | for i in G2.nodes:
48 | for j in G2.nodes:
49 | S = G2.nodes[i]["members"].intersection(G2.nodes[j]["members"])
50 | w = len(S)
51 | if w > 0:
52 | G2.add_edge(i, j, weight=w, members=S)
53 |
54 | # Return a minimum spanning tree of G2
55 | return nx.minimum_spanning_tree(G2)
56 |
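A small worked example: a triangle plus a pendant edge is chordal, so ``get_clique_tree`` returns a two-node tree whose edge carries the separator set:

```python
from snorkel.labeling.model.graph_utils import get_clique_tree

nodes = [0, 1, 2, 3]
edges = [(0, 1), (1, 2), (0, 2), (2, 3)]  # triangle {0,1,2} plus edge (2,3)
tree = get_clique_tree(nodes, edges)

for i in tree.nodes:
    print(i, tree.nodes[i]["members"])  # the two maximal cliques
for i, j in tree.edges:
    print(tree[i][j]["members"])        # separator set, here {2}
```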
--------------------------------------------------------------------------------
/snorkel/labeling/model/logger.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from collections import defaultdict
3 | from typing import DefaultDict, Dict, List, Optional
4 |
5 |
6 | class Logger:
7 | """Class for logging LabelModel.
8 |
9 | Parameters
10 | ----------
11 | log_freq
12 | Number of units at which to log model
13 |
14 | Attributes
15 | ----------
16 | log_freq
17 | Number of units at which to log model
18 | unit_count
19 | Running count of units (e.g. epochs) seen so far
20 | """
21 |
22 | def __init__(self, log_freq: int) -> None:
23 | self.log_freq = log_freq
24 | self.unit_count = -1
25 |
26 | def check(self) -> bool:
27 | """Check if the logging frequency has been met.
28 |
29 | Returns
30 | -------
31 | bool
32 | Whether to log or not based on logging frequency
33 | """
34 | self.unit_count += 1
35 | return self.unit_count % self.log_freq == 0
36 |
37 | def log(self, metrics_dict: Dict[str, float]) -> None:
38 | """Print all metrics in metrics_dict to screen.
39 |
40 | Parameters
41 | ----------
42 | metrics_dict
43 | Dictionary of metric names (keys) and values to log
44 |
45 | Raises
46 | ------
47 | Exception
48 | If metric names formatted incorrectly
49 | """
50 | score_strings: DefaultDict[str, List[str]] = defaultdict(list)
51 | for full_name, value in metrics_dict.items():
52 | task: Optional[str]
53 | if full_name.count("/") == 2:
54 | task, split, metric = full_name.split("/")
55 | elif full_name.count("/") == 1:
56 | task = None
57 | split, metric = full_name.split("/")
58 | else:
59 | msg = f"Metric should have form task/split/metric or split/metric, not: {full_name}"
60 | raise Exception(msg)
61 |
62 | if task:
63 | metric_name = f"{task}/{metric}"
64 | else:
65 | metric_name = metric
66 | if isinstance(value, float):
67 | score_strings[split].append(f"{metric_name}={value:0.3f}")
68 | else:
69 | score_strings[split].append(f"{metric_name}={value}")
70 |
71 | string = f"[{self.unit_count} epochs]:"
72 |
73 | if score_strings["train"]:
74 | train_scores = f"{', '.join(score_strings['train'])}"
75 | string += f" TRAIN:[{train_scores}]"
76 | if score_strings["valid"]:
77 | valid_scores = f"{', '.join(score_strings['valid'])}"
78 | string += f" VALID:[{valid_scores}]"
79 | logging.info(string)
80 |
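A sketch of the name format ``log`` expects (``task/split/metric`` or ``split/metric``); assumes the root logger is configured to show INFO messages:

```python
import logging

from snorkel.labeling.model.logger import Logger

logging.basicConfig(level=logging.INFO)
logger = Logger(log_freq=2)
for epoch in range(4):
    if logger.check():  # True on units 0 and 2 with log_freq=2
        logger.log({"train/loss": 0.123, "task1/valid/accuracy": 0.9})
```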
--------------------------------------------------------------------------------
/snorkel/labeling/utils.py:
--------------------------------------------------------------------------------
1 | from typing import Tuple
2 |
3 | import numpy as np
4 | import pandas as pd
5 |
6 |
7 | def filter_unlabeled_dataframe(
8 | X: pd.DataFrame, y: np.ndarray, L: np.ndarray
9 | ) -> Tuple[pd.DataFrame, np.ndarray]:
10 | """Filter out examples not covered by any labeling function.
11 |
12 | Parameters
13 | ----------
14 | X
15 | Data points in a Pandas DataFrame.
16 | y
17 | Matrix of probabilities output by label model's predict_proba method.
18 | L
19 | Matrix of labels emitted by LFs.
20 |
21 | Returns
22 | -------
23 | pd.DataFrame
24 | Data points that were labeled by at least one LF in L.
25 | np.ndarray
26 | Probabilities matrix for data points labeled by at least one LF in L.
27 | """
28 | mask = (L != -1).any(axis=1)
29 | return X.iloc[mask], y[mask]
30 |
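A short sketch with toy data; row 1 is all abstains, so it is dropped along with its probability row:

```python
import numpy as np
import pandas as pd

from snorkel.labeling import filter_unlabeled_dataframe

df = pd.DataFrame({"text": ["a", "b", "c"]})
probs = np.array([[0.9, 0.1], [0.5, 0.5], [0.2, 0.8]])
L = np.array([[0, -1], [-1, -1], [1, 1]])  # row 1 received no LF labels
df_filtered, probs_filtered = filter_unlabeled_dataframe(X=df, y=probs, L=L)
# df_filtered keeps rows 0 and 2; probs_filtered has shape (2, 2)
```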
--------------------------------------------------------------------------------
/snorkel/map/__init__.py:
--------------------------------------------------------------------------------
1 | """Generic utilities for data point to data point operations."""
2 |
3 | from .core import BaseMapper, LambdaMapper, Mapper, lambda_mapper # noqa: F401
4 |
--------------------------------------------------------------------------------
/snorkel/map/spark.py:
--------------------------------------------------------------------------------
1 | from pyspark.sql import Row
2 |
3 | from snorkel.types import FieldMap
4 |
5 | from .core import Mapper
6 |
7 |
8 | def _update_fields(x: Row, mapped_fields: FieldMap) -> Row:
9 | # ``pyspark.sql.Row`` objects are not mutable, so we need to
10 | # reconstruct them to update fields
11 | all_fields = x.asDict()
12 | all_fields.update(mapped_fields)
13 | return Row(**all_fields)
14 |
15 |
16 | def make_spark_mapper(mapper: Mapper) -> Mapper:
17 | """Convert ``Mapper`` to be compatible with PySpark.
18 |
19 | Parameters
20 | ----------
21 | mapper
22 | Mapper to make compatible with PySpark
23 | """
24 | mapper._update_fields = _update_fields # type: ignore
25 | return mapper
26 |
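A hedged sketch of wrapping a field-level ``Mapper`` for Spark ``Row`` objects; ``CombineTextMapper`` is hypothetical, and running ``map`` over an RDD assumes a live SparkContext:

```python
from pyspark.sql import Row

from snorkel.map import Mapper
from snorkel.map.spark import make_spark_mapper

class CombineTextMapper(Mapper):
    """Hypothetical mapper concatenating two text fields."""

    def run(self, title: str, body: str):
        return dict(article=f"{title} {body}")

mapper = make_spark_mapper(
    CombineTextMapper(
        "combine_text",
        field_names=dict(title="title", body="body"),
        mapped_field_names=dict(article="article"),
    )
)
row = Row(title="Hello", body="world")
print(mapper(row).article)  # "Hello world"; on an RDD: rdd.map(mapper)
```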
--------------------------------------------------------------------------------
/snorkel/preprocess/__init__.py:
--------------------------------------------------------------------------------
1 | """Preprocessors for LFs, TFs, and SFs."""
2 |
3 | from .core import ( # noqa: F401
4 | BasePreprocessor,
5 | LambdaPreprocessor,
6 | Preprocessor,
7 | preprocessor,
8 | )
9 |
--------------------------------------------------------------------------------
/snorkel/preprocess/core.py:
--------------------------------------------------------------------------------
1 | """Base classes for preprocessors.
2 |
3 | A preprocessor is a data point to data point mapping in a labeling
4 | pipeline. This allows Snorkel operations (e.g. LFs) to share common
5 | preprocessing steps that make it easier to express labeling logic.
6 | A simple example for text processing is concatenating the title and
7 | body of an article. For a more complex example, see
8 | ``snorkel.preprocess.nlp.SpacyPreprocessor``.
9 | """
10 |
11 | from snorkel.map import BaseMapper, LambdaMapper, Mapper, lambda_mapper
12 |
13 | # Used for type checking only
14 | # Note: subclassing as below trips up mypy
15 | BasePreprocessor = BaseMapper
16 |
17 |
18 | class Preprocessor(Mapper):
19 | """Base class for preprocessors.
20 |
21 | See ``snorkel.map.core.Mapper`` for details.
22 | """
23 |
24 | pass
25 |
26 |
27 | class LambdaPreprocessor(LambdaMapper):
28 | """Convenience class for defining preprocessors from functions.
29 |
30 | See ``snorkel.map.core.LambdaMapper`` for details.
31 | """
32 |
33 | pass
34 |
35 |
36 | class preprocessor(lambda_mapper):
37 | """Decorate functions to create preprocessors.
38 |
39 | See ``snorkel.map.core.lambda_mapper`` for details.
40 |
41 | Example
42 | -------
43 | >>> @preprocessor()
44 | ... def combine_text_preprocessor(x):
45 | ... x.article = f"{x.title} {x.body}"
46 | ... return x
47 | >>> from snorkel.preprocess.nlp import SpacyPreprocessor
48 | >>> spacy_preprocessor = SpacyPreprocessor("article", "article_parsed")
49 |
50 | We can now add our preprocessors to an LF.
51 |
52 | >>> preprocessors = [combine_text_preprocessor, spacy_preprocessor]
53 | >>> from snorkel.labeling.lf import labeling_function
54 | >>> @labeling_function(pre=preprocessors)
55 | ... def article_mentions_person(x):
56 | ... for ent in x.article_parsed.ents:
57 | ... if ent.label_ == "PERSON":
58 | ... return ABSTAIN
59 | ... return NEGATIVE
60 | """
61 |
62 | pass
63 |
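A self-contained variant of the docstring example above; note that ``ABSTAIN`` and ``NEGATIVE`` there are user-chosen label constants, not Snorkel exports:

```python
from types import SimpleNamespace

from snorkel.preprocess import preprocessor

@preprocessor()
def combine_text(x):
    x.article = f"{x.title} {x.body}"
    return x

x = combine_text(SimpleNamespace(title="Breaking", body="news today"))
assert x.article == "Breaking news today"
```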
--------------------------------------------------------------------------------
/snorkel/preprocess/nlp.py:
--------------------------------------------------------------------------------
1 | from typing import List, Optional
2 |
3 | import spacy
4 |
5 | from snorkel.types import FieldMap, HashingFunction
6 |
7 | from .core import BasePreprocessor, Preprocessor
8 |
9 | EN_CORE_WEB_SM = "en_core_web_sm"
10 |
11 |
12 | class SpacyPreprocessor(Preprocessor):
13 | """Preprocessor that parses input text via a SpaCy model.
14 |
15 | A common approach to writing LFs over text is to first use
16 | a natural language parser to decompose the text into tokens,
17 | part-of-speech tags, etc. SpaCy (https://spacy.io/) is a
18 | popular tool for doing this. This preprocessor adds a
19 | SpaCy ``Doc`` object to the data point. A ``Doc`` object is
20 | a sequence of ``Token`` objects, which contain information
21 | on lemmatization, parts-of-speech, etc. ``Doc`` objects also
22 | contain fields like ``Doc.ents``, a list of named entities,
23 | and ``Doc.noun_chunks``, a list of noun phrases. For details
24 | of SpaCy ``Doc`` objects and a full attribute listing,
25 | see https://spacy.io/api/doc.
26 |
27 | Parameters
28 | ----------
29 | text_field
30 | Name of data point text field to input
31 | doc_field
32 | Name of data point field to output parsed document to
33 | language
34 | SpaCy model to load
35 | See https://spacy.io/usage/models#usage
36 | disable
37 | List of pipeline components to disable
38 | See https://spacy.io/usage/processing-pipelines#disabling
39 | pre
40 | Preprocessors to run before this preprocessor is executed
41 | memoize
42 | Memoize preprocessor outputs?
43 | memoize_key
44 | Hashing function to handle the memoization (defaults to snorkel.map.core.get_hashable)
45 | gpu
46 | Prefer spaCy GPU processing?
47 | """
48 |
49 | def __init__(
50 | self,
51 | text_field: str,
52 | doc_field: str,
53 | language: str = EN_CORE_WEB_SM,
54 | disable: Optional[List[str]] = None,
55 | pre: Optional[List[BasePreprocessor]] = None,
56 | memoize: bool = False,
57 | memoize_key: Optional[HashingFunction] = None,
58 | gpu: bool = False,
59 | ) -> None:
60 | name = type(self).__name__
61 | super().__init__(
62 | name,
63 | field_names=dict(text=text_field),
64 | mapped_field_names=dict(doc=doc_field),
65 | pre=pre,
66 | memoize=memoize,
67 | memoize_key=memoize_key,
68 | )
69 | self.gpu = gpu
70 | if self.gpu:
71 | spacy.prefer_gpu()
72 | self._nlp = spacy.load(language, disable=disable or [])
73 |
74 | def run(self, text: str) -> FieldMap: # type: ignore
75 | """Run the SpaCy model on input text.
76 |
77 | Parameters
78 | ----------
79 | text
80 | Text of document to parse
81 |
82 | Returns
83 | -------
84 | FieldMap
85 | Dictionary with a single key (``"doc"``), mapping to the
86 | parsed SpaCy ``Doc`` object
87 | """
88 | # Note: not trying to add the fields of `Doc` to top-level
89 | # as most are Cython property methods computed on the fly.
90 | return dict(doc=self._nlp(text))
91 |
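A usage sketch; assumes the ``en_core_web_sm`` model has been downloaded (``python -m spacy download en_core_web_sm``):

```python
from types import SimpleNamespace

from snorkel.preprocess.nlp import SpacyPreprocessor

spacy_pre = SpacyPreprocessor(text_field="text", doc_field="doc")
x = spacy_pre(SimpleNamespace(text="Jane visited Paris."))
print([(ent.text, ent.label_) for ent in x.doc.ents])
# e.g. [('Jane', 'PERSON'), ('Paris', 'GPE')]
```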
--------------------------------------------------------------------------------
/snorkel/preprocess/spark.py:
--------------------------------------------------------------------------------
1 | from snorkel.map.spark import make_spark_mapper as make_spark_preprocessor # noqa: F401
2 |
--------------------------------------------------------------------------------
/snorkel/slicing/__init__.py:
--------------------------------------------------------------------------------
1 | """Programmatic data set slicing: SF creation, monitoring utilities, and representation learning for slices."""
2 |
3 | from .apply.core import PandasSFApplier, SFApplier # noqa: F401
4 | from .modules.slice_combiner import SliceCombinerModule # noqa: F401
5 | from .monitor import slice_dataframe # noqa: F401
6 | from .sf.core import SlicingFunction, slicing_function # noqa: F401
7 | from .sliceaware_classifier import SliceAwareClassifier # noqa: F401
8 | from .utils import add_slice_labels, convert_to_slice_tasks # noqa: F401
9 |
--------------------------------------------------------------------------------
/snorkel/slicing/apply/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/snorkel-team/snorkel/617c92400c50e95ce41fcee84309a86f76cf525c/snorkel/slicing/apply/__init__.py
--------------------------------------------------------------------------------
/snorkel/slicing/apply/core.py:
--------------------------------------------------------------------------------
1 | from snorkel.labeling import LFApplier, PandasLFApplier
2 |
3 |
4 | class SFApplier(LFApplier):
5 | """SF applier for a list of data points.
6 |
7 | See ``snorkel.labeling.core.LFApplier`` for details.
8 | """
9 |
10 | _use_recarray = True
11 |
12 |
13 | class PandasSFApplier(PandasLFApplier):
14 | """SF applier for a Pandas DataFrame.
15 |
16 | See ``snorkel.labeling.core.PandasLFApplier`` for details.
17 | """
18 |
19 | _use_recarray = True
20 |
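Because of ``_use_recarray``, these appliers return a NumPy record array indexed by SF name rather than a plain label matrix; a short sketch:

```python
import pandas as pd

from snorkel.slicing import PandasSFApplier, slicing_function

@slicing_function()
def short_text(x):
    return len(x.text) < 10

df = pd.DataFrame({"text": ["hi", "a much longer example"]})
S = PandasSFApplier([short_text]).apply(df)
print(S["short_text"])  # per-row slice indicator, e.g. [1 0]
```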
--------------------------------------------------------------------------------
/snorkel/slicing/apply/dask.py:
--------------------------------------------------------------------------------
1 | from snorkel.labeling.apply.dask import ( # pragma: no cover
2 | DaskLFApplier,
3 | PandasParallelLFApplier,
4 | )
5 |
6 |
7 | class DaskSFApplier(DaskLFApplier): # pragma: no cover
8 | """SF applier for a Dask DataFrame.
9 |
10 | See ``snorkel.labeling.apply.dask.DaskLFApplier`` for details.
11 | """
12 |
13 | _use_recarray = True
14 |
15 |
16 | class PandasParallelSFApplier(PandasParallelLFApplier): # pragma: no cover
17 | """Parallel SF applier for a Pandas DataFrame.
18 |
19 | See ``snorkel.labeling.apply.dask.PandasParallelLFApplier`` for details.
20 | """
21 |
22 | _use_recarray = True
23 |
--------------------------------------------------------------------------------
/snorkel/slicing/apply/spark.py:
--------------------------------------------------------------------------------
1 | from snorkel.labeling.apply.spark import SparkLFApplier as SparkSFApplier # noqa: F401
2 |
--------------------------------------------------------------------------------
/snorkel/slicing/modules/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/snorkel-team/snorkel/617c92400c50e95ce41fcee84309a86f76cf525c/snorkel/slicing/modules/__init__.py
--------------------------------------------------------------------------------
/snorkel/slicing/monitor.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pandas as pd
3 |
4 | from snorkel.slicing import PandasSFApplier
5 | from snorkel.slicing.sf import SlicingFunction
6 |
7 |
8 | def slice_dataframe(
9 | df: pd.DataFrame, slicing_function: SlicingFunction
10 | ) -> pd.DataFrame:
11 | """Return a dataframe with examples corresponding to specified ``SlicingFunction``.
12 |
13 | Parameters
14 | ----------
15 | df
16 | A pandas DataFrame that will be sliced
17 | slicing_function
18 | SlicingFunction to apply to ``df``; the returned subset contains
19 | the rows for which ``slicing_function`` output is True
20 |
21 | Returns
22 | -------
23 | pd.DataFrame
24 | A DataFrame including only the examples belonging to the slice
25 | """
26 |
27 | S = PandasSFApplier([slicing_function]).apply(df)
28 |
29 | # Index into the SF labels by name
30 | df_idx = np.where(S[slicing_function.name])[0] # type: ignore
31 | return df.iloc[df_idx]
32 |
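A quick sketch of using ``slice_dataframe`` to eyeball a slice during development:

```python
import pandas as pd

from snorkel.slicing import slice_dataframe, slicing_function

@slicing_function()
def mentions_foo(x):
    return "foo" in x.text

df = pd.DataFrame({"text": ["foo bar", "baz"]})
print(slice_dataframe(df, mentions_foo))  # only the "foo bar" row
```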
--------------------------------------------------------------------------------
/snorkel/slicing/sf/__init__.py:
--------------------------------------------------------------------------------
1 | from .core import SlicingFunction, slicing_function # noqa: F401
2 |
--------------------------------------------------------------------------------
/snorkel/slicing/sf/core.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Callable, List, Mapping, Optional
2 |
3 | from snorkel.labeling.lf import LabelingFunction
4 | from snorkel.preprocess import BasePreprocessor
5 |
6 |
7 | class SlicingFunction(LabelingFunction):
8 | """Base class for slicing functions.
9 |
10 | See ``snorkel.labeling.lf.LabelingFunction`` for details.
11 | """
12 |
13 | pass
14 |
15 |
16 | class slicing_function:
17 | """Decorator to define a SlicingFunction object from a function.
18 |
19 | Parameters
20 | ----------
21 | name
22 | Name of the SF
23 | resources
24 | Slicing resources passed in to ``f`` via ``kwargs``
25 | preprocessors
26 | Preprocessors to run on data points before SF execution
27 |
28 | Examples
29 | --------
30 | >>> @slicing_function()
31 | ... def f(x):
32 | ... return x.a > 42
33 | >>> f
34 | SlicingFunction f, Preprocessors: []
35 | >>> from types import SimpleNamespace
36 | >>> x = SimpleNamespace(a=90, b=12)
37 | >>> f(x)
38 | True
39 |
40 | >>> @slicing_function(name="my_sf")
41 | ... def g(x):
42 | ... return 0 if x.a > 42 else -1
43 | >>> g
44 | SlicingFunction my_sf, Preprocessors: []
45 | """
46 |
47 | def __init__(
48 | self,
49 | name: Optional[str] = None,
50 | resources: Optional[Mapping[str, Any]] = None,
51 | pre: Optional[List[BasePreprocessor]] = None,
52 | ) -> None:
53 | if callable(name):
54 | raise ValueError("Looks like this decorator is missing parentheses!")
55 | self.name = name
56 | self.resources = resources
57 | self.pre = pre
58 |
59 | def __call__(self, f: Callable[..., int]) -> SlicingFunction:
60 | """Wrap a function to create a ``SlicingFunction``.
61 |
62 | Parameters
63 | ----------
64 | f
65 | Function that implements the core LF logic
66 |
67 | Returns
68 | -------
69 | SlicingFunction
70 | New ``SlicingFunction`` executing logic in wrapped function
71 | """
72 | name = self.name or f.__name__
73 | return SlicingFunction(name=name, f=f, resources=self.resources, pre=self.pre)
74 |
--------------------------------------------------------------------------------
/snorkel/slicing/sf/nlp.py:
--------------------------------------------------------------------------------
1 | from snorkel.labeling.lf.nlp import (
2 | BaseNLPLabelingFunction,
3 | SpacyPreprocessorParameters,
4 | base_nlp_labeling_function,
5 | )
6 | from snorkel.preprocess.nlp import SpacyPreprocessor
7 |
8 |
9 | class NLPSlicingFunction(BaseNLPLabelingFunction):
10 | r"""Special labeling function type for spaCy-based LFs.
11 |
12 | This class is a special version of ``LabelingFunction``. It
13 | has a ``SpacyPreprocessor`` integrated which shares a cache
14 | with all other ``NLPLabelingFunction`` instances. This makes
15 | it easy to define LFs that have a text input field and have
16 | logic written over spaCy ``Doc`` objects. Examples passed
17 | into an ``NLPLabelingFunction`` will have a new field which
18 | can be accessed which contains a spaCy ``Doc``. By default,
19 | this field is called ``doc``. A ``Doc`` object is
20 | a sequence of ``Token`` objects, which contain information
21 | on lemmatization, parts-of-speech, etc. ``Doc`` objects also
22 | contain fields like ``Doc.ents``, a list of named entities,
23 | and ``Doc.noun_chunks``, a list of noun phrases. For details
24 | of spaCy ``Doc`` objects and a full attribute listing,
25 | see https://spacy.io/api/doc.
26 |
27 | Simple ``NLPSlicingFunction``\s can be defined via a
28 | decorator. See ``nlp_slicing_function``.
29 |
30 | Parameters
31 | ----------
32 | name
33 | Name of the SF
34 | f
35 | Function that implements the core SF logic
36 | resources
37 | Slicing resources passed in to ``f`` via ``kwargs``
38 | pre
39 | Preprocessors to run before SpacyPreprocessor is executed
40 | text_field
41 | Name of data point text field to input
42 | doc_field
43 | Name of data point field to output parsed document to
44 | language
45 | spaCy model to load
46 | See https://spacy.io/usage/models#usage
47 | disable
48 | List of pipeline components to disable
49 | See https://spacy.io/usage/processing-pipelines#disabling
50 | memoize
51 | Memoize preprocessor outputs?
52 | memoize_key
53 | Hashing function to handle the memoization (defaults to snorkel.map.core.get_hashable)
54 |
55 | Raises
56 | ------
57 | ValueError
58 | Calling incorrectly defined preprocessors
59 |
60 | Example
61 | -------
62 | >>> def f(x):
63 | ... person_ents = [ent for ent in x.doc.ents if ent.label_ == "PERSON"]
64 | ... return len(person_ents) > 0
65 | >>> has_person_mention = NLPSlicingFunction(name="has_person_mention", f=f)
66 | >>> has_person_mention
67 | NLPSlicingFunction has_person_mention, Preprocessors: [SpacyPreprocessor...]
68 |
69 | >>> from types import SimpleNamespace
70 | >>> x = SimpleNamespace(text="The movie was good.")
71 | >>> has_person_mention(x)
72 | False
73 |
74 | Attributes
75 | ----------
76 | name
77 | See above
78 | """
79 |
80 | @classmethod
81 | def _create_preprocessor(
82 | cls, parameters: SpacyPreprocessorParameters
83 | ) -> SpacyPreprocessor:
84 | return SpacyPreprocessor(**parameters._asdict())
85 |
86 |
87 | class nlp_slicing_function(base_nlp_labeling_function):
88 | """Decorator to define a NLPSlicingFunction child object from a function.
89 |
90 | TODO: Implement a common parent decorator for Snorkel operators
91 | """
92 |
93 | _lf_cls = NLPSlicingFunction
94 |
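A decorator-form sketch; note that the ``SpacyPreprocessor`` (and hence the spaCy model) is instantiated at decoration time, so ``en_core_web_sm`` must be installed:

```python
from snorkel.slicing.sf.nlp import nlp_slicing_function

@nlp_slicing_function()
def has_person(x):
    return any(ent.label_ == "PERSON" for ent in x.doc.ents)

# has_person is now an NLPSlicingFunction with the shared SpacyPreprocessor
```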
--------------------------------------------------------------------------------
/snorkel/synthetic/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/snorkel-team/snorkel/617c92400c50e95ce41fcee84309a86f76cf525c/snorkel/synthetic/__init__.py
--------------------------------------------------------------------------------
/snorkel/synthetic/synthetic_data.py:
--------------------------------------------------------------------------------
1 | from typing import Tuple
2 |
3 | import numpy as np
4 |
5 |
6 | def generate_simple_label_matrix(
7 | n: int, m: int, cardinality: int, abstain_multiplier: float = 1.0
8 | ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
9 | """Generate a synthetic label matrix with true parameters and labels.
10 |
11 | This function generates a set of labeling function conditional probability tables,
12 | P(LF=l | Y=y), stored as a matrix P, and true labels Y, and then generates the
13 | resulting label matrix L.
14 |
15 | Parameters
16 | ----------
17 | n
18 | Number of data points
19 | m
20 | Number of labeling functions
21 | cardinality
22 | Cardinality of true labels (i.e. not including abstains)
23 | abstain_multiplier
24 | Factor to multiply the probability of abstaining by
25 |
26 | Returns
27 | -------
28 | Tuple[np.ndarray, np.ndarray, np.ndarray]
29 | A tuple containing the LF conditional probabilities P,
30 | the true labels Y, and the output label matrix L
31 | """
32 | # Generate the conditional probability tables for the LFs
33 | # The first axis is LF, second is LF output label, third is true class label
34 | # Note that we include abstains in the LF output space, and that we bias the
35 | # conditional probabilities towards being non-adversarial
36 | P = np.empty((m, cardinality + 1, cardinality))
37 | for i in range(m):
38 | p = np.random.rand(cardinality + 1, cardinality)
39 |
40 | # Bias the LFs to being non-adversarial
41 | p[1:, :] += (cardinality - 1) * np.eye(cardinality)
42 |
43 | # Optionally increase the abstain probability by some multiplier; note this is
44 | # to simulate the common setting where LFs label very sparsely
45 | p[0, :] *= abstain_multiplier
46 |
47 | # Normalize the conditional probabilities table
48 | P[i] = p @ np.diag(1 / p.sum(axis=0))
49 |
50 | # Generate the true datapoint labels
51 | # Note: Assuming balanced classes to start
52 | Y = np.random.choice(cardinality, n)
53 |
54 | # Generate the label matrix L
55 | L: np.ndarray = np.empty((n, m), dtype=int)
56 | for i in range(n):
57 | for j in range(m):
58 | L[i, j] = np.random.choice(cardinality + 1, p=P[j, :, Y[i]]) - 1
59 | return P, Y, L
60 |
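A quick sketch of the generator's output shapes; the seed is only for reproducibility:

```python
import numpy as np

from snorkel.synthetic.synthetic_data import generate_simple_label_matrix

np.random.seed(0)
P, Y, L = generate_simple_label_matrix(n=1000, m=5, cardinality=2)
print(P.shape, Y.shape, L.shape)   # (5, 3, 2) (1000,) (1000, 5)
coverage = (L != -1).mean(axis=0)  # fraction of points each LF labels
```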
--------------------------------------------------------------------------------
/snorkel/types/__init__.py:
--------------------------------------------------------------------------------
1 | from .classifier import Config # noqa: F401
2 | from .data import DataPoint, DataPoints, Field, FieldMap # noqa: F401
3 | from .hashing import HashingFunction # noqa: F401
4 |
--------------------------------------------------------------------------------
/snorkel/types/classifier.py:
--------------------------------------------------------------------------------
1 | from typing import NamedTuple
2 |
3 | Config = NamedTuple
4 |
--------------------------------------------------------------------------------
/snorkel/types/data.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Mapping, Sequence
2 |
3 | DataPoint = Any
4 | DataPoints = Sequence[DataPoint]
5 |
6 | Field = Any
7 | FieldMap = Mapping[str, Field]
8 |
--------------------------------------------------------------------------------
/snorkel/types/hashing.py:
--------------------------------------------------------------------------------
1 | from collections.abc import Hashable
2 | from typing import Any, Callable
3 |
4 | HashingFunction = Callable[[Any], Hashable]
5 |
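Any ``Callable[[Any], Hashable]`` qualifies as a ``memoize_key``; a one-line sketch (the ``uid`` field is an assumed attribute of the data point):

```python
from snorkel.types import HashingFunction

key_by_uid: HashingFunction = lambda x: x.uid  # hash data points by an assumed uid field
```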
--------------------------------------------------------------------------------
/snorkel/utils/__init__.py:
--------------------------------------------------------------------------------
1 | """General machine learning utilities shared across Snorkel."""
2 |
3 | from .core import ( # noqa: F401
4 | filter_labels,
5 | preds_to_probs,
6 | probs_to_preds,
7 | to_int_label_array,
8 | )
9 |
--------------------------------------------------------------------------------
/snorkel/utils/config_utils.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Dict
2 |
3 | from snorkel.types import Config
4 |
5 |
6 | def merge_config(config: Config, config_updates: Dict[str, Any]) -> Config:
7 | """Merge a (potentially nested) dict of kwargs into a config (NamedTuple).
8 |
9 | Parameters
10 | ----------
11 | config
12 | An instantiated Config to update
13 | config_updates
14 | A potentially nested dict of settings to update in the Config
15 |
16 | Returns
17 | -------
18 | Config
19 | The updated Config
20 |
21 | Example
22 | -------
23 | ```
24 | config_updates = {
25 | "n_epochs": 5,
26 | "optimizer_config": {
27 | "lr": 0.001,
28 | }
29 | }
30 | trainer_config = merge_config(TrainerConfig(), config_updates)
31 | ```
32 | """
33 | for key, value in config_updates.items():
34 | if isinstance(value, dict):
35 | config_updates[key] = merge_config(getattr(config, key), value)
36 | return config._replace(**config_updates)
37 |
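A runnable version of the docstring example with stand-in configs (``ToyOptimizerConfig`` and ``ToyTrainerConfig`` are local toys, not Snorkel's real trainer configs):

```python
from snorkel.types import Config
from snorkel.utils.config_utils import merge_config

class ToyOptimizerConfig(Config):
    lr: float = 0.01

class ToyTrainerConfig(Config):
    n_epochs: int = 1
    optimizer_config: ToyOptimizerConfig = ToyOptimizerConfig()

cfg = merge_config(ToyTrainerConfig(), {"n_epochs": 5, "optimizer_config": {"lr": 0.001}})
assert cfg.n_epochs == 5 and cfg.optimizer_config.lr == 0.001
```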
--------------------------------------------------------------------------------
/snorkel/utils/data_operators.py:
--------------------------------------------------------------------------------
1 | from collections import Counter
2 | from typing import List
3 |
4 |
5 | def check_unique_names(names: List[str]) -> None:
6 | """Check that operator names are unique."""
7 | for k, ct in Counter(names).most_common(1):
8 | if ct > 1:
9 | raise ValueError(f"Operator names not unique: {ct} operators with name {k}")
10 |
--------------------------------------------------------------------------------
/snorkel/utils/lr_schedulers.py:
--------------------------------------------------------------------------------
1 | from snorkel.types import Config
2 |
3 |
4 | class ExponentialLRSchedulerConfig(Config):
5 | """Settings for Exponential decay learning rate scheduler."""
6 |
7 | gamma: float = 0.9
8 |
9 |
10 | class StepLRSchedulerConfig(Config):
11 | """Settings for Step decay learning rate scheduler."""
12 |
13 | gamma: float = 0.9
14 | step_size: int = 5
15 |
16 |
17 | class LRSchedulerConfig(Config):
18 | """Settings common to all LRSchedulers.
19 |
20 | Parameters
21 | ----------
22 | warmup_steps
23 | The number of warmup_units over which to perform learning rate warmup (a linear
24 | increase from 0 to the specified lr)
25 | warmup_unit
26 | The unit to use when counting warmup (one of ["batches", "epochs"])
27 | warmup_percentage
28 | The percentage of the training procedure to warm up over (ignored if
29 | warmup_steps is non-zero)
30 | min_lr
31 | The minimum learning rate to use during training (the learning rate specified
32 | by a learning rate scheduler will be rounded up to this if it is lower)
33 | exponential_config
34 | Extra settings for the ExponentialLRScheduler
35 | step_config
36 | Extra settings for the StepLRScheduler
37 | """
38 |
39 | warmup_steps: float = 0 # warm up steps
40 | warmup_unit: str = "batches" # [epochs, batches]
41 | warmup_percentage: float = 0.0 # warm up percentage
42 | min_lr: float = 0.0 # minimum learning rate
43 | exponential_config: ExponentialLRSchedulerConfig = (
44 | ExponentialLRSchedulerConfig() # type:ignore
45 | )
46 | step_config: StepLRSchedulerConfig = StepLRSchedulerConfig() # type:ignore
47 |
--------------------------------------------------------------------------------
/snorkel/utils/optimizers.py:
--------------------------------------------------------------------------------
1 | from typing import Tuple
2 |
3 | from snorkel.types import Config
4 |
5 |
6 | class SGDOptimizerConfig(Config):
7 | """Settings for SGD optimizer."""
8 |
9 | momentum: float = 0.9
10 |
11 |
12 | class AdamOptimizerConfig(Config):
13 | """Settings for Adam optimizer."""
14 |
15 | amsgrad: bool = False
16 | betas: Tuple[float, float] = (0.9, 0.999)
17 |
18 |
19 | class AdamaxOptimizerConfig(Config):
20 | """Settings for Adamax optimizer."""
21 |
22 | betas: Tuple[float, float] = (0.9, 0.999)
23 | eps: float = 1e-8
24 |
25 |
26 | class OptimizerConfig(Config):
27 | """Settings common to all optimizers."""
28 |
29 | sgd_config: SGDOptimizerConfig = SGDOptimizerConfig() # type:ignore
30 | adam_config: AdamOptimizerConfig = AdamOptimizerConfig() # type:ignore
31 | adamax_config: AdamaxOptimizerConfig = AdamaxOptimizerConfig() # type:ignore
32 |
--------------------------------------------------------------------------------
/snorkel/version.py:
--------------------------------------------------------------------------------
1 | _MAJOR = "0"
2 | _MINOR = "10"
3 | _REVISION = "1+dev"
4 |
5 | VERSION_SHORT = f"{_MAJOR}.{_MINOR}"
6 | VERSION = f"{_MAJOR}.{_MINOR}.{_REVISION}"
7 |
--------------------------------------------------------------------------------
/test/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/snorkel-team/snorkel/617c92400c50e95ce41fcee84309a86f76cf525c/test/__init__.py
--------------------------------------------------------------------------------
/test/analysis/test_error_analysis.py:
--------------------------------------------------------------------------------
1 | import unittest
2 |
3 | import numpy as np
4 |
5 | from snorkel.analysis import get_label_buckets, get_label_instances
6 |
7 |
8 | class ErrorAnalysisTest(unittest.TestCase):
9 | def test_get_label_buckets(self) -> None:
10 | y1 = np.array([[2], [1], [3], [1], [1], [3]])
11 | y2 = np.array([1, 2, 3, 1, 2, 3])
12 | buckets = get_label_buckets(y1, y2)
13 | expected_buckets = {(2, 1): [0], (1, 2): [1, 4], (3, 3): [2, 5], (1, 1): [3]}
14 | expected_buckets = {k: np.array(v) for k, v in expected_buckets.items()}
15 | np.testing.assert_equal(buckets, expected_buckets)
16 |
17 | y1_1d = np.array([2, 1, 3, 1, 1, 3])
18 | buckets = get_label_buckets(y1_1d, y2)
19 | np.testing.assert_equal(buckets, expected_buckets)
20 |
21 | def test_get_label_buckets_multi(self) -> None:
22 | y1 = np.array([[2], [1], [3], [1], [1], [3]])
23 | y2 = np.array([1, 2, 3, 1, 2, 3])
24 | y3 = np.array([[3], [2], [1], [1], [2], [3]])
25 | buckets = get_label_buckets(y1, y2, y3)
26 | expected_buckets = {
27 | (2, 1, 3): [0],
28 | (1, 2, 2): [1, 4],
29 | (3, 3, 1): [2],
30 | (1, 1, 1): [3],
31 | (3, 3, 3): [5],
32 | }
33 | expected_buckets = {k: np.array(v) for k, v in expected_buckets.items()}
34 | np.testing.assert_equal(buckets, expected_buckets)
35 |
36 | def test_get_label_buckets_bad_shape(self) -> None:
37 | with self.assertRaisesRegex(ValueError, "same number of elements"):
38 | get_label_buckets(np.array([0, 1, 1]), np.array([1, 1]))
39 |
40 | def test_get_label_instances(self) -> None:
41 | x = np.array([[1, 2], [3, 4], [5, 6], [7, 8]])
42 | y1 = np.array([1, 0, 0, 0])
43 | y2 = np.array([1, 1, 1, 0])
44 | instances = get_label_instances((0, 1), x, y1, y2)
45 | expected_instances = np.array([[3, 4], [5, 6]])
46 | np.testing.assert_equal(instances, expected_instances)
47 |
48 | x = np.array(["this", "is", "a", "test", "of", "multi"])
49 | y1 = np.array([[2], [1], [3], [1], [1], [3]])
50 | y2 = np.array([1, 2, 3, 1, 2, 3])
51 | y3 = np.array([[3], [2], [1], [1], [2], [3]])
52 | instances = get_label_instances((3, 3, 3), x, y1, y2, y3)
53 | expected_instances = np.array(["multi"])
54 | np.testing.assert_equal(instances, expected_instances)
55 |
56 | def test_get_label_instances_exceptions(self) -> None:
57 | x = np.array([[1, 2], [3, 4], [5, 6], [7, 8]])
58 | y1 = np.array([1, 0, 0, 0])
59 | y2 = np.array([1, 1, 1, 0])
60 | instances = get_label_instances((2, 0), x, y1, y2)
61 | expected_instances = np.array([])
62 | np.testing.assert_equal(instances, expected_instances)
63 |
64 | with self.assertRaisesRegex(
65 | ValueError, "Number of lists must match the amount of labels in bucket"
66 | ):
67 | get_label_instances((1, 0), x, y1)
68 |
69 | x = np.array([[1, 2], [3, 4], [5, 6]])
70 | with self.assertRaisesRegex(
71 | ValueError,
72 | "Number of rows in x does not match number of elements in at least one label list",
73 | ):
74 | get_label_instances((1, 0), x, y1, y2)
75 |
76 |
77 | if __name__ == "__main__":
78 | unittest.main()
79 |
--------------------------------------------------------------------------------
/test/augmentation/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/snorkel-team/snorkel/617c92400c50e95ce41fcee84309a86f76cf525c/test/augmentation/__init__.py
--------------------------------------------------------------------------------
/test/augmentation/apply/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/snorkel-team/snorkel/617c92400c50e95ce41fcee84309a86f76cf525c/test/augmentation/apply/__init__.py
--------------------------------------------------------------------------------
/test/augmentation/policy/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/snorkel-team/snorkel/617c92400c50e95ce41fcee84309a86f76cf525c/test/augmentation/policy/__init__.py
--------------------------------------------------------------------------------
/test/augmentation/policy/test_core.py:
--------------------------------------------------------------------------------
1 | import unittest
2 |
3 | from snorkel.augmentation import ApplyAllPolicy, ApplyEachPolicy
4 |
5 |
6 | class TestPolicy(unittest.TestCase):
7 | def test_apply_each_policy(self):
8 | policy = ApplyEachPolicy(3, keep_original=True)
9 | samples = policy.generate_for_example()
10 | self.assertEqual(samples, [[], [0], [1], [2]])
11 |
12 | policy = ApplyEachPolicy(3, keep_original=False)
13 | samples = policy.generate_for_example()
14 | self.assertEqual(samples, [[0], [1], [2]])
15 |
16 | def test_apply_all_policy(self):
17 | policy = ApplyAllPolicy(3, n_per_original=2, keep_original=False)
18 | samples = policy.generate_for_example()
19 | self.assertEqual(samples, [[0, 1, 2], [0, 1, 2]])
20 |
--------------------------------------------------------------------------------
/test/augmentation/policy/test_sampling.py:
--------------------------------------------------------------------------------
1 | import unittest
2 |
3 | from snorkel.augmentation import MeanFieldPolicy, RandomPolicy
4 |
5 |
6 | class TestSamplingPolicy(unittest.TestCase):
7 | def test_random_policy(self):
8 | policy = RandomPolicy(2, sequence_length=2)
9 | n_samples = 100
10 | samples = [policy.generate() for _ in range(n_samples)]
11 | a_ct = samples.count([0, 0])
12 | b_ct = samples.count([0, 1])
13 | c_ct = samples.count([1, 0])
14 | d_ct = samples.count([1, 1])
15 | self.assertGreater(a_ct, 0)
16 | self.assertGreater(b_ct, 0)
17 | self.assertGreater(c_ct, 0)
18 | self.assertGreater(d_ct, 0)
19 | self.assertEqual(a_ct + b_ct + c_ct + d_ct, n_samples)
20 |
21 | def test_mean_field_policy(self):
22 | policy = MeanFieldPolicy(2, sequence_length=2, p=[1, 0])
23 | n_samples = 100
24 | samples = [policy.generate() for _ in range(n_samples)]
25 | self.assertEqual(samples.count([0, 0]), n_samples)
26 |
--------------------------------------------------------------------------------
/test/classification/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/snorkel-team/snorkel/617c92400c50e95ce41fcee84309a86f76cf525c/test/classification/__init__.py
--------------------------------------------------------------------------------
/test/classification/test_classifier_convergence.py:
--------------------------------------------------------------------------------
1 | import random
2 | import unittest
3 | from typing import List
4 |
5 | import numpy as np
6 | import pandas as pd
7 | import pytest
8 | import torch
9 | import torch.nn as nn
10 |
11 | from snorkel.analysis import Scorer
12 | from snorkel.classification import (
13 | DictDataLoader,
14 | DictDataset,
15 | MultitaskClassifier,
16 | Operation,
17 | Task,
18 | Trainer,
19 | )
20 |
21 | N_TRAIN = 1000
22 | N_VALID = 300
23 |
24 |
25 | class ClassifierConvergenceTest(unittest.TestCase):
26 | @classmethod
27 | def setUpClass(cls):
28 | # Ensure deterministic runs
29 | random.seed(123)
30 | np.random.seed(123)
31 | torch.manual_seed(123)
32 |
33 | @pytest.mark.complex
34 | def test_convergence(self):
35 | """Test multitask classifier convergence with two tasks."""
36 |
37 | dataloaders = []
38 |
39 | for offset, task_name in zip([0.0, 0.25], ["task1", "task2"]):
40 | df = create_data(N_TRAIN, offset)
41 | dataloader = create_dataloader(df, "train", task_name)
42 | dataloaders.append(dataloader)
43 |
44 | for offset, task_name in zip([0.0, 0.25], ["task1", "task2"]):
45 | df = create_data(N_VALID, offset)
46 | dataloader = create_dataloader(df, "valid", task_name)
47 | dataloaders.append(dataloader)
48 |
49 | task1 = create_task("task1", module_suffixes=["A", "A"])
50 | task2 = create_task("task2", module_suffixes=["A", "B"])
51 | model = MultitaskClassifier(tasks=[task1, task2])
52 |
53 | # Train
54 | trainer = Trainer(lr=0.0024, n_epochs=10, progress_bar=False)
55 | trainer.fit(model, dataloaders)
56 | scores = model.score(dataloaders)
57 |
58 | # Confirm near perfect scores on both tasks
59 | for idx, task_name in enumerate(["task1", "task2"]):
60 | self.assertGreater(scores[f"{task_name}/TestData/valid/accuracy"], 0.95)
61 |
62 | # Calculate/check train/val loss
63 | train_dataset = dataloaders[idx].dataset
64 | train_loss_output = model.calculate_loss(
65 | train_dataset.X_dict, train_dataset.Y_dict
66 | )
67 | train_loss = train_loss_output[0][task_name].item()
68 | self.assertLess(train_loss, 0.05)
69 |
70 | val_dataset = dataloaders[2 + idx].dataset
71 | val_loss_output = model.calculate_loss(
72 | val_dataset.X_dict, val_dataset.Y_dict
73 | )
74 | val_loss = val_loss_output[0][task_name].item()
75 | self.assertLess(val_loss, 0.05)
76 |
77 |
78 | def create_data(n: int, offset=0) -> pd.DataFrame:
79 | """Create uniform X data from [-1, 1] on both axes.
80 |
81 | Create labels with linear decision boundaries related to the two coordinates of X.
82 | """
83 | X = (np.random.random((n, 2)) * 2 - 1).astype(np.float32)
84 | Y = (X[:, 0] < X[:, 1] + offset).astype(int)
85 |
86 | df = pd.DataFrame({"x1": X[:, 0], "x2": X[:, 1], "y": Y})
87 | return df
88 |
89 |
90 | def create_dataloader(df: pd.DataFrame, split: str, task_name: str) -> DictDataLoader:
91 | dataset = DictDataset(
92 | name="TestData",
93 | split=split,
94 | X_dict={
95 | "coordinates": torch.stack(
96 | (torch.tensor(df["x1"]), torch.tensor(df["x2"])), dim=1
97 | )
98 | },
99 | Y_dict={task_name: torch.tensor(df["y"], dtype=torch.long)},
100 | )
101 |
102 | dataloader = DictDataLoader(
103 | dataset=dataset, batch_size=4, shuffle=(dataset.split == "train")
104 | )
105 | return dataloader
106 |
107 |
108 | def create_task(task_name: str, module_suffixes: List[str]) -> Task:
109 | module1_name = f"linear1{module_suffixes[0]}"
110 | module2_name = f"linear2{module_suffixes[1]}"
111 |
112 | module_pool = nn.ModuleDict(
113 | {
114 | module1_name: nn.Sequential(nn.Linear(2, 20), nn.ReLU()),
115 | module2_name: nn.Linear(20, 2),
116 | }
117 | )
118 |
119 | op1 = Operation(module_name=module1_name, inputs=[("_input_", "coordinates")])
120 | op2 = Operation(module_name=module2_name, inputs=[op1.name])
121 |
122 | op_sequence = [op1, op2]
123 |
124 | task = Task(
125 | name=task_name,
126 | module_pool=module_pool,
127 | op_sequence=op_sequence,
128 | scorer=Scorer(metrics=["accuracy"]),
129 | )
130 |
131 | return task
132 |
133 |
134 | if __name__ == "__main__":
135 | unittest.main()
136 |
--------------------------------------------------------------------------------
/test/classification/test_task.py:
--------------------------------------------------------------------------------
1 | import unittest
2 |
3 | import torch.nn as nn
4 |
5 | from snorkel.classification import Operation, Task
6 |
7 | TASK_NAME = "TestTask"
8 |
9 |
10 | class TaskTest(unittest.TestCase):
11 | def test_task_creation(self):
12 | module_pool = nn.ModuleDict(
13 | {
14 | "linear1": nn.Sequential(nn.Linear(2, 10), nn.ReLU()),
15 | "linear2": nn.Linear(10, 1),
16 | }
17 | )
18 |
19 | op_sequence = [
20 | Operation(
21 | name="the_first_layer", module_name="linear1", inputs=["_input_"]
22 | ),
23 | Operation(
24 | name="the_second_layer",
25 | module_name="linear2",
26 | inputs=["the_first_layer"],
27 | ),
28 | ]
29 |
30 | task = Task(name=TASK_NAME, module_pool=module_pool, op_sequence=op_sequence)
31 |
32 | # Task has no functionality on its own
33 | # Here we only confirm that the object was initialized
34 | self.assertEqual(task.name, TASK_NAME)
35 |
36 |
37 | if __name__ == "__main__":
38 | unittest.main()
39 |
--------------------------------------------------------------------------------
/test/classification/test_utils.py:
--------------------------------------------------------------------------------
1 | import unittest
2 |
3 | import torch
4 |
5 | from snorkel.classification.utils import (
6 | collect_flow_outputs_by_suffix,
7 | list_to_tensor,
8 | pad_batch,
9 | )
10 |
11 |
12 | class UtilsTest(unittest.TestCase):
13 | def test_pad_batch(self):
14 | batch = [torch.Tensor([1, 2]), torch.Tensor([3]), torch.Tensor([4, 5, 6])]
15 | padded_batch, mask_batch = pad_batch(batch)
16 |
17 | self.assertTrue(
18 | torch.equal(padded_batch, torch.Tensor([[1, 2, 0], [3, 0, 0], [4, 5, 6]]))
19 | )
20 | self.assertTrue(
21 | torch.equal(mask_batch, torch.Tensor([[0, 0, 1], [0, 1, 1], [0, 0, 0]]))
22 | )
23 |
24 | padded_batch, mask_batch = pad_batch(batch, max_len=2)
25 |
26 | self.assertTrue(
27 | torch.equal(padded_batch, torch.Tensor([[1, 2], [3, 0], [4, 5]]))
28 | )
29 | self.assertTrue(torch.equal(mask_batch, torch.Tensor([[0, 0], [0, 1], [0, 0]])))
30 |
31 | padded_batch, mask_batch = pad_batch(batch, pad_value=-1)
32 |
33 | self.assertTrue(
34 | torch.equal(
35 | padded_batch, torch.Tensor([[1, 2, -1], [3, -1, -1], [4, 5, 6]])
36 | )
37 | )
38 | self.assertTrue(
39 | torch.equal(mask_batch, torch.Tensor([[0, 0, 1], [0, 1, 1], [0, 0, 0]]))
40 | )
41 |
42 | padded_batch, mask_batch = pad_batch(batch, left_padded=True)
43 |
44 | self.assertTrue(
45 | torch.equal(padded_batch, torch.Tensor([[0, 1, 2], [0, 0, 3], [4, 5, 6]]))
46 | )
47 | self.assertTrue(
48 | torch.equal(mask_batch, torch.Tensor([[1, 0, 0], [1, 1, 0], [0, 0, 0]]))
49 | )
50 |
51 | padded_batch, mask_batch = pad_batch(batch, max_len=2, left_padded=True)
52 |
53 | self.assertTrue(
54 | torch.equal(padded_batch, torch.Tensor([[1, 2], [0, 3], [5, 6]]))
55 | )
56 | self.assertTrue(torch.equal(mask_batch, torch.Tensor([[0, 0], [1, 0], [0, 0]])))
57 |
58 | def test_list_to_tensor(self):
59 | # list of 1-D tensor with the different length
60 | batch = [torch.Tensor([1, 2]), torch.Tensor([3]), torch.Tensor([4, 5, 6])]
61 |
62 | padded_batch = list_to_tensor(batch)
63 |
64 | self.assertTrue(
65 | torch.equal(padded_batch, torch.Tensor([[1, 2, 0], [3, 0, 0], [4, 5, 6]]))
66 | )
67 |
68 | # list of 1-D tensor with the same length
69 | batch = [
70 | torch.Tensor([1, 2, 3]),
71 | torch.Tensor([4, 5, 6]),
72 | torch.Tensor([7, 8, 9]),
73 | ]
74 |
75 | padded_batch = list_to_tensor(batch)
76 |
77 | self.assertTrue(
78 | torch.equal(padded_batch, torch.Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]))
79 | )
80 |
81 | # list of 2-D tensor with the same size
82 | batch = [
83 | torch.Tensor([[1, 2, 3], [1, 2, 3]]),
84 | torch.Tensor([[4, 5, 6], [4, 5, 6]]),
85 | torch.Tensor([[7, 8, 9], [7, 8, 9]]),
86 | ]
87 |
88 | padded_batch = list_to_tensor(batch)
89 |
90 | self.assertTrue(
91 | torch.equal(
92 | padded_batch,
93 | torch.Tensor(
94 | [
95 | [[1, 2, 3], [1, 2, 3]],
96 | [[4, 5, 6], [4, 5, 6]],
97 | [[7, 8, 9], [7, 8, 9]],
98 | ]
99 | ),
100 | )
101 | )
102 |
103 | # list of tensor with the different size
104 | batch = [
105 | torch.Tensor([[1, 2], [2, 3]]),
106 | torch.Tensor([4, 5, 6]),
107 | torch.Tensor([7, 8, 9, 0]),
108 | ]
109 |
110 | padded_batch = list_to_tensor(batch)
111 |
112 | self.assertTrue(
113 | torch.equal(
114 | padded_batch, torch.Tensor([[1, 2, 2, 3], [4, 5, 6, 0], [7, 8, 9, 0]])
115 | )
116 | )
117 |
118 | def test_collect_flow_outputs_by_suffix(self):
119 | flow_dict = {
120 | "a_pred_head": torch.Tensor([1]),
121 | "b_pred_head": torch.Tensor([2]),
122 | "c_pred": torch.Tensor([3]),
123 | }
124 | outputs = collect_flow_outputs_by_suffix(flow_dict, "_head")
125 | self.assertIn(torch.Tensor([1]), outputs)
126 | self.assertIn(torch.Tensor([2]), outputs)
127 |
128 |
129 | if __name__ == "__main__":
130 | unittest.main()
--------------------------------------------------------------------------------
/test/classification/training/loggers/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/snorkel-team/snorkel/617c92400c50e95ce41fcee84309a86f76cf525c/test/classification/training/loggers/__init__.py
--------------------------------------------------------------------------------
/test/classification/training/loggers/test_checkpointer.py:
--------------------------------------------------------------------------------
1 | import os
2 | import shutil
3 | import tempfile
4 | import unittest
5 |
6 | from snorkel.classification.multitask_classifier import MultitaskClassifier
7 | from snorkel.classification.training.loggers import Checkpointer
8 |
9 | log_manager_config = {"counter_unit": "epochs", "evaluation_freq": 1}
10 |
11 |
12 | class TestCheckpointer(unittest.TestCase):
13 | def setUp(self) -> None:
14 | self.test_dir = tempfile.mkdtemp()
15 |
16 | def tearDown(self) -> None:
17 | shutil.rmtree(self.test_dir)
18 |
19 | def test_checkpointer(self) -> None:
20 | checkpointer = Checkpointer(
21 | **log_manager_config,
22 | checkpoint_dir=self.test_dir,
23 | checkpoint_runway=3,
24 | checkpoint_metric="task/dataset/valid/f1:max",
25 | )
26 | model = MultitaskClassifier([])
27 | checkpointer.checkpoint(2, model, {"task/dataset/valid/f1": 0.5})
28 | self.assertEqual(len(checkpointer.best_metric_dict), 0)
29 | checkpointer.checkpoint(
30 | 3, model, {"task/dataset/valid/f1": 0.8, "task/dataset/valid/f2": 0.5}
31 | )
32 | self.assertEqual(checkpointer.best_metric_dict["task/dataset/valid/f1"], 0.8)
33 | checkpointer.checkpoint(4, model, {"task/dataset/valid/f1": 0.9})
34 | self.assertEqual(checkpointer.best_metric_dict["task/dataset/valid/f1"], 0.9)
35 |
36 | def test_checkpointer_min(self) -> None:
37 | checkpointer = Checkpointer(
38 | **log_manager_config,
39 | checkpoint_dir=self.test_dir,
40 | checkpoint_runway=3,
41 | checkpoint_metric="task/dataset/valid/f1:min",
42 | )
43 | model = MultitaskClassifier([])
44 | checkpointer.checkpoint(
45 | 3, model, {"task/dataset/valid/f1": 0.8, "task/dataset/valid/f2": 0.5}
46 | )
47 | self.assertEqual(checkpointer.best_metric_dict["task/dataset/valid/f1"], 0.8)
48 | checkpointer.checkpoint(4, model, {"task/dataset/valid/f1": 0.7})
49 | self.assertEqual(checkpointer.best_metric_dict["task/dataset/valid/f1"], 0.7)
50 |
51 | def test_checkpointer_clear(self) -> None:
52 | checkpoint_dir = os.path.join(self.test_dir, "clear")
53 | checkpointer = Checkpointer(
54 | **log_manager_config,
55 | checkpoint_dir=checkpoint_dir,
56 | checkpoint_metric="task/dataset/valid/f1:max",
57 | checkpoint_clear=True,
58 | )
59 | model = MultitaskClassifier([])
60 | checkpointer.checkpoint(1, model, {"task/dataset/valid/f1": 0.8})
61 | expected_files = ["checkpoint_1.pth", "best_model_task_dataset_valid_f1.pth"]
62 | self.assertEqual(set(os.listdir(checkpoint_dir)), set(expected_files))
63 | checkpointer.clear()
64 | expected_files = ["best_model_task_dataset_valid_f1.pth"]
65 | self.assertEqual(os.listdir(checkpoint_dir), expected_files)
66 |
67 | def test_checkpointer_load_best(self) -> None:
68 | checkpoint_dir = os.path.join(self.test_dir, "clear")
69 | checkpointer = Checkpointer(
70 | **log_manager_config,
71 | checkpoint_dir=checkpoint_dir,
72 | checkpoint_metric="task/dataset/valid/f1:max",
73 | )
74 | model = MultitaskClassifier([])
75 | checkpointer.checkpoint(1, model, {"task/dataset/valid/f1": 0.8})
76 | load_model = checkpointer.load_best_model(model)
77 | self.assertEqual(model, load_model)
78 |
79 | def test_bad_checkpoint_runway(self) -> None:
80 | with self.assertRaisesRegex(ValueError, "checkpoint_runway"):
81 | Checkpointer(**log_manager_config, checkpoint_runway=-1)
82 |
83 | def test_no_zero_frequency(self) -> None:
84 | with self.assertRaisesRegex(ValueError, "checkpoint freq"):
85 | Checkpointer(
86 | **log_manager_config, checkpoint_dir=self.test_dir, checkpoint_factor=0
87 | )
88 |
89 | def test_bad_metric_name(self) -> None:
90 | with self.assertRaisesRegex(ValueError, "metric_name:mode"):
91 | Checkpointer(
92 | **log_manager_config,
93 | checkpoint_dir=self.test_dir,
94 | checkpoint_metric="task/dataset/split/f1-min",
95 | )
96 |
97 | with self.assertRaisesRegex(ValueError, "metric mode"):
98 | Checkpointer(
99 | **log_manager_config,
100 | checkpoint_dir=self.test_dir,
101 | checkpoint_metric="task/dataset/split/f1:mode",
102 | )
103 |
104 | with self.assertRaisesRegex(ValueError, "checkpoint_metric must be formatted"):
105 | Checkpointer(
106 | **log_manager_config,
107 | checkpoint_dir=self.test_dir,
108 | checkpoint_metric="accuracy:max",
109 | )
110 |
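Taken together, the three failure cases above pin down the accepted `checkpoint_metric` grammar: a full `task/dataset/split/metric` identifier, a colon, and a `min`/`max` mode. A minimal valid construction, sketched with a hypothetical stand-in for the module-level `log_manager_config` fixture (its real keys are defined earlier in this test file):

```
from snorkel.classification.training.loggers import Checkpointer

# Hypothetical stand-in for the log_manager_config defined above in this file.
log_manager_config = {"counter_unit": "epochs", "evaluation_freq": 1}

checkpointer = Checkpointer(
    **log_manager_config,
    checkpoint_dir="checkpoints",
    checkpoint_metric="task/dataset/valid/f1:max",  # "metric_name:mode"
)
```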
--------------------------------------------------------------------------------
/test/classification/training/loggers/test_log_writer.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | import shutil
4 | import tempfile
5 | import unittest
6 |
7 | from snorkel.classification.training.loggers import LogWriter
8 | from snorkel.types import Config
9 |
10 |
11 | class TempConfig(Config):
12 | a: int = 42
13 | b: str = "foo"
14 |
15 |
16 | class TestLogWriter(unittest.TestCase):
17 | def setUp(self):
18 | self.test_dir = tempfile.mkdtemp()
19 |
20 | def tearDown(self):
21 | shutil.rmtree(self.test_dir)
22 |
23 | def test_log_writer(self):
24 | run_name = "my_run"
25 | log_writer = LogWriter(run_name=run_name, log_dir=self.test_dir)
26 | log_writer.add_scalar("my_value", value=0.5, step=2)
27 |
28 | log_filename = "my_log.json"
29 | log_writer.write_log(log_filename)
30 |
31 | log_path = os.path.join(self.test_dir, run_name, log_filename)
32 | with open(log_path, "r") as f:
33 | log = json.load(f)
34 |
35 | log_expected = dict(my_value=[[2, 0.5]])
36 | self.assertEqual(log, log_expected)
37 |
38 | def test_write_text(self) -> None:
39 | run_name = "my_run"
40 | filename = "my_text.txt"
41 | text = "my log text"
42 | log_writer = LogWriter(run_name=run_name, log_dir=self.test_dir)
43 | log_writer.write_text(text, filename)
44 | log_path = os.path.join(self.test_dir, run_name, filename)
45 | with open(log_path, "r") as f:
46 | file_text = f.read()
47 | self.assertEqual(text, file_text)
48 |
49 | def test_write_config(self) -> None:
50 | run_name = "my_run"
51 | config = TempConfig(b="bar") # type: ignore
52 | log_writer = LogWriter(run_name=run_name, log_dir=self.test_dir)
53 | log_writer.write_config(config)
54 | log_path = os.path.join(self.test_dir, run_name, "config.json")
55 | with open(log_path, "r") as f:
56 | file_config = json.load(f)
57 | self.assertEqual(config._asdict(), file_config)
58 |
59 |
60 | if __name__ == "__main__":
61 | unittest.main()
62 |
--------------------------------------------------------------------------------
/test/classification/training/loggers/test_tensorboard_writer.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | import shutil
4 | import tempfile
5 | import unittest
6 |
7 | from snorkel.classification.training.loggers import TensorBoardWriter
8 | from snorkel.types import Config
9 |
10 |
11 | class TempConfig(Config):
12 | a: int = 42
13 | b: str = "foo"
14 |
15 |
16 | class TestTensorBoardWriter(unittest.TestCase):
17 | def setUp(self):
18 | self.test_dir = tempfile.mkdtemp()
19 |
20 | def tearDown(self):
21 | shutil.rmtree(self.test_dir)
22 |
23 | def test_tensorboard_writer(self):
24 | # Note: this just tests API calls. We rely on
25 | # tensorboard's unit tests for correctness.
26 | run_name = "my_run"
27 | config = TempConfig(b="bar")
28 | writer = TensorBoardWriter(run_name=run_name, log_dir=self.test_dir)
29 | writer.add_scalar("my_value", value=0.5, step=2)
30 | writer.write_config(config)
31 | log_path = os.path.join(self.test_dir, run_name, "config.json")
32 | with open(log_path, "r") as f:
33 | file_config = json.load(f)
34 | self.assertEqual(config._asdict(), file_config)
35 | writer.cleanup()
36 |
--------------------------------------------------------------------------------
/test/classification/training/schedulers/test_schedulers.py:
--------------------------------------------------------------------------------
1 | import random
2 | import unittest
3 |
4 | import numpy as np
5 | import torch
6 |
7 | from snorkel.classification import DictDataLoader, DictDataset
8 | from snorkel.classification.training.schedulers import (
9 | SequentialScheduler,
10 | ShuffledScheduler,
11 | )
12 |
13 | dataset1 = DictDataset(
14 | "d1",
15 | "train",
16 | X_dict={"data": [0, 1, 2, 3, 4]},
17 | Y_dict={"labels": torch.LongTensor([1, 1, 1, 1, 1])},
18 | )
19 | dataset2 = DictDataset(
20 | "d2",
21 | "train",
22 | X_dict={"data": [5, 6, 7, 8, 9]},
23 | Y_dict={"labels": torch.LongTensor([2, 2, 2, 2, 2])},
24 | )
25 |
26 | dataloader1 = DictDataLoader(dataset1, batch_size=2)
27 | dataloader2 = DictDataLoader(dataset2, batch_size=2)
28 | dataloaders = [dataloader1, dataloader2]
29 |
30 |
31 | class SchedulerTest(unittest.TestCase):
32 | def test_sequential(self):
33 | scheduler = SequentialScheduler()
34 | data = []
35 | for batch, dl in scheduler.get_batches(dataloaders):
36 | X_dict, Y_dict = batch
37 | data.extend(X_dict["data"])
38 | self.assertEqual(data, sorted(data))
39 |
40 | def test_shuffled(self):
41 | random.seed(123)
42 | np.random.seed(123)
43 | torch.manual_seed(123)
44 | scheduler = ShuffledScheduler()
45 | data = []
46 | for batch, dl in scheduler.get_batches(dataloaders):
47 | X_dict, Y_dict = batch
48 | data.extend(X_dict["data"])
49 | self.assertNotEqual(data, sorted(data))
50 |
51 |
52 | if __name__ == "__main__":
53 | unittest.main()
54 |
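For orientation: a scheduler here is just an iterator of `(batch, dataloader)` pairs, where each batch is an `(X_dict, Y_dict)` tuple. A short sketch of exactly what `test_sequential` asserts, reusing the two dataloaders defined above (batch size 2 over five points each):

```
from snorkel.classification.training.schedulers import SequentialScheduler

# SequentialScheduler drains dataloader1's batches before dataloader2's,
# so the concatenated "data" fields arrive already sorted:
#   [0, 1], [2, 3], [4], [5, 6], [7, 8], [9]
scheduler = SequentialScheduler()
data = []
for (X_dict, Y_dict), dl in scheduler.get_batches(dataloaders):
    data.extend(X_dict["data"])
assert data == sorted(data)
```

`ShuffledScheduler` visits the same batches in a shuffled order, which is why `test_shuffled` expects `data != sorted(data)`.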
--------------------------------------------------------------------------------
/test/labeling/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/snorkel-team/snorkel/617c92400c50e95ce41fcee84309a86f76cf525c/test/labeling/__init__.py
--------------------------------------------------------------------------------
/test/labeling/apply/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/snorkel-team/snorkel/617c92400c50e95ce41fcee84309a86f76cf525c/test/labeling/apply/__init__.py
--------------------------------------------------------------------------------
/test/labeling/apply/lf_applier_spark_test_script.py:
--------------------------------------------------------------------------------
1 | """
2 | This script is used to manually test
3 | `snorkel.labeling.apply.lf_applier_spark.SparkLFApplier`
4 |
5 | To test on AWS EMR:
6 | 1. Allocate an EMR cluster (e.g. release label emr-5.24.0) with > 1 worker and SSH permissions
7 | 2. Clone and pip install snorkel on the master node
8 | ```
9 | sudo yum install git
10 | git clone https://github.com/snorkel-team/snorkel
11 | cd snorkel
12 | python3 -m pip install -t snorkel-package .
13 | cd snorkel-package
14 | zip -r ../snorkel-package.zip .
15 | cd ..
16 | ```
17 | 3. Run
18 | ```
19 | sudo sed -i -e \
20 | '$a\\export PYSPARK_PYTHON=/usr/bin/python3' \
21 | /etc/spark/conf/spark-env.sh
22 | ```
23 | 4. Run
24 | ```
25 | spark-submit \
26 | --py-files snorkel-package.zip \
27 | test/labeling/apply/lf_applier_spark_test_script.py
28 | ```
29 | """
30 |
31 | import logging
32 | from typing import List
33 |
34 | import numpy as np
35 | from pyspark import SparkContext
36 |
37 | from snorkel.labeling.apply.spark import SparkLFApplier
38 | from snorkel.labeling.lf import labeling_function
39 | from snorkel.types import DataPoint
40 |
41 | logging.basicConfig(level=logging.INFO)
42 |
43 |
44 | @labeling_function()
45 | def f(x: DataPoint) -> int:
46 | return 1 if x > 42 else 0
47 |
48 |
49 | @labeling_function(resources=dict(db=[3, 6, 9]))
50 | def g(x: DataPoint, db: List[int]) -> int:
51 | return 1 if x in db else 0
52 |
53 |
54 | DATA = [3, 43, 12, 9]
55 | L_EXPECTED = np.array([[0, 1], [1, 0], [0, 0], [0, 1]])
56 |
57 |
58 | def build_lf_matrix() -> None:
59 | logging.info("Getting Spark context")
60 | sc = SparkContext()
61 | sc.addPyFile("snorkel-package.zip")
62 | rdd = sc.parallelize(DATA)
63 |
64 | logging.info("Applying LFs")
65 | lf_applier = SparkLFApplier([f, g])
66 | L = lf_applier.apply(rdd)
67 |
68 |     np.testing.assert_equal(L, L_EXPECTED)
69 |
70 |
71 | if __name__ == "__main__":
72 | build_lf_matrix()
73 |
--------------------------------------------------------------------------------
/test/labeling/apply/test_spark.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | from typing import List
3 |
4 | import numpy as np
5 | import pandas as pd
6 | import pytest
7 | from pyspark import SparkContext
8 | from pyspark.sql import Row, SQLContext
9 |
10 | from snorkel.labeling import labeling_function
11 | from snorkel.labeling.apply.spark import SparkLFApplier
12 | from snorkel.preprocess import preprocessor
13 | from snorkel.types import DataPoint
14 |
15 |
16 | @preprocessor()
17 | def square(x: Row) -> Row:
18 | return Row(num=x.num, num_squared=x.num**2)
19 |
20 |
21 | @labeling_function()
22 | def f(x: DataPoint) -> int:
23 | return 0 if x.num > 42 else -1
24 |
25 |
26 | @labeling_function(pre=[square])
27 | def fp(x: DataPoint) -> int:
28 | return 0 if x.num_squared > 42 else -1
29 |
30 |
31 | @labeling_function(resources=dict(db=[3, 6, 9]))
32 | def g(x: DataPoint, db: List[int]) -> int:
33 | return 0 if x.num in db else -1
34 |
35 |
36 | @labeling_function()
37 | def f_bad(x: DataPoint) -> int:
38 | return 0 if x.mum > 42 else -1
39 |
40 |
41 | DATA = [3, 43, 12, 9, 3]
42 | L_EXPECTED = np.array([[-1, 0], [0, -1], [-1, -1], [-1, 0], [-1, 0]])
43 | L_EXPECTED_BAD = np.array([[-1, -1], [0, -1], [-1, -1], [-1, -1], [-1, -1]])
44 | L_PREPROCESS_EXPECTED = np.array([[-1, -1], [0, 0], [-1, 0], [-1, 0], [-1, -1]])
45 |
46 | TEXT_DATA = ["Jane", "Jane plays soccer.", "Jane plays soccer."]
47 | L_TEXT_EXPECTED = np.array([[0, -1], [0, 0], [0, 0]])
48 |
49 |
50 | class TestSparkApplier(unittest.TestCase):
51 | @pytest.mark.complex
52 | @pytest.mark.spark
53 | def test_lf_applier_spark(self) -> None:
54 | sc = SparkContext.getOrCreate()
55 | sql = SQLContext(sc)
56 | df = pd.DataFrame(dict(num=DATA))
57 | rdd = sql.createDataFrame(df).rdd
58 | applier = SparkLFApplier([f, g])
59 | L = applier.apply(rdd)
60 | np.testing.assert_equal(L, L_EXPECTED)
61 |
62 | @pytest.mark.complex
63 | @pytest.mark.spark
64 | def test_lf_applier_spark_fault(self) -> None:
65 | sc = SparkContext.getOrCreate()
66 | sql = SQLContext(sc)
67 | df = pd.DataFrame(dict(num=DATA))
68 | rdd = sql.createDataFrame(df).rdd
69 | applier = SparkLFApplier([f, f_bad])
70 | with self.assertRaises(Exception):
71 | applier.apply(rdd)
72 | L = applier.apply(rdd, fault_tolerant=True)
73 | np.testing.assert_equal(L, L_EXPECTED_BAD)
74 |
75 | @pytest.mark.complex
76 | @pytest.mark.spark
77 | def test_lf_applier_spark_preprocessor(self) -> None:
78 | sc = SparkContext.getOrCreate()
79 | sql = SQLContext(sc)
80 | df = pd.DataFrame(dict(num=DATA))
81 | rdd = sql.createDataFrame(df).rdd
82 | applier = SparkLFApplier([f, fp])
83 | L = applier.apply(rdd)
84 | np.testing.assert_equal(L, L_PREPROCESS_EXPECTED)
85 |
86 | @pytest.mark.complex
87 | @pytest.mark.spark
88 | def test_lf_applier_spark_preprocessor_memoized(self) -> None:
89 | sc = SparkContext.getOrCreate()
90 | sql = SQLContext(sc)
91 |
92 | @preprocessor(memoize=True)
93 | def square_memoize(x: DataPoint) -> DataPoint:
94 | return Row(num=x.num, num_squared=x.num**2)
95 |
96 | @labeling_function(pre=[square_memoize])
97 | def fp_memoized(x: DataPoint) -> int:
98 | return 0 if x.num_squared > 42 else -1
99 |
100 | df = pd.DataFrame(dict(num=DATA))
101 | rdd = sql.createDataFrame(df).rdd
102 | applier = SparkLFApplier([f, fp_memoized])
103 | L = applier.apply(rdd)
104 | np.testing.assert_equal(L, L_PREPROCESS_EXPECTED)
105 |
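The memoization test deserves a note: without `memoize=True`, a preprocessor re-runs once per labeling function per data point. A minimal non-Spark sketch of the same pattern (names here are illustrative, not taken from this file):

```
from types import SimpleNamespace

from snorkel.labeling import labeling_function
from snorkel.preprocess import preprocessor


@preprocessor(memoize=True)
def square_memoize(x):
    # Cached per distinct data point, so LFs sharing this preprocessor
    # reuse the result rather than recomputing it.
    x.num_squared = x.num**2
    return x


@labeling_function(pre=[square_memoize])
def fp_memoized(x):
    return 0 if x.num_squared > 42 else -1


fp_memoized(SimpleNamespace(num=7))  # -> 0, since 7**2 = 49 > 42
```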
--------------------------------------------------------------------------------
/test/labeling/lf/test_core.py:
--------------------------------------------------------------------------------
1 | import pickle
2 | import unittest
3 | from types import SimpleNamespace
4 | from typing import List
5 |
6 | from snorkel.labeling import LabelingFunction, labeling_function
7 | from snorkel.preprocess import preprocessor
8 | from snorkel.types import DataPoint
9 |
10 |
11 | @preprocessor()
12 | def square(x: DataPoint) -> DataPoint:
13 | x.num = x.num**2
14 | return x
15 |
16 |
17 | @preprocessor()
18 | def returns_none(x: DataPoint) -> DataPoint:
19 | return None
20 |
21 |
22 | def f(x: DataPoint) -> int:
23 | return 0 if x.num > 42 else -1
24 |
25 |
26 | def g(x: DataPoint, db: List[int]) -> int:
27 | return 0 if x.num in db else -1
28 |
29 |
30 | class TestLabelingFunction(unittest.TestCase):
31 | def _run_lf(self, lf: LabelingFunction) -> None:
32 | x_43 = SimpleNamespace(num=43)
33 | x_19 = SimpleNamespace(num=19)
34 | self.assertEqual(lf(x_43), 0)
35 | self.assertEqual(lf(x_19), -1)
36 |
37 | def test_labeling_function(self) -> None:
38 | lf = LabelingFunction(name="my_lf", f=f)
39 | self._run_lf(lf)
40 |
41 | def test_labeling_function_resources(self) -> None:
42 | db = [3, 6, 43]
43 | lf = LabelingFunction(name="my_lf", f=g, resources=dict(db=db))
44 | self._run_lf(lf)
45 |
46 | def test_labeling_function_preprocessor(self) -> None:
47 | lf = LabelingFunction(name="my_lf", f=f, pre=[square, square])
48 | x_43 = SimpleNamespace(num=43)
49 | x_6 = SimpleNamespace(num=6)
50 | x_2 = SimpleNamespace(num=2)
51 | self.assertEqual(lf(x_43), 0)
52 | self.assertEqual(lf(x_6), 0)
53 | self.assertEqual(lf(x_2), -1)
54 |
55 | def test_labeling_function_returns_none(self) -> None:
56 | lf = LabelingFunction(name="my_lf", f=f, pre=[square, returns_none])
57 | x_43 = SimpleNamespace(num=43)
58 | with self.assertRaises(ValueError):
59 | lf(x_43)
60 |
61 | def test_labeling_function_serialize(self) -> None:
62 | db = [3, 6, 43]
63 | lf = LabelingFunction(name="my_lf", f=g, resources=dict(db=db))
64 | lf_load = pickle.loads(pickle.dumps(lf))
65 | self._run_lf(lf_load)
66 |
67 | def test_labeling_function_decorator(self) -> None:
68 | @labeling_function()
69 | def lf(x: DataPoint) -> int:
70 | return 0 if x.num > 42 else -1
71 |
72 | self.assertIsInstance(lf, LabelingFunction)
73 | self.assertEqual(lf.name, "lf")
74 | self._run_lf(lf)
75 |
76 | def test_labeling_function_decorator_args(self) -> None:
77 | db = [3, 6, 43**2]
78 |
79 | @labeling_function(name="my_lf", resources=dict(db=db), pre=[square])
80 | def lf(x: DataPoint, db: List[int]) -> int:
81 | return 0 if x.num in db else -1
82 |
83 | self.assertIsInstance(lf, LabelingFunction)
84 | self.assertEqual(lf.name, "my_lf")
85 | self._run_lf(lf)
86 |
87 | def test_labeling_function_decorator_no_parens(self) -> None:
88 | with self.assertRaisesRegex(ValueError, "missing parentheses"):
89 |
90 | @labeling_function
91 | def lf(x: DataPoint) -> int:
92 | return 0 if x.num > 42 else -1
93 |
--------------------------------------------------------------------------------
/test/labeling/lf/test_nlp.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | from types import SimpleNamespace
3 |
4 | import dill
5 | import pytest
6 |
7 | from snorkel.labeling.lf.nlp import NLPLabelingFunction, nlp_labeling_function
8 | from snorkel.preprocess import preprocessor
9 | from snorkel.types import DataPoint
10 |
11 |
12 | @preprocessor()
13 | def combine_text(x: DataPoint) -> DataPoint:
14 | x.text = f"{x.title} {x.article}"
15 | return x
16 |
17 |
18 | def has_person_mention(x: DataPoint) -> int:
19 | person_ents = [ent for ent in x.doc.ents if ent.label_ == "PERSON"]
20 | return 0 if len(person_ents) > 0 else -1
21 |
22 |
23 | class TestNLPLabelingFunction(unittest.TestCase):
24 | def _run_lf(self, lf: NLPLabelingFunction) -> None:
25 | x = SimpleNamespace(
26 | num=8, title="Great film!", article="The movie is really great!"
27 | )
28 | self.assertEqual(lf(x), -1)
29 | x = SimpleNamespace(num=8, title="Nice movie!", article="Jane Doe acted well.")
30 | self.assertEqual(lf(x), 0)
31 |
32 | def test_nlp_labeling_function(self) -> None:
33 | lf = NLPLabelingFunction(name="my_lf", f=has_person_mention, pre=[combine_text])
34 | self._run_lf(lf)
35 |
36 | def test_nlp_labeling_function_memoized(self) -> None:
37 | lf = NLPLabelingFunction(name="my_lf", f=has_person_mention, pre=[combine_text])
38 | lf._nlp_config.nlp.reset_cache()
39 | self.assertEqual(len(lf._nlp_config.nlp._cache), 0)
40 | self._run_lf(lf)
41 | self.assertEqual(len(lf._nlp_config.nlp._cache), 2)
42 | self._run_lf(lf)
43 | self.assertEqual(len(lf._nlp_config.nlp._cache), 2)
44 |
45 | @pytest.mark.complex
46 | def test_labeling_function_serialize(self) -> None:
47 | lf = NLPLabelingFunction(name="my_lf", f=has_person_mention, pre=[combine_text])
48 | lf_load = dill.loads(dill.dumps(lf))
49 | self._run_lf(lf_load)
50 |
51 | def test_nlp_labeling_function_decorator(self) -> None:
52 | @nlp_labeling_function(pre=[combine_text])
53 | def has_person_mention(x: DataPoint) -> int:
54 | person_ents = [ent for ent in x.doc.ents if ent.label_ == "PERSON"]
55 | return 0 if len(person_ents) > 0 else -1
56 |
57 | self.assertIsInstance(has_person_mention, NLPLabelingFunction)
58 | self.assertEqual(has_person_mention.name, "has_person_mention")
59 | self._run_lf(has_person_mention)
60 |
61 | def test_nlp_labeling_function_decorator_no_parens(self) -> None:
62 | with self.assertRaisesRegex(ValueError, "missing parentheses"):
63 |
64 | @nlp_labeling_function
65 | def has_person_mention(x: DataPoint) -> int:
66 | person_ents = [ent for ent in x.doc.ents if ent.label_ == "PERSON"]
67 | return 0 if len(person_ents) > 0 else -1
68 |
69 | def test_nlp_labeling_function_shared_cache(self) -> None:
70 | lf = NLPLabelingFunction(name="my_lf", f=has_person_mention, pre=[combine_text])
71 |
72 | @nlp_labeling_function(pre=[combine_text])
73 | def lf2(x: DataPoint) -> int:
74 | return 0 if len(x.doc) < 9 else -1
75 |
76 | lf._nlp_config.nlp.reset_cache()
77 | self.assertEqual(len(lf._nlp_config.nlp._cache), 0)
78 | self.assertEqual(len(lf2._nlp_config.nlp._cache), 0)
79 | self._run_lf(lf)
80 | self.assertEqual(len(lf._nlp_config.nlp._cache), 2)
81 | self.assertEqual(len(lf2._nlp_config.nlp._cache), 2)
82 | self._run_lf(lf2)
83 | self.assertEqual(len(lf._nlp_config.nlp._cache), 2)
84 | self.assertEqual(len(lf2._nlp_config.nlp._cache), 2)
85 |
86 | def test_nlp_labeling_function_raises(self) -> None:
87 | with self.assertRaisesRegex(ValueError, "different parameters"):
88 |
89 | @nlp_labeling_function()
90 | def has_person_mention(x: DataPoint) -> int:
91 | person_ents = [ent for ent in x.doc.ents if ent.label_ == "PERSON"]
92 | return 0 if len(person_ents) > 0 else -1
93 |
--------------------------------------------------------------------------------
/test/labeling/lf/test_nlp_spark.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | from types import SimpleNamespace
3 |
4 | import pytest
5 | from pyspark.sql import Row
6 |
7 | from snorkel.labeling.lf.nlp import NLPLabelingFunction
8 | from snorkel.labeling.lf.nlp_spark import (
9 | SparkNLPLabelingFunction,
10 | spark_nlp_labeling_function,
11 | )
12 | from snorkel.types import DataPoint
13 |
14 |
15 | def has_person_mention(x: DataPoint) -> int:
16 | person_ents = [ent for ent in x.doc.ents if ent.label_ == "PERSON"]
17 | return 0 if len(person_ents) > 0 else -1
18 |
19 |
20 | @pytest.mark.spark
21 | class TestNLPLabelingFunction(unittest.TestCase):
22 | def _run_lf(self, lf: SparkNLPLabelingFunction) -> None:
23 | x = Row(num=8, text="The movie is really great!")
24 | self.assertEqual(lf(x), -1)
25 | x = Row(num=8, text="Jane Doe acted well.")
26 | self.assertEqual(lf(x), 0)
27 |
28 | def test_nlp_labeling_function(self) -> None:
29 | lf = SparkNLPLabelingFunction(name="my_lf", f=has_person_mention)
30 | self._run_lf(lf)
31 |
32 | def test_nlp_labeling_function_decorator(self) -> None:
33 | @spark_nlp_labeling_function()
34 | def has_person_mention(x: DataPoint) -> int:
35 | person_ents = [ent for ent in x.doc.ents if ent.label_ == "PERSON"]
36 | return 0 if len(person_ents) > 0 else -1
37 |
38 | self.assertIsInstance(has_person_mention, SparkNLPLabelingFunction)
39 | self.assertEqual(has_person_mention.name, "has_person_mention")
40 | self._run_lf(has_person_mention)
41 |
42 | def test_spark_nlp_labeling_function_with_nlp_labeling_function(self) -> None:
43 | # Do they have separate _nlp_configs?
44 | lf = NLPLabelingFunction(name="my_lf", f=has_person_mention)
45 | lf_spark = SparkNLPLabelingFunction(name="my_lf_spark", f=has_person_mention)
46 | self.assertEqual(lf(SimpleNamespace(num=8, text="Jane Doe acted well.")), 0)
47 | self._run_lf(lf_spark)
48 |
--------------------------------------------------------------------------------
/test/labeling/model/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/snorkel-team/snorkel/617c92400c50e95ce41fcee84309a86f76cf525c/test/labeling/model/__init__.py
--------------------------------------------------------------------------------
/test/labeling/model/test_baseline.py:
--------------------------------------------------------------------------------
1 | import unittest
2 |
3 | import numpy as np
4 |
5 | from snorkel.labeling.model import MajorityClassVoter, MajorityLabelVoter, RandomVoter
6 |
7 |
8 | class BaselineModelTest(unittest.TestCase):
9 | def test_random_vote(self):
10 | L = np.array([[0, 1, 0], [-1, 3, 2], [2, -1, -1], [0, 1, 1]])
11 | rand_voter = RandomVoter()
12 | Y_p = rand_voter.predict_proba(L)
13 | self.assertLessEqual(Y_p.max(), 1.0)
14 | self.assertGreaterEqual(Y_p.min(), 0.0)
15 | np.testing.assert_array_almost_equal(
16 | np.sum(Y_p, axis=1), np.ones(np.shape(L)[0])
17 | )
18 |
19 | def test_majority_class_vote(self):
20 | L = np.array([[0, 1, 0], [1, 1, 0], [1, 1, 0], [-1, -1, 1]])
21 | mc_voter = MajorityClassVoter()
22 | mc_voter.fit(balance=np.array([0.8, 0.2]))
23 | Y_p = mc_voter.predict_proba(L)
24 |
25 | Y_p_true = np.array([[1.0, 0.0], [1.0, 0.0], [1.0, 0.0], [1.0, 0.0]])
26 | np.testing.assert_array_almost_equal(Y_p, Y_p_true)
27 |
28 | def test_majority_label_vote(self):
29 | L = np.array([[0, 1, 0], [0, 1, 0], [1, 0, 0], [-1, -1, 1]])
30 | ml_voter = MajorityLabelVoter()
31 | Y_p = ml_voter.predict_proba(L)
32 |
33 | Y_p_true = np.array([[1.0, 0.0], [1.0, 0.0], [1.0, 0.0], [0.0, 1.0]])
34 | np.testing.assert_array_almost_equal(Y_p, Y_p_true)
35 |
36 |
37 | if __name__ == "__main__":
38 | unittest.main()
39 |
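The expected arrays are easy to verify by hand: `MajorityLabelVoter` tallies the non-abstain votes in each row and splits the probability mass evenly among the winning labels. For instance, for the final row of `test_majority_label_vote`:

```
import numpy as np

from snorkel.labeling.model import MajorityLabelVoter

# Row [-1, -1, 1]: the only non-abstain vote is for class 1, so all of the
# row's probability mass lands on class 1 (matching Y_p_true above).
Y_p = MajorityLabelVoter().predict_proba(np.array([[-1, -1, 1]]))
print(Y_p)  # [[0. 1.]]
```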
--------------------------------------------------------------------------------
/test/labeling/model/test_logger.py:
--------------------------------------------------------------------------------
1 | import unittest
2 |
3 | from snorkel.labeling.model.logger import Logger
4 |
5 |
6 | class LoggerTest(unittest.TestCase):
7 | def test_basic(self):
8 | metrics_dict = {"train/loss": 0.01}
9 | logger = Logger(log_freq=1)
10 | logger.log(metrics_dict)
11 |
12 | metrics_dict = {"train/message": "well done!"}
13 | logger = Logger(log_freq=1)
14 | logger.log(metrics_dict)
15 |
16 | def test_bad_metrics_dict(self):
17 | bad_metrics_dict = {"task1/slice1/train/loss": 0.05}
18 | logger = Logger(log_freq=1)
19 | self.assertRaises(Exception, logger.log, bad_metrics_dict)
20 |
21 | def test_valid_metrics_dict(self):
22 | mtl_metrics_dict = {"task1/valid/loss": 0.05}
23 | logger = Logger(log_freq=1)
24 | logger.log(mtl_metrics_dict)
25 |
26 |
27 | if __name__ == "__main__":
28 | unittest.main()
29 |
--------------------------------------------------------------------------------
/test/labeling/preprocess/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/snorkel-team/snorkel/617c92400c50e95ce41fcee84309a86f76cf525c/test/labeling/preprocess/__init__.py
--------------------------------------------------------------------------------
/test/labeling/preprocess/test_nlp.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | from types import SimpleNamespace
3 |
4 | from snorkel.preprocess.nlp import SpacyPreprocessor
5 |
6 |
7 | class TestSpacyPreprocessor(unittest.TestCase):
8 | def test_spacy_preprocessor(self) -> None:
9 | x = SimpleNamespace(text="Jane plays soccer.")
10 | preprocessor = SpacyPreprocessor("text", "doc")
11 | x_preprocessed = preprocessor(x)
12 | assert x_preprocessed is not None
13 | self.assertEqual(len(x_preprocessed.doc), 4)
14 | token = x_preprocessed.doc[0]
15 | self.assertEqual(token.text, "Jane")
16 | self.assertEqual(token.pos_, "PROPN")
17 |
--------------------------------------------------------------------------------
/test/labeling/test_convergence.py:
--------------------------------------------------------------------------------
1 | import random
2 | import unittest
3 |
4 | import numpy as np
5 | import pandas as pd
6 | import pytest
7 | import torch
8 |
9 | from snorkel.labeling import LabelingFunction, PandasLFApplier, labeling_function
10 | from snorkel.labeling.model import LabelModel
11 | from snorkel.preprocess import preprocessor
12 | from snorkel.types import DataPoint
13 |
14 |
15 | def create_data(n: int) -> pd.DataFrame:
16 | """Create random pairs x1, x2 in [-1., 1.] with label x1 > x2 + 0.25."""
17 | X = np.random.random((n, 2)) * 2 - 1
18 | Y = (X[:, 0] > X[:, 1] + 0.25).astype(int)
19 |
20 | df = pd.DataFrame(
21 | {"x0": np.random.randint(0, 1000, n), "x1": X[:, 0], "x2": X[:, 1], "y": Y}
22 | )
23 | return df
24 |
25 |
26 | def get_positive_labeling_function(divisor: int) -> LabelingFunction:
27 | """Get LabelingFunction that abstains unless x0 is divisible by divisor."""
28 |
29 | def f(x):
30 | return 1 if x.x0 % divisor == 0 and x.x1 > x.x2 + 0.25 else -1
31 |
32 | return LabelingFunction(f"lf_pos_{divisor}", f)
33 |
34 |
35 | def get_negative_labeling_function(divisor: int) -> LabelingFunction:
36 | """Get LabelingFunction that abstains unless x0 is divisible by divisor."""
37 |
38 | def f(x):
39 | return 0 if x.x0 % divisor == 0 and x.x1 <= x.x2 + 0.25 else -1
40 |
41 | return LabelingFunction(f"lf_neg_{divisor}", f)
42 |
43 |
44 | @preprocessor()
45 | def copy_features(x: DataPoint) -> DataPoint:
46 | """Compute x2 + 0.25 for direct comparison to x1."""
47 | x.x3 = x.x2 + 0.25
48 | return x
49 |
50 |
51 | @labeling_function(pre=[copy_features], resources=dict(divisor=3))
52 | def f(x: DataPoint, divisor: int) -> int:
53 |     # Abstain unless x0 % divisor == 1.
54 | return 0 if x.x0 % divisor == 1 and x.x1 > x.x3 else -1
55 |
56 |
57 | class LabelingConvergenceTest(unittest.TestCase):
58 | @classmethod
59 | def setUpClass(cls):
60 | # Ensure deterministic runs
61 | random.seed(123)
62 | np.random.seed(123)
63 | torch.manual_seed(123)
64 |
65 | # Create raw data
66 | cls.N_TRAIN = 1500
67 |
68 | cls.cardinality = 2
69 | cls.df_train = create_data(cls.N_TRAIN)
70 |
71 | @pytest.mark.complex
72 | def test_labeling_convergence(self) -> None:
73 | """Test convergence of end to end labeling pipeline."""
74 | # Apply LFs
75 | labeling_functions = (
76 | [f]
77 | + [get_positive_labeling_function(divisor) for divisor in range(2, 9)]
78 | + [get_negative_labeling_function(divisor) for divisor in range(2, 9)]
79 | )
80 | applier = PandasLFApplier(labeling_functions)
81 | L_train = applier.apply(self.df_train, progress_bar=False)
82 |
83 | self.assertEqual(L_train.shape, (self.N_TRAIN, len(labeling_functions)))
84 |
85 | # Train LabelModel
86 | label_model = LabelModel(cardinality=self.cardinality, verbose=False)
87 | label_model.fit(L_train, n_epochs=100, lr=0.01, l2=0.0)
88 | Y_lm = label_model.predict_proba(L_train).argmax(axis=1)
89 | Y = self.df_train.y
90 | err = np.where(Y != Y_lm, 1, 0).sum() / self.N_TRAIN
91 | self.assertLess(err, 0.06)
92 |
93 |
94 | if __name__ == "__main__":
95 | unittest.main()
96 |
--------------------------------------------------------------------------------
/test/labeling/test_utils.py:
--------------------------------------------------------------------------------
1 | import unittest
2 |
3 | import numpy as np
4 | import pandas as pd
5 |
6 | from snorkel.labeling import filter_unlabeled_dataframe
7 |
8 |
9 | class TestLabelingUtils(unittest.TestCase):
10 | def test_filter_unlabeled_dataframe(self) -> None:
11 | X = pd.DataFrame(dict(A=["x", "y", "z"], B=[1, 2, 3]))
12 | y = np.array(
13 | [[0.25, 0.25, 0.25, 0.25], [1.0, 0.0, 0.0, 0.0], [0.2, 0.3, 0.5, 0.0]]
14 | )
15 | L = np.array([[0, 1, -1], [-1, -1, -1], [1, 1, 0]])
16 | X_filtered, y_filtered = filter_unlabeled_dataframe(X, y, L)
17 |         np.testing.assert_array_equal(
18 |             X_filtered.values, np.array([["x", 1], ["z", 3]], dtype=object)
19 |         )
20 |         np.testing.assert_array_almost_equal(
21 |             y_filtered, np.array([[0.25, 0.25, 0.25, 0.25], [0.2, 0.3, 0.5, 0.0]])
22 |         )
23 |
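To spell out the semantics under test: a row is "unlabeled" when every labeling function abstained on it (an all `-1` row of `L`), and `filter_unlabeled_dataframe` drops such rows from both the data frame and the probability matrix. A self-contained sketch:

```
import numpy as np
import pandas as pd

from snorkel.labeling import filter_unlabeled_dataframe

X = pd.DataFrame(dict(A=["x", "y", "z"]))
probs = np.full((3, 2), 0.5)
L = np.array([[0, 1], [-1, -1], [1, 0]])  # row 1: every LF abstained

X_f, probs_f = filter_unlabeled_dataframe(X, probs, L)
print(X_f.index.tolist())  # [0, 2] -- the all-abstain row is dropped
```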
--------------------------------------------------------------------------------
/test/map/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/snorkel-team/snorkel/617c92400c50e95ce41fcee84309a86f76cf525c/test/map/__init__.py
--------------------------------------------------------------------------------
/test/slicing/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/snorkel-team/snorkel/617c92400c50e95ce41fcee84309a86f76cf525c/test/slicing/__init__.py
--------------------------------------------------------------------------------
/test/slicing/apply/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/snorkel-team/snorkel/617c92400c50e95ce41fcee84309a86f76cf525c/test/slicing/apply/__init__.py
--------------------------------------------------------------------------------
/test/slicing/apply/test_sf_applier.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | from types import SimpleNamespace
3 | from typing import List
4 |
5 | from snorkel.preprocess import preprocessor
6 | from snorkel.slicing import SFApplier, slicing_function
7 | from snorkel.types import DataPoint
8 |
9 |
10 | @preprocessor()
11 | def square(x: DataPoint) -> DataPoint:
12 | x.num_squared = x.num**2
13 | return x
14 |
15 |
16 | class SquareHitTracker:
17 | def __init__(self):
18 | self.n_hits = 0
19 |
20 | def __call__(self, x: float) -> float:
21 | self.n_hits += 1
22 | return x**2
23 |
24 |
25 | @slicing_function()
26 | def f(x: DataPoint) -> int:
27 | return x.num > 42
28 |
29 |
30 | @slicing_function(pre=[square])
31 | def fp(x: DataPoint) -> int:
32 | return x.num_squared > 42
33 |
34 |
35 | @slicing_function(resources=dict(db=[3, 6, 9]))
36 | def g(x: DataPoint, db: List[int]) -> int:
37 | return x.num in db
38 |
39 |
40 | DATA = [3, 43, 12, 9, 3]
41 | S_EXPECTED = {"f": [0, 1, 0, 0, 0], "g": [1, 0, 0, 1, 1]}
42 | S_PREPROCESS_EXPECTED = {"f": [0, 1, 0, 0, 0], "fp": [0, 1, 1, 1, 0]}
43 |
44 |
45 | class TestSFApplier(unittest.TestCase):
46 | def test_sf_applier(self) -> None:
47 | data_points = [SimpleNamespace(num=num) for num in DATA]
48 | applier = SFApplier([f, g])
49 | S = applier.apply(data_points, progress_bar=False)
50 | self.assertEqual(S["f"].tolist(), S_EXPECTED["f"])
51 | self.assertEqual(S["g"].tolist(), S_EXPECTED["g"])
52 | S = applier.apply(data_points, progress_bar=True)
53 | self.assertEqual(S["f"].tolist(), S_EXPECTED["f"])
54 | self.assertEqual(S["g"].tolist(), S_EXPECTED["g"])
55 |
56 | def test_sf_applier_preprocessor(self) -> None:
57 | data_points = [SimpleNamespace(num=num) for num in DATA]
58 | applier = SFApplier([f, fp])
59 | S = applier.apply(data_points, progress_bar=False)
60 | self.assertEqual(S["f"].tolist(), S_PREPROCESS_EXPECTED["f"])
61 | self.assertEqual(S["fp"].tolist(), S_PREPROCESS_EXPECTED["fp"])
62 |
--------------------------------------------------------------------------------
/test/slicing/sf/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/snorkel-team/snorkel/617c92400c50e95ce41fcee84309a86f76cf525c/test/slicing/sf/__init__.py
--------------------------------------------------------------------------------
/test/slicing/sf/test_core.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | from types import SimpleNamespace
3 |
4 | from snorkel.slicing import SlicingFunction, slicing_function
5 |
6 |
7 | class TestSlicingFunction(unittest.TestCase):
8 | def _run_sf(self, sf: SlicingFunction) -> None:
9 | x_43 = SimpleNamespace(num=43)
10 | x_19 = SimpleNamespace(num=19)
11 | self.assertEqual(sf(x_43), True)
12 | self.assertEqual(sf(x_19), False)
13 |
14 | def _run_sf_raise(self, sf: SlicingFunction) -> None:
15 | x_none = SimpleNamespace(num=None)
16 | with self.assertRaises(TypeError):
17 | sf(x_none)
18 |
19 | def test_slicing_function_decorator(self) -> None:
20 | @slicing_function()
21 | def sf(x) -> int:
22 | return x.num > 42
23 |
24 | self.assertIsInstance(sf, SlicingFunction)
25 | self.assertEqual(sf.name, "sf")
26 | self._run_sf(sf)
27 | self._run_sf_raise(sf)
28 |
29 | def test_slicing_function_decorator_no_parens(self) -> None:
30 | with self.assertRaisesRegex(ValueError, "missing parentheses"):
31 |
32 | @slicing_function
33 | def sf(x) -> int:
34 | return 0 if x.num > 42 else -1
35 |
--------------------------------------------------------------------------------
/test/slicing/sf/test_nlp.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | from types import SimpleNamespace
3 |
4 | from snorkel.preprocess import preprocessor
5 | from snorkel.slicing.sf.nlp import NLPSlicingFunction, nlp_slicing_function
6 | from snorkel.types import DataPoint
7 |
8 |
9 | @preprocessor()
10 | def combine_text(x: DataPoint) -> DataPoint:
11 | x.text = f"{x.title} {x.article}"
12 | return x
13 |
14 |
15 | def has_person_mention(x: DataPoint) -> int:
16 | person_ents = [ent for ent in x.doc.ents if ent.label_ == "PERSON"]
17 | return 0 if len(person_ents) > 0 else -1
18 |
19 |
20 | class TestNLPSlicingFunction(unittest.TestCase):
21 | def _run_sf(self, sf: NLPSlicingFunction) -> None:
22 | x = SimpleNamespace(
23 | num=8, title="Great film!", article="The movie is really great!"
24 | )
25 | self.assertEqual(sf(x), -1)
26 | x = SimpleNamespace(num=8, title="Nice movie!", article="Jane Doe acted well.")
27 | self.assertEqual(sf(x), 0)
28 |
29 | def test_nlp_slicing_function(self) -> None:
30 | sf = NLPSlicingFunction(name="my_sf", f=has_person_mention, pre=[combine_text])
31 | self._run_sf(sf)
32 |
33 | def test_nlp_slicing_function_decorator(self) -> None:
34 | @nlp_slicing_function(pre=[combine_text])
35 | def has_person_mention(x: DataPoint) -> int:
36 | person_ents = [ent for ent in x.doc.ents if ent.label_ == "PERSON"]
37 | return 0 if len(person_ents) > 0 else -1
38 |
39 | self.assertIsInstance(has_person_mention, NLPSlicingFunction)
40 | self.assertEqual(has_person_mention.name, "has_person_mention")
41 | self._run_sf(has_person_mention)
42 |
--------------------------------------------------------------------------------
/test/slicing/test_monitor.py:
--------------------------------------------------------------------------------
1 | import unittest
2 |
3 | import pandas as pd
4 |
5 | from snorkel.slicing import slicing_function
6 | from snorkel.slicing.monitor import slice_dataframe
7 |
8 | DATA = [5, 10, 19, 22, 25]
9 |
10 |
11 | @slicing_function()
12 | def sf(x):
13 | return x.num < 20
14 |
15 |
16 | class PandasSlicerTest(unittest.TestCase):
17 | @classmethod
18 | def setUpClass(cls):
19 | cls.df = pd.DataFrame(dict(num=DATA))
20 |
21 | def test_slice(self):
22 | self.assertEqual(len(self.df), 5)
23 |
24 | # Should return a subset
25 | sliced_df = slice_dataframe(self.df, sf)
26 | self.assertEqual(len(sliced_df), 3)
27 |
--------------------------------------------------------------------------------
/test/slicing/test_utils.py:
--------------------------------------------------------------------------------
1 | import unittest
2 |
3 | import pandas as pd
4 | import torch
5 | import torch.nn as nn
6 |
7 | from snorkel.classification import DictDataLoader, DictDataset, Operation, Task
8 | from snorkel.slicing import (
9 | PandasSFApplier,
10 | add_slice_labels,
11 | convert_to_slice_tasks,
12 | slicing_function,
13 | )
14 |
15 |
16 | @slicing_function()
17 | def f(x):
18 | return x.val < 0.25
19 |
20 |
21 | class UtilsTest(unittest.TestCase):
22 | def test_add_slice_labels(self):
23 | # Create dummy data
24 | # Given slicing function f(), we expect the first two entries to be active
25 | x = torch.Tensor([0.1, 0.2, 0.3, 0.4, 0.5])
26 | y = torch.Tensor([0, 1, 1, 0, 1]).long()
27 | dataset = DictDataset(
28 | name="TestData", split="train", X_dict={"data": x}, Y_dict={"TestTask": y}
29 | )
30 |
31 | # Ensure that we start with 1 labelset
32 | self.assertEqual(len(dataset.Y_dict), 1)
33 |
34 | # Apply SFs with PandasSFApplier
35 | df = pd.DataFrame({"val": x, "y": y})
36 | slicing_functions = [f]
37 | applier = PandasSFApplier(slicing_functions)
38 | S = applier.apply(df, progress_bar=False)
39 |
40 | dataloader = DictDataLoader(dataset)
41 |
42 | dummy_task = create_dummy_task(task_name="TestTask")
43 | add_slice_labels(dataloader, dummy_task, S)
44 |
45 | # Ensure that all the fields are present
46 | labelsets = dataloader.dataset.Y_dict
47 | self.assertIn("TestTask", labelsets)
48 | self.assertIn("TestTask_slice:base_ind", labelsets)
49 | self.assertIn("TestTask_slice:base_pred", labelsets)
50 | self.assertIn("TestTask_slice:f_ind", labelsets)
51 | self.assertIn("TestTask_slice:f_pred", labelsets)
52 | self.assertEqual(len(labelsets), 5)
53 |
54 | # Ensure "ind" contains mask
55 | self.assertEqual(
56 | labelsets["TestTask_slice:f_ind"].numpy().tolist(), [1, 1, 0, 0, 0]
57 | )
58 | self.assertEqual(
59 | labelsets["TestTask_slice:base_ind"].numpy().tolist(), [1, 1, 1, 1, 1]
60 | )
61 |
62 | # Ensure "pred" contains masked elements
63 | self.assertEqual(
64 | labelsets["TestTask_slice:f_pred"].numpy().tolist(), [0, 1, -1, -1, -1]
65 | )
66 | self.assertEqual(
67 | labelsets["TestTask_slice:base_pred"].numpy().tolist(), [0, 1, 1, 0, 1]
68 | )
69 | self.assertEqual(labelsets["TestTask"].numpy().tolist(), [0, 1, 1, 0, 1])
70 |
71 | def test_convert_to_slice_tasks(self):
72 | task_name = "TestTask"
73 | task = create_dummy_task(task_name)
74 |
75 | slice_names = ["slice_a", "slice_b", "slice_c"]
76 | slice_tasks = convert_to_slice_tasks(task, slice_names)
77 |
78 | slice_task_names = [t.name for t in slice_tasks]
79 | # Check for original base task
80 | self.assertIn(task_name, slice_task_names)
81 |
82 | # Check for 2 tasks (pred + ind) per slice, accounting for base slice
83 | for slice_name in slice_names + ["base"]:
84 | self.assertIn(f"{task_name}_slice:{slice_name}_pred", slice_task_names)
85 | self.assertIn(f"{task_name}_slice:{slice_name}_ind", slice_task_names)
86 |
87 | self.assertEqual(len(slice_tasks), 2 * (len(slice_names) + 1) + 1)
88 |
89 | # Test that modules share the same body flow operations
90 | # NOTE: Use "is" comparison to check object equality
91 | body_flow = task.op_sequence[:-1]
92 | ind_and_pred_tasks = [
93 | t for t in slice_tasks if "_ind" in t.name or "_pred" in t.name
94 | ]
95 | for op in body_flow:
96 | for slice_task in ind_and_pred_tasks:
97 | self.assertTrue(
98 | slice_task.module_pool[op.module_name]
99 | is task.module_pool[op.module_name]
100 | )
101 |
102 | # Test that pred tasks share the same predictor head
103 | pred_tasks = [t for t in slice_tasks if "_pred" in t.name]
104 | predictor_head_name = pred_tasks[0].op_sequence[-1].module_name
105 | shared_predictor_head = pred_tasks[0].module_pool[predictor_head_name]
106 | for pred_task in pred_tasks[1:]:
107 | self.assertTrue(
108 | pred_task.module_pool[predictor_head_name] is shared_predictor_head
109 | )
110 |
111 |
112 | def create_dummy_task(task_name):
113 | # Create dummy task
114 | module_pool = nn.ModuleDict(
115 | {"linear1": nn.Linear(2, 10), "linear2": nn.Linear(10, 2)}
116 | )
117 |
118 | op_sequence = [
119 | Operation(name="encoder", module_name="linear1", inputs=["_input_"]),
120 | Operation(name="prediction_head", module_name="linear2", inputs=["encoder"]),
121 | ]
122 |
123 | task = Task(name=task_name, module_pool=module_pool, op_sequence=op_sequence)
124 | return task
125 |
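The ind/pred split above is the heart of slice-aware training: for each slice, the `_ind` labelset is a binary membership mask over the batch, while `_pred` keeps the task labels inside the slice and masks everything else to -1 so the loss ignores it. Lined up against the inputs of `test_add_slice_labels`:

```
# val:                    [0.1, 0.2, 0.3, 0.4, 0.5]
# y:                      [  0,   1,   1,   0,   1]
# f(x) = x.val < 0.25 ->
# TestTask_slice:f_ind:   [  1,   1,   0,   0,   0]   # in-slice indicator
# TestTask_slice:f_pred:  [  0,   1,  -1,  -1,  -1]   # labels kept only in-slice
```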
--------------------------------------------------------------------------------
/test/synthetic/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/snorkel-team/snorkel/617c92400c50e95ce41fcee84309a86f76cf525c/test/synthetic/__init__.py
--------------------------------------------------------------------------------
/test/synthetic/test_synthetic_data.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | from typing import Optional
3 |
4 | import numpy as np
5 |
6 | from snorkel.labeling import LFAnalysis
7 | from snorkel.synthetic.synthetic_data import generate_simple_label_matrix
8 |
9 |
10 | class TestGenerateSimpleLabelMatrix(unittest.TestCase):
11 | """Testing the generate_simple_label_matrix function."""
12 |
13 | def setUp(self) -> None:
14 | """Set constants for the tests."""
15 | self.m = 10 # Number of LFs
16 | self.n = 1000 # Number of data points
17 |
18 | def _test_generate_L(self, k: int, decimal: Optional[int] = 2) -> None:
19 | """Test generated label matrix L for consistency with P, Y.
20 |
21 | This tests for consistency between the true conditional LF probabilities, P,
22 | and the empirical ones computed from L and Y, where P, L, and Y are generated
23 | by the generate_simple_label_matrix function.
24 |
25 | Parameters
26 | ----------
27 | k
28 | Cardinality
29 | decimal
30 | Number of decimals to check element-wise error, err < 1.5 * 10**(-decimal)
31 | """
32 | np.random.seed(123)
33 | P, Y, L = generate_simple_label_matrix(self.n, self.m, k)
34 | P_emp = LFAnalysis(L).lf_empirical_probs(Y, k=k)
35 | np.testing.assert_array_almost_equal(P, P_emp, decimal=decimal)
36 |
37 | def test_generate_L(self) -> None:
38 | """Test the generated dataset for consistency."""
39 | self._test_generate_L(2, decimal=1)
40 |
41 | def test_generate_L_multiclass(self) -> None:
42 | """Test the generated dataset for consistency with cardinality=3."""
43 | self._test_generate_L(3, decimal=1)
44 |
45 |
46 | if __name__ == "__main__":
47 | unittest.main()
48 |
--------------------------------------------------------------------------------
/test/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/snorkel-team/snorkel/617c92400c50e95ce41fcee84309a86f76cf525c/test/utils/__init__.py
--------------------------------------------------------------------------------
/test/utils/test_config_utils.py:
--------------------------------------------------------------------------------
1 | import unittest
2 |
3 | from snorkel.types import Config
4 | from snorkel.utils.config_utils import merge_config
5 |
6 |
7 | class FooConfig(Config):
8 | a: float = 0.5
9 |
10 |
11 | class BarConfig(Config):
12 | a: int = 1
13 | foo_config: FooConfig = FooConfig() # type: ignore
14 |
15 |
16 | class UtilsTest(unittest.TestCase):
17 | def test_merge_config(self):
18 | config_updates = {"a": 2, "foo_config": {"a": 0.75}}
19 | bar_config = merge_config(BarConfig(), config_updates)
20 | self.assertEqual(bar_config.a, 2)
21 | self.assertEqual(bar_config.foo_config.a, 0.75)
22 |
--------------------------------------------------------------------------------
/test/utils/test_core.py:
--------------------------------------------------------------------------------
1 | import unittest
2 |
3 | import numpy as np
4 |
5 | from snorkel.utils import (
6 | filter_labels,
7 | preds_to_probs,
8 | probs_to_preds,
9 | to_int_label_array,
10 | )
11 |
12 | PROBS = np.array([[0.1, 0.9], [0.7, 0.3]])
13 | PREDS = np.array([1, 0])
14 | PREDS_ROUND = np.array([[0, 1], [1, 0]])
15 |
16 |
17 | class UtilsTest(unittest.TestCase):
18 | def test_to_int_label_array(self):
19 | X = np.array([[1], [0], [2.0]])
20 | Y_expected = np.array([1, 0, 2])
21 | Y = to_int_label_array(X, flatten_vector=True)
22 | np.testing.assert_array_equal(Y, Y_expected)
23 |
24 | Y = to_int_label_array(np.array([[1]]), flatten_vector=True)
25 | Y_expected = np.array([1])
26 | np.testing.assert_array_equal(Y, Y_expected)
27 |
28 | Y = to_int_label_array(X, flatten_vector=False)
29 | Y_expected = np.array([[1], [0], [2]])
30 | np.testing.assert_array_equal(Y, Y_expected)
31 |
32 | X = np.array([[1], [0], [2.1]])
33 | with self.assertRaisesRegex(ValueError, "non-integer value"):
34 | to_int_label_array(X)
35 |
36 | X = np.array([[1, 0], [0, 1]])
37 | with self.assertRaisesRegex(ValueError, "1d np.array"):
38 | to_int_label_array(X, flatten_vector=True)
39 |
40 | def test_preds_to_probs(self):
41 | np.testing.assert_array_equal(preds_to_probs(PREDS, 2), PREDS_ROUND)
42 |
43 | def test_probs_to_preds(self):
44 | np.testing.assert_array_equal(probs_to_preds(PROBS), PREDS)
45 |
46 |         # abstains with ties
47 | probs = np.array([[0.33, 0.33, 0.33]])
48 | preds = probs_to_preds(probs, tie_break_policy="abstain")
49 | true_preds = np.array([-1])
50 | np.testing.assert_array_equal(preds, true_preds)
51 |
52 | # true random with ties
53 | probs = np.array([[0.33, 0.33, 0.33]])
54 | random_preds = []
55 |         for _ in range(10):
56 | preds = probs_to_preds(probs, tie_break_policy="true-random")
57 | random_preds.append(preds[0])
58 |
59 | # check predicted labels within range
60 | self.assertLessEqual(max(random_preds), 2)
61 | self.assertGreaterEqual(min(random_preds), 0)
62 |
63 | # deterministic random with ties
64 | probs = np.array(
65 | [[0.33, 0.33, 0.33], [0.0, 0.5, 0.5], [0.33, 0.33, 0.33], [0.5, 0.5, 0]]
66 | )
67 | random_preds = []
68 | for _ in range(10):
69 | preds = probs_to_preds(probs, tie_break_policy="random")
70 | random_preds.append(preds)
71 |
72 | # check labels are same across seeds
73 | for i in range(len(random_preds) - 1):
74 | np.testing.assert_array_equal(random_preds[i], random_preds[i + 1])
75 |
76 |         # check predicted labels within range (one instance suffices; all runs should match)
77 | self.assertLessEqual(max(random_preds[0]), 2)
78 | self.assertGreaterEqual(min(random_preds[0]), 0)
79 |
80 | # check invalid policy
81 | with self.assertRaisesRegex(ValueError, "policy not recognized"):
82 | preds = probs_to_preds(probs, tie_break_policy="negative")
83 |
84 | # check invalid input
85 | with self.assertRaisesRegex(ValueError, "probs must have probabilities"):
86 | preds = probs_to_preds(np.array([[0.33], [0.33]]))
87 |
88 | def test_filter_labels(self):
89 | golds = np.array([-1, 0, 0, 1, 1])
90 | preds = np.array([0, 0, 1, 1, -1])
91 | filtered = filter_labels(
92 | label_dict={"golds": golds, "preds": preds},
93 | filter_dict={"golds": [-1], "preds": [-1]},
94 | )
95 | np.testing.assert_array_equal(filtered["golds"], np.array([0, 0, 1]))
96 | np.testing.assert_array_equal(filtered["preds"], np.array([0, 1, 1]))
97 |
98 | def test_filter_labels_probs(self):
99 | golds = np.array([-1, 0, 0, 1, 1])
100 | preds = np.array([0, 0, 1, 1, -1])
101 | probs = np.array([[0.8, 0.2], [0.8, 0.2], [0.2, 0.8], [0.2, 0.8], [0.5, 0.5]])
102 | filtered = filter_labels(
103 | label_dict={"golds": golds, "preds": preds, "probs": probs},
104 | filter_dict={"golds": [-1], "preds": [-1]},
105 | )
106 | np.testing.assert_array_equal(filtered["golds"], np.array([0, 0, 1]))
107 |         np.testing.assert_array_equal(filtered["preds"], np.array([0, 1, 1]))
108 |         np.testing.assert_array_equal(filtered["probs"], probs[1:4])
109 |
110 |
111 | if __name__ == "__main__":
112 |     unittest.main()
113 |
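The tie-break cases above fix the contract: "abstain" maps a tied row to -1, "true-random" draws a fresh winner on each call, and "random" picks a winner that stays stable across calls. A minimal sketch of the abstain policy:

```
import numpy as np

from snorkel.utils import probs_to_preds

probs = np.array([[0.4, 0.4, 0.2]])  # classes 0 and 1 tie
print(probs_to_preds(probs, tie_break_policy="abstain"))  # [-1]
```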
--------------------------------------------------------------------------------
/test/utils/test_data_operators.py:
--------------------------------------------------------------------------------
1 | import unittest
2 |
3 | from snorkel.utils.data_operators import check_unique_names
4 |
5 |
6 | class DataOperatorsTest(unittest.TestCase):
7 | def test_check_unique_names(self):
8 | check_unique_names(["alice", "bob", "chuck"])
9 | with self.assertRaisesRegex(ValueError, "3 operators with name c"):
10 | check_unique_names(["a", "a", "b", "c", "c", "c"])
11 |
--------------------------------------------------------------------------------
/tox.ini:
--------------------------------------------------------------------------------
1 | [tox]
2 | skip_missing_interpreters = true
3 | envlist =
4 |     py311,
5 |     type,
6 |     check,
7 |     doctest
8 | isolated_build = true
9 |
10 | [testenv]
11 | description = run the test driver with {basepython}
12 | # Note: in order to allow dependency library install reuse
13 | # on CI, we allow overriding the default envdir
14 | # (specified as `{toxworkdir}/{envname}`) by setting the
15 | # environment variable `TOX_INSTALL_DIR`. We avoid
16 | # collision with the already-used `TOX_ENV_DIR`.
17 | envdir = {env:TOX_INSTALL_DIR:{toxworkdir}/{envname}}
18 | # Note: we try to keep the deps the same for all tests
19 | # running on CI so that we skip reinstalling dependency
20 | # libraries for all testenvs
21 | deps =
22 | -rrequirements.txt
23 | -rrequirements-pyspark.txt
24 | commands_pre = python -m spacy download en_core_web_sm
25 | commands = python -m pytest {posargs:-m 'not spark and not complex'}
26 |
27 | [testenv:spark]
28 | description = run the test driver for spark tests with {basepython}
29 | passenv = JAVA_HOME
30 | commands = python -m pytest -m spark {posargs}
31 |
32 | [testenv:complex]
33 | description = run the test driver for integration tests with {basepython}
34 | commands = python -m pytest -m 'complex and not spark' {posargs}
35 |
36 | [testenv:doctest]
37 | description = run doctest
38 | skipsdist = true
39 | commands = python -m pytest --doctest-plus snorkel
40 |
41 | [testenv:check]
42 | description = check the code and doc style
43 | basepython = python3
44 | allowlist_externals =
45 | {toxinidir}/scripts/check_requirements.py
46 | {toxinidir}/scripts/sync_api_docs.py
47 | commands_pre =
48 | commands =
49 | isort -rc -c .
50 | black --check .
51 | flake8 .
52 | pydocstyle snorkel
53 | {toxinidir}/scripts/check_requirements.py
54 | {toxinidir}/scripts/sync_api_docs.py --check
55 |
56 | [testenv:type]
57 | description = run static type checking
58 | basepython = python3
59 | commands_pre =
60 | commands = mypy -p snorkel --disallow-untyped-defs --disallow-incomplete-defs --no-implicit-optional
61 |
62 | [testenv:coverage]
63 | description = run coverage checks
64 | basepython = python3
65 | # Note: make sure this matches testenv since this is used
66 | # on CI as the default unit test runner
67 | commands = python -m pytest -m 'not spark and not complex' --cov=snorkel
68 |
69 | [testenv:fix]
70 | description = run code stylers
71 | basepython = python3
72 | usedevelop = True
73 | commands_pre =
74 | commands =
75 | isort -rc .
76 | black .
77 |
78 | [testenv:doc]
79 | description = build docs
80 | basepython = python3
81 | skipsdist = True
82 | commands_pre = python -m pip install -U -r docs/requirements-doc.txt
83 | commands =
84 | rm -rf docs/_build
85 | rm -rf docs/packages/_autosummary
86 | make -C docs/ html
87 | {toxinidir}/scripts/sync_api_docs.py
88 |
--------------------------------------------------------------------------------